From 8182a1cfb523fb1e8d328716111e98be3a1c5c35 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Jun 2019 11:09:08 +0100 Subject: Refactor email tests --- tests/push/test_email.py | 64 +++++++++++++++++++++++++++++++----------------- 1 file changed, 42 insertions(+), 22 deletions(-) (limited to 'tests') diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 9cdde1a9bd..62b3c2a99d 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -15,6 +15,7 @@ import os +import attr import pkg_resources from twisted.internet.defer import Deferred @@ -30,6 +31,13 @@ except Exception: load_jinja2_templates = None +@attr.s +class _User(object): + "Helper wrapper for user ID and access token" + id = attr.ib() + token = attr.ib() + + class EmailPusherTests(HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None @@ -77,25 +85,32 @@ class EmailPusherTests(HomeserverTestCase): return hs - def test_sends_email(self): - + def prepare(self, reactor, clock, hs): # Register the user who gets notified - user_id = self.register_user("user", "pass") - access_token = self.login("user", "pass") - - # Register the user who sends the message - other_user_id = self.register_user("otheruser", "pass") - other_access_token = self.login("otheruser", "pass") + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + # Register other users + self.others = [ + _User( + id=self.register_user("otheruser1", "pass"), + token=self.login("otheruser1", "pass"), + ), + _User( + id=self.register_user("otheruser2", "pass"), + token=self.login("otheruser2", "pass"), + ), + ] # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastore().get_user_by_access_token(self.access_token) ) token_id = user_tuple["token_id"] self.get_success( self.hs.get_pusherpool().add_pusher( - user_id=user_id, + user_id=self.user_id, access_token=token_id, kind="email", app_id="m.email", @@ -107,22 +122,27 @@ class EmailPusherTests(HomeserverTestCase): ) ) - # Create a room - room = self.helper.create_room_as(user_id, tok=access_token) + def test_simple_sends_email(self): + # Create a simple room with two users + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.invite( + room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id, + ) + self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token) - # Invite the other person - self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) + # The other user sends some messages + self.helper.send(room, body="Hi!", tok=self.others[0].token) + self.helper.send(room, body="There!", tok=self.others[0].token) - # The other user joins - self.helper.join(room=room, user=other_user_id, tok=other_access_token) + # We should get emailed about that message + self._check_for_mail() - # The other user sends some messages - self.helper.send(room, body="Hi!", tok=other_access_token) - self.helper.send(room, body="There!", tok=other_access_token) + def _check_for_mail(self): + "Check that the user receives an email notification" # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -132,7 +152,7 @@ class EmailPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) @@ -149,7 +169,7 @@ class EmailPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) -- cgit 1.5.1 From 2ebeda48b2e6ba522fe049ee7ef13450f6839e1b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Jun 2019 12:10:23 +0100 Subject: Add test --- synapse/push/emailpusher.py | 19 +++++++++++++++++++ synapse/push/pusherpool.py | 30 +++++++++++++++++++++++------- tests/push/test_email.py | 29 ++++++++++++++++++++++++++++- 3 files changed, 70 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index e8ee67401f..c89a8438a9 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -114,6 +114,21 @@ class EmailPusher(object): run_as_background_process("emailpush.process", self._process) + def _pause_processing(self): + """Used by tests to temporarily pause processing of events. + + Asserts that its not currently processing. + """ + assert not self._is_processing + self._is_processing = True + + def _resume_processing(self): + """Used by tests to resume processing of events after pausing. + """ + assert self._is_processing + self._is_processing = False + self._start_processing() + @defer.inlineCallbacks def _process(self): # we should never get here if we are already processing @@ -215,6 +230,10 @@ class EmailPusher(object): @defer.inlineCallbacks def save_last_stream_ordering_and_success(self, last_stream_ordering): + if last_stream_ordering is None: + # This happens if we haven't yet processed anything + return + self.last_stream_ordering = last_stream_ordering yield self.store.update_pusher_last_stream_ordering_and_success( self.app_id, self.email, self.user_id, diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 40a7709c09..63c583565f 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -60,6 +60,11 @@ class PusherPool: def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, pushkey, lang, data, profile_tag=""): + """Creates a new pusher and adds it to the pool + + Returns: + Deferred[EmailPusher|HttpPusher] + """ time_now_msec = self.clock.time_msec() # we try to create the pusher just to validate the config: it @@ -103,7 +108,9 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, ) - yield self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id) + + defer.returnValue(pusher) @defer.inlineCallbacks def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, @@ -184,7 +191,11 @@ class PusherPool: @defer.inlineCallbacks def start_pusher_by_id(self, app_id, pushkey, user_id): - """Look up the details for the given pusher, and start it""" + """Look up the details for the given pusher, and start it + + Returns: + Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any + """ if not self._should_start_pushers: return @@ -192,13 +203,16 @@ class PusherPool: app_id, pushkey ) - p = None + pusher_dict = None for r in resultlist: if r['user_name'] == user_id: - p = r + pusher_dict = r - if p: - yield self._start_pusher(p) + pusher = None + if pusher_dict: + pusher = yield self._start_pusher(pusher_dict) + + defer.returnValue(pusher) @defer.inlineCallbacks def _start_pushers(self): @@ -224,7 +238,7 @@ class PusherPool: pusherdict (dict): Returns: - None + Deferred[EmailPusher|HttpPusher] """ try: p = self.pusher_factory.create_pusher(pusherdict) @@ -270,6 +284,8 @@ class PusherPool: p.on_started(have_notifs) + defer.returnValue(p) + @defer.inlineCallbacks def remove_pusher(self, app_id, pushkey, user_id): appid_pushkey = "%s:%s" % (app_id, pushkey) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 62b3c2a99d..c10b65d4b8 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -108,7 +108,7 @@ class EmailPusherTests(HomeserverTestCase): ) token_id = user_tuple["token_id"] - self.get_success( + self.pusher = self.get_success( self.hs.get_pusherpool().add_pusher( user_id=self.user_id, access_token=token_id, @@ -137,6 +137,33 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about that message self._check_for_mail() + def test_multiple_members_email(self): + # We want to test multiple notifications, so we pause processing of push + # while we send messages. + self.pusher._pause_processing() + + # Create a simple room with multiple other users + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + + for other in self.others: + self.helper.invite( + room=room, src=self.user_id, tok=self.access_token, targ=other.id, + ) + self.helper.join(room=room, user=other.id, tok=other.token) + + # The other users send some messages + self.helper.send(room, body="Hi!", tok=self.others[0].token) + self.helper.send(room, body="There!", tok=self.others[1].token) + self.helper.send(room, body="There!", tok=self.others[1].token) + + # Nothing should have happened yet, as we're paused. + assert not self.email_attempts + + self.pusher._resume_processing() + + # We should get emailed about those messages + self._check_for_mail() + def _check_for_mail(self): "Check that the user receives an email notification" -- cgit 1.5.1 From 6312d6cc7c5bc80984758a70e2c368d8b4fb3bfd Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 13 Jun 2019 22:40:52 +1000 Subject: Expose statistics on extrems to prometheus (#5384) --- changelog.d/5384.feature | 1 + scripts/generate_signing_key.py | 2 +- synapse/metrics/__init__.py | 112 +++++++++++++++++++++++------ synapse/storage/events.py | 44 ++++++++---- tests/storage/test_cleanup_extrems.py | 128 +++++++++++++--------------------- tests/storage/test_event_metrics.py | 97 ++++++++++++++++++++++++++ tests/unittest.py | 61 +++++++++++++++- 7 files changed, 331 insertions(+), 114 deletions(-) create mode 100644 changelog.d/5384.feature create mode 100644 tests/storage/test_event_metrics.py (limited to 'tests') diff --git a/changelog.d/5384.feature b/changelog.d/5384.feature new file mode 100644 index 0000000000..9497f521c8 --- /dev/null +++ b/changelog.d/5384.feature @@ -0,0 +1 @@ +Statistics on forward extremities per room are now exposed via Prometheus. diff --git a/scripts/generate_signing_key.py b/scripts/generate_signing_key.py index ba3ba97395..36e9140b50 100755 --- a/scripts/generate_signing_key.py +++ b/scripts/generate_signing_key.py @@ -16,7 +16,7 @@ import argparse import sys -from signedjson.key import write_signing_keys, generate_signing_key +from signedjson.key import generate_signing_key, write_signing_keys from synapse.util.stringutils import random_string diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index ef48984fdd..539c353528 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -25,7 +25,7 @@ import six import attr from prometheus_client import Counter, Gauge, Histogram -from prometheus_client.core import REGISTRY, GaugeMetricFamily +from prometheus_client.core import REGISTRY, GaugeMetricFamily, HistogramMetricFamily from twisted.internet import reactor @@ -40,7 +40,6 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") class RegistryProxy(object): - @staticmethod def collect(): for metric in REGISTRY.collect(): @@ -63,10 +62,7 @@ class LaterGauge(object): try: calls = self.caller() except Exception: - logger.exception( - "Exception running callback for LaterGauge(%s)", - self.name, - ) + logger.exception("Exception running callback for LaterGauge(%s)", self.name) yield g return @@ -116,9 +112,7 @@ class InFlightGauge(object): # Create a class which have the sub_metrics values as attributes, which # default to 0 on initialization. Used to pass to registered callbacks. self._metrics_class = attr.make_class( - "_MetricsEntry", - attrs={x: attr.ib(0) for x in sub_metrics}, - slots=True, + "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True ) # Counts number of in flight blocks for a given set of label values @@ -157,7 +151,9 @@ class InFlightGauge(object): Note: may be called by a separate thread. """ - in_flight = GaugeMetricFamily(self.name + "_total", self.desc, labels=self.labels) + in_flight = GaugeMetricFamily( + self.name + "_total", self.desc, labels=self.labels + ) metrics_by_key = {} @@ -179,7 +175,9 @@ class InFlightGauge(object): yield in_flight for name in self.sub_metrics: - gauge = GaugeMetricFamily("_".join([self.name, name]), "", labels=self.labels) + gauge = GaugeMetricFamily( + "_".join([self.name, name]), "", labels=self.labels + ) for key, metrics in six.iteritems(metrics_by_key): gauge.add_metric(key, getattr(metrics, name)) yield gauge @@ -193,12 +191,75 @@ class InFlightGauge(object): all_gauges[self.name] = self +@attr.s(hash=True) +class BucketCollector(object): + """ + Like a Histogram, but allows buckets to be point-in-time instead of + incrementally added to. + + Args: + name (str): Base name of metric to be exported to Prometheus. + data_collector (callable -> dict): A synchronous callable that + returns a dict mapping bucket to number of items in the + bucket. If these buckets are not the same as the buckets + given to this class, they will be remapped into them. + buckets (list[float]): List of floats/ints of the buckets to + give to Prometheus. +Inf is ignored, if given. + + """ + + name = attr.ib() + data_collector = attr.ib() + buckets = attr.ib() + + def collect(self): + + # Fetch the data -- this must be synchronous! + data = self.data_collector() + + buckets = {} + + res = [] + for x in data.keys(): + for i, bound in enumerate(self.buckets): + if x <= bound: + buckets[bound] = buckets.get(bound, 0) + data[x] + break + + for i in self.buckets: + res.append([i, buckets.get(i, 0)]) + + res.append(["+Inf", sum(data.values())]) + + metric = HistogramMetricFamily( + self.name, + "", + buckets=res, + sum_value=sum([x * y for x, y in data.items()]), + ) + yield metric + + def __attrs_post_init__(self): + self.buckets = [float(x) for x in self.buckets if x != "+Inf"] + if self.buckets != sorted(self.buckets): + raise ValueError("Buckets not sorted") + + self.buckets = tuple(self.buckets) + + if self.name in all_gauges.keys(): + logger.warning("%s already registered, reregistering" % (self.name,)) + REGISTRY.unregister(all_gauges.pop(self.name)) + + REGISTRY.register(self) + all_gauges[self.name] = self + + # # Detailed CPU metrics # -class CPUMetrics(object): +class CPUMetrics(object): def __init__(self): ticks_per_sec = 100 try: @@ -237,13 +298,28 @@ gc_time = Histogram( "python_gc_time", "Time taken to GC (sec)", ["gen"], - buckets=[0.0025, 0.005, 0.01, 0.025, 0.05, 0.10, 0.25, 0.50, 1.00, 2.50, - 5.00, 7.50, 15.00, 30.00, 45.00, 60.00], + buckets=[ + 0.0025, + 0.005, + 0.01, + 0.025, + 0.05, + 0.10, + 0.25, + 0.50, + 1.00, + 2.50, + 5.00, + 7.50, + 15.00, + 30.00, + 45.00, + 60.00, + ], ) class GCCounts(object): - def collect(self): cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) for n, m in enumerate(gc.get_count()): @@ -279,9 +355,7 @@ sent_transactions_counter = Counter("synapse_federation_client_sent_transactions events_processed_counter = Counter("synapse_federation_client_events_processed", "") event_processing_loop_counter = Counter( - "synapse_event_processing_loop_count", - "Event processing loop iterations", - ["name"], + "synapse_event_processing_loop_count", "Event processing loop iterations", ["name"] ) event_processing_loop_room_count = Counter( @@ -311,7 +385,6 @@ last_ticked = time.time() class ReactorLastSeenMetric(object): - def collect(self): cm = GaugeMetricFamily( "python_twisted_reactor_last_seen", @@ -325,7 +398,6 @@ REGISTRY.register(ReactorLastSeenMetric()) def runUntilCurrentTimer(func): - @functools.wraps(func) def f(*args, **kwargs): now = reactor.seconds() diff --git a/synapse/storage/events.py b/synapse/storage/events.py index f9162be9b9..1578403f79 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -17,7 +17,7 @@ import itertools import logging -from collections import OrderedDict, deque, namedtuple +from collections import Counter as c_counter, OrderedDict, deque, namedtuple from functools import wraps from six import iteritems, text_type @@ -33,6 +33,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateResolutionStore from synapse.storage.background_updates import BackgroundUpdateStore @@ -220,13 +221,38 @@ class EventsStore( EventsWorkerStore, BackgroundUpdateStore, ): - def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() + # Collect metrics on the number of forward extremities that exist. + self._current_forward_extremities_amount = {} + + BucketCollector( + "synapse_forward_extremities", + lambda: self._current_forward_extremities_amount, + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"] + ) + + # Read the extrems every 60 minutes + hs.get_clock().looping_call(self._read_forward_extremities, 60 * 60 * 1000) + + @defer.inlineCallbacks + def _read_forward_extremities(self): + def fetch(txn): + txn.execute( + """ + select count(*) c from event_forward_extremities + group by room_id + """ + ) + return txn.fetchall() + + res = yield self.runInteraction("read_forward_extremities", fetch) + self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) + @defer.inlineCallbacks def persist_events(self, events_and_contexts, backfilled=False): """ @@ -568,17 +594,11 @@ class EventsStore( ) txn.execute(sql, batch) - results.extend( - r[0] - for r in txn - if not json.loads(r[1]).get("soft_failed") - ) + results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): yield self.runInteraction( - "_get_events_which_are_prevs", - _get_events_which_are_prevs_txn, - chunk, + "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk ) defer.returnValue(results) @@ -640,9 +660,7 @@ class EventsStore( for chunk in batch_iter(event_ids, 100): yield self.runInteraction( - "_get_prevs_before_rejected", - _get_prevs_before_rejected_txn, - chunk, + "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk ) defer.returnValue(existing_prevs) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 6aa8b8b3c6..f4c81ef77d 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -15,7 +15,6 @@ import os.path -from synapse.api.constants import EventTypes from synapse.storage import prepare_database from synapse.types import Requester, UserID @@ -23,17 +22,12 @@ from tests.unittest import HomeserverTestCase class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): - """Test the background update to clean forward extremities table. """ - def make_homeserver(self, reactor, clock): - # Hack until we understand why test_forked_graph_cleanup fails with v4 - config = self.default_config() - config['default_room_version'] = '1' - return self.setup_test_homeserver(config=config) + Test the background update to clean forward extremities table. + """ def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() - self.event_creator = homeserver.get_event_creation_handler() self.room_creator = homeserver.get_room_creation_handler() # Create a test user and room @@ -42,56 +36,6 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): info = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] - def create_and_send_event(self, soft_failed=False, prev_event_ids=None): - """Create and send an event. - - Args: - soft_failed (bool): Whether to create a soft failed event or not - prev_event_ids (list[str]|None): Explicitly set the prev events, - or if None just use the default - - Returns: - str: The new event's ID. - """ - prev_events_and_hashes = None - if prev_event_ids: - prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] - - event, context = self.get_success( - self.event_creator.create_event( - self.requester, - { - "type": EventTypes.Message, - "room_id": self.room_id, - "sender": self.user.to_string(), - "content": {"body": "", "msgtype": "m.text"}, - }, - prev_events_and_hashes=prev_events_and_hashes, - ) - ) - - if soft_failed: - event.internal_metadata.soft_failed = True - - self.get_success( - self.event_creator.send_nonmember_event(self.requester, event, context) - ) - - return event.event_id - - def add_extremity(self, event_id): - """Add the given event as an extremity to the room. - """ - self.get_success( - self.store._simple_insert( - table="event_forward_extremities", - values={"room_id": self.room_id, "event_id": event_id}, - desc="test_add_extremity", - ) - ) - - self.store.get_latest_event_ids_in_room.invalidate((self.room_id,)) - def run_background_update(self): """Re run the background update to clean up the extremities. """ @@ -131,10 +75,16 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ # Create the room graph - event_id_1 = self.create_and_send_event() - event_id_2 = self.create_and_send_event(True, [event_id_1]) - event_id_3 = self.create_and_send_event(True, [event_id_2]) - event_id_4 = self.create_and_send_event(False, [event_id_3]) + event_id_1 = self.create_and_send_event(self.room_id, self.user) + event_id_2 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_1] + ) + event_id_3 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_2] + ) + event_id_4 = self.create_and_send_event( + self.room_id, self.user, False, [event_id_3] + ) # Check the latest events are as expected latest_event_ids = self.get_success( @@ -154,12 +104,16 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): Where SF* are soft failed, and with extremities of A and B """ # Create the room graph - event_id_a = self.create_and_send_event() - event_id_sf1 = self.create_and_send_event(True, [event_id_a]) - event_id_b = self.create_and_send_event(False, [event_id_sf1]) + event_id_a = self.create_and_send_event(self.room_id, self.user) + event_id_sf1 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a] + ) + event_id_b = self.create_and_send_event( + self.room_id, self.user, False, [event_id_sf1] + ) # Add the new extremity and check the latest events are as expected - self.add_extremity(event_id_a) + self.add_extremity(self.room_id, event_id_a) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) @@ -185,13 +139,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): Where SF* are soft failed, and with extremities of A and B """ # Create the room graph - event_id_a = self.create_and_send_event() - event_id_sf1 = self.create_and_send_event(True, [event_id_a]) - event_id_sf2 = self.create_and_send_event(True, [event_id_sf1]) - event_id_b = self.create_and_send_event(False, [event_id_sf2]) + event_id_a = self.create_and_send_event(self.room_id, self.user) + event_id_sf1 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a] + ) + event_id_sf2 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_sf1] + ) + event_id_b = self.create_and_send_event( + self.room_id, self.user, False, [event_id_sf2] + ) # Add the new extremity and check the latest events are as expected - self.add_extremity(event_id_a) + self.add_extremity(self.room_id, event_id_a) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) @@ -227,16 +187,26 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ # Create the room graph - event_id_a = self.create_and_send_event() - event_id_b = self.create_and_send_event() - event_id_sf1 = self.create_and_send_event(True, [event_id_a]) - event_id_sf2 = self.create_and_send_event(True, [event_id_a, event_id_b]) - event_id_sf3 = self.create_and_send_event(True, [event_id_sf1]) - self.create_and_send_event(True, [event_id_sf2, event_id_sf3]) # SF4 - event_id_c = self.create_and_send_event(False, [event_id_sf3]) + event_id_a = self.create_and_send_event(self.room_id, self.user) + event_id_b = self.create_and_send_event(self.room_id, self.user) + event_id_sf1 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a] + ) + event_id_sf2 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a, event_id_b] + ) + event_id_sf3 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_sf1] + ) + self.create_and_send_event( + self.room_id, self.user, True, [event_id_sf2, event_id_sf3] + ) # SF4 + event_id_c = self.create_and_send_event( + self.room_id, self.user, False, [event_id_sf3] + ) # Add the new extremity and check the latest events are as expected - self.add_extremity(event_id_a) + self.add_extremity(self.room_id, event_id_a) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py new file mode 100644 index 0000000000..20a068f1fc --- /dev/null +++ b/tests/storage/test_event_metrics.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.metrics import REGISTRY +from synapse.types import Requester, UserID + +from tests.unittest import HomeserverTestCase + + +class ExtremStatisticsTestCase(HomeserverTestCase): + def test_exposed_to_prometheus(self): + """ + Forward extremity counts are exposed via Prometheus. + """ + room_creator = self.hs.get_room_creation_handler() + + user = UserID("alice", "test") + requester = Requester(user, None, False, None, None) + + # Real events, forward extremities + events = [(3, 2), (6, 2), (4, 6)] + + for event_count, extrems in events: + info = self.get_success(room_creator.create_room(requester, {})) + room_id = info["room_id"] + + last_event = None + + # Make a real event chain + for i in range(event_count): + ev = self.create_and_send_event(room_id, user, False, last_event) + last_event = [ev] + + # Sprinkle in some extremities + for i in range(extrems): + ev = self.create_and_send_event(room_id, user, False, last_event) + + # Let it run for a while, then pull out the statistics from the + # Prometheus client registry + self.reactor.advance(60 * 60 * 1000) + self.pump(1) + + items = list( + filter( + lambda x: x.name == "synapse_forward_extremities", + list(REGISTRY.collect()), + ) + ) + + # Check the values are what we want + buckets = {} + _count = 0 + _sum = 0 + + for i in items[0].samples: + if i[0].endswith("_bucket"): + buckets[i[1]['le']] = i[2] + elif i[0].endswith("_count"): + _count = i[2] + elif i[0].endswith("_sum"): + _sum = i[2] + + # 3 buckets, 2 with 2 extrems, 1 with 6 extrems (bucketed as 7), and + # +Inf which is all + self.assertEqual( + buckets, + { + 1.0: 0, + 2.0: 2, + 3.0: 0, + 5.0: 0, + 7.0: 1, + 10.0: 0, + 15.0: 0, + 20.0: 0, + 50.0: 0, + 100.0: 0, + 200.0: 0, + 500.0: 0, + "+Inf": 3, + }, + ) + # 3 rooms, with 10 total events + self.assertEqual(_count, 3) + self.assertEqual(_sum, 10) diff --git a/tests/unittest.py b/tests/unittest.py index 7dbb64af59..b6dc7932ce 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -27,11 +27,12 @@ import twisted.logger from twisted.internet.defer import Deferred from twisted.trial import unittest +from synapse.api.constants import EventTypes from synapse.config.homeserver import HomeServerConfig from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import UserID, create_requester +from synapse.types import Requester, UserID, create_requester from synapse.util.logcontext import LoggingContext from tests.server import get_clock, make_request, render, setup_test_homeserver @@ -442,6 +443,64 @@ class HomeserverTestCase(TestCase): access_token = channel.json_body["access_token"] return access_token + def create_and_send_event( + self, room_id, user, soft_failed=False, prev_event_ids=None + ): + """ + Create and send an event. + + Args: + soft_failed (bool): Whether to create a soft failed event or not + prev_event_ids (list[str]|None): Explicitly set the prev events, + or if None just use the default + + Returns: + str: The new event's ID. + """ + event_creator = self.hs.get_event_creation_handler() + secrets = self.hs.get_secrets() + requester = Requester(user, None, False, None, None) + + prev_events_and_hashes = None + if prev_event_ids: + prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] + + event, context = self.get_success( + event_creator.create_event( + requester, + { + "type": EventTypes.Message, + "room_id": room_id, + "sender": user.to_string(), + "content": {"body": secrets.token_hex(), "msgtype": "m.text"}, + }, + prev_events_and_hashes=prev_events_and_hashes, + ) + ) + + if soft_failed: + event.internal_metadata.soft_failed = True + + self.get_success( + event_creator.send_nonmember_event(requester, event, context) + ) + + return event.event_id + + def add_extremity(self, room_id, event_id): + """ + Add the given event as an extremity to the room. + """ + self.get_success( + self.hs.get_datastore()._simple_insert( + table="event_forward_extremities", + values={"room_id": room_id, "event_id": event_id}, + desc="test_add_extremity", + ) + ) + + self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,)) + def attempt_wrong_password_login(self, username, password): """Attempts to login as the user with the given password, asserting that the attempt *fails*. -- cgit 1.5.1 From a10c8dae85d3706afbab588e1004350aa5b49539 Mon Sep 17 00:00:00 2001 From: "Amber H. Brown" Date: Fri, 14 Jun 2019 21:09:33 +1000 Subject: fix prometheus rendering error --- synapse/metrics/__init__.py | 2 +- tests/storage/test_event_metrics.py | 61 ++++++++++++++----------------------- 2 files changed, 24 insertions(+), 39 deletions(-) (limited to 'tests') diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 539c353528..0d3ae1a43d 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -227,7 +227,7 @@ class BucketCollector(object): break for i in self.buckets: - res.append([i, buckets.get(i, 0)]) + res.append([str(i), buckets.get(i, 0)]) res.append(["+Inf", sum(data.values())]) diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 20a068f1fc..1655fcdafc 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from prometheus_client.exposition import generate_latest + from synapse.metrics import REGISTRY from synapse.types import Requester, UserID @@ -52,46 +54,29 @@ class ExtremStatisticsTestCase(HomeserverTestCase): self.reactor.advance(60 * 60 * 1000) self.pump(1) - items = list( + items = set( filter( - lambda x: x.name == "synapse_forward_extremities", - list(REGISTRY.collect()), + lambda x: b"synapse_forward_extremities_" in x, + generate_latest(REGISTRY).split(b"\n"), ) ) - # Check the values are what we want - buckets = {} - _count = 0 - _sum = 0 - - for i in items[0].samples: - if i[0].endswith("_bucket"): - buckets[i[1]['le']] = i[2] - elif i[0].endswith("_count"): - _count = i[2] - elif i[0].endswith("_sum"): - _sum = i[2] + expected = set([ + b'synapse_forward_extremities_bucket{le="1.0"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0"} 0.0', + b'synapse_forward_extremities_bucket{le="5.0"} 0.0', + b'synapse_forward_extremities_bucket{le="7.0"} 1.0', + b'synapse_forward_extremities_bucket{le="10.0"} 0.0', + b'synapse_forward_extremities_bucket{le="15.0"} 0.0', + b'synapse_forward_extremities_bucket{le="20.0"} 0.0', + b'synapse_forward_extremities_bucket{le="50.0"} 0.0', + b'synapse_forward_extremities_bucket{le="100.0"} 0.0', + b'synapse_forward_extremities_bucket{le="200.0"} 0.0', + b'synapse_forward_extremities_bucket{le="500.0"} 0.0', + b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', + b'synapse_forward_extremities_count 3.0', + b'synapse_forward_extremities_sum 10.0', + ]) - # 3 buckets, 2 with 2 extrems, 1 with 6 extrems (bucketed as 7), and - # +Inf which is all - self.assertEqual( - buckets, - { - 1.0: 0, - 2.0: 2, - 3.0: 0, - 5.0: 0, - 7.0: 1, - 10.0: 0, - 15.0: 0, - 20.0: 0, - 50.0: 0, - 100.0: 0, - 200.0: 0, - 500.0: 0, - "+Inf": 3, - }, - ) - # 3 rooms, with 10 total events - self.assertEqual(_count, 3) - self.assertEqual(_sum, 10) + self.assertEqual(items, expected) -- cgit 1.5.1 From d0530382eeff053547304532167c0e4654af172c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 14 Jun 2019 13:18:24 +0100 Subject: Track deactivated accounts in the database (#5378) --- changelog.d/5378.misc | 1 + synapse/handlers/deactivate_account.py | 4 + synapse/storage/registration.py | 114 +++++++++++++++++++++ .../schema/delta/55/users_alter_deactivated.sql | 19 ++++ tests/rest/client/v2_alpha/test_account.py | 45 ++++++++ 5 files changed, 183 insertions(+) create mode 100644 changelog.d/5378.misc create mode 100644 synapse/storage/schema/delta/55/users_alter_deactivated.sql (limited to 'tests') diff --git a/changelog.d/5378.misc b/changelog.d/5378.misc new file mode 100644 index 0000000000..365e49d634 --- /dev/null +++ b/changelog.d/5378.misc @@ -0,0 +1 @@ +Track deactivated accounts in the database. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 6a91f7698e..b29089d82c 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017, 2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -114,6 +115,9 @@ class DeactivateAccountHandler(BaseHandler): # parts users from rooms (if it isn't already running) self._start_user_parting() + # Mark the user as deactivated. + yield self.store.set_user_deactivated_status(user_id, True) + defer.returnValue(identity_server_supports_unbinding) def _start_user_parting(self): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 1dd1182e82..4c5751b57f 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import re from six import iterkeys @@ -31,6 +32,8 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 +logger = logging.getLogger(__name__) + class RegistrationWorkerStore(SQLBaseStore): def __init__(self, db_conn, hs): @@ -598,11 +601,75 @@ class RegistrationStore( "user_threepids_grandfather", self._bg_user_threepids_grandfather, ) + self.register_background_update_handler( + "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag, + ) + # Create a background job for culling expired 3PID validity tokens hs.get_clock().looping_call( self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS, ) + @defer.inlineCallbacks + def _backgroud_update_set_deactivated_flag(self, progress, batch_size): + """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 + for each of them. + """ + + last_user = progress.get("user_id", "") + + def _backgroud_update_set_deactivated_flag_txn(txn): + txn.execute( + """ + SELECT + users.name, + COUNT(access_tokens.token) AS count_tokens, + COUNT(user_threepids.address) AS count_threepids + FROM users + LEFT JOIN access_tokens ON (access_tokens.user_id = users.name) + LEFT JOIN user_threepids ON (user_threepids.user_id = users.name) + WHERE password_hash IS NULL OR password_hash = '' + AND users.name > ? + GROUP BY users.name + ORDER BY users.name ASC + LIMIT ?; + """, + (last_user, batch_size), + ) + + rows = self.cursor_to_dict(txn) + + if not rows: + return True + + rows_processed_nb = 0 + + for user in rows: + if not user["count_tokens"] and not user["count_threepids"]: + self.set_user_deactivated_status_txn(txn, user["user_id"], True) + rows_processed_nb += 1 + + logger.info("Marked %d rows as deactivated", rows_processed_nb) + + self._background_update_progress_txn( + txn, "users_set_deactivated_flag", {"user_id": rows[-1]["user_id"]} + ) + + if batch_size > len(rows): + return True + else: + return False + + end = yield self.runInteraction( + "users_set_deactivated_flag", + _backgroud_update_set_deactivated_flag_txn, + ) + + if end: + yield self._end_background_update("users_set_deactivated_flag") + + defer.returnValue(batch_size) + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. @@ -1268,3 +1335,50 @@ class RegistrationStore( "delete_threepid_session", delete_threepid_session_txn, ) + + def set_user_deactivated_status_txn(self, txn, user_id, deactivated): + self._simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,), + ) + + @defer.inlineCallbacks + def set_user_deactivated_status(self, user_id, deactivated): + """Set the `deactivated` property for the provided user to the provided value. + + Args: + user_id (str): The ID of the user to set the status for. + deactivated (bool): The value to set for `deactivated`. + """ + + yield self.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, deactivated, + ) + + @cachedInlineCallbacks() + def get_user_deactivated_status(self, user_id): + """Retrieve the value for the `deactivated` property for the provided user. + + Args: + user_id (str): The ID of the user to retrieve the status for. + + Returns: + defer.Deferred(bool): The requested value. + """ + + res = yield self._simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="deactivated", + desc="get_user_deactivated_status", + ) + + # Convert the integer into a boolean. + defer.returnValue(res == 1) diff --git a/synapse/storage/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/schema/delta/55/users_alter_deactivated.sql new file mode 100644 index 0000000000..dabdde489b --- /dev/null +++ b/synapse/storage/schema/delta/55/users_alter_deactivated.sql @@ -0,0 +1,19 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE users ADD deactivated SMALLINT DEFAULT 0 NOT NULL; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('users_set_deactivated_flag', '{}'); diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 0d1c0868ce..a60a4a3b87 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import re from email.parser import Parser @@ -239,3 +240,47 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEquals(expected_code, channel.code, channel.result) + + +class DeactivateTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs + + def test_deactivate_account(self): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + request_data = json.dumps({ + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + }) + request, channel = self.make_request( + "POST", + "account/deactivate", + request_data, + access_token=tok, + ) + self.render(request) + self.assertEqual(request.code, 200) + + store = self.hs.get_datastore() + + # Check that the user has been marked as deactivated. + self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) + + # Check that this access token has been invalidated. + request, channel = self.make_request("GET", "account/whoami") + self.render(request) + self.assertEqual(request.code, 401) -- cgit 1.5.1 From 3ed595e327aee6d45ed0371c98e828d724c26b2d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 14 Jun 2019 14:07:32 +0100 Subject: Prometheus histograms are cumalative --- synapse/metrics/__init__.py | 1 - synapse/storage/events.py | 3 ++- tests/storage/test_event_metrics.py | 20 ++++++++++---------- 3 files changed, 12 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 0d3ae1a43d..8aee14a8a8 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -224,7 +224,6 @@ class BucketCollector(object): for i, bound in enumerate(self.buckets): if x <= bound: buckets[bound] = buckets.get(bound, 0) + data[x] - break for i in self.buckets: res.append([str(i), buckets.get(i, 0)]) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 1578403f79..f631fb1733 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -228,7 +228,8 @@ class EventsStore( self._state_resolution_handler = hs.get_state_resolution_handler() # Collect metrics on the number of forward extremities that exist. - self._current_forward_extremities_amount = {} + # Counter of number of extremities to count + self._current_forward_extremities_amount = c_counter() BucketCollector( "synapse_forward_extremities", diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 1655fcdafc..19f9ccf5e0 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -64,16 +64,16 @@ class ExtremStatisticsTestCase(HomeserverTestCase): expected = set([ b'synapse_forward_extremities_bucket{le="1.0"} 0.0', b'synapse_forward_extremities_bucket{le="2.0"} 2.0', - b'synapse_forward_extremities_bucket{le="3.0"} 0.0', - b'synapse_forward_extremities_bucket{le="5.0"} 0.0', - b'synapse_forward_extremities_bucket{le="7.0"} 1.0', - b'synapse_forward_extremities_bucket{le="10.0"} 0.0', - b'synapse_forward_extremities_bucket{le="15.0"} 0.0', - b'synapse_forward_extremities_bucket{le="20.0"} 0.0', - b'synapse_forward_extremities_bucket{le="50.0"} 0.0', - b'synapse_forward_extremities_bucket{le="100.0"} 0.0', - b'synapse_forward_extremities_bucket{le="200.0"} 0.0', - b'synapse_forward_extremities_bucket{le="500.0"} 0.0', + b'synapse_forward_extremities_bucket{le="3.0"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0"} 3.0', + b'synapse_forward_extremities_bucket{le="50.0"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0"} 3.0', b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', b'synapse_forward_extremities_count 3.0', b'synapse_forward_extremities_sum 10.0', -- cgit 1.5.1 From 6d56a694f4cbfaf9c57a56837d4170e6c6783f3c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 7 Jun 2019 15:30:54 +0100 Subject: Don't send renewal emails to deactivated users --- changelog.d/5394.bugfix | 1 + synapse/handlers/account_validity.py | 3 ++ synapse/handlers/deactivate_account.py | 6 +++ synapse/storage/_base.py | 4 +- synapse/storage/registration.py | 14 ++++++ tests/rest/client/v2_alpha/test_register.py | 67 ++++++++++++++++++----------- 6 files changed, 68 insertions(+), 27 deletions(-) create mode 100644 changelog.d/5394.bugfix (limited to 'tests') diff --git a/changelog.d/5394.bugfix b/changelog.d/5394.bugfix new file mode 100644 index 0000000000..2ad9fbe82c --- /dev/null +++ b/changelog.d/5394.bugfix @@ -0,0 +1 @@ +Fix a bug where deactivated users could receive renewal emails if the account validity feature is on. diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 261446517d..5e0b92eb1c 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -110,6 +110,9 @@ class AccountValidityHandler(object): # Stop right here if the user doesn't have at least one email address. # In this case, they will have to ask their server admin to renew their # account manually. + # We don't need to do a specific check to make sure the account isn't + # deactivated, as a deactivated account isn't supposed to have any + # email address attached to it. if not addresses: return diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index b29089d82c..7378b56c1d 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -43,6 +43,8 @@ class DeactivateAccountHandler(BaseHandler): # it left off (if it has work left to do). hs.get_reactor().callWhenRunning(self._start_user_parting) + self._account_validity_enabled = hs.config.account_validity.enabled + @defer.inlineCallbacks def deactivate_account(self, user_id, erase_data, id_server=None): """Deactivate a user's account @@ -115,6 +117,10 @@ class DeactivateAccountHandler(BaseHandler): # parts users from rooms (if it isn't already running) self._start_user_parting() + # Remove all information on the user from the account_validity table. + if self._account_validity_enabled: + yield self.store.delete_account_validity_for_user(user_id) + # Mark the user as deactivated. yield self.store.set_user_deactivated_status(user_id, True) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index ae891aa332..941c07fce5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -299,12 +299,12 @@ class SQLBaseStore(object): def select_users_with_no_expiration_date_txn(txn): """Retrieves the list of registered users with no expiration date from the - database. + database, filtering out deactivated users. """ sql = ( "SELECT users.name FROM users" " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL;" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" ) txn.execute(sql, []) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 4c5751b57f..9f910eac9c 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -251,6 +251,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="set_renewal_mail_status", ) + @defer.inlineCallbacks + def delete_account_validity_for_user(self, user_id): + """Deletes the entry for the given user in the account validity table, removing + their expiration date and renewal token. + + Args: + user_id (str): ID of the user to remove from the account validity table. + """ + yield self._simple_delete_one( + table="account_validity", + keyvalues={"user_id": user_id}, + desc="delete_account_validity_for_user", + ) + @defer.inlineCallbacks def is_server_admin(self, user): res = yield self._simple_select_one_onecol( diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 8536e6777a..b35b215446 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -26,7 +26,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import account_validity, register, sync +from synapse.rest.client.v2_alpha import account, account_validity, register, sync from tests import unittest @@ -308,6 +308,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): login.register_servlets, sync.register_servlets, account_validity.register_servlets, + account.register_servlets, ] def make_homeserver(self, reactor, clock): @@ -358,20 +359,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): def test_renewal_email(self): self.email_attempts = [] - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - # We need to manually add an email address otherwise the handler will do - # nothing. - now = self.hs.clock.time_msec() - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address="kermit@example.com", - validated_at=now, - added_at=now, - ) - ) + (user_id, tok) = self.create_user() # Move 6 days forward. This should trigger a renewal email to be sent. self.reactor.advance(datetime.timedelta(days=6).total_seconds()) @@ -396,6 +384,44 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): def test_manual_email_send(self): self.email_attempts = [] + (user_id, tok) = self.create_user() + request, channel = self.make_request( + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", + access_token=tok, + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertEqual(len(self.email_attempts), 1) + + def test_deactivated_user(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + + request_data = json.dumps({ + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + }) + request, channel = self.make_request( + "POST", + "account/deactivate", + request_data, + access_token=tok, + ) + self.render(request) + self.assertEqual(request.code, 200) + + self.reactor.advance(datetime.timedelta(days=8).total_seconds()) + + self.assertEqual(len(self.email_attempts), 0) + + def create_user(self): user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") # We need to manually add an email address otherwise the handler will do @@ -410,16 +436,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): added_at=now, ) ) - - request, channel = self.make_request( - b"POST", - "/_matrix/client/unstable/account_validity/send_mail", - access_token=tok, - ) - self.render(request) - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.assertEqual(len(self.email_attempts), 1) + return (user_id, tok) def test_manual_email_send_expired_account(self): user_id = self.register_user("kermit", "monkey") -- cgit 1.5.1 From b42f90470f00831bfd9b7ebca19111ed229599b0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 17 Jun 2019 18:04:42 +0100 Subject: Add experimental option to reduce extremities. Adds new config option `cleanup_extremities_with_dummy_events` which periodically sends dummy events to rooms with more than 10 extremities. THIS IS REALLY EXPERIMENTAL. --- synapse/config/server.py | 6 +++ synapse/events/__init__.py | 12 ++++++ synapse/federation/sender/__init__.py | 3 ++ synapse/handlers/message.py | 72 ++++++++++++++++++++++++++++++++++- synapse/storage/event_federation.py | 29 ++++++++++++++ tests/storage/test_cleanup_extrems.py | 41 ++++++++++++++++++++ 6 files changed, 162 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/synapse/config/server.py b/synapse/config/server.py index 7d56e2d141..6e5b46e6c3 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -317,6 +317,12 @@ class ServerConfig(Config): _check_resource_config(self.listeners) + # An experimental option to try and periodically clean up extremities + # by sending dummy events. + self.cleanup_extremities_with_dummy_events = config.get( + "cleanup_extremities_with_dummy_events", False, + ) + def has_tls_listener(self): return any(l["tls"] for l in self.listeners) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 1edd19cc13..f1fbb3d14a 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -92,6 +92,18 @@ class _EventInternalMetadata(object): """ return getattr(self, "soft_failed", False) + def should_proactively_send(self): + """Whether the eventm, if ours, should be sent to other clients and + servers. + + This is used for sending dummy events internally. Servers and clients + can still explicitly fetch the event. + + Returns: + bool + """ + return getattr(self, "proactively_send", True) + def _event_dict_property(key): # We want to be able to use hasattr with the event dict properties. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4f0f939102..4224b29ecf 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -168,6 +168,9 @@ class FederationSender(object): if not is_mine and send_on_behalf_of is None: return + if not event.internal_metadata.should_proactively_send(): + return + try: # Get the state from before the event. # We need to make sure that this is the state from before diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 11650dc80c..3b5942b7ae 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -36,7 +36,7 @@ from synapse.api.urls import ConsentURIBuilder from synapse.events.validator import EventValidator from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.state import StateFilter -from synapse.types import RoomAlias, UserID +from synapse.types import RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import run_in_background @@ -261,6 +261,16 @@ class EventCreationHandler(object): if self._block_events_without_consent_error: self._consent_uri_builder = ConsentURIBuilder(self.config) + if ( + not self.config.worker_app + and self.config.cleanup_extremities_with_dummy_events + ): + # XXX: Send dummy events. + self.clock.looping_call( + self._send_dummy_events_to_fill_extremities, + 5 * 60 * 1000, + ) + @defer.inlineCallbacks def create_event(self, requester, event_dict, token_id=None, txn_id=None, prev_events_and_hashes=None, require_consent=True): @@ -874,3 +884,63 @@ class EventCreationHandler(object): yield presence.bump_presence_active_time(user) except Exception: logger.exception("Error bumping presence active time") + + @defer.inlineCallbacks + def _send_dummy_events_to_fill_extremities(self): + """Background task to send dummy events into rooms that have a large + number of extremities + """ + + room_ids = yield self.store.get_rooms_with_many_extremities( + min_count=10, limit=5, + ) + + for room_id in room_ids: + # For each room we need to find a joined member we can use to send + # the dummy event with. + + prev_events_and_hashes = yield self.store.get_prev_events_for_room( + room_id, + ) + + latest_event_ids = ( + event_id for (event_id, _, _) in prev_events_and_hashes + ) + + members = yield self.state.get_current_users_in_room( + room_id, latest_event_ids=latest_event_ids, + ) + + user_id = None + for member in members: + if self.hs.is_mine_id(member): + user_id = member + break + + if not user_id: + # We don't have a joined user. + # TODO: We should do something here to stop the room from + # appearing next time. + continue + + requester = create_requester(user_id) + + event, context = yield self.create_event( + requester, + { + "type": "org.matrix.dummy_event", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_events_and_hashes=prev_events_and_hashes, + ) + + event.internal_metadata.proactively_send = False + + yield self.send_nonmember_event( + requester, + event, + context, + ratelimit=False, + ) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 09e39c2c28..e8d16edbc8 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -190,6 +190,35 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas room_id, ) + def get_rooms_with_many_extremities(self, min_count, limit): + """Get the top rooms with at least N extremities. + + Args: + min_count (int): The minimum number of extremities + limit (int): The maximum number of rooms to return. + + Returns: + Deferred[list]: At most `limit` room IDs that have at least + `min_count` extremities, sorted by extremity count. + """ + + def _get_rooms_with_many_extremities_txn(txn): + sql = """ + SELECT room_id FROM event_forward_extremities + GROUP BY room_id + HAVING count(*) > ? + ORDER BY count(*) DESC + LIMIT ? + """ + + txn.execute(sql, (min_count, limit)) + return [room_id for room_id, in txn] + + return self.runInteraction( + "get_rooms_with_many_extremities", + _get_rooms_with_many_extremities_txn, + ) + @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): return self._simple_select_onecol( diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index f4c81ef77d..ed5d58f58c 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -222,3 +222,44 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): self.store.get_latest_event_ids_in_room(self.room_id) ) self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c])) + + +class CleanupExtremDummyEventsTestCase(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["cleanup_extremities_with_dummy_events"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() + self.room_creator = homeserver.get_room_creation_handler() + + # Create a test user and room + self.user = UserID("alice", "test") + self.requester = Requester(self.user, None, False, None, None) + info = self.get_success(self.room_creator.create_room(self.requester, {})) + self.room_id = info["room_id"] + + def test_send_dummy_event(self): + # Create a bushy graph with 50 extremities. + + event_id_start = self.create_and_send_event(self.room_id, self.user) + + for _ in range(50): + self.create_and_send_event( + self.room_id, self.user, prev_event_ids=[event_id_start] + ) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(len(latest_event_ids), 50) + + # Bump the reacto repeatedly so that the background updates have a + # chance to run. + self.pump(10 * 60) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) -- cgit 1.5.1 From 554609288b0fc5f36d9dd9c45a939e7c81698b12 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 19 Jun 2019 11:33:03 +0100 Subject: Run as background process and fix comments --- synapse/events/__init__.py | 2 +- synapse/handlers/message.py | 7 +++++-- tests/storage/test_cleanup_extrems.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index f1fbb3d14a..7154bcbea6 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -93,7 +93,7 @@ class _EventInternalMetadata(object): return getattr(self, "soft_failed", False) def should_proactively_send(self): - """Whether the eventm, if ours, should be sent to other clients and + """Whether the event, if ours, should be sent to other clients and servers. This is used for sending dummy events internally. Servers and clients diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3b5942b7ae..7728ea230d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -34,6 +34,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.events.validator import EventValidator +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester @@ -265,9 +266,11 @@ class EventCreationHandler(object): not self.config.worker_app and self.config.cleanup_extremities_with_dummy_events ): - # XXX: Send dummy events. self.clock.looping_call( - self._send_dummy_events_to_fill_extremities, + lambda: run_as_background_process( + "send_dummy_events_to_fill_extremities", + self._send_dummy_events_to_fill_extremities + ), 5 * 60 * 1000, ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index ed5d58f58c..e9e2d5337c 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -255,7 +255,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): ) self.assertEqual(len(latest_event_ids), 50) - # Bump the reacto repeatedly so that the background updates have a + # Pump the reactor repeatedly so that the background updates have a # chance to run. self.pump(10 * 60) -- cgit 1.5.1 From 32e7c9e7f20b57dd081023ac42d6931a8da9b3a3 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 20 Jun 2019 19:32:02 +1000 Subject: Run Black. (#5482) --- .buildkite/pipeline.yml | 4 +- changelog.d/5482.misc | 1 + contrib/cmdclient/console.py | 309 ++++--- contrib/cmdclient/http.py | 38 +- contrib/experiments/cursesio.py | 36 +- contrib/experiments/test_messaging.py | 81 +- contrib/graph/graph.py | 33 +- contrib/graph/graph2.py | 51 +- contrib/graph/graph3.py | 46 +- contrib/jitsimeetbridge/jitsimeetbridge.py | 257 +++--- contrib/scripts/kick_users.py | 39 +- demo/webserver.py | 25 +- docker/start.py | 58 +- docs/sphinx/conf.py | 158 ++-- pyproject.toml | 19 + scripts-dev/check_auth.py | 4 +- scripts-dev/check_event_hash.py | 2 +- scripts-dev/check_signature.py | 5 +- scripts-dev/convert_server_keys.py | 2 +- scripts-dev/definitions.py | 62 +- scripts-dev/federation_client.py | 32 +- scripts-dev/hash_history.py | 2 +- scripts-dev/list_url_patterns.py | 4 +- scripts-dev/tail-synapse.py | 2 +- scripts/generate_signing_key.py | 8 +- scripts/move_remote_media_to_new_store.py | 4 +- setup.cfg | 7 +- setup.py | 31 +- synapse/__init__.py | 1 + synapse/_scripts/register_new_matrix_user.py | 24 +- synapse/api/auth.py | 159 ++-- synapse/api/constants.py | 46 +- synapse/api/errors.py | 112 ++- synapse/api/filtering.py | 189 ++--- synapse/api/ratelimiting.py | 20 +- synapse/api/room_versions.py | 25 +- synapse/api/urls.py | 17 +- synapse/app/__init__.py | 4 +- synapse/app/_base.py | 77 +- synapse/app/appservice.py | 29 +- synapse/app/client_reader.py | 29 +- synapse/app/event_creator.py | 39 +- synapse/app/federation_reader.py | 38 +- synapse/app/federation_sender.py | 46 +- synapse/app/frontend_proxy.py | 69 +- synapse/app/homeserver.py | 163 ++-- synapse/app/media_repository.py | 37 +- synapse/app/pusher.py | 60 +- synapse/app/synchrotron.py | 90 ++- synapse/app/user_dir.py | 45 +- synapse/appservice/__init__.py | 39 +- synapse/appservice/api.py | 66 +- synapse/appservice/scheduler.py | 50 +- synapse/config/_base.py | 12 +- synapse/config/api.py | 22 +- synapse/config/appservice.py | 54 +- synapse/config/captcha.py | 1 - synapse/config/consent_config.py | 27 +- synapse/config/database.py | 31 +- synapse/config/emailconfig.py | 71 +- synapse/config/jwt_config.py | 5 +- synapse/config/key.py | 5 +- synapse/config/logger.py | 70 +- synapse/config/metrics.py | 8 +- synapse/config/password_auth_providers.py | 16 +- synapse/config/registration.py | 41 +- synapse/config/repository.py | 74 +- synapse/config/room_directory.py | 19 +- synapse/config/saml2_config.py | 18 +- synapse/config/server.py | 226 +++--- synapse/config/server_notices_config.py | 17 +- synapse/config/tls.py | 52 +- synapse/config/user_directory.py | 8 +- synapse/config/voip.py | 3 +- synapse/config/workers.py | 16 +- synapse/crypto/event_signing.py | 17 +- synapse/crypto/keyring.py | 33 +- synapse/event_auth.py | 177 ++--- synapse/events/__init__.py | 30 +- synapse/events/builder.py | 37 +- synapse/events/snapshot.py | 49 +- synapse/events/spamcheck.py | 4 +- synapse/events/third_party_rules.py | 5 +- synapse/events/utils.py | 58 +- synapse/events/validator.py | 21 +- synapse/federation/federation_base.py | 78 +- synapse/federation/federation_client.py | 217 +++-- synapse/federation/federation_server.py | 234 +++--- synapse/federation/persistence.py | 15 +- synapse/federation/send_queue.py | 145 ++-- synapse/federation/sender/__init__.py | 67 +- synapse/federation/sender/per_destination_queue.py | 7 +- synapse/federation/sender/transaction_manager.py | 38 +- synapse/federation/transport/client.py | 294 ++++--- synapse/federation/transport/server.py | 237 +++--- synapse/federation/units.py | 33 +- synapse/groups/attestations.py | 35 +- synapse/groups/groups_server.py | 317 +++----- synapse/handlers/_base.py | 10 +- synapse/handlers/account_data.py | 23 +- synapse/handlers/account_validity.py | 61 +- synapse/handlers/acme.py | 14 +- synapse/handlers/admin.py | 23 +- synapse/handlers/appservice.py | 74 +- synapse/handlers/auth.py | 266 +++---- synapse/handlers/deactivate_account.py | 12 +- synapse/handlers/device.py | 112 ++- synapse/handlers/devicemessage.py | 7 +- synapse/handlers/directory.py | 136 ++-- synapse/handlers/e2e_keys.py | 119 ++- synapse/handlers/e2e_room_keys.py | 38 +- synapse/handlers/events.py | 49 +- synapse/handlers/federation.py | 884 +++++++++------------ synapse/handlers/groups_local.py | 95 +-- synapse/handlers/identity.py | 132 ++- synapse/handlers/initial_sync.py | 170 ++-- synapse/handlers/message.py | 258 +++--- synapse/handlers/pagination.py | 68 +- synapse/handlers/presence.py | 311 ++++---- synapse/handlers/profile.py | 65 +- synapse/handlers/read_marker.py | 9 +- synapse/handlers/receipts.py | 26 +- synapse/handlers/register.py | 224 +++--- synapse/handlers/room.py | 295 +++---- synapse/handlers/room_list.py | 254 +++--- synapse/handlers/room_member.py | 304 +++---- synapse/handlers/room_member_worker.py | 8 +- synapse/handlers/search.py | 129 ++- synapse/handlers/set_password.py | 5 +- synapse/handlers/state_deltas.py | 1 - synapse/handlers/stats.py | 2 +- synapse/handlers/sync.py | 694 ++++++++-------- synapse/handlers/typing.py | 74 +- synapse/http/__init__.py | 9 +- synapse/http/additional_resource.py | 1 + synapse/http/client.py | 28 +- synapse/http/endpoint.py | 20 +- synapse/http/federation/matrix_federation_agent.py | 98 ++- synapse/http/federation/srv_resolver.py | 37 +- synapse/http/matrixfederationclient.py | 242 +++--- synapse/http/server.py | 80 +- synapse/http/servlet.py | 45 +- synapse/http/site.py | 48 +- synapse/metrics/__init__.py | 7 +- synapse/metrics/background_process_metrics.py | 33 +- synapse/module_api/__init__.py | 6 +- synapse/notifier.py | 70 +- synapse/push/action_generator.py | 4 +- synapse/push/baserules.py | 404 ++++------ synapse/push/bulk_push_rule_evaluator.py | 63 +- synapse/push/clientformat.py | 38 +- synapse/push/emailpusher.py | 63 +- synapse/push/httppusher.py | 226 +++--- synapse/push/mailer.py | 347 ++++---- synapse/push/presentable_names.py | 45 +- synapse/push/push_rule_evaluator.py | 57 +- synapse/push/push_tools.py | 8 +- synapse/push/pusher.py | 10 +- synapse/push/pusherpool.py | 97 ++- synapse/push/rulekinds.py | 10 +- synapse/python_dependencies.py | 24 +- synapse/replication/http/_base.py | 22 +- synapse/replication/http/federation.py | 51 +- synapse/replication/http/login.py | 7 +- synapse/replication/http/membership.py | 34 +- synapse/replication/http/register.py | 17 +- synapse/replication/http/send_event.py | 13 +- synapse/replication/slave/storage/_base.py | 2 +- synapse/replication/slave/storage/account_data.py | 17 +- synapse/replication/slave/storage/appservice.py | 5 +- synapse/replication/slave/storage/client_ips.py | 4 +- synapse/replication/slave/storage/deviceinbox.py | 6 +- synapse/replication/slave/storage/devices.py | 14 +- synapse/replication/slave/storage/events.py | 77 +- synapse/replication/slave/storage/groups.py | 9 +- synapse/replication/slave/storage/presence.py | 8 +- synapse/replication/slave/storage/push_rule.py | 6 +- synapse/replication/slave/storage/pushers.py | 4 +- synapse/replication/slave/storage/receipts.py | 1 - synapse/replication/slave/storage/room.py | 4 +- synapse/replication/tcp/client.py | 6 +- synapse/replication/tcp/commands.py | 71 +- synapse/replication/tcp/protocol.py | 76 +- synapse/replication/tcp/resource.py | 54 +- synapse/replication/tcp/streams/_base.py | 160 ++-- synapse/replication/tcp/streams/events.py | 32 +- synapse/replication/tcp/streams/federation.py | 12 +- synapse/rest/__init__.py | 1 + synapse/rest/admin/__init__.py | 153 ++-- synapse/rest/admin/server_notice_servlet.py | 13 +- synapse/rest/client/transactions.py | 3 +- synapse/rest/client/v1/directory.py | 31 +- synapse/rest/client/v1/events.py | 7 +- synapse/rest/client/v1/login.py | 121 ++- synapse/rest/client/v1/logout.py | 3 +- synapse/rest/client/v1/presence.py | 2 +- synapse/rest/client/v1/profile.py | 6 +- synapse/rest/client/v1/push_rule.py | 119 ++- synapse/rest/client/v1/pusher.py | 70 +- synapse/rest/client/v1/room.py | 185 ++--- synapse/rest/client/v1/voip.py | 24 +- synapse/rest/client/v2_alpha/_base.py | 12 +- synapse/rest/client/v2_alpha/account.py | 190 ++--- synapse/rest/client/v2_alpha/account_data.py | 16 +- synapse/rest/client/v2_alpha/account_validity.py | 16 +- synapse/rest/client/v2_alpha/auth.py | 62 +- synapse/rest/client/v2_alpha/devices.py | 19 +- synapse/rest/client/v2_alpha/filter.py | 11 +- synapse/rest/client/v2_alpha/groups.py | 157 ++-- synapse/rest/client/v2_alpha/keys.py | 27 +- synapse/rest/client/v2_alpha/notifications.py | 28 +- synapse/rest/client/v2_alpha/openid.py | 22 +- synapse/rest/client/v2_alpha/read_marker.py | 4 +- synapse/rest/client/v2_alpha/receipts.py | 5 +- synapse/rest/client/v2_alpha/register.py | 132 ++- synapse/rest/client/v2_alpha/relations.py | 5 +- synapse/rest/client/v2_alpha/report_event.py | 4 +- synapse/rest/client/v2_alpha/room_keys.py | 54 +- .../client/v2_alpha/room_upgrade_rest_servlet.py | 9 +- synapse/rest/client/v2_alpha/sendtodevice.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 129 +-- synapse/rest/client/v2_alpha/tags.py | 14 +- synapse/rest/client/v2_alpha/thirdparty.py | 2 +- synapse/rest/client/v2_alpha/tokenrefresh.py | 1 + synapse/rest/client/v2_alpha/user_directory.py | 7 +- synapse/rest/client/versions.py | 43 +- synapse/rest/consent/consent_resource.py | 28 +- synapse/rest/key/v2/local_key_resource.py | 28 +- synapse/rest/key/v2/remote_key_resource.py | 56 +- synapse/rest/media/v0/content_repository.py | 15 +- synapse/rest/media/v1/_base.py | 57 +- synapse/rest/media/v1/config_resource.py | 4 +- synapse/rest/media/v1/download_resource.py | 8 +- synapse/rest/media/v1/filepath.py | 131 ++- synapse/rest/media/v1/media_repository.py | 215 ++--- synapse/rest/media/v1/media_storage.py | 17 +- synapse/rest/media/v1/preview_url_resource.py | 215 +++-- synapse/rest/media/v1/storage_provider.py | 6 +- synapse/rest/media/v1/thumbnail_resource.py | 120 ++- synapse/rest/media/v1/thumbnailer.py | 15 +- synapse/rest/media/v1/upload_resource.py | 30 +- synapse/rest/saml2/metadata_resource.py | 2 +- synapse/rest/saml2/response_resource.py | 13 +- synapse/rest/well_known.py | 11 +- synapse/secrets.py | 3 +- synapse/server.py | 156 ++-- synapse/server.pyi | 51 +- synapse/server_notices/consent_server_notices.py | 24 +- .../resource_limits_server_notices.py | 33 +- synapse/server_notices/server_notices_manager.py | 30 +- synapse/server_notices/server_notices_sender.py | 11 +- .../server_notices/worker_server_notices_sender.py | 1 + synapse/state/__init__.py | 107 +-- synapse/state/v1.py | 56 +- synapse/state/v2.py | 92 +-- synapse/storage/__init__.py | 10 +- synapse/storage/_base.py | 6 +- synapse/storage/background_updates.py | 2 +- synapse/storage/devices.py | 12 +- synapse/storage/e2e_room_keys.py | 28 +- synapse/storage/engines/sqlite.py | 2 +- synapse/storage/event_federation.py | 3 +- synapse/storage/event_push_actions.py | 4 +- synapse/storage/events.py | 2 +- synapse/storage/events_bg_updates.py | 16 +- synapse/storage/events_worker.py | 10 +- synapse/storage/group_server.py | 8 +- synapse/storage/keys.py | 2 +- synapse/storage/media_repository.py | 22 +- synapse/storage/monthly_active_users.py | 6 +- synapse/storage/prepare_database.py | 13 +- synapse/storage/profile.py | 2 +- synapse/storage/push_rule.py | 30 +- synapse/storage/pusher.py | 36 +- synapse/storage/receipts.py | 2 +- synapse/storage/registration.py | 107 ++- synapse/storage/relations.py | 6 +- synapse/storage/roommember.py | 8 +- synapse/storage/schema/delta/20/pushers.py | 22 +- synapse/storage/schema/delta/30/as_users.py | 17 +- synapse/storage/schema/delta/31/pushers.py | 24 +- synapse/storage/schema/delta/33/remote_media_ts.py | 2 +- synapse/storage/schema/delta/47/state_group_seq.py | 5 +- .../schema/delta/48/group_unique_indexes.py | 14 +- .../schema/delta/50/make_event_content_nullable.py | 12 +- synapse/storage/search.py | 6 +- synapse/storage/stats.py | 34 +- synapse/storage/stream.py | 32 +- synapse/streams/config.py | 27 +- synapse/streams/events.py | 43 +- synapse/types.py | 108 +-- synapse/util/__init__.py | 23 +- synapse/util/async_helpers.py | 41 +- synapse/util/caches/__init__.py | 6 +- synapse/util/caches/descriptors.py | 100 +-- synapse/util/caches/dictionary_cache.py | 15 +- synapse/util/caches/expiringcache.py | 18 +- synapse/util/caches/lrucache.py | 14 +- synapse/util/caches/response_cache.py | 22 +- synapse/util/caches/stream_change_cache.py | 13 +- synapse/util/caches/treecache.py | 1 + synapse/util/caches/ttlcache.py | 1 + synapse/util/distributor.py | 29 +- synapse/util/frozenutils.py | 9 +- synapse/util/httpresourcetree.py | 8 +- synapse/util/jsonobject.py | 6 +- synapse/util/logcontext.py | 83 +- synapse/util/logformatter.py | 3 +- synapse/util/logutils.py | 51 +- synapse/util/manhole.py | 18 +- synapse/util/metrics.py | 32 +- synapse/util/module_loader.py | 6 +- synapse/util/msisdn.py | 6 +- synapse/util/ratelimitutils.py | 35 +- synapse/util/stringutils.py | 16 +- synapse/util/threepids.py | 10 +- synapse/util/versionstring.py | 61 +- synapse/util/wheel_timer.py | 4 +- synapse/visibility.py | 69 +- tests/api/test_auth.py | 12 +- tests/config/test_server.py | 8 +- tests/config/test_tls.py | 2 +- tests/crypto/test_event_signing.py | 46 +- tests/events/test_utils.py | 94 +-- tests/federation/test_complexity.py | 2 +- tests/federation/test_federation_sender.py | 48 +- tests/handlers/test_auth.py | 10 +- tests/handlers/test_directory.py | 8 +- tests/handlers/test_e2e_room_keys.py | 18 +- tests/handlers/test_register.py | 31 +- tests/handlers/test_stats.py | 5 +- tests/handlers/test_typing.py | 16 +- tests/handlers/test_user_directory.py | 8 +- .../federation/test_matrix_federation_agent.py | 216 ++--- tests/http/federation/test_srv_resolver.py | 2 +- tests/http/test_endpoint.py | 12 +- tests/http/test_fedclient.py | 19 +- tests/push/test_email.py | 6 +- tests/rest/admin/test_admin.py | 100 ++- tests/rest/client/test_consent.py | 10 +- tests/rest/client/test_identity.py | 2 +- tests/rest/client/v1/test_profile.py | 7 +- tests/rest/client/v1/test_rooms.py | 46 +- tests/rest/client/v1/utils.py | 8 +- tests/rest/client/v2_alpha/test_account.py | 31 +- tests/rest/client/v2_alpha/test_capabilities.py | 14 +- tests/rest/client/v2_alpha/test_register.py | 35 +- tests/rest/client/v2_alpha/test_relations.py | 10 +- tests/rest/media/v1/test_base.py | 14 +- tests/rest/media/v1/test_media_storage.py | 6 +- tests/rest/media/v1/test_url_preview.py | 56 +- tests/server.py | 16 +- .../test_resource_limits_server_notices.py | 4 +- tests/state/test_v2.py | 20 +- tests/storage/test_appservice.py | 4 +- tests/storage/test_client_ips.py | 20 +- tests/storage/test_devices.py | 16 +- tests/storage/test_end_to_end_keys.py | 8 +- tests/storage/test_event_federation.py | 14 +- tests/storage/test_event_metrics.py | 36 +- tests/storage/test_monthly_active_users.py | 22 +- tests/storage/test_redaction.py | 4 +- tests/storage/test_registration.py | 2 +- tests/storage/test_room.py | 4 +- tests/storage/test_state.py | 16 +- tests/test_preview.py | 136 ++-- tests/test_server.py | 14 +- tests/test_state.py | 6 +- tests/test_types.py | 4 +- tests/test_utils/logging_setup.py | 2 +- tests/test_visibility.py | 2 +- tests/unittest.py | 18 +- tests/util/caches/test_descriptors.py | 60 +- tests/util/caches/test_ttlcache.py | 46 +- tests/utils.py | 18 +- tox.ini | 9 +- 376 files changed, 9153 insertions(+), 10399 deletions(-) create mode 100644 changelog.d/5482.misc (limited to 'tests') diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 20c7aab5a7..513eb3bde9 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -5,8 +5,8 @@ steps: - command: - "python -m pip install tox" - - "tox -e pep8" - label: "\U0001F9F9 PEP-8" + - "tox -e check_codestyle" + label: "\U0001F9F9 Check Style" plugins: - docker#v3.0.1: image: "python:3.6" diff --git a/changelog.d/5482.misc b/changelog.d/5482.misc new file mode 100644 index 0000000000..0332d1133b --- /dev/null +++ b/changelog.d/5482.misc @@ -0,0 +1 @@ +Synapse's codebase is now formatted by `black`. diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 462f146113..af8f39c8c2 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -37,9 +37,8 @@ from signedjson.sign import verify_signed_json, SignatureVerifyException CONFIG_JSON = "cmdclient_config.json" -TRUSTED_ID_SERVERS = [ - 'localhost:8001' -] +TRUSTED_ID_SERVERS = ["localhost:8001"] + class SynapseCmd(cmd.Cmd): @@ -59,7 +58,7 @@ class SynapseCmd(cmd.Cmd): "token": token, "verbose": "on", "complete_usernames": "on", - "send_delivery_receipts": "on" + "send_delivery_receipts": "on", } self.path_prefix = "/_matrix/client/api/v1" self.event_stream_token = "END" @@ -120,12 +119,11 @@ class SynapseCmd(cmd.Cmd): config_rules = [ # key, valid_values ("verbose", ["on", "off"]), ("complete_usernames", ["on", "off"]), - ("send_delivery_receipts", ["on", "off"]) + ("send_delivery_receipts", ["on", "off"]), ] for key, valid_vals in config_rules: if key == args["key"] and args["val"] not in valid_vals: - print("%s value must be one of %s" % (args["key"], - valid_vals)) + print("%s value must be one of %s" % (args["key"], valid_vals)) return # toggle the http client verbosity @@ -159,16 +157,13 @@ class SynapseCmd(cmd.Cmd): else: password = pwd - body = { - "type": "m.login.password" - } + body = {"type": "m.login.password"} if "userid" in args: body["user"] = args["userid"] if password: body["password"] = password - reactor.callFromThread(self._do_register, body, - "noupdate" not in args) + reactor.callFromThread(self._do_register, body, "noupdate" not in args) @defer.inlineCallbacks def _do_register(self, data, update_config): @@ -179,7 +174,9 @@ class SynapseCmd(cmd.Cmd): passwordFlow = None for flow in json_res["flows"]: - if flow["type"] == "m.login.recaptcha" or ("stages" in flow and "m.login.recaptcha" in flow["stages"]): + if flow["type"] == "m.login.recaptcha" or ( + "stages" in flow and "m.login.recaptcha" in flow["stages"] + ): print("Unable to register: Home server requires captcha.") return if flow["type"] == "m.login.password" and "stages" not in flow: @@ -202,9 +199,7 @@ class SynapseCmd(cmd.Cmd): """ try: args = self._parse(line, ["user_id"], force_keys=True) - can_login = threads.blockingCallFromThread( - reactor, - self._check_can_login) + can_login = threads.blockingCallFromThread(reactor, self._check_can_login) if can_login: p = getpass.getpass("Enter your password: ") user = args["user_id"] @@ -212,20 +207,16 @@ class SynapseCmd(cmd.Cmd): domain = self._domain() if domain: user = "@" + user + ":" + domain - + reactor.callFromThread(self._do_login, user, p) - #print " got %s " % p + # print " got %s " % p except Exception as e: print(e) @defer.inlineCallbacks def _do_login(self, user, password): path = "/login" - data = { - "user": user, - "password": password, - "type": "m.login.password" - } + data = {"user": user, "password": password, "type": "m.login.password"} url = self._url() + path json_res = yield self.http_client.do_request("POST", url, data=data) print(json_res) @@ -249,12 +240,13 @@ class SynapseCmd(cmd.Cmd): print("Failed to find any login flows.") defer.returnValue(False) - flow = json_res["flows"][0] # assume first is the one we want. - if ("type" not in flow or "m.login.password" != flow["type"] or - "stages" in flow): + flow = json_res["flows"][0] # assume first is the one we want. + if "type" not in flow or "m.login.password" != flow["type"] or "stages" in flow: fallback_url = self._url() + "/login/fallback" - print ("Unable to login via the command line client. Please visit " - "%s to login." % fallback_url) + print( + "Unable to login via the command line client. Please visit " + "%s to login." % fallback_url + ) defer.returnValue(False) defer.returnValue(True) @@ -264,21 +256,33 @@ class SynapseCmd(cmd.Cmd): A string of characters generated when requesting an email that you'll supply in subsequent calls to identify yourself The number of times the user has requested an email. Leave this the same between requests to retry the request at the transport level. Increment it to request that the email be sent again. """ - args = self._parse(line, ['address', 'clientSecret', 'sendAttempt']) + args = self._parse(line, ["address", "clientSecret", "sendAttempt"]) - postArgs = {'email': args['address'], 'clientSecret': args['clientSecret'], 'sendAttempt': args['sendAttempt']} + postArgs = { + "email": args["address"], + "clientSecret": args["clientSecret"], + "sendAttempt": args["sendAttempt"], + } reactor.callFromThread(self._do_emailrequest, postArgs) @defer.inlineCallbacks def _do_emailrequest(self, args): - url = self._identityServerUrl()+"/_matrix/identity/api/v1/validate/email/requestToken" - - json_res = yield self.http_client.do_request("POST", url, data=urllib.urlencode(args), jsonreq=False, - headers={'Content-Type': ['application/x-www-form-urlencoded']}) + url = ( + self._identityServerUrl() + + "/_matrix/identity/api/v1/validate/email/requestToken" + ) + + json_res = yield self.http_client.do_request( + "POST", + url, + data=urllib.urlencode(args), + jsonreq=False, + headers={"Content-Type": ["application/x-www-form-urlencoded"]}, + ) print(json_res) - if 'sid' in json_res: - print("Token sent. Your session ID is %s" % (json_res['sid'])) + if "sid" in json_res: + print("Token sent. Your session ID is %s" % (json_res["sid"])) def do_emailvalidate(self, line): """Validate and associate a third party ID @@ -286,18 +290,30 @@ class SynapseCmd(cmd.Cmd): The token sent to your third party identifier address The same clientSecret you supplied in requestToken """ - args = self._parse(line, ['sid', 'token', 'clientSecret']) + args = self._parse(line, ["sid", "token", "clientSecret"]) - postArgs = { 'sid' : args['sid'], 'token' : args['token'], 'clientSecret': args['clientSecret'] } + postArgs = { + "sid": args["sid"], + "token": args["token"], + "clientSecret": args["clientSecret"], + } reactor.callFromThread(self._do_emailvalidate, postArgs) @defer.inlineCallbacks def _do_emailvalidate(self, args): - url = self._identityServerUrl()+"/_matrix/identity/api/v1/validate/email/submitToken" - - json_res = yield self.http_client.do_request("POST", url, data=urllib.urlencode(args), jsonreq=False, - headers={'Content-Type': ['application/x-www-form-urlencoded']}) + url = ( + self._identityServerUrl() + + "/_matrix/identity/api/v1/validate/email/submitToken" + ) + + json_res = yield self.http_client.do_request( + "POST", + url, + data=urllib.urlencode(args), + jsonreq=False, + headers={"Content-Type": ["application/x-www-form-urlencoded"]}, + ) print(json_res) def do_3pidbind(self, line): @@ -305,19 +321,24 @@ class SynapseCmd(cmd.Cmd): The session ID (sid) given to you in the response to requestToken The same clientSecret you supplied in requestToken """ - args = self._parse(line, ['sid', 'clientSecret']) + args = self._parse(line, ["sid", "clientSecret"]) - postArgs = { 'sid' : args['sid'], 'clientSecret': args['clientSecret'] } - postArgs['mxid'] = self.config["user"] + postArgs = {"sid": args["sid"], "clientSecret": args["clientSecret"]} + postArgs["mxid"] = self.config["user"] reactor.callFromThread(self._do_3pidbind, postArgs) @defer.inlineCallbacks def _do_3pidbind(self, args): - url = self._identityServerUrl()+"/_matrix/identity/api/v1/3pid/bind" - - json_res = yield self.http_client.do_request("POST", url, data=urllib.urlencode(args), jsonreq=False, - headers={'Content-Type': ['application/x-www-form-urlencoded']}) + url = self._identityServerUrl() + "/_matrix/identity/api/v1/3pid/bind" + + json_res = yield self.http_client.do_request( + "POST", + url, + data=urllib.urlencode(args), + jsonreq=False, + headers={"Content-Type": ["application/x-www-form-urlencoded"]}, + ) print(json_res) def do_join(self, line): @@ -356,9 +377,7 @@ class SynapseCmd(cmd.Cmd): if "topic" not in args: print("Must specify a new topic.") return - body = { - "topic": args["topic"] - } + body = {"topic": args["topic"]} reactor.callFromThread(self._run_and_pprint, "PUT", path, body) elif args["action"].lower() == "get": reactor.callFromThread(self._run_and_pprint, "GET", path) @@ -378,45 +397,60 @@ class SynapseCmd(cmd.Cmd): @defer.inlineCallbacks def _do_invite(self, roomid, userstring): - if (not userstring.startswith('@') and - self._is_on("complete_usernames")): - url = self._identityServerUrl()+"/_matrix/identity/api/v1/lookup" + if not userstring.startswith("@") and self._is_on("complete_usernames"): + url = self._identityServerUrl() + "/_matrix/identity/api/v1/lookup" - json_res = yield self.http_client.do_request("GET", url, qparams={'medium':'email','address':userstring}) + json_res = yield self.http_client.do_request( + "GET", url, qparams={"medium": "email", "address": userstring} + ) mxid = None - if 'mxid' in json_res and 'signatures' in json_res: - url = self._identityServerUrl()+"/_matrix/identity/api/v1/pubkey/ed25519" + if "mxid" in json_res and "signatures" in json_res: + url = ( + self._identityServerUrl() + + "/_matrix/identity/api/v1/pubkey/ed25519" + ) pubKey = None pubKeyObj = yield self.http_client.do_request("GET", url) - if 'public_key' in pubKeyObj: - pubKey = nacl.signing.VerifyKey(pubKeyObj['public_key'], encoder=nacl.encoding.HexEncoder) + if "public_key" in pubKeyObj: + pubKey = nacl.signing.VerifyKey( + pubKeyObj["public_key"], encoder=nacl.encoding.HexEncoder + ) else: print("No public key found in pubkey response!") sigValid = False if pubKey: - for signame in json_res['signatures']: + for signame in json_res["signatures"]: if signame not in TRUSTED_ID_SERVERS: - print("Ignoring signature from untrusted server %s" % (signame)) + print( + "Ignoring signature from untrusted server %s" + % (signame) + ) else: try: verify_signed_json(json_res, signame, pubKey) sigValid = True - print("Mapping %s -> %s correctly signed by %s" % (userstring, json_res['mxid'], signame)) + print( + "Mapping %s -> %s correctly signed by %s" + % (userstring, json_res["mxid"], signame) + ) break except SignatureVerifyException as e: print("Invalid signature from %s" % (signame)) print(e) if sigValid: - print("Resolved 3pid %s to %s" % (userstring, json_res['mxid'])) - mxid = json_res['mxid'] + print("Resolved 3pid %s to %s" % (userstring, json_res["mxid"])) + mxid = json_res["mxid"] else: - print("Got association for %s but couldn't verify signature" % (userstring)) + print( + "Got association for %s but couldn't verify signature" + % (userstring) + ) if not mxid: mxid = "@" + userstring + ":" + self._domain() @@ -435,12 +469,11 @@ class SynapseCmd(cmd.Cmd): """Sends a message. "send " """ args = self._parse(line, ["roomid", "body"]) txn_id = "txn%s" % int(time.time()) - path = "/rooms/%s/send/m.room.message/%s" % (urllib.quote(args["roomid"]), - txn_id) - body_json = { - "msgtype": "m.text", - "body": args["body"] - } + path = "/rooms/%s/send/m.room.message/%s" % ( + urllib.quote(args["roomid"]), + txn_id, + ) + body_json = {"msgtype": "m.text", "body": args["body"]} reactor.callFromThread(self._run_and_pprint, "PUT", path, body_json) def do_list(self, line): @@ -472,8 +505,7 @@ class SynapseCmd(cmd.Cmd): print("Bad query param: %s" % key_value) return - reactor.callFromThread(self._run_and_pprint, "GET", path, - query_params=qp) + reactor.callFromThread(self._run_and_pprint, "GET", path, query_params=qp) def do_create(self, line): """Creates a room. @@ -513,8 +545,16 @@ class SynapseCmd(cmd.Cmd): return args["method"] = args["method"].upper() - valid_methods = ["PUT", "GET", "POST", "DELETE", - "XPUT", "XGET", "XPOST", "XDELETE"] + valid_methods = [ + "PUT", + "GET", + "POST", + "DELETE", + "XPUT", + "XGET", + "XPOST", + "XDELETE", + ] if args["method"] not in valid_methods: print("Unsupported method: %s" % args["method"]) return @@ -541,10 +581,13 @@ class SynapseCmd(cmd.Cmd): except: pass - reactor.callFromThread(self._run_and_pprint, args["method"], - args["path"], - args["data"], - query_params=qp) + reactor.callFromThread( + self._run_and_pprint, + args["method"], + args["path"], + args["data"], + query_params=qp, + ) def do_stream(self, line): """Stream data from the server: "stream " """ @@ -561,19 +604,22 @@ class SynapseCmd(cmd.Cmd): @defer.inlineCallbacks def _do_event_stream(self, timeout): res = yield self.http_client.get_json( - self._url() + "/events", - { - "access_token": self._tok(), - "timeout": str(timeout), - "from": self.event_stream_token - }) + self._url() + "/events", + { + "access_token": self._tok(), + "timeout": str(timeout), + "from": self.event_stream_token, + }, + ) print(json.dumps(res, indent=4)) if "chunk" in res: for event in res["chunk"]: - if (event["type"] == "m.room.message" and - self._is_on("send_delivery_receipts") and - event["user_id"] != self._usr()): # not sent by us + if ( + event["type"] == "m.room.message" + and self._is_on("send_delivery_receipts") + and event["user_id"] != self._usr() + ): # not sent by us self._send_receipt(event, "d") # update the position in the stram @@ -581,18 +627,28 @@ class SynapseCmd(cmd.Cmd): self.event_stream_token = res["end"] def _send_receipt(self, event, feedback_type): - path = ("/rooms/%s/messages/%s/%s/feedback/%s/%s" % - (urllib.quote(event["room_id"]), event["user_id"], event["msg_id"], - self._usr(), feedback_type)) + path = "/rooms/%s/messages/%s/%s/feedback/%s/%s" % ( + urllib.quote(event["room_id"]), + event["user_id"], + event["msg_id"], + self._usr(), + feedback_type, + ) data = {} - reactor.callFromThread(self._run_and_pprint, "PUT", path, data=data, - alt_text="Sent receipt for %s" % event["msg_id"]) + reactor.callFromThread( + self._run_and_pprint, + "PUT", + path, + data=data, + alt_text="Sent receipt for %s" % event["msg_id"], + ) def _do_membership_change(self, roomid, membership, userid): - path = "/rooms/%s/state/m.room.member/%s" % (urllib.quote(roomid), urllib.quote(userid)) - data = { - "membership": membership - } + path = "/rooms/%s/state/m.room.member/%s" % ( + urllib.quote(roomid), + urllib.quote(userid), + ) + data = {"membership": membership} reactor.callFromThread(self._run_and_pprint, "PUT", path, data=data) def do_displayname(self, line): @@ -645,15 +701,20 @@ class SynapseCmd(cmd.Cmd): for i, arg in enumerate(line_args): for config_key in self.config: if ("$" + config_key) in arg: - arg = arg.replace("$" + config_key, - self.config[config_key]) + arg = arg.replace("$" + config_key, self.config[config_key]) line_args[i] = arg return dict(zip(keys, line_args)) @defer.inlineCallbacks - def _run_and_pprint(self, method, path, data=None, - query_params={"access_token": None}, alt_text=None): + def _run_and_pprint( + self, + method, + path, + data=None, + query_params={"access_token": None}, + alt_text=None, + ): """ Runs an HTTP request and pretty prints the output. Args: @@ -666,9 +727,9 @@ class SynapseCmd(cmd.Cmd): if "access_token" in query_params: query_params["access_token"] = self._tok() - json_res = yield self.http_client.do_request(method, url, - data=data, - qparams=query_params) + json_res = yield self.http_client.do_request( + method, url, data=data, qparams=query_params + ) if alt_text: print(alt_text) else: @@ -676,7 +737,7 @@ class SynapseCmd(cmd.Cmd): def save_config(config): - with open(CONFIG_JSON, 'w') as out: + with open(CONFIG_JSON, "w") as out: json.dump(config, out) @@ -700,7 +761,7 @@ def main(server_url, identity_server_url, username, token, config_path): global CONFIG_JSON CONFIG_JSON = config_path # bit cheeky, but just overwrite the global try: - with open(config_path, 'r') as config: + with open(config_path, "r") as config: syn_cmd.config = json.load(config) try: http_client.verbose = "on" == syn_cmd.config["verbose"] @@ -717,23 +778,33 @@ def main(server_url, identity_server_url, username, token, config_path): reactor.run() -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser("Starts a synapse client.") parser.add_argument( - "-s", "--server", dest="server", default="http://localhost:8008", - help="The URL of the home server to talk to.") - parser.add_argument( - "-i", "--identity-server", dest="identityserver", default="http://localhost:8090", - help="The URL of the identity server to talk to.") + "-s", + "--server", + dest="server", + default="http://localhost:8008", + help="The URL of the home server to talk to.", + ) parser.add_argument( - "-u", "--username", dest="username", - help="Your username on the server.") + "-i", + "--identity-server", + dest="identityserver", + default="http://localhost:8090", + help="The URL of the identity server to talk to.", + ) parser.add_argument( - "-t", "--token", dest="token", - help="Your access token.") + "-u", "--username", dest="username", help="Your username on the server." + ) + parser.add_argument("-t", "--token", dest="token", help="Your access token.") parser.add_argument( - "-c", "--config", dest="config", default=CONFIG_JSON, - help="The location of the config.json file to read from.") + "-c", + "--config", + dest="config", + default=CONFIG_JSON, + help="The location of the config.json file to read from.", + ) args = parser.parse_args() if not args.server: diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index 1bd600e148..0e101d2be5 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -73,9 +73,7 @@ class TwistedHttpClient(HttpClient): @defer.inlineCallbacks def put_json(self, url, data): response = yield self._create_put_request( - url, - data, - headers_dict={"Content-Type": ["application/json"]} + url, data, headers_dict={"Content-Type": ["application/json"]} ) body = yield readBody(response) defer.returnValue((response.code, body)) @@ -95,40 +93,34 @@ class TwistedHttpClient(HttpClient): """ if "Content-Type" not in headers_dict: - raise defer.error( - RuntimeError("Must include Content-Type header for PUTs")) + raise defer.error(RuntimeError("Must include Content-Type header for PUTs")) return self._create_request( - "PUT", - url, - producer=_JsonProducer(json_data), - headers_dict=headers_dict + "PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict ) def _create_get_request(self, url, headers_dict={}): """ Wrapper of _create_request to issue a GET request """ - return self._create_request( - "GET", - url, - headers_dict=headers_dict - ) + return self._create_request("GET", url, headers_dict=headers_dict) @defer.inlineCallbacks - def do_request(self, method, url, data=None, qparams=None, jsonreq=True, headers={}): + def do_request( + self, method, url, data=None, qparams=None, jsonreq=True, headers={} + ): if qparams: url = "%s?%s" % (url, urllib.urlencode(qparams, True)) if jsonreq: prod = _JsonProducer(data) - headers['Content-Type'] = ["application/json"]; + headers["Content-Type"] = ["application/json"] else: prod = _RawProducer(data) if method in ["POST", "PUT"]: - response = yield self._create_request(method, url, - producer=prod, - headers_dict=headers) + response = yield self._create_request( + method, url, producer=prod, headers_dict=headers + ) else: response = yield self._create_request(method, url) @@ -155,10 +147,7 @@ class TwistedHttpClient(HttpClient): while True: try: response = yield self.agent.request( - method, - url.encode("UTF8"), - Headers(headers_dict), - producer + method, url.encode("UTF8"), Headers(headers_dict), producer ) break except Exception as e: @@ -179,6 +168,7 @@ class TwistedHttpClient(HttpClient): reactor.callLater(seconds, d.callback, seconds) return d + class _RawProducer(object): def __init__(self, data): self.data = data @@ -195,9 +185,11 @@ class _RawProducer(object): def stopProducing(self): pass + class _JsonProducer(object): """ Used by the twisted http client to create the HTTP body from json """ + def __init__(self, jsn): self.data = jsn self.body = json.dumps(jsn).encode("utf8") diff --git a/contrib/experiments/cursesio.py b/contrib/experiments/cursesio.py index 44afe81008..ffefe3bb39 100644 --- a/contrib/experiments/cursesio.py +++ b/contrib/experiments/cursesio.py @@ -19,13 +19,13 @@ from curses.ascii import isprint from twisted.internet import reactor -class CursesStdIO(): +class CursesStdIO: def __init__(self, stdscr, callback=None): self.statusText = "Synapse test app -" - self.searchText = '' + self.searchText = "" self.stdscr = stdscr - self.logLine = '' + self.logLine = "" self.callback = callback @@ -71,8 +71,7 @@ class CursesStdIO(): i = 0 index = len(self.lines) - 1 while i < (self.rows - 3) and index >= 0: - self.stdscr.addstr(self.rows - 3 - i, 0, self.lines[index], - curses.A_NORMAL) + self.stdscr.addstr(self.rows - 3 - i, 0, self.lines[index], curses.A_NORMAL) i = i + 1 index = index - 1 @@ -85,15 +84,13 @@ class CursesStdIO(): raise RuntimeError("TextTooLongError") self.stdscr.addstr( - self.rows - 2, 0, - text + ' ' * (self.cols - len(text)), - curses.A_STANDOUT) + self.rows - 2, 0, text + " " * (self.cols - len(text)), curses.A_STANDOUT + ) def printLogLine(self, text): self.stdscr.addstr( - 0, 0, - text + ' ' * (self.cols - len(text)), - curses.A_STANDOUT) + 0, 0, text + " " * (self.cols - len(text)), curses.A_STANDOUT + ) def doRead(self): """ Input is ready! """ @@ -105,7 +102,7 @@ class CursesStdIO(): elif c == curses.KEY_ENTER or c == 10: text = self.searchText - self.searchText = '' + self.searchText = "" self.print_line(">> %s" % text) @@ -122,11 +119,13 @@ class CursesStdIO(): return self.searchText = self.searchText + chr(c) - self.stdscr.addstr(self.rows - 1, 0, - self.searchText + (' ' * ( - self.cols - len(self.searchText) - 2))) + self.stdscr.addstr( + self.rows - 1, + 0, + self.searchText + (" " * (self.cols - len(self.searchText) - 2)), + ) - self.paintStatus(self.statusText + ' %d' % len(self.searchText)) + self.paintStatus(self.statusText + " %d" % len(self.searchText)) self.stdscr.move(self.rows - 1, len(self.searchText)) self.stdscr.refresh() @@ -143,7 +142,6 @@ class CursesStdIO(): class Callback(object): - def __init__(self, stdio): self.stdio = stdio @@ -152,7 +150,7 @@ class Callback(object): def main(stdscr): - screen = CursesStdIO(stdscr) # create Screen object + screen = CursesStdIO(stdscr) # create Screen object callback = Callback(screen) @@ -164,5 +162,5 @@ def main(stdscr): screen.close() -if __name__ == '__main__': +if __name__ == "__main__": curses.wrapper(main) diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index 85c9c11984..c7e55d8aa7 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -28,9 +28,7 @@ Currently assumes the local address is localhost: """ -from synapse.federation import ( - ReplicationHandler -) +from synapse.federation import ReplicationHandler from synapse.federation.units import Pdu @@ -38,7 +36,7 @@ from synapse.util import origin_from_ucid from synapse.app.homeserver import SynapseHomeServer -#from synapse.util.logutils import log_function +# from synapse.util.logutils import log_function from twisted.internet import reactor, defer from twisted.python import log @@ -83,7 +81,7 @@ class InputOutput(object): room_name, = m.groups() self.print_line("%s joining %s" % (self.user, room_name)) self.server.join_room(room_name, self.user, self.user) - #self.print_line("OK.") + # self.print_line("OK.") return m = re.match("^invite (\S+) (\S+)$", line) @@ -92,7 +90,7 @@ class InputOutput(object): room_name, invitee = m.groups() self.print_line("%s invited to %s" % (invitee, room_name)) self.server.invite_to_room(room_name, self.user, invitee) - #self.print_line("OK.") + # self.print_line("OK.") return m = re.match("^send (\S+) (.*)$", line) @@ -101,7 +99,7 @@ class InputOutput(object): room_name, body = m.groups() self.print_line("%s send to %s" % (self.user, room_name)) self.server.send_message(room_name, self.user, body) - #self.print_line("OK.") + # self.print_line("OK.") return m = re.match("^backfill (\S+)$", line) @@ -125,7 +123,6 @@ class InputOutput(object): class IOLoggerHandler(logging.Handler): - def __init__(self, io): logging.Handler.__init__(self) self.io = io @@ -142,6 +139,7 @@ class Room(object): """ Used to store (in memory) the current membership state of a room, and which home servers we should send PDUs associated with the room to. """ + def __init__(self, room_name): self.room_name = room_name self.invited = set() @@ -175,6 +173,7 @@ class HomeServer(ReplicationHandler): """ A very basic home server implentation that allows people to join a room and then invite other people. """ + def __init__(self, server_name, replication_layer, output): self.server_name = server_name self.replication_layer = replication_layer @@ -197,26 +196,27 @@ class HomeServer(ReplicationHandler): elif pdu.content["membership"] == "invite": self._on_invite(pdu.origin, pdu.context, pdu.state_key) else: - self.output.print_line("#%s (unrec) %s = %s" % - (pdu.context, pdu.pdu_type, json.dumps(pdu.content)) + self.output.print_line( + "#%s (unrec) %s = %s" + % (pdu.context, pdu.pdu_type, json.dumps(pdu.content)) ) - #def on_state_change(self, pdu): - ##self.output.print_line("#%s (state) %s *** %s" % - ##(pdu.context, pdu.state_key, pdu.pdu_type) - ##) + # def on_state_change(self, pdu): + ##self.output.print_line("#%s (state) %s *** %s" % + ##(pdu.context, pdu.state_key, pdu.pdu_type) + ##) - #if "joinee" in pdu.content: - #self._on_join(pdu.context, pdu.content["joinee"]) - #elif "invitee" in pdu.content: - #self._on_invite(pdu.origin, pdu.context, pdu.content["invitee"]) + # if "joinee" in pdu.content: + # self._on_join(pdu.context, pdu.content["joinee"]) + # elif "invitee" in pdu.content: + # self._on_invite(pdu.origin, pdu.context, pdu.content["invitee"]) def _on_message(self, pdu): """ We received a message """ - self.output.print_line("#%s %s %s" % - (pdu.context, pdu.content["sender"], pdu.content["body"]) - ) + self.output.print_line( + "#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"]) + ) def _on_join(self, context, joinee): """ Someone has joined a room, either a remote user or a local user @@ -224,9 +224,7 @@ class HomeServer(ReplicationHandler): room = self._get_or_create_room(context) room.add_participant(joinee) - self.output.print_line("#%s %s %s" % - (context, joinee, "*** JOINED") - ) + self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED")) def _on_invite(self, origin, context, invitee): """ Someone has been invited @@ -234,9 +232,7 @@ class HomeServer(ReplicationHandler): room = self._get_or_create_room(context) room.add_invited(invitee) - self.output.print_line("#%s %s %s" % - (context, invitee, "*** INVITED") - ) + self.output.print_line("#%s %s %s" % (context, invitee, "*** INVITED")) if not room.have_got_metadata and origin is not self.server_name: logger.debug("Get room state") @@ -272,14 +268,14 @@ class HomeServer(ReplicationHandler): try: pdu = Pdu.create_new( - context=room_name, - pdu_type="sy.room.member", - is_state=True, - state_key=joinee, - content={"membership": "join"}, - origin=self.server_name, - destinations=destinations, - ) + context=room_name, + pdu_type="sy.room.member", + is_state=True, + state_key=joinee, + content={"membership": "join"}, + origin=self.server_name, + destinations=destinations, + ) yield self.replication_layer.send_pdu(pdu) except Exception as e: logger.exception(e) @@ -318,21 +314,21 @@ class HomeServer(ReplicationHandler): return self.replication_layer.backfill(dest, room_name, limit) def _get_room_remote_servers(self, room_name): - return [i for i in self.joined_rooms.setdefault(room_name,).servers] + return [i for i in self.joined_rooms.setdefault(room_name).servers] def _get_or_create_room(self, room_name): return self.joined_rooms.setdefault(room_name, Room(room_name)) def get_servers_for_context(self, context): return defer.succeed( - self.joined_rooms.setdefault(context, Room(context)).servers - ) + self.joined_rooms.setdefault(context, Room(context)).servers + ) def main(stdscr): parser = argparse.ArgumentParser() - parser.add_argument('user', type=str) - parser.add_argument('-v', '--verbose', action='count') + parser.add_argument("user", type=str) + parser.add_argument("-v", "--verbose", action="count") args = parser.parse_args() user = args.user @@ -342,8 +338,9 @@ def main(stdscr): root_logger = logging.getLogger() - formatter = logging.Formatter('%(asctime)s - %(name)s - %(lineno)d - ' - '%(levelname)s - %(message)s') + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(lineno)d - " "%(levelname)s - %(message)s" + ) if not os.path.exists("logs"): os.makedirs("logs") fh = logging.FileHandler("logs/%s" % user) diff --git a/contrib/graph/graph.py b/contrib/graph/graph.py index e174ff5026..92736480eb 100644 --- a/contrib/graph/graph.py +++ b/contrib/graph/graph.py @@ -1,4 +1,5 @@ from __future__ import print_function + # Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -58,9 +59,9 @@ def make_graph(pdus, room, filename_prefix): name = make_name(pdu.get("pdu_id"), pdu.get("origin")) pdu_map[name] = pdu - t = datetime.datetime.fromtimestamp( - float(pdu["ts"]) / 1000 - ).strftime('%Y-%m-%d %H:%M:%S,%f') + t = datetime.datetime.fromtimestamp(float(pdu["ts"]) / 1000).strftime( + "%Y-%m-%d %H:%M:%S,%f" + ) label = ( "<" @@ -80,11 +81,7 @@ def make_graph(pdus, room, filename_prefix): "depth": pdu.get("depth"), } - node = pydot.Node( - name=name, - label=label, - color=color_map[pdu.get("origin")] - ) + node = pydot.Node(name=name, label=label, color=color_map[pdu.get("origin")]) node_map[name] = node graph.add_node(node) @@ -108,14 +105,13 @@ def make_graph(pdus, room, filename_prefix): if prev_state_name in node_map: state_edge = pydot.Edge( - node_map[start_name], node_map[prev_state_name], - style='dotted' + node_map[start_name], node_map[prev_state_name], style="dotted" ) graph.add_edge(state_edge) - graph.write('%s.dot' % filename_prefix, format='raw', prog='dot') -# graph.write_png("%s.png" % filename_prefix, prog='dot') - graph.write_svg("%s.svg" % filename_prefix, prog='dot') + graph.write("%s.dot" % filename_prefix, format="raw", prog="dot") + # graph.write_png("%s.png" % filename_prefix, prog='dot') + graph.write_svg("%s.svg" % filename_prefix, prog="dot") def get_pdus(host, room): @@ -131,15 +127,14 @@ def get_pdus(host, room): if __name__ == "__main__": parser = argparse.ArgumentParser( description="Generate a PDU graph for a given room by talking " - "to the given homeserver to get the list of PDUs. \n" - "Requires pydot." + "to the given homeserver to get the list of PDUs. \n" + "Requires pydot." ) parser.add_argument( - "-p", "--prefix", dest="prefix", - help="String to prefix output files with" + "-p", "--prefix", dest="prefix", help="String to prefix output files with" ) - parser.add_argument('host') - parser.add_argument('room') + parser.add_argument("host") + parser.add_argument("room") args = parser.parse_args() diff --git a/contrib/graph/graph2.py b/contrib/graph/graph2.py index 1ccad65728..9db8725eee 100644 --- a/contrib/graph/graph2.py +++ b/contrib/graph/graph2.py @@ -36,10 +36,7 @@ def make_graph(db_name, room_id, file_prefix, limit): args = [room_id] if limit: - sql += ( - " ORDER BY topological_ordering DESC, stream_ordering DESC " - "LIMIT ?" - ) + sql += " ORDER BY topological_ordering DESC, stream_ordering DESC " "LIMIT ?" args.append(limit) @@ -56,9 +53,8 @@ def make_graph(db_name, room_id, file_prefix, limit): for event in events: c = conn.execute( - "SELECT state_group FROM event_to_state_groups " - "WHERE event_id = ?", - (event.event_id,) + "SELECT state_group FROM event_to_state_groups " "WHERE event_id = ?", + (event.event_id,), ) res = c.fetchone() @@ -69,7 +65,7 @@ def make_graph(db_name, room_id, file_prefix, limit): t = datetime.datetime.fromtimestamp( float(event.origin_server_ts) / 1000 - ).strftime('%Y-%m-%d %H:%M:%S,%f') + ).strftime("%Y-%m-%d %H:%M:%S,%f") content = json.dumps(unfreeze(event.get_dict()["content"])) @@ -93,10 +89,7 @@ def make_graph(db_name, room_id, file_prefix, limit): "state_group": state_group, } - node = pydot.Node( - name=event.event_id, - label=label, - ) + node = pydot.Node(name=event.event_id, label=label) node_map[event.event_id] = node graph.add_node(node) @@ -106,10 +99,7 @@ def make_graph(db_name, room_id, file_prefix, limit): try: end_node = node_map[prev_id] except: - end_node = pydot.Node( - name=prev_id, - label="<%s>" % (prev_id,), - ) + end_node = pydot.Node(name=prev_id, label="<%s>" % (prev_id,)) node_map[prev_id] = end_node graph.add_node(end_node) @@ -121,36 +111,33 @@ def make_graph(db_name, room_id, file_prefix, limit): if len(event_ids) <= 1: continue - cluster = pydot.Cluster( - str(group), - label="" % (str(group),) - ) + cluster = pydot.Cluster(str(group), label="" % (str(group),)) for event_id in event_ids: cluster.add_node(node_map[event_id]) graph.add_subgraph(cluster) - graph.write('%s.dot' % file_prefix, format='raw', prog='dot') - graph.write_svg("%s.svg" % file_prefix, prog='dot') + graph.write("%s.dot" % file_prefix, format="raw", prog="dot") + graph.write_svg("%s.svg" % file_prefix, prog="dot") + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Generate a PDU graph for a given room by talking " - "to the given homeserver to get the list of PDUs. \n" - "Requires pydot." + "to the given homeserver to get the list of PDUs. \n" + "Requires pydot." ) parser.add_argument( - "-p", "--prefix", dest="prefix", + "-p", + "--prefix", + dest="prefix", help="String to prefix output files with", - default="graph_output" - ) - parser.add_argument( - "-l", "--limit", - help="Only retrieve the last N events.", + default="graph_output", ) - parser.add_argument('db') - parser.add_argument('room') + parser.add_argument("-l", "--limit", help="Only retrieve the last N events.") + parser.add_argument("db") + parser.add_argument("room") args = parser.parse_args() diff --git a/contrib/graph/graph3.py b/contrib/graph/graph3.py index fe1dc81e90..7f9e5374a6 100644 --- a/contrib/graph/graph3.py +++ b/contrib/graph/graph3.py @@ -1,4 +1,5 @@ from __future__ import print_function + # Copyright 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -42,7 +43,7 @@ def make_graph(file_name, room_id, file_prefix, limit): print("Sorted events") if limit: - events = events[-int(limit):] + events = events[-int(limit) :] node_map = {} @@ -51,7 +52,7 @@ def make_graph(file_name, room_id, file_prefix, limit): for event in events: t = datetime.datetime.fromtimestamp( float(event.origin_server_ts) / 1000 - ).strftime('%Y-%m-%d %H:%M:%S,%f') + ).strftime("%Y-%m-%d %H:%M:%S,%f") content = json.dumps(unfreeze(event.get_dict()["content"]), indent=4) content = content.replace("\n", "
\n") @@ -67,9 +68,10 @@ def make_graph(file_name, room_id, file_prefix, limit): value = json.dumps(value) content.append( - "%s: %s," % ( - cgi.escape(key, quote=True).encode("ascii", 'xmlcharrefreplace'), - cgi.escape(value, quote=True).encode("ascii", 'xmlcharrefreplace'), + "%s: %s," + % ( + cgi.escape(key, quote=True).encode("ascii", "xmlcharrefreplace"), + cgi.escape(value, quote=True).encode("ascii", "xmlcharrefreplace"), ) ) @@ -95,10 +97,7 @@ def make_graph(file_name, room_id, file_prefix, limit): "depth": event.depth, } - node = pydot.Node( - name=event.event_id, - label=label, - ) + node = pydot.Node(name=event.event_id, label=label) node_map[event.event_id] = node graph.add_node(node) @@ -110,10 +109,7 @@ def make_graph(file_name, room_id, file_prefix, limit): try: end_node = node_map[prev_id] except: - end_node = pydot.Node( - name=prev_id, - label="<%s>" % (prev_id,), - ) + end_node = pydot.Node(name=prev_id, label="<%s>" % (prev_id,)) node_map[prev_id] = end_node graph.add_node(end_node) @@ -123,31 +119,31 @@ def make_graph(file_name, room_id, file_prefix, limit): print("Created edges") - graph.write('%s.dot' % file_prefix, format='raw', prog='dot') + graph.write("%s.dot" % file_prefix, format="raw", prog="dot") print("Created Dot") - graph.write_svg("%s.svg" % file_prefix, prog='dot') + graph.write_svg("%s.svg" % file_prefix, prog="dot") print("Created svg") + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Generate a PDU graph for a given room by reading " - "from a file with line deliminated events. \n" - "Requires pydot." + "from a file with line deliminated events. \n" + "Requires pydot." ) parser.add_argument( - "-p", "--prefix", dest="prefix", + "-p", + "--prefix", + dest="prefix", help="String to prefix output files with", - default="graph_output" - ) - parser.add_argument( - "-l", "--limit", - help="Only retrieve the last N events.", + default="graph_output", ) - parser.add_argument('event_file') - parser.add_argument('room') + parser.add_argument("-l", "--limit", help="Only retrieve the last N events.") + parser.add_argument("event_file") + parser.add_argument("room") args = parser.parse_args() diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py index e82d1be5d2..67fb2cd1a7 100644 --- a/contrib/jitsimeetbridge/jitsimeetbridge.py +++ b/contrib/jitsimeetbridge/jitsimeetbridge.py @@ -20,24 +20,25 @@ import urllib import subprocess import time -#ACCESS_TOKEN="" # +# ACCESS_TOKEN="" # -MATRIXBASE = 'https://matrix.org/_matrix/client/api/v1/' -MYUSERNAME = '@davetest:matrix.org' +MATRIXBASE = "https://matrix.org/_matrix/client/api/v1/" +MYUSERNAME = "@davetest:matrix.org" -HTTPBIND = 'https://meet.jit.si/http-bind' -#HTTPBIND = 'https://jitsi.vuc.me/http-bind' -#ROOMNAME = "matrix" +HTTPBIND = "https://meet.jit.si/http-bind" +# HTTPBIND = 'https://jitsi.vuc.me/http-bind' +# ROOMNAME = "matrix" ROOMNAME = "pibble" -HOST="guest.jit.si" -#HOST="jitsi.vuc.me" +HOST = "guest.jit.si" +# HOST="jitsi.vuc.me" -TURNSERVER="turn.guest.jit.si" -#TURNSERVER="turn.jitsi.vuc.me" +TURNSERVER = "turn.guest.jit.si" +# TURNSERVER="turn.jitsi.vuc.me" + +ROOMDOMAIN = "meet.jit.si" +# ROOMDOMAIN="conference.jitsi.vuc.me" -ROOMDOMAIN="meet.jit.si" -#ROOMDOMAIN="conference.jitsi.vuc.me" class TrivialMatrixClient: def __init__(self, access_token): @@ -46,38 +47,50 @@ class TrivialMatrixClient: def getEvent(self): while True: - url = MATRIXBASE+'events?access_token='+self.access_token+"&timeout=60000" + url = ( + MATRIXBASE + + "events?access_token=" + + self.access_token + + "&timeout=60000" + ) if self.token: - url += "&from="+self.token + url += "&from=" + self.token req = grequests.get(url) resps = grequests.map([req]) obj = json.loads(resps[0].content) - print("incoming from matrix",obj) - if 'end' not in obj: + print("incoming from matrix", obj) + if "end" not in obj: continue - self.token = obj['end'] - if len(obj['chunk']): - return obj['chunk'][0] + self.token = obj["end"] + if len(obj["chunk"]): + return obj["chunk"][0] def joinRoom(self, roomId): - url = MATRIXBASE+'rooms/'+roomId+'/join?access_token='+self.access_token + url = MATRIXBASE + "rooms/" + roomId + "/join?access_token=" + self.access_token print(url) - headers={ 'Content-Type': 'application/json' } - req = grequests.post(url, headers=headers, data='{}') + headers = {"Content-Type": "application/json"} + req = grequests.post(url, headers=headers, data="{}") resps = grequests.map([req]) obj = json.loads(resps[0].content) - print("response: ",obj) + print("response: ", obj) def sendEvent(self, roomId, evType, event): - url = MATRIXBASE+'rooms/'+roomId+'/send/'+evType+'?access_token='+self.access_token + url = ( + MATRIXBASE + + "rooms/" + + roomId + + "/send/" + + evType + + "?access_token=" + + self.access_token + ) print(url) print(json.dumps(event)) - headers={ 'Content-Type': 'application/json' } + headers = {"Content-Type": "application/json"} req = grequests.post(url, headers=headers, data=json.dumps(event)) resps = grequests.map([req]) obj = json.loads(resps[0].content) - print("response: ",obj) - + print("response: ", obj) xmppClients = {} @@ -87,38 +100,39 @@ def matrixLoop(): while True: ev = matrixCli.getEvent() print(ev) - if ev['type'] == 'm.room.member': - print('membership event') - if ev['membership'] == 'invite' and ev['state_key'] == MYUSERNAME: - roomId = ev['room_id'] + if ev["type"] == "m.room.member": + print("membership event") + if ev["membership"] == "invite" and ev["state_key"] == MYUSERNAME: + roomId = ev["room_id"] print("joining room %s" % (roomId)) matrixCli.joinRoom(roomId) - elif ev['type'] == 'm.room.message': - if ev['room_id'] in xmppClients: + elif ev["type"] == "m.room.message": + if ev["room_id"] in xmppClients: print("already have a bridge for that user, ignoring") continue print("got message, connecting") - xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) - gevent.spawn(xmppClients[ev['room_id']].xmppLoop) - elif ev['type'] == 'm.call.invite': + xmppClients[ev["room_id"]] = TrivialXmppClient(ev["room_id"], ev["user_id"]) + gevent.spawn(xmppClients[ev["room_id"]].xmppLoop) + elif ev["type"] == "m.call.invite": print("Incoming call") - #sdp = ev['content']['offer']['sdp'] - #print "sdp: %s" % (sdp) - #xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) - #gevent.spawn(xmppClients[ev['room_id']].xmppLoop) - elif ev['type'] == 'm.call.answer': + # sdp = ev['content']['offer']['sdp'] + # print "sdp: %s" % (sdp) + # xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) + # gevent.spawn(xmppClients[ev['room_id']].xmppLoop) + elif ev["type"] == "m.call.answer": print("Call answered") - sdp = ev['content']['answer']['sdp'] - if ev['room_id'] not in xmppClients: + sdp = ev["content"]["answer"]["sdp"] + if ev["room_id"] not in xmppClients: print("We didn't have a call for that room") continue # should probably check call ID too - xmppCli = xmppClients[ev['room_id']] + xmppCli = xmppClients[ev["room_id"]] xmppCli.sendAnswer(sdp) - elif ev['type'] == 'm.call.hangup': - if ev['room_id'] in xmppClients: - xmppClients[ev['room_id']].stop() - del xmppClients[ev['room_id']] + elif ev["type"] == "m.call.hangup": + if ev["room_id"] in xmppClients: + xmppClients[ev["room_id"]].stop() + del xmppClients[ev["room_id"]] + class TrivialXmppClient: def __init__(self, matrixRoom, userId): @@ -132,130 +146,155 @@ class TrivialXmppClient: def nextRid(self): self.rid += 1 - return '%d' % (self.rid) + return "%d" % (self.rid) def sendIq(self, xml): - fullXml = "%s" % (self.nextRid(), self.sid, xml) - #print "\t>>>%s" % (fullXml) + fullXml = ( + "%s" + % (self.nextRid(), self.sid, xml) + ) + # print "\t>>>%s" % (fullXml) return self.xmppPoke(fullXml) def xmppPoke(self, xml): - headers = {'Content-Type': 'application/xml'} + headers = {"Content-Type": "application/xml"} req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml) resps = grequests.map([req]) obj = BeautifulSoup(resps[0].content) return obj def sendAnswer(self, answer): - print("sdp from matrix client",answer) - p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--sdp'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) + print("sdp from matrix client", answer) + p = subprocess.Popen( + ["node", "unjingle/unjingle.js", "--sdp"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) jingle, out_err = p.communicate(answer) jingle = jingle % { - 'tojid': self.callfrom, - 'action': 'session-accept', - 'initiator': self.callfrom, - 'responder': self.jid, - 'sid': self.callsid + "tojid": self.callfrom, + "action": "session-accept", + "initiator": self.callfrom, + "responder": self.jid, + "sid": self.callsid, } - print("answer jingle from sdp",jingle) + print("answer jingle from sdp", jingle) res = self.sendIq(jingle) - print("reply from answer: ",res) + print("reply from answer: ", res) self.ssrcs = {} jingleSoup = BeautifulSoup(jingle) - for cont in jingleSoup.iq.jingle.findAll('content'): + for cont in jingleSoup.iq.jingle.findAll("content"): if cont.description: - self.ssrcs[cont['name']] = cont.description['ssrc'] - print("my ssrcs:",self.ssrcs) + self.ssrcs[cont["name"]] = cont.description["ssrc"] + print("my ssrcs:", self.ssrcs) - gevent.joinall([ - gevent.spawn(self.advertiseSsrcs) - ]) + gevent.joinall([gevent.spawn(self.advertiseSsrcs)]) def advertiseSsrcs(self): time.sleep(7) print("SSRC spammer started") while self.running: - ssrcMsg = "%(nick)s" % { 'tojid': "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), 'nick': self.userId, 'assrc': self.ssrcs['audio'], 'vssrc': self.ssrcs['video'] } + ssrcMsg = ( + "%(nick)s" + % { + "tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), + "nick": self.userId, + "assrc": self.ssrcs["audio"], + "vssrc": self.ssrcs["video"], + } + ) res = self.sendIq(ssrcMsg) - print("reply from ssrc announce: ",res) + print("reply from ssrc announce: ", res) time.sleep(10) - - def xmppLoop(self): self.matrixCallId = time.time() - res = self.xmppPoke("" % (self.nextRid(), HOST)) + res = self.xmppPoke( + "" + % (self.nextRid(), HOST) + ) print(res) - self.sid = res.body['sid'] + self.sid = res.body["sid"] print("sid %s" % (self.sid)) - res = self.sendIq("") + res = self.sendIq( + "" + ) - res = self.xmppPoke("" % (self.nextRid(), self.sid, HOST)) + res = self.xmppPoke( + "" + % (self.nextRid(), self.sid, HOST) + ) - res = self.sendIq("") + res = self.sendIq( + "" + ) print(res) self.jid = res.body.iq.bind.jid.string print("jid: %s" % (self.jid)) - self.shortJid = self.jid.split('-')[0] + self.shortJid = self.jid.split("-")[0] - res = self.sendIq("") + res = self.sendIq( + "" + ) - #randomthing = res.body.iq['to'] - #whatsitpart = randomthing.split('-')[0] + # randomthing = res.body.iq['to'] + # whatsitpart = randomthing.split('-')[0] - #print "other random bind thing: %s" % (randomthing) + # print "other random bind thing: %s" % (randomthing) # advertise preence to the jitsi room, with our nick - res = self.sendIq("%s" % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId)) - self.muc = {'users': []} - for p in res.body.findAll('presence'): + res = self.sendIq( + "%s" + % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId) + ) + self.muc = {"users": []} + for p in res.body.findAll("presence"): u = {} - u['shortJid'] = p['from'].split('/')[1] + u["shortJid"] = p["from"].split("/")[1] if p.c and p.c.nick: - u['nick'] = p.c.nick.string - self.muc['users'].append(u) - print("muc: ",self.muc) + u["nick"] = p.c.nick.string + self.muc["users"].append(u) + print("muc: ", self.muc) # wait for stuff while True: print("waiting...") res = self.sendIq("") - print("got from stream: ",res) + print("got from stream: ", res) if res.body.iq: - jingles = res.body.iq.findAll('jingle') + jingles = res.body.iq.findAll("jingle") if len(jingles): - self.callfrom = res.body.iq['from'] + self.callfrom = res.body.iq["from"] self.handleInvite(jingles[0]) - elif 'type' in res.body and res.body['type'] == 'terminate': + elif "type" in res.body and res.body["type"] == "terminate": self.running = False del xmppClients[self.matrixRoom] return def handleInvite(self, jingle): - self.initiator = jingle['initiator'] - self.callsid = jingle['sid'] - p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--jingle'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) - print("raw jingle invite",str(jingle)) + self.initiator = jingle["initiator"] + self.callsid = jingle["sid"] + p = subprocess.Popen( + ["node", "unjingle/unjingle.js", "--jingle"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + print("raw jingle invite", str(jingle)) sdp, out_err = p.communicate(str(jingle)) - print("transformed remote offer sdp",sdp) + print("transformed remote offer sdp", sdp) inviteEvent = { - 'offer': { - 'type': 'offer', - 'sdp': sdp - }, - 'call_id': self.matrixCallId, - 'version': 0, - 'lifetime': 30000 + "offer": {"type": "offer", "sdp": sdp}, + "call_id": self.matrixCallId, + "version": 0, + "lifetime": 30000, } - matrixCli.sendEvent(self.matrixRoom, 'm.call.invite', inviteEvent) + matrixCli.sendEvent(self.matrixRoom, "m.call.invite", inviteEvent) -matrixCli = TrivialMatrixClient(ACCESS_TOKEN) # Undefined name -gevent.joinall([ - gevent.spawn(matrixLoop) -]) +matrixCli = TrivialMatrixClient(ACCESS_TOKEN) # Undefined name +gevent.joinall([gevent.spawn(matrixLoop)]) diff --git a/contrib/scripts/kick_users.py b/contrib/scripts/kick_users.py index b4a14385d0..f57e6e7d25 100755 --- a/contrib/scripts/kick_users.py +++ b/contrib/scripts/kick_users.py @@ -11,22 +11,22 @@ try: except NameError: # Python 3 raw_input = input + def _mkurl(template, kws): for key in kws: template = template.replace(key, kws[key]) return template + def main(hs, room_id, access_token, user_id_prefix, why): if not why: why = "Automated kick." - print("Kicking members on %s in room %s matching %s" % (hs, room_id, user_id_prefix)) + print( + "Kicking members on %s in room %s matching %s" % (hs, room_id, user_id_prefix) + ) room_state_url = _mkurl( "$HS/_matrix/client/api/v1/rooms/$ROOM/state?access_token=$TOKEN", - { - "$HS": hs, - "$ROOM": room_id, - "$TOKEN": access_token - } + {"$HS": hs, "$ROOM": room_id, "$TOKEN": access_token}, ) print("Getting room state => %s" % room_state_url) res = requests.get(room_state_url) @@ -57,24 +57,16 @@ def main(hs, room_id, access_token, user_id_prefix, why): for uid in kick_list: print(uid) doit = raw_input("Continue? [Y]es\n") - if len(doit) > 0 and doit.lower() == 'y': + if len(doit) > 0 and doit.lower() == "y": print("Kicking members...") # encode them all kick_list = [urllib.quote(uid) for uid in kick_list] for uid in kick_list: kick_url = _mkurl( "$HS/_matrix/client/api/v1/rooms/$ROOM/state/m.room.member/$UID?access_token=$TOKEN", - { - "$HS": hs, - "$UID": uid, - "$ROOM": room_id, - "$TOKEN": access_token - } + {"$HS": hs, "$UID": uid, "$ROOM": room_id, "$TOKEN": access_token}, ) - kick_body = { - "membership": "leave", - "reason": why - } + kick_body = {"membership": "leave", "reason": why} print("Kicking %s" % uid) res = requests.put(kick_url, data=json.dumps(kick_body)) if res.status_code != 200: @@ -83,14 +75,15 @@ def main(hs, room_id, access_token, user_id_prefix, why): print("ERROR: JSON %s" % res.json()) - if __name__ == "__main__": parser = ArgumentParser("Kick members in a room matching a certain user ID prefix.") - parser.add_argument("-u","--user-id",help="The user ID prefix e.g. '@irc_'") - parser.add_argument("-t","--token",help="Your access_token") - parser.add_argument("-r","--room",help="The room ID to kick members in") - parser.add_argument("-s","--homeserver",help="The base HS url e.g. http://matrix.org") - parser.add_argument("-w","--why",help="Reason for the kick. Optional.") + parser.add_argument("-u", "--user-id", help="The user ID prefix e.g. '@irc_'") + parser.add_argument("-t", "--token", help="Your access_token") + parser.add_argument("-r", "--room", help="The room ID to kick members in") + parser.add_argument( + "-s", "--homeserver", help="The base HS url e.g. http://matrix.org" + ) + parser.add_argument("-w", "--why", help="Reason for the kick. Optional.") args = parser.parse_args() if not args.room or not args.token or not args.user_id or not args.homeserver: parser.print_help() diff --git a/demo/webserver.py b/demo/webserver.py index 875095c877..ba176d3bd2 100644 --- a/demo/webserver.py +++ b/demo/webserver.py @@ -6,23 +6,25 @@ import cgi, logging from daemonize import Daemonize + class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler): UPLOAD_PATH = "upload" """ Accept all post request as file upload """ + def do_POST(self): path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path)) - length = self.headers['content-length'] + length = self.headers["content-length"] data = self.rfile.read(int(length)) - with open(path, 'wb') as fh: + with open(path, "wb") as fh: fh.write(data) self.send_response(200) - self.send_header('Content-Type', 'application/json') + self.send_header("Content-Type", "application/json") self.end_headers() # Return the absolute path of the uploaded file @@ -33,30 +35,25 @@ def setup(): parser = argparse.ArgumentParser() parser.add_argument("directory") parser.add_argument("-p", "--port", dest="port", type=int, default=8080) - parser.add_argument('-P', "--pid-file", dest="pid", default="web.pid") + parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid") args = parser.parse_args() # Get absolute path to directory to serve, as daemonize changes to '/' os.chdir(args.directory) dr = os.getcwd() - httpd = BaseHTTPServer.HTTPServer( - ('', args.port), - SimpleHTTPRequestHandlerWithPOST - ) + httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST) def run(): os.chdir(dr) httpd.serve_forever() daemon = Daemonize( - app="synapse-webclient", - pid=args.pid, - action=run, - auto_close_fds=False, - ) + app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False + ) daemon.start() -if __name__ == '__main__': + +if __name__ == "__main__": setup() diff --git a/docker/start.py b/docker/start.py index 2da555272a..a7a54dacf7 100755 --- a/docker/start.py +++ b/docker/start.py @@ -8,7 +8,10 @@ import glob import codecs # Utility functions -convert = lambda src, dst, environ: open(dst, "w").write(jinja2.Template(open(src).read()).render(**environ)) +convert = lambda src, dst, environ: open(dst, "w").write( + jinja2.Template(open(src).read()).render(**environ) +) + def check_arguments(environ, args): for argument in args: @@ -16,18 +19,22 @@ def check_arguments(environ, args): print("Environment variable %s is mandatory, exiting." % argument) sys.exit(2) + def generate_secrets(environ, secrets): for name, secret in secrets.items(): if secret not in environ: filename = "/data/%s.%s.key" % (environ["SYNAPSE_SERVER_NAME"], name) if os.path.exists(filename): - with open(filename) as handle: value = handle.read() + with open(filename) as handle: + value = handle.read() else: print("Generating a random secret for {}".format(name)) value = codecs.encode(os.urandom(32), "hex").decode() - with open(filename, "w") as handle: handle.write(value) + with open(filename, "w") as handle: + handle.write(value) environ[secret] = value + # Prepare the configuration mode = sys.argv[1] if len(sys.argv) > 1 else None environ = os.environ.copy() @@ -36,12 +43,17 @@ args = ["python", "-m", "synapse.app.homeserver"] # In generate mode, generate a configuration, missing keys, then exit if mode == "generate": - check_arguments(environ, ("SYNAPSE_SERVER_NAME", "SYNAPSE_REPORT_STATS", "SYNAPSE_CONFIG_PATH")) + check_arguments( + environ, ("SYNAPSE_SERVER_NAME", "SYNAPSE_REPORT_STATS", "SYNAPSE_CONFIG_PATH") + ) args += [ - "--server-name", environ["SYNAPSE_SERVER_NAME"], - "--report-stats", environ["SYNAPSE_REPORT_STATS"], - "--config-path", environ["SYNAPSE_CONFIG_PATH"], - "--generate-config" + "--server-name", + environ["SYNAPSE_SERVER_NAME"], + "--report-stats", + environ["SYNAPSE_REPORT_STATS"], + "--config-path", + environ["SYNAPSE_CONFIG_PATH"], + "--generate-config", ] os.execv("/usr/local/bin/python", args) @@ -51,15 +63,19 @@ else: config_path = environ["SYNAPSE_CONFIG_PATH"] else: check_arguments(environ, ("SYNAPSE_SERVER_NAME", "SYNAPSE_REPORT_STATS")) - generate_secrets(environ, { - "registration": "SYNAPSE_REGISTRATION_SHARED_SECRET", - "macaroon": "SYNAPSE_MACAROON_SECRET_KEY" - }) + generate_secrets( + environ, + { + "registration": "SYNAPSE_REGISTRATION_SHARED_SECRET", + "macaroon": "SYNAPSE_MACAROON_SECRET_KEY", + }, + ) environ["SYNAPSE_APPSERVICES"] = glob.glob("/data/appservices/*.yaml") - if not os.path.exists("/compiled"): os.mkdir("/compiled") + if not os.path.exists("/compiled"): + os.mkdir("/compiled") config_path = "/compiled/homeserver.yaml" - + # Convert SYNAPSE_NO_TLS to boolean if exists if "SYNAPSE_NO_TLS" in environ: tlsanswerstring = str.lower(environ["SYNAPSE_NO_TLS"]) @@ -69,19 +85,23 @@ else: if tlsanswerstring in ("false", "off", "0", "no"): environ["SYNAPSE_NO_TLS"] = False else: - print("Environment variable \"SYNAPSE_NO_TLS\" found but value \"" + tlsanswerstring + "\" unrecognized; exiting.") + print( + 'Environment variable "SYNAPSE_NO_TLS" found but value "' + + tlsanswerstring + + '" unrecognized; exiting.' + ) sys.exit(2) convert("/conf/homeserver.yaml", config_path, environ) convert("/conf/log.config", "/compiled/log.config", environ) subprocess.check_output(["chown", "-R", ownership, "/data"]) - args += [ - "--config-path", config_path, - + "--config-path", + config_path, # tell synapse to put any generated keys in /data rather than /compiled - "--keys-directory", "/data", + "--keys-directory", + "/data", ] # Generate missing keys and start synapse diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py index 0b15bd8912..ca4b879526 100644 --- a/docs/sphinx/conf.py +++ b/docs/sphinx/conf.py @@ -18,226 +18,220 @@ import os # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.coverage', - 'sphinx.ext.ifconfig', - 'sphinxcontrib.napoleon', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.coverage", + "sphinx.ext.ifconfig", + "sphinxcontrib.napoleon", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Synapse' -copyright = u'Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd' +project = "Synapse" +copyright = ( + "Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd" +) # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '1.0' +version = "1.0" # The full version, including alpha/beta/rc tags. -release = '1.0' +release = "1.0" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = "default" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'Synapsedoc' +htmlhelp_basename = "Synapsedoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). -latex_documents = [ - ('index', 'Synapse.tex', u'Synapse Documentation', - u'TNG', 'manual'), -] +latex_documents = [("index", "Synapse.tex", "Synapse Documentation", "TNG", "manual")] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'synapse', u'Synapse Documentation', - [u'TNG'], 1) -] +man_pages = [("index", "synapse", "Synapse Documentation", ["TNG"], 1)] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -246,26 +240,32 @@ man_pages = [ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'Synapse', u'Synapse Documentation', - u'TNG', 'Synapse', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "Synapse", + "Synapse Documentation", + "TNG", + "Synapse", + "One line description of project.", + "Miscellaneous", + ) ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'http://docs.python.org/': None} +intersphinx_mapping = {"http://docs.python.org/": None} napoleon_include_special_with_doc = True napoleon_use_ivar = True diff --git a/pyproject.toml b/pyproject.toml index dd099dc9c8..ec23258da8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,3 +28,22 @@ directory = "misc" name = "Internal Changes" showcontent = true + +[tool.black] +target-version = ['py34'] +exclude = ''' + +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.tox + | \.venv + | _build + | _trial_temp.* + | build + | dist + | debian + )/ +) +''' diff --git a/scripts-dev/check_auth.py b/scripts-dev/check_auth.py index b3d11f49ec..2a1c5f39d4 100644 --- a/scripts-dev/check_auth.py +++ b/scripts-dev/check_auth.py @@ -39,11 +39,11 @@ def check_auth(auth, auth_chain, events): print("Success:", e.event_id, e.type, e.state_key) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - 'json', nargs='?', type=argparse.FileType('r'), default=sys.stdin + "json", nargs="?", type=argparse.FileType("r"), default=sys.stdin ) args = parser.parse_args() diff --git a/scripts-dev/check_event_hash.py b/scripts-dev/check_event_hash.py index 8535f99697..cd5599e9a1 100644 --- a/scripts-dev/check_event_hash.py +++ b/scripts-dev/check_event_hash.py @@ -30,7 +30,7 @@ class dictobj(dict): def main(): parser = argparse.ArgumentParser() parser.add_argument( - "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin + "input_json", nargs="?", type=argparse.FileType("r"), default=sys.stdin ) args = parser.parse_args() logging.basicConfig() diff --git a/scripts-dev/check_signature.py b/scripts-dev/check_signature.py index 612f17ca7f..ecda103cf7 100644 --- a/scripts-dev/check_signature.py +++ b/scripts-dev/check_signature.py @@ -1,4 +1,3 @@ - import argparse import json import logging @@ -40,7 +39,7 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("signature_name") parser.add_argument( - "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin + "input_json", nargs="?", type=argparse.FileType("r"), default=sys.stdin ) args = parser.parse_args() @@ -69,5 +68,5 @@ def main(): print("FAIL %s" % (key_id,)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index ac152b5c42..179be61c30 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -116,5 +116,5 @@ def main(): connection.commit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py index 1deb0fe2b7..9eddb6d515 100755 --- a/scripts-dev/definitions.py +++ b/scripts-dev/definitions.py @@ -19,10 +19,10 @@ class DefinitionVisitor(ast.NodeVisitor): self.names = {} self.attrs = set() self.definitions = { - 'def': self.functions, - 'class': self.classes, - 'names': self.names, - 'attrs': self.attrs, + "def": self.functions, + "class": self.classes, + "names": self.names, + "attrs": self.attrs, } def visit_Name(self, node): @@ -47,23 +47,23 @@ class DefinitionVisitor(ast.NodeVisitor): def non_empty(defs): - functions = {name: non_empty(f) for name, f in defs['def'].items()} - classes = {name: non_empty(f) for name, f in defs['class'].items()} + functions = {name: non_empty(f) for name, f in defs["def"].items()} + classes = {name: non_empty(f) for name, f in defs["class"].items()} result = {} if functions: - result['def'] = functions + result["def"] = functions if classes: - result['class'] = classes - names = defs['names'] + result["class"] = classes + names = defs["names"] uses = [] - for name in names.get('Load', ()): - if name not in names.get('Param', ()) and name not in names.get('Store', ()): + for name in names.get("Load", ()): + if name not in names.get("Param", ()) and name not in names.get("Store", ()): uses.append(name) - uses.extend(defs['attrs']) + uses.extend(defs["attrs"]) if uses: - result['uses'] = uses - result['names'] = names - result['attrs'] = defs['attrs'] + result["uses"] = uses + result["names"] = names + result["attrs"] = defs["attrs"] return result @@ -81,33 +81,33 @@ def definitions_in_file(filepath): def defined_names(prefix, defs, names): - for name, funcs in defs.get('def', {}).items(): - names.setdefault(name, {'defined': []})['defined'].append(prefix + name) + for name, funcs in defs.get("def", {}).items(): + names.setdefault(name, {"defined": []})["defined"].append(prefix + name) defined_names(prefix + name + ".", funcs, names) - for name, funcs in defs.get('class', {}).items(): - names.setdefault(name, {'defined': []})['defined'].append(prefix + name) + for name, funcs in defs.get("class", {}).items(): + names.setdefault(name, {"defined": []})["defined"].append(prefix + name) defined_names(prefix + name + ".", funcs, names) def used_names(prefix, item, defs, names): - for name, funcs in defs.get('def', {}).items(): + for name, funcs in defs.get("def", {}).items(): used_names(prefix + name + ".", name, funcs, names) - for name, funcs in defs.get('class', {}).items(): + for name, funcs in defs.get("class", {}).items(): used_names(prefix + name + ".", name, funcs, names) - path = prefix.rstrip('.') - for used in defs.get('uses', ()): + path = prefix.rstrip(".") + for used in defs.get("uses", ()): if used in names: if item: - names[item].setdefault('uses', []).append(used) - names[used].setdefault('used', {}).setdefault(item, []).append(path) + names[item].setdefault("uses", []).append(used) + names[used].setdefault("used", {}).setdefault(item, []).append(path) -if __name__ == '__main__': +if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Find definitions.') + parser = argparse.ArgumentParser(description="Find definitions.") parser.add_argument( "--unused", action="store_true", help="Only list unused definitions" ) @@ -119,7 +119,7 @@ if __name__ == '__main__': ) parser.add_argument( "directories", - nargs='+', + nargs="+", metavar="DIR", help="Directories to search for definitions", ) @@ -164,7 +164,7 @@ if __name__ == '__main__': continue if ignore and any(pattern.match(name) for pattern in ignore): continue - if args.unused and definition.get('used'): + if args.unused and definition.get("used"): continue result[name] = definition @@ -196,9 +196,9 @@ if __name__ == '__main__': continue result[name] = definition - if args.format == 'yaml': + if args.format == "yaml": yaml.dump(result, sys.stdout, default_flow_style=False) - elif args.format == 'dot': + elif args.format == "dot": print("digraph {") for name, entry in result.items(): print(name) diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 41e7b24418..7c19e405d4 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -63,7 +63,7 @@ def encode_canonical_json(value): # Encode code-points outside of ASCII as UTF-8 rather than \u escapes ensure_ascii=False, # Remove unecessary white space. - separators=(',', ':'), + separators=(",", ":"), # Sort the keys of dictionaries. sort_keys=True, # Encode the resulting unicode as UTF-8 bytes. @@ -145,7 +145,7 @@ def request_json(method, origin_name, origin_key, destination, path, content): authorization_headers = [] for key, sig in signed_json["signatures"][origin_name].items(): - header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (origin_name, key, sig) + header = 'X-Matrix origin=%s,key="%s",sig="%s"' % (origin_name, key, sig) authorization_headers.append(header.encode("ascii")) print("Authorization: %s" % header, file=sys.stderr) @@ -161,11 +161,7 @@ def request_json(method, origin_name, origin_key, destination, path, content): headers["Content-Type"] = "application/json" result = s.request( - method=method, - url=dest, - headers=headers, - verify=False, - data=content, + method=method, url=dest, headers=headers, verify=False, data=content ) sys.stderr.write("Status Code: %d\n" % (result.status_code,)) return result.json() @@ -241,18 +237,18 @@ def main(): def read_args_from_config(args): - with open(args.config, 'r') as fh: + with open(args.config, "r") as fh: config = yaml.safe_load(fh) if not args.server_name: - args.server_name = config['server_name'] + args.server_name = config["server_name"] if not args.signing_key_path: - args.signing_key_path = config['signing_key_path'] + args.signing_key_path = config["signing_key_path"] class MatrixConnectionAdapter(HTTPAdapter): @staticmethod def lookup(s, skip_well_known=False): - if s[-1] == ']': + if s[-1] == "]": # ipv6 literal (with no port) return s, 8448 @@ -268,9 +264,7 @@ class MatrixConnectionAdapter(HTTPAdapter): if not skip_well_known: well_known = MatrixConnectionAdapter.get_well_known(s) if well_known: - return MatrixConnectionAdapter.lookup( - well_known, skip_well_known=True - ) + return MatrixConnectionAdapter.lookup(well_known, skip_well_known=True) try: srv = srvlookup.lookup("matrix", "tcp", s)[0] @@ -280,8 +274,8 @@ class MatrixConnectionAdapter(HTTPAdapter): @staticmethod def get_well_known(server_name): - uri = "https://%s/.well-known/matrix/server" % (server_name, ) - print("fetching %s" % (uri, ), file=sys.stderr) + uri = "https://%s/.well-known/matrix/server" % (server_name,) + print("fetching %s" % (uri,), file=sys.stderr) try: resp = requests.get(uri) @@ -294,12 +288,12 @@ class MatrixConnectionAdapter(HTTPAdapter): raise Exception("not a dict") if "m.server" not in parsed_well_known: raise Exception("Missing key 'm.server'") - new_name = parsed_well_known['m.server'] - print("well-known lookup gave %s" % (new_name, ), file=sys.stderr) + new_name = parsed_well_known["m.server"] + print("well-known lookup gave %s" % (new_name,), file=sys.stderr) return new_name except Exception as e: - print("Invalid response from %s: %s" % (uri, e, ), file=sys.stderr) + print("Invalid response from %s: %s" % (uri, e), file=sys.stderr) return None def get_connection(self, url, proxies=None): diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index 514d80fa60..d20f6db176 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -79,5 +79,5 @@ def main(): conn.commit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts-dev/list_url_patterns.py b/scripts-dev/list_url_patterns.py index 62e5a07472..26ad7c67f4 100755 --- a/scripts-dev/list_url_patterns.py +++ b/scripts-dev/list_url_patterns.py @@ -35,11 +35,11 @@ def find_patterns_in_file(filepath): find_patterns_in_code(f.read()) -parser = argparse.ArgumentParser(description='Find url patterns.') +parser = argparse.ArgumentParser(description="Find url patterns.") parser.add_argument( "directories", - nargs='+', + nargs="+", metavar="DIR", help="Directories to search for definitions", ) diff --git a/scripts-dev/tail-synapse.py b/scripts-dev/tail-synapse.py index 7c9985d9f0..44e3a6dbf1 100644 --- a/scripts-dev/tail-synapse.py +++ b/scripts-dev/tail-synapse.py @@ -63,5 +63,5 @@ def main(): streams[update.name] = update.position -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/generate_signing_key.py b/scripts/generate_signing_key.py index 36e9140b50..16d7c4f382 100755 --- a/scripts/generate_signing_key.py +++ b/scripts/generate_signing_key.py @@ -24,14 +24,14 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "-o", "--output_file", - - type=argparse.FileType('w'), + "-o", + "--output_file", + type=argparse.FileType("w"), default=sys.stdout, help="Where to write the output to", ) args = parser.parse_args() key_id = "a_" + random_string(4) - key = generate_signing_key(key_id), + key = (generate_signing_key(key_id),) write_signing_keys(args.output_file, key) diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py index e630936f78..12747c6024 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/scripts/move_remote_media_to_new_store.py @@ -50,7 +50,7 @@ def main(src_repo, dest_repo): dest_paths = MediaFilePaths(dest_repo) for line in sys.stdin: line = line.strip() - parts = line.split('|') + parts = line.split("|") if len(parts) != 2: print("Unable to parse input line %s" % line, file=sys.stderr) exit(1) @@ -107,7 +107,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter ) - parser.add_argument("-v", action='store_true', help='enable debug logging') + parser.add_argument("-v", action="store_true", help="enable debug logging") parser.add_argument("src_repo", help="Path to source content repo") parser.add_argument("dest_repo", help="Path to source content repo") args = parser.parse_args() diff --git a/setup.cfg b/setup.cfg index b6b4aa740d..12a7849081 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,18 +15,17 @@ ignore = tox.ini [flake8] -max-line-length = 90 - # see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes # for error codes. The ones we ignore are: # W503: line break before binary operator # W504: line break after binary operator # E203: whitespace before ':' (which is contrary to pep8?) # E731: do not assign a lambda expression, use a def -ignore=W503,W504,E203,E731 +# E501: Line too long (black enforces this for us) +ignore=W503,W504,E203,E731,E501 [isort] -line_length = 89 +line_length = 88 not_skip = __init__.py sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER default_section=THIRDPARTY diff --git a/setup.py b/setup.py index 3492cdc5a0..5ce06c8987 100755 --- a/setup.py +++ b/setup.py @@ -60,9 +60,12 @@ class TestCommand(Command): pass def run(self): - print ("""Synapse's tests cannot be run via setup.py. To run them, try: + print( + """Synapse's tests cannot be run via setup.py. To run them, try: PYTHONPATH="." trial tests -""") +""" + ) + def read_file(path_segments): """Read a file from the package. Takes a list of strings to join to @@ -84,9 +87,9 @@ version = exec_file(("synapse", "__init__.py"))["__version__"] dependencies = exec_file(("synapse", "python_dependencies.py")) long_description = read_file(("README.rst",)) -REQUIREMENTS = dependencies['REQUIREMENTS'] -CONDITIONAL_REQUIREMENTS = dependencies['CONDITIONAL_REQUIREMENTS'] -ALL_OPTIONAL_REQUIREMENTS = dependencies['ALL_OPTIONAL_REQUIREMENTS'] +REQUIREMENTS = dependencies["REQUIREMENTS"] +CONDITIONAL_REQUIREMENTS = dependencies["CONDITIONAL_REQUIREMENTS"] +ALL_OPTIONAL_REQUIREMENTS = dependencies["ALL_OPTIONAL_REQUIREMENTS"] # Make `pip install matrix-synapse[all]` install all the optional dependencies. CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS) @@ -102,16 +105,16 @@ setup( include_package_data=True, zip_safe=False, long_description=long_description, - python_requires='~=3.5', + python_requires="~=3.5", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Topic :: Communications :: Chat', - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', + "Development Status :: 5 - Production/Stable", + "Topic :: Communications :: Chat", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", ], scripts=["synctl"] + glob.glob("scripts/*"), - cmdclass={'test': TestCommand}, + cmdclass={"test": TestCommand}, ) diff --git a/synapse/__init__.py b/synapse/__init__.py index 0c01546789..119359be68 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -28,6 +28,7 @@ try: from twisted.internet import protocol from twisted.internet.protocol import Factory from twisted.names.dns import DNSDatagramProtocol + protocol.Factory.noisy = False Factory.noisy = False DNSDatagramProtocol.noisy = False diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 6e93f5a0c6..bdcd915bbe 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -57,18 +57,18 @@ def request_registration( nonce = r.json()["nonce"] - mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1) + mac = hmac.new(key=shared_secret.encode("utf8"), digestmod=hashlib.sha1) - mac.update(nonce.encode('utf8')) + mac.update(nonce.encode("utf8")) mac.update(b"\x00") - mac.update(user.encode('utf8')) + mac.update(user.encode("utf8")) mac.update(b"\x00") - mac.update(password.encode('utf8')) + mac.update(password.encode("utf8")) mac.update(b"\x00") mac.update(b"admin" if admin else b"notadmin") if user_type: mac.update(b"\x00") - mac.update(user_type.encode('utf8')) + mac.update(user_type.encode("utf8")) mac = mac.hexdigest() @@ -134,8 +134,9 @@ def register_new_user(user, password, server_location, shared_secret, admin, use else: admin = False - request_registration(user, password, server_location, shared_secret, - bool(admin), user_type) + request_registration( + user, password, server_location, shared_secret, bool(admin), user_type + ) def main(): @@ -189,7 +190,7 @@ def main(): group.add_argument( "-c", "--config", - type=argparse.FileType('r'), + type=argparse.FileType("r"), help="Path to server config file. Used to read in shared secret.", ) @@ -200,7 +201,7 @@ def main(): parser.add_argument( "server_url", default="https://localhost:8448", - nargs='?', + nargs="?", help="URL to use to talk to the home server. Defaults to " " 'https://localhost:8448'.", ) @@ -220,8 +221,9 @@ def main(): if args.admin or args.no_admin: admin = args.admin - register_new_user(args.user, args.password, args.server_url, secret, - admin, args.user_type) + register_new_user( + args.user, args.password, args.server_url, secret, admin, args.user_type + ) if __name__ == "__main__": diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 79e2808dc5..86f145649c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -36,8 +36,11 @@ logger = logging.getLogger(__name__) AuthEventTypes = ( - EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels, - EventTypes.JoinRules, EventTypes.RoomHistoryVisibility, + EventTypes.Create, + EventTypes.Member, + EventTypes.PowerLevels, + EventTypes.JoinRules, + EventTypes.RoomHistoryVisibility, EventTypes.ThirdPartyInvite, ) @@ -54,6 +57,7 @@ class Auth(object): FIXME: This class contains a mix of functions for authenticating users of our client-server API and authenticating events added to room graphs. """ + def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() @@ -70,15 +74,12 @@ class Auth(object): def check_from_context(self, room_version, event, context, do_sig_check=True): prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in itervalues(auth_events) - } + auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} self.check( - room_version, event, - auth_events=auth_events, do_sig_check=do_sig_check, + room_version, event, auth_events=auth_events, do_sig_check=do_sig_check ) def check(self, room_version, event, auth_events, do_sig_check=True): @@ -115,15 +116,10 @@ class Auth(object): the room. """ if current_state: - member = current_state.get( - (EventTypes.Member, user_id), - None - ) + member = current_state.get((EventTypes.Member, user_id), None) else: member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) self._check_joined_room(member, user_id, room_id) @@ -143,23 +139,17 @@ class Auth(object): the room. This will be the leave event if they have left the room. """ member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None if membership not in (Membership.JOIN, Membership.LEAVE): - raise AuthError(403, "User %s not in room %s" % ( - user_id, room_id - )) + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) if membership == Membership.LEAVE: forgot = yield self.store.did_forget(user_id, room_id) if forgot: - raise AuthError(403, "User %s not in room %s" % ( - user_id, room_id - )) + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) defer.returnValue(member) @@ -171,9 +161,9 @@ class Auth(object): def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s (%s)" % ( - user_id, room_id, repr(member) - )) + raise AuthError( + 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) + ) def can_federate(self, event, auth_events): creation_event = auth_events.get((EventTypes.Create, "")) @@ -185,11 +175,7 @@ class Auth(object): @defer.inlineCallbacks def get_user_by_req( - self, - request, - allow_guest=False, - rights="access", - allow_expired=False, + self, request, allow_guest=False, rights="access", allow_expired=False ): """ Get a registered user's ID. @@ -209,9 +195,8 @@ class Auth(object): try: ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( - b"User-Agent", - default=[b""] - )[0].decode('ascii', 'surrogateescape') + b"User-Agent", default=[b""] + )[0].decode("ascii", "surrogateescape") access_token = self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS @@ -243,11 +228,12 @@ class Auth(object): if self._account_validity.enabled and not allow_expired: user_id = user.to_string() expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - if expiration_ts is not None and self.clock.time_msec() >= expiration_ts: + if ( + expiration_ts is not None + and self.clock.time_msec() >= expiration_ts + ): raise AuthError( - 403, - "User account has expired", - errcode=Codes.EXPIRED_ACCOUNT, + 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) # device_id may not be present if get_user_by_access_token has been @@ -265,18 +251,23 @@ class Auth(object): if is_guest and not allow_guest: raise AuthError( - 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + 403, + "Guest access not allowed", + errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) request.authenticated_entity = user.to_string() - defer.returnValue(synapse.types.create_requester( - user, token_id, is_guest, device_id, app_service=app_service) + defer.returnValue( + synapse.types.create_requester( + user, token_id, is_guest, device_id, app_service=app_service + ) ) except KeyError: raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", - errcode=Codes.MISSING_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Missing access token.", + errcode=Codes.MISSING_TOKEN, ) @defer.inlineCallbacks @@ -297,20 +288,14 @@ class Auth(object): if b"user_id" not in request.args: defer.returnValue((app_service.sender, app_service)) - user_id = request.args[b"user_id"][0].decode('utf8') + user_id = request.args[b"user_id"][0].decode("utf8") if app_service.sender == user_id: defer.returnValue((app_service.sender, app_service)) if not app_service.is_interested_in_user(user_id): - raise AuthError( - 403, - "Application service cannot masquerade as this user." - ) + raise AuthError(403, "Application service cannot masquerade as this user.") if not (yield self.store.get_user_by_id(user_id)): - raise AuthError( - 403, - "Application service has not registered this user" - ) + raise AuthError(403, "Application service has not registered this user") defer.returnValue((user_id, app_service)) @defer.inlineCallbacks @@ -368,13 +353,13 @@ class Auth(object): raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unknown user_id %s" % user_id, - errcode=Codes.UNKNOWN_TOKEN + errcode=Codes.UNKNOWN_TOKEN, ) if not stored_user["is_guest"]: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Guest access token used for regular user", - errcode=Codes.UNKNOWN_TOKEN + errcode=Codes.UNKNOWN_TOKEN, ) ret = { "user": user, @@ -402,8 +387,9 @@ class Auth(object): ) as e: logger.warning("Invalid macaroon in auth: %s %s", type(e), e) raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN, ) def _parse_and_validate_macaroon(self, token, rights="access"): @@ -441,13 +427,13 @@ class Auth(object): guest = True self.validate_macaroon( - macaroon, rights, self.hs.config.expire_access_token, - user_id=user_id, + macaroon, rights, self.hs.config.expire_access_token, user_id=user_id ) except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN, ) if not has_expiry and rights == "access": @@ -472,10 +458,11 @@ class Auth(object): user_prefix = "user_id = " for caveat in macaroon.caveats: if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix):] + return caveat.caveat_id[len(user_prefix) :] raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN, ) def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): @@ -522,7 +509,7 @@ class Auth(object): prefix = "time < " if not caveat.startswith(prefix): return False - expiry = int(caveat[len(prefix):]) + expiry = int(caveat[len(prefix) :]) now = self.hs.get_clock().time_msec() return now < expiry @@ -554,14 +541,12 @@ class Auth(object): raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", - errcode=Codes.UNKNOWN_TOKEN + errcode=Codes.UNKNOWN_TOKEN, ) request.authenticated_entity = service.sender return defer.succeed(service) except KeyError: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token." - ) + raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.") def is_server_admin(self, user): """ Check if the given user is a local server admin. @@ -581,19 +566,19 @@ class Auth(object): auth_ids = [] - key = (EventTypes.PowerLevels, "", ) + key = (EventTypes.PowerLevels, "") power_level_event_id = current_state_ids.get(key) if power_level_event_id: auth_ids.append(power_level_event_id) - key = (EventTypes.JoinRules, "", ) + key = (EventTypes.JoinRules, "") join_rule_event_id = current_state_ids.get(key) - key = (EventTypes.Member, event.sender, ) + key = (EventTypes.Member, event.sender) member_event_id = current_state_ids.get(key) - key = (EventTypes.Create, "", ) + key = (EventTypes.Create, "") create_event_id = current_state_ids.get(key) if create_event_id: auth_ids.append(create_event_id) @@ -619,7 +604,7 @@ class Auth(object): auth_ids.append(member_event_id) if for_verification: - key = (EventTypes.Member, event.state_key, ) + key = (EventTypes.Member, event.state_key) existing_event_id = current_state_ids.get(key) if existing_event_id: auth_ids.append(existing_event_id) @@ -628,7 +613,7 @@ class Auth(object): if "third_party_invite" in event.content: key = ( EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"] + event.content["third_party_invite"]["signed"]["token"], ) third_party_invite_id = current_state_ids.get(key) if third_party_invite_id: @@ -684,7 +669,7 @@ class Auth(object): auth_events[(EventTypes.PowerLevels, "")] = power_level_event send_level = event_auth.get_send_level( - EventTypes.Aliases, "", power_level_event, + EventTypes.Aliases, "", power_level_event ) user_level = event_auth.get_user_power_level(user_id, auth_events) @@ -692,7 +677,7 @@ class Auth(object): raise AuthError( 403, "This server requires you to be a moderator in the room to" - " edit its room list entry" + " edit its room list entry", ) @staticmethod @@ -742,7 +727,7 @@ class Auth(object): ) parts = auth_headers[0].split(b" ") if parts[0] == b"Bearer" and len(parts) == 2: - return parts[1].decode('ascii') + return parts[1].decode("ascii") else: raise AuthError( token_not_found_http_status, @@ -755,10 +740,10 @@ class Auth(object): raise AuthError( token_not_found_http_status, "Missing access token.", - errcode=Codes.MISSING_TOKEN + errcode=Codes.MISSING_TOKEN, ) - return query_params[0].decode('ascii') + return query_params[0].decode("ascii") @defer.inlineCallbacks def check_in_room_or_world_readable(self, room_id, user_id): @@ -785,8 +770,8 @@ class Auth(object): room_id, EventTypes.RoomHistoryVisibility, "" ) if ( - visibility and - visibility.content["history_visibility"] == "world_readable" + visibility + and visibility.content["history_visibility"] == "world_readable" ): defer.returnValue((Membership.JOIN, None)) return @@ -820,10 +805,11 @@ class Auth(object): if self.hs.config.hs_disabled: raise ResourceLimitError( - 403, self.hs.config.hs_disabled_message, + 403, + self.hs.config.hs_disabled_message, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_contact=self.hs.config.admin_contact, - limit_type=self.hs.config.hs_disabled_limit_type + limit_type=self.hs.config.hs_disabled_limit_type, ) if self.hs.config.limit_usage_by_mau is True: assert not (user_id and threepid) @@ -848,8 +834,9 @@ class Auth(object): current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise ResourceLimitError( - 403, "Monthly Active User Limit Exceeded", + 403, + "Monthly Active User Limit Exceeded", admin_contact=self.hs.config.admin_contact, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - limit_type="monthly_active_user" + limit_type="monthly_active_user", ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index ee129c8689..3ffde0d7fc 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -18,7 +18,7 @@ """Contains constants from the specification.""" # the "depth" field on events is limited to 2**63 - 1 -MAX_DEPTH = 2**63 - 1 +MAX_DEPTH = 2 ** 63 - 1 # the maximum length for a room alias is 255 characters MAX_ALIAS_LENGTH = 255 @@ -30,39 +30,41 @@ MAX_USERID_LENGTH = 255 class Membership(object): """Represents the membership states of a user in a room.""" - INVITE = u"invite" - JOIN = u"join" - KNOCK = u"knock" - LEAVE = u"leave" - BAN = u"ban" + + INVITE = "invite" + JOIN = "join" + KNOCK = "knock" + LEAVE = "leave" + BAN = "ban" LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) class PresenceState(object): """Represents the presence state of a user.""" - OFFLINE = u"offline" - UNAVAILABLE = u"unavailable" - ONLINE = u"online" + + OFFLINE = "offline" + UNAVAILABLE = "unavailable" + ONLINE = "online" class JoinRules(object): - PUBLIC = u"public" - KNOCK = u"knock" - INVITE = u"invite" - PRIVATE = u"private" + PUBLIC = "public" + KNOCK = "knock" + INVITE = "invite" + PRIVATE = "private" class LoginType(object): - PASSWORD = u"m.login.password" - EMAIL_IDENTITY = u"m.login.email.identity" - MSISDN = u"m.login.msisdn" - RECAPTCHA = u"m.login.recaptcha" - TERMS = u"m.login.terms" - DUMMY = u"m.login.dummy" + PASSWORD = "m.login.password" + EMAIL_IDENTITY = "m.login.email.identity" + MSISDN = "m.login.msisdn" + RECAPTCHA = "m.login.recaptcha" + TERMS = "m.login.terms" + DUMMY = "m.login.dummy" # Only for C/S API v1 - APPLICATION_SERVICE = u"m.login.application_service" - SHARED_SECRET = u"org.matrix.login.shared_secret" + APPLICATION_SERVICE = "m.login.application_service" + SHARED_SECRET = "org.matrix.login.shared_secret" class EventTypes(object): @@ -118,6 +120,7 @@ class UserTypes(object): """Allows for user type specific behaviour. With the benefit of hindsight 'admin' and 'guest' users should also be UserTypes. Normal users are type None """ + SUPPORT = "support" ALL_USER_TYPES = (SUPPORT,) @@ -125,6 +128,7 @@ class UserTypes(object): class RelationTypes(object): """The types of relations known to this server. """ + ANNOTATION = "m.annotation" REPLACE = "m.replace" REFERENCE = "m.reference" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 66201d6efe..28b5c2af9b 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -70,6 +70,7 @@ class CodeMessageException(RuntimeError): code (int): HTTP error code msg (str): string describing the error """ + def __init__(self, code, msg): super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) self.code = code @@ -83,6 +84,7 @@ class SynapseError(CodeMessageException): Attributes: errcode (str): Matrix error code e.g 'M_FORBIDDEN' """ + def __init__(self, code, msg, errcode=Codes.UNKNOWN): """Constructs a synapse error. @@ -95,10 +97,7 @@ class SynapseError(CodeMessageException): self.errcode = errcode def error_dict(self): - return cs_error( - self.msg, - self.errcode, - ) + return cs_error(self.msg, self.errcode) class ProxiedRequestError(SynapseError): @@ -107,27 +106,23 @@ class ProxiedRequestError(SynapseError): Attributes: errcode (str): Matrix error code e.g 'M_FORBIDDEN' """ + def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): - super(ProxiedRequestError, self).__init__( - code, msg, errcode - ) + super(ProxiedRequestError, self).__init__(code, msg, errcode) if additional_fields is None: self._additional_fields = {} else: self._additional_fields = dict(additional_fields) def error_dict(self): - return cs_error( - self.msg, - self.errcode, - **self._additional_fields - ) + return cs_error(self.msg, self.errcode, **self._additional_fields) class ConsentNotGivenError(SynapseError): """The error returned to the client when the user has not consented to the privacy policy. """ + def __init__(self, msg, consent_uri): """Constructs a ConsentNotGivenError @@ -136,22 +131,17 @@ class ConsentNotGivenError(SynapseError): consent_url (str): The URL where the user can give their consent """ super(ConsentNotGivenError, self).__init__( - code=http_client.FORBIDDEN, - msg=msg, - errcode=Codes.CONSENT_NOT_GIVEN + code=http_client.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN ) self._consent_uri = consent_uri def error_dict(self): - return cs_error( - self.msg, - self.errcode, - consent_uri=self._consent_uri - ) + return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri) class RegistrationError(SynapseError): """An error raised when a registration event fails.""" + pass @@ -190,15 +180,17 @@ class InteractiveAuthIncompleteError(Exception): result (dict): the server response to the request, which should be passed back to the client """ + def __init__(self, result): super(InteractiveAuthIncompleteError, self).__init__( - "Interactive auth not yet complete", + "Interactive auth not yet complete" ) self.result = result class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" + def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.UNRECOGNIZED @@ -207,21 +199,14 @@ class UnrecognizedRequestError(SynapseError): message = "Unrecognized request" else: message = args[0] - super(UnrecognizedRequestError, self).__init__( - 400, - message, - **kwargs - ) + super(UnrecognizedRequestError, self).__init__(400, message, **kwargs) class NotFoundError(SynapseError): """An error indicating we can't find the thing you asked for""" + def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND): - super(NotFoundError, self).__init__( - 404, - msg, - errcode=errcode - ) + super(NotFoundError, self).__init__(404, msg, errcode=errcode) class AuthError(SynapseError): @@ -238,8 +223,11 @@ class ResourceLimitError(SynapseError): Any error raised when there is a problem with resource usage. For instance, the monthly active user limit for the server has been exceeded """ + def __init__( - self, code, msg, + self, + code, + msg, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_contact=None, limit_type=None, @@ -253,7 +241,7 @@ class ResourceLimitError(SynapseError): self.msg, self.errcode, admin_contact=self.admin_contact, - limit_type=self.limit_type + limit_type=self.limit_type, ) @@ -268,6 +256,7 @@ class EventSizeError(SynapseError): class EventStreamError(SynapseError): """An error raised when there a problem with the event stream.""" + def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.BAD_PAGINATION @@ -276,47 +265,53 @@ class EventStreamError(SynapseError): class LoginError(SynapseError): """An error raised when there was a problem logging in.""" + pass class StoreError(SynapseError): """An error raised when there was a problem storing some data.""" + pass class InvalidCaptchaError(SynapseError): - def __init__(self, code=400, msg="Invalid captcha.", error_url=None, - errcode=Codes.CAPTCHA_INVALID): + def __init__( + self, + code=400, + msg="Invalid captcha.", + error_url=None, + errcode=Codes.CAPTCHA_INVALID, + ): super(InvalidCaptchaError, self).__init__(code, msg, errcode) self.error_url = error_url def error_dict(self): - return cs_error( - self.msg, - self.errcode, - error_url=self.error_url, - ) + return cs_error(self.msg, self.errcode, error_url=self.error_url) class LimitExceededError(SynapseError): """A client has sent too many requests and is being throttled. """ - def __init__(self, code=429, msg="Too Many Requests", retry_after_ms=None, - errcode=Codes.LIMIT_EXCEEDED): + + def __init__( + self, + code=429, + msg="Too Many Requests", + retry_after_ms=None, + errcode=Codes.LIMIT_EXCEEDED, + ): super(LimitExceededError, self).__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms def error_dict(self): - return cs_error( - self.msg, - self.errcode, - retry_after_ms=self.retry_after_ms, - ) + return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store """ + def __init__(self, current_version): """ Args: @@ -331,6 +326,7 @@ class RoomKeysVersionError(SynapseError): class UnsupportedRoomVersionError(SynapseError): """The client's request to create a room used a room version that the server does not support.""" + def __init__(self): super(UnsupportedRoomVersionError, self).__init__( code=400, @@ -354,22 +350,19 @@ class IncompatibleRoomVersionError(SynapseError): Unlike UnsupportedRoomVersionError, it is specific to the case of the make_join failing. """ + def __init__(self, room_version): super(IncompatibleRoomVersionError, self).__init__( code=400, msg="Your homeserver does not support the features required to " - "join this room", + "join this room", errcode=Codes.INCOMPATIBLE_ROOM_VERSION, ) self._room_version = room_version def error_dict(self): - return cs_error( - self.msg, - self.errcode, - room_version=self._room_version, - ) + return cs_error(self.msg, self.errcode, room_version=self._room_version) class RequestSendFailed(RuntimeError): @@ -380,11 +373,11 @@ class RequestSendFailed(RuntimeError): networking (e.g. DNS failures, connection timeouts etc), versus unexpected errors (like programming errors). """ + def __init__(self, inner_exception, can_retry): super(RequestSendFailed, self).__init__( - "Failed to send request: %s: %s" % ( - type(inner_exception).__name__, inner_exception, - ) + "Failed to send request: %s: %s" + % (type(inner_exception).__name__, inner_exception) ) self.inner_exception = inner_exception self.can_retry = can_retry @@ -428,7 +421,7 @@ class FederationError(RuntimeError): self.affected = affected self.source = source - msg = "%s %s: %s" % (level, code, reason,) + msg = "%s %s: %s" % (level, code, reason) super(FederationError, self).__init__(msg) def get_dict(self): @@ -448,6 +441,7 @@ class HttpResponseException(CodeMessageException): Attributes: response (bytes): body of response """ + def __init__(self, code, msg, response): """ @@ -486,7 +480,7 @@ class HttpResponseException(CodeMessageException): if not isinstance(j, dict): j = {} - errcode = j.pop('errcode', Codes.UNKNOWN) - errmsg = j.pop('error', self.msg) + errcode = j.pop("errcode", Codes.UNKNOWN) + errmsg = j.pop("error", self.msg) return ProxiedRequestError(self.code, errmsg, errcode, j) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 3906475403..9b3daca29b 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -28,117 +28,55 @@ FILTER_SCHEMA = { "additionalProperties": False, "type": "object", "properties": { - "limit": { - "type": "number" - }, - "senders": { - "$ref": "#/definitions/user_id_array" - }, - "not_senders": { - "$ref": "#/definitions/user_id_array" - }, + "limit": {"type": "number"}, + "senders": {"$ref": "#/definitions/user_id_array"}, + "not_senders": {"$ref": "#/definitions/user_id_array"}, # TODO: We don't limit event type values but we probably should... # check types are valid event types - "types": { - "type": "array", - "items": { - "type": "string" - } - }, - "not_types": { - "type": "array", - "items": { - "type": "string" - } - } - } + "types": {"type": "array", "items": {"type": "string"}}, + "not_types": {"type": "array", "items": {"type": "string"}}, + }, } ROOM_FILTER_SCHEMA = { "additionalProperties": False, "type": "object", "properties": { - "not_rooms": { - "$ref": "#/definitions/room_id_array" - }, - "rooms": { - "$ref": "#/definitions/room_id_array" - }, - "ephemeral": { - "$ref": "#/definitions/room_event_filter" - }, - "include_leave": { - "type": "boolean" - }, - "state": { - "$ref": "#/definitions/room_event_filter" - }, - "timeline": { - "$ref": "#/definitions/room_event_filter" - }, - "account_data": { - "$ref": "#/definitions/room_event_filter" - }, - } + "not_rooms": {"$ref": "#/definitions/room_id_array"}, + "rooms": {"$ref": "#/definitions/room_id_array"}, + "ephemeral": {"$ref": "#/definitions/room_event_filter"}, + "include_leave": {"type": "boolean"}, + "state": {"$ref": "#/definitions/room_event_filter"}, + "timeline": {"$ref": "#/definitions/room_event_filter"}, + "account_data": {"$ref": "#/definitions/room_event_filter"}, + }, } ROOM_EVENT_FILTER_SCHEMA = { "additionalProperties": False, "type": "object", "properties": { - "limit": { - "type": "number" - }, - "senders": { - "$ref": "#/definitions/user_id_array" - }, - "not_senders": { - "$ref": "#/definitions/user_id_array" - }, - "types": { - "type": "array", - "items": { - "type": "string" - } - }, - "not_types": { - "type": "array", - "items": { - "type": "string" - } - }, - "rooms": { - "$ref": "#/definitions/room_id_array" - }, - "not_rooms": { - "$ref": "#/definitions/room_id_array" - }, - "contains_url": { - "type": "boolean" - }, - "lazy_load_members": { - "type": "boolean" - }, - "include_redundant_members": { - "type": "boolean" - }, - } + "limit": {"type": "number"}, + "senders": {"$ref": "#/definitions/user_id_array"}, + "not_senders": {"$ref": "#/definitions/user_id_array"}, + "types": {"type": "array", "items": {"type": "string"}}, + "not_types": {"type": "array", "items": {"type": "string"}}, + "rooms": {"$ref": "#/definitions/room_id_array"}, + "not_rooms": {"$ref": "#/definitions/room_id_array"}, + "contains_url": {"type": "boolean"}, + "lazy_load_members": {"type": "boolean"}, + "include_redundant_members": {"type": "boolean"}, + }, } USER_ID_ARRAY_SCHEMA = { "type": "array", - "items": { - "type": "string", - "format": "matrix_user_id" - } + "items": {"type": "string", "format": "matrix_user_id"}, } ROOM_ID_ARRAY_SCHEMA = { "type": "array", - "items": { - "type": "string", - "format": "matrix_room_id" - } + "items": {"type": "string", "format": "matrix_room_id"}, } USER_FILTER_SCHEMA = { @@ -150,22 +88,13 @@ USER_FILTER_SCHEMA = { "user_id_array": USER_ID_ARRAY_SCHEMA, "filter": FILTER_SCHEMA, "room_filter": ROOM_FILTER_SCHEMA, - "room_event_filter": ROOM_EVENT_FILTER_SCHEMA + "room_event_filter": ROOM_EVENT_FILTER_SCHEMA, }, "properties": { - "presence": { - "$ref": "#/definitions/filter" - }, - "account_data": { - "$ref": "#/definitions/filter" - }, - "room": { - "$ref": "#/definitions/room_filter" - }, - "event_format": { - "type": "string", - "enum": ["client", "federation"] - }, + "presence": {"$ref": "#/definitions/filter"}, + "account_data": {"$ref": "#/definitions/filter"}, + "room": {"$ref": "#/definitions/room_filter"}, + "event_format": {"type": "string", "enum": ["client", "federation"]}, "event_fields": { "type": "array", "items": { @@ -177,26 +106,25 @@ USER_FILTER_SCHEMA = { # # Note that because this is a regular expression, we have to escape # each backslash in the pattern. - "pattern": r"^((?!\\\\).)*$" - } - } + "pattern": r"^((?!\\\\).)*$", + }, + }, }, - "additionalProperties": False + "additionalProperties": False, } -@FormatChecker.cls_checks('matrix_room_id') +@FormatChecker.cls_checks("matrix_room_id") def matrix_room_id_validator(room_id_str): return RoomID.from_string(room_id_str) -@FormatChecker.cls_checks('matrix_user_id') +@FormatChecker.cls_checks("matrix_user_id") def matrix_user_id_validator(user_id_str): return UserID.from_string(user_id_str) class Filtering(object): - def __init__(self, hs): super(Filtering, self).__init__() self.store = hs.get_datastore() @@ -228,8 +156,9 @@ class Filtering(object): # individual top-level key e.g. public_user_data. Filters are made of # many definitions. try: - jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA, - format_checker=FormatChecker()) + jsonschema.validate( + user_filter_json, USER_FILTER_SCHEMA, format_checker=FormatChecker() + ) except jsonschema.ValidationError as e: raise SynapseError(400, str(e)) @@ -240,10 +169,9 @@ class FilterCollection(object): room_filter_json = self._filter_json.get("room", {}) - self._room_filter = Filter({ - k: v for k, v in room_filter_json.items() - if k in ("rooms", "not_rooms") - }) + self._room_filter = Filter( + {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")} + ) self._room_timeline_filter = Filter(room_filter_json.get("timeline", {})) self._room_state_filter = Filter(room_filter_json.get("state", {})) @@ -252,9 +180,7 @@ class FilterCollection(object): self._presence_filter = Filter(filter_json.get("presence", {})) self._account_data = Filter(filter_json.get("account_data", {})) - self.include_leave = filter_json.get("room", {}).get( - "include_leave", False - ) + self.include_leave = filter_json.get("room", {}).get("include_leave", False) self.event_fields = filter_json.get("event_fields", []) self.event_format = filter_json.get("event_format", "client") @@ -299,22 +225,22 @@ class FilterCollection(object): def blocks_all_presence(self): return ( - self._presence_filter.filters_all_types() or - self._presence_filter.filters_all_senders() + self._presence_filter.filters_all_types() + or self._presence_filter.filters_all_senders() ) def blocks_all_room_ephemeral(self): return ( - self._room_ephemeral_filter.filters_all_types() or - self._room_ephemeral_filter.filters_all_senders() or - self._room_ephemeral_filter.filters_all_rooms() + self._room_ephemeral_filter.filters_all_types() + or self._room_ephemeral_filter.filters_all_senders() + or self._room_ephemeral_filter.filters_all_rooms() ) def blocks_all_room_timeline(self): return ( - self._room_timeline_filter.filters_all_types() or - self._room_timeline_filter.filters_all_senders() or - self._room_timeline_filter.filters_all_rooms() + self._room_timeline_filter.filters_all_types() + or self._room_timeline_filter.filters_all_senders() + or self._room_timeline_filter.filters_all_rooms() ) @@ -375,12 +301,7 @@ class Filter(object): # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) - return self.check_fields( - room_id, - sender, - ev_type, - contains_url, - ) + return self.check_fields(room_id, sender, ev_type, contains_url) def check_fields(self, room_id, sender, event_type, contains_url): """Checks whether the filter matches the given event fields. @@ -391,7 +312,7 @@ class Filter(object): literal_keys = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, - "types": lambda v: _matches_wildcard(event_type, v) + "types": lambda v: _matches_wildcard(event_type, v), } for name, match_func in literal_keys.items(): diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 296c4a1c17..172841f595 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -44,29 +44,25 @@ class Ratelimiter(object): """ self.prune_message_counts(time_now_s) message_count, time_start, _ignored = self.message_counts.get( - key, (0., time_now_s, None), + key, (0.0, time_now_s, None) ) time_delta = time_now_s - time_start sent_count = message_count - time_delta * rate_hz if sent_count < 0: allowed = True time_start = time_now_s - message_count = 1. - elif sent_count > burst_count - 1.: + message_count = 1.0 + elif sent_count > burst_count - 1.0: allowed = False else: allowed = True message_count += 1 if update: - self.message_counts[key] = ( - message_count, time_start, rate_hz - ) + self.message_counts[key] = (message_count, time_start, rate_hz) if rate_hz > 0: - time_allowed = ( - time_start + (message_count - burst_count + 1) / rate_hz - ) + time_allowed = time_start + (message_count - burst_count + 1) / rate_hz if time_allowed < time_now_s: time_allowed = time_now_s else: @@ -76,9 +72,7 @@ class Ratelimiter(object): def prune_message_counts(self, time_now_s): for key in list(self.message_counts.keys()): - message_count, time_start, rate_hz = ( - self.message_counts[key] - ) + message_count, time_start, rate_hz = self.message_counts[key] time_delta = time_now_s - time_start if message_count - time_delta * rate_hz > 0: break @@ -92,5 +86,5 @@ class Ratelimiter(object): if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)), + retry_after_ms=int(1000 * (time_allowed - time_now_s)) ) diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index d644803d38..95292b7dec 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -19,9 +19,10 @@ class EventFormatVersions(object): """This is an internal enum for tracking the version of the event format, independently from the room version. """ - V1 = 1 # $id:server event id format - V2 = 2 # MSC1659-style $hash event id format: introduced for room v3 - V3 = 3 # MSC1884-style $hash format: introduced for room v4 + + V1 = 1 # $id:server event id format + V2 = 2 # MSC1659-style $hash event id format: introduced for room v3 + V3 = 3 # MSC1884-style $hash format: introduced for room v4 KNOWN_EVENT_FORMAT_VERSIONS = { @@ -33,8 +34,9 @@ KNOWN_EVENT_FORMAT_VERSIONS = { class StateResolutionVersions(object): """Enum to identify the state resolution algorithms""" - V1 = 1 # room v1 state res - V2 = 2 # MSC1442 state res: room v2 and later + + V1 = 1 # room v1 state res + V2 = 2 # MSC1442 state res: room v2 and later class RoomDisposition(object): @@ -46,10 +48,10 @@ class RoomDisposition(object): class RoomVersion(object): """An object which describes the unique attributes of a room version.""" - identifier = attr.ib() # str; the identifier for this version - disposition = attr.ib() # str; one of the RoomDispositions - event_format = attr.ib() # int; one of the EventFormatVersions - state_res = attr.ib() # int; one of the StateResolutionVersions + identifier = attr.ib() # str; the identifier for this version + disposition = attr.ib() # str; one of the RoomDispositions + event_format = attr.ib() # int; one of the EventFormatVersions + state_res = attr.ib() # int; one of the StateResolutionVersions enforce_key_validity = attr.ib() # bool @@ -92,11 +94,12 @@ class RoomVersions(object): KNOWN_ROOM_VERSIONS = { - v.identifier: v for v in ( + v.identifier: v + for v in ( RoomVersions.V1, RoomVersions.V2, RoomVersions.V3, RoomVersions.V4, RoomVersions.V5, ) -} # type: dict[str, RoomVersion] +} # type: dict[str, RoomVersion] diff --git a/synapse/api/urls.py b/synapse/api/urls.py index e16c386a14..ff1f39e86c 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -42,13 +42,9 @@ class ConsentURIBuilder(object): hs_config (synapse.config.homeserver.HomeServerConfig): """ if hs_config.form_secret is None: - raise ConfigError( - "form_secret not set in config", - ) + raise ConfigError("form_secret not set in config") if hs_config.public_baseurl is None: - raise ConfigError( - "public_baseurl not set in config", - ) + raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") self._public_baseurl = hs_config.public_baseurl @@ -64,15 +60,10 @@ class ConsentURIBuilder(object): (str) the URI where the user can do consent """ mac = hmac.new( - key=self._hmac_secret, - msg=user_id.encode('ascii'), - digestmod=sha256, + key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256 ).hexdigest() consent_uri = "%s_matrix/consent?%s" % ( self._public_baseurl, - urlencode({ - "u": user_id, - "h": mac - }), + urlencode({"u": user_id, "h": mac}), ) return consent_uri diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index f56f5fcc13..d877c77834 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -43,7 +43,7 @@ def check_bind_error(e, address, bind_addresses): address (str): Address on which binding was attempted. bind_addresses (list): Addresses on which the service listens. """ - if address == '0.0.0.0' and '::' in bind_addresses: - logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]') + if address == "0.0.0.0" and "::" in bind_addresses: + logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]") else: raise e diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 8cc990399f..df4c2d4c97 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -75,14 +75,14 @@ def start_worker_reactor(appname, config): def start_reactor( - appname, - soft_file_limit, - gc_thresholds, - pid_file, - daemonize, - cpu_affinity, - print_pidfile, - logger, + appname, + soft_file_limit, + gc_thresholds, + pid_file, + daemonize, + cpu_affinity, + print_pidfile, + logger, ): """ Run the reactor in the main process @@ -149,10 +149,10 @@ def start_reactor( def quit_with_error(error_string): message_lines = error_string.split("\n") line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 - sys.stderr.write("*" * line_length + '\n') + sys.stderr.write("*" * line_length + "\n") for line in message_lines: sys.stderr.write(" %s\n" % (line.rstrip(),)) - sys.stderr.write("*" * line_length + '\n') + sys.stderr.write("*" * line_length + "\n") sys.exit(1) @@ -178,14 +178,7 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): r = [] for address in bind_addresses: try: - r.append( - reactor.listenTCP( - port, - factory, - backlog, - address - ) - ) + r.append(reactor.listenTCP(port, factory, backlog, address)) except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) @@ -205,13 +198,7 @@ def listen_ssl( for address in bind_addresses: try: r.append( - reactor.listenSSL( - port, - factory, - context_factory, - backlog, - address - ) + reactor.listenSSL(port, factory, context_factory, backlog, address) ) except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) @@ -243,15 +230,13 @@ def refresh_certificate(hs): if isinstance(i.factory, TLSMemoryBIOFactory): addr = i.getHost() logger.info( - "Replacing TLS context factory on [%s]:%i", addr.host, addr.port, + "Replacing TLS context factory on [%s]:%i", addr.host, addr.port ) # We want to replace TLS factories with a new one, with the new # TLS configuration. We do this by reaching in and pulling out # the wrappedFactory, and then re-wrapping it. i.factory = TLSMemoryBIOFactory( - hs.tls_server_context_factory, - False, - i.factory.wrappedFactory + hs.tls_server_context_factory, False, i.factory.wrappedFactory ) logger.info("Context factories updated.") @@ -267,6 +252,7 @@ def start(hs, listeners=None): try: # Set up the SIGHUP machinery. if hasattr(signal, "SIGHUP"): + def handle_sighup(*args, **kwargs): for i in _sighup_callbacks: i(hs) @@ -302,10 +288,8 @@ def setup_sentry(hs): return import sentry_sdk - sentry_sdk.init( - dsn=hs.config.sentry_dsn, - release=get_version_string(synapse), - ) + + sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse)) # We set some default tags that give some context to this instance with sentry_sdk.configure_scope() as scope: @@ -326,7 +310,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100): many DNS queries at once """ new_resolver = _LimitedHostnameResolver( - reactor.nameResolver, max_dns_requests_in_flight, + reactor.nameResolver, max_dns_requests_in_flight ) reactor.installNameResolver(new_resolver) @@ -339,11 +323,17 @@ class _LimitedHostnameResolver(object): def __init__(self, resolver, max_dns_requests_in_flight): self._resolver = resolver self._limiter = Linearizer( - name="dns_client_limiter", max_count=max_dns_requests_in_flight, + name="dns_client_limiter", max_count=max_dns_requests_in_flight ) - def resolveHostName(self, resolutionReceiver, hostName, portNumber=0, - addressTypes=None, transportSemantics='TCP'): + def resolveHostName( + self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics="TCP", + ): # We need this function to return `resolutionReceiver` so we do all the # actual logic involving deferreds in a separate function. @@ -363,8 +353,14 @@ class _LimitedHostnameResolver(object): return resolutionReceiver @defer.inlineCallbacks - def _resolve(self, resolutionReceiver, hostName, portNumber=0, - addressTypes=None, transportSemantics='TCP'): + def _resolve( + self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics="TCP", + ): with (yield self._limiter.queue(())): # resolveHostName doesn't return a Deferred, so we need to hook into @@ -374,8 +370,7 @@ class _LimitedHostnameResolver(object): receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred) self._resolver.resolveHostName( - receiver, hostName, portNumber, - addressTypes, transportSemantics, + receiver, hostName, portNumber, addressTypes, transportSemantics ) yield deferred diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 33107f56d1..9120bdb143 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -44,7 +44,9 @@ logger = logging.getLogger("synapse.app.appservice") class AppserviceSlaveStore( - DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore, + DirectoryStore, + SlavedEventStore, + SlavedApplicationServiceStore, SlavedRegistrationStore, ): pass @@ -74,7 +76,7 @@ class AppserviceServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse appservice now listening on port %d", port) @@ -88,18 +90,19 @@ class AppserviceServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -132,9 +135,7 @@ class ASReplicationHandler(ReplicationClientHandler): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse appservice", config_options - ) + config = HomeServerConfig.load_config("Synapse appservice", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -173,6 +174,6 @@ def start(config_options): _base.start_worker_reactor("synapse-appservice", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index a16e037f32..310cdab2e4 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -118,9 +118,7 @@ class ClientReaderServer(HomeServer): PushRuleRestServlet(self).register(resource) VersionsRestServlet().register(resource) - resources.update({ - "/_matrix/client": resource, - }) + resources.update({"/_matrix/client": resource}) root_resource = create_resource_tree(resources, NoResource()) @@ -133,7 +131,7 @@ class ClientReaderServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse client reader now listening on port %d", port) @@ -147,18 +145,19 @@ class ClientReaderServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -170,9 +169,7 @@ class ClientReaderServer(HomeServer): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse client reader", config_options - ) + config = HomeServerConfig.load_config("Synapse client reader", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -199,6 +196,6 @@ def start(config_options): _base.start_worker_reactor("synapse-client-reader", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index b8e5196152..ff522e4499 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -109,12 +109,14 @@ class EventCreatorServer(HomeServer): ProfileAvatarURLRestServlet(self).register(resource) ProfileDisplaynameRestServlet(self).register(resource) ProfileRestServlet(self).register(resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -127,7 +129,7 @@ class EventCreatorServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse event creator now listening on port %d", port) @@ -141,18 +143,19 @@ class EventCreatorServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -164,9 +167,7 @@ class EventCreatorServer(HomeServer): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse event creator", config_options - ) + config = HomeServerConfig.load_config("Synapse event creator", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -198,6 +199,6 @@ def start(config_options): _base.start_worker_reactor("synapse-event-creator", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 7da79dc827..9421420930 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -86,19 +86,18 @@ class FederationReaderServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "federation": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self), - }) + resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) if name == "openid" and "federation" not in res["names"]: # Only load the openid resource separately if federation resource # is not specified since federation resource includes openid # resource. - resources.update({ - FEDERATION_PREFIX: TransportLayerServer( - self, - servlet_groups=["openid"], - ), - }) + resources.update( + { + FEDERATION_PREFIX: TransportLayerServer( + self, servlet_groups=["openid"] + ) + } + ) if name in ["keys", "federation"]: resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) @@ -115,7 +114,7 @@ class FederationReaderServer(HomeServer): root_resource, self.version_string, ), - reactor=self.get_reactor() + reactor=self.get_reactor(), ) logger.info("Synapse federation reader now listening on port %d", port) @@ -129,18 +128,19 @@ class FederationReaderServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -181,6 +181,6 @@ def start(config_options): _base.start_worker_reactor("synapse-federation-reader", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 1d43f2b075..969be58d0b 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -52,8 +52,13 @@ logger = logging.getLogger("synapse.app.federation_sender") class FederationSenderSlaveStore( - SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore, - SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore, + SlavedDeviceInboxStore, + SlavedTransactionStore, + SlavedReceiptsStore, + SlavedEventStore, + SlavedRegistrationStore, + SlavedDeviceStore, + SlavedPresenceStore, ): def __init__(self, db_conn, hs): super(FederationSenderSlaveStore, self).__init__(db_conn, hs) @@ -65,10 +70,7 @@ class FederationSenderSlaveStore( self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) def _get_federation_out_pos(self, db_conn): - sql = ( - "SELECT stream_id FROM federation_stream_position" - " WHERE type = ?" - ) + sql = "SELECT stream_id FROM federation_stream_position" " WHERE type = ?" sql = self.database_engine.convert_param_style(sql) txn = db_conn.cursor() @@ -103,7 +105,7 @@ class FederationSenderServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse federation_sender now listening on port %d", port) @@ -117,18 +119,19 @@ class FederationSenderServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -151,7 +154,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): self.send_handler.process_replication_rows(stream_name, token, rows) def get_streams_to_replicate(self): - args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate() + args = super( + FederationSenderReplicationHandler, self + ).get_streams_to_replicate() args.update(self.send_handler.stream_positions()) return args @@ -203,6 +208,7 @@ class FederationSenderHandler(object): """Processes the replication stream and forwards the appropriate entries to the federation sender. """ + def __init__(self, hs, replication_client): self.store = hs.get_datastore() self._is_mine_id = hs.is_mine_id @@ -241,7 +247,7 @@ class FederationSenderHandler(object): # ... and when new receipts happen elif stream_name == ReceiptsStream.NAME: run_as_background_process( - "process_receipts_for_federation", self._on_new_receipts, rows, + "process_receipts_for_federation", self._on_new_receipts, rows ) @defer.inlineCallbacks @@ -278,12 +284,14 @@ class FederationSenderHandler(object): # We ACK this token over replication so that the master can drop # its in memory queues - self.replication_client.send_federation_ack(self.federation_position) + self.replication_client.send_federation_ack( + self.federation_position + ) self._last_ack = self.federation_position except Exception: logger.exception("Error updating federation stream position") -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 6504da5278..2fd7d57ebf 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -62,14 +62,11 @@ class PresenceStatusStubServlet(RestServlet): # Pass through the auth headers, if any, in case the access token # is there. auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) - headers = { - "Authorization": auth_headers, - } + headers = {"Authorization": auth_headers} try: result = yield self.http_client.get_json( - self.main_uri + request.uri.decode('ascii'), - headers=headers, + self.main_uri + request.uri.decode("ascii"), headers=headers ) except HttpResponseException as e: raise e.to_synapse_error() @@ -105,18 +102,19 @@ class KeyUploadServlet(RestServlet): if device_id is not None: # passing the device_id here is deprecated; however, we allow it # for now for compatibility with older clients. - if (requester.device_id is not None and - device_id != requester.device_id): - logger.warning("Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, device_id) + if requester.device_id is not None and device_id != requester.device_id: + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id if device_id is None: raise SynapseError( - 400, - "To upload keys, you must pass device_id when authenticating" + 400, "To upload keys, you must pass device_id when authenticating" ) if body: @@ -124,13 +122,9 @@ class KeyUploadServlet(RestServlet): # Pass through the auth headers, if any, in case the access token # is there. auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) - headers = { - "Authorization": auth_headers, - } + headers = {"Authorization": auth_headers} result = yield self.http_client.post_json_get_json( - self.main_uri + request.uri.decode('ascii'), - body, - headers=headers, + self.main_uri + request.uri.decode("ascii"), body, headers=headers ) defer.returnValue((200, result)) @@ -171,12 +165,14 @@ class FrontendProxyServer(HomeServer): if not self.config.use_presence: PresenceStatusStubServlet(self).register(resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -190,7 +186,7 @@ class FrontendProxyServer(HomeServer): root_resource, self.version_string, ), - reactor=self.get_reactor() + reactor=self.get_reactor(), ) logger.info("Synapse client reader now listening on port %d", port) @@ -204,18 +200,19 @@ class FrontendProxyServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -227,9 +224,7 @@ class FrontendProxyServer(HomeServer): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse frontend proxy", config_options - ) + config = HomeServerConfig.load_config("Synapse frontend proxy", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -258,6 +253,6 @@ def start(config_options): _base.start_worker_reactor("synapse-frontend-proxy", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b27b12e73d..d19c7c7d71 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -101,13 +101,12 @@ class SynapseHomeServer(HomeServer): # Skip loading openid resource if federation is defined # since federation resource will include openid continue - resources.update(self._configure_named_resource( - name, res.get("compress", False), - )) + resources.update( + self._configure_named_resource(name, res.get("compress", False)) + ) additional_resources = listener_config.get("additional_resources", {}) - logger.debug("Configuring additional resources: %r", - additional_resources) + logger.debug("Configuring additional resources: %r", additional_resources) module_api = ModuleApi(self, self.get_auth_handler()) for path, resmodule in additional_resources.items(): handler_cls, config = load_module(resmodule) @@ -174,59 +173,67 @@ class SynapseHomeServer(HomeServer): if compress: client_resource = gz_wrap(client_resource) - resources.update({ - "/_matrix/client/api/v1": client_resource, - "/_matrix/client/r0": client_resource, - "/_matrix/client/unstable": client_resource, - "/_matrix/client/v2_alpha": client_resource, - "/_matrix/client/versions": client_resource, - "/.well-known/matrix/client": WellKnownResource(self), - "/_synapse/admin": AdminRestResource(self), - }) + resources.update( + { + "/_matrix/client/api/v1": client_resource, + "/_matrix/client/r0": client_resource, + "/_matrix/client/unstable": client_resource, + "/_matrix/client/v2_alpha": client_resource, + "/_matrix/client/versions": client_resource, + "/.well-known/matrix/client": WellKnownResource(self), + "/_synapse/admin": AdminRestResource(self), + } + ) if self.get_config().saml2_enabled: from synapse.rest.saml2 import SAML2Resource + resources["/_matrix/saml2"] = SAML2Resource(self) if name == "consent": from synapse.rest.consent.consent_resource import ConsentResource + consent_resource = ConsentResource(self) if compress: consent_resource = gz_wrap(consent_resource) - resources.update({ - "/_matrix/consent": consent_resource, - }) + resources.update({"/_matrix/consent": consent_resource}) if name == "federation": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self), - }) + resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) if name == "openid": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self, servlet_groups=["openid"]), - }) + resources.update( + { + FEDERATION_PREFIX: TransportLayerServer( + self, servlet_groups=["openid"] + ) + } + ) if name in ["static", "client"]: - resources.update({ - STATIC_PREFIX: File( - os.path.join(os.path.dirname(synapse.__file__), "static") - ), - }) + resources.update( + { + STATIC_PREFIX: File( + os.path.join(os.path.dirname(synapse.__file__), "static") + ) + } + ) if name in ["media", "federation", "client"]: if self.get_config().enable_media_repo: media_repo = self.get_media_repository_resource() - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) + resources.update( + { + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + } + ) elif name == "media": raise ConfigError( - "'media' resource conflicts with enable_media_repo=False", + "'media' resource conflicts with enable_media_repo=False" ) if name in ["keys", "federation"]: @@ -257,18 +264,14 @@ class SynapseHomeServer(HomeServer): for listener in listeners: if listener["type"] == "http": - self._listening_services.extend( - self._listener_http(config, listener) - ) + self._listening_services.extend(self._listener_http(config, listener)) elif listener["type"] == "manhole": listen_tcp( listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "replication": services = listen_tcp( @@ -277,16 +280,17 @@ class SynapseHomeServer(HomeServer): ReplicationStreamProtocolFactory(self), ) for s in services: - reactor.addSystemEventTrigger( - "before", "shutdown", s.stopListening, - ) + reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -312,7 +316,7 @@ current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") registered_reserved_users_mau_gauge = Gauge( "synapse_admin_mau:registered_reserved_users", - "Registered users with reserved threepids" + "Registered users with reserved threepids", ) @@ -327,8 +331,7 @@ def setup(config_options): """ try: config = HomeServerConfig.load_or_generate_config( - "Synapse Homeserver", - config_options, + "Synapse Homeserver", config_options ) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") @@ -339,10 +342,7 @@ def setup(config_options): # generating config files and shouldn't try to continue. sys.exit(0) - synapse.config.logger.setup_logging( - config, - use_worker_options=False - ) + synapse.config.logger.setup_logging(config, use_worker_options=False) events.USE_FROZEN_DICTS = config.use_frozen_dicts @@ -357,7 +357,7 @@ def setup(config_options): database_engine=database_engine, ) - logger.info("Preparing database: %s...", config.database_config['name']) + logger.info("Preparing database: %s...", config.database_config["name"]) try: with hs.get_db_conn(run_new_connection=False) as db_conn: @@ -375,7 +375,7 @@ def setup(config_options): ) sys.exit(1) - logger.info("Database prepared in %s.", config.database_config['name']) + logger.info("Database prepared in %s.", config.database_config["name"]) hs.setup() hs.setup_master() @@ -391,9 +391,7 @@ def setup(config_options): acme = hs.get_acme_handler() # Check how long the certificate is active for. - cert_days_remaining = hs.config.is_disk_cert_valid( - allow_self_signed=False - ) + cert_days_remaining = hs.config.is_disk_cert_valid(allow_self_signed=False) # We want to reprovision if cert_days_remaining is None (meaning no # certificate exists), or the days remaining number it returns @@ -401,8 +399,8 @@ def setup(config_options): provision = False if ( - cert_days_remaining is None or - cert_days_remaining < hs.config.acme_reprovision_threshold + cert_days_remaining is None + or cert_days_remaining < hs.config.acme_reprovision_threshold ): provision = True @@ -433,10 +431,7 @@ def setup(config_options): yield do_acme() # Check if it needs to be reprovisioned every day. - hs.get_clock().looping_call( - reprovision_acme, - 24 * 60 * 60 * 1000 - ) + hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) _base.start(hs, config.listeners) @@ -463,6 +458,7 @@ class SynapseService(service.Service): A twisted Service class that will start synapse. Used to run synapse via twistd and a .tac. """ + def __init__(self, config): self.config = config @@ -479,6 +475,7 @@ class SynapseService(service.Service): def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: + def profile(func): from cProfile import Profile from threading import current_thread @@ -489,13 +486,14 @@ def run(hs): func(*args, **kargs) profile.disable() ident = current_thread().ident - profile.dump_stats("/tmp/%s.%s.%i.pstat" % ( - hs.hostname, func.__name__, ident - )) + profile.dump_stats( + "/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident) + ) return profiled from twisted.python.threadpool import ThreadPool + ThreadPool._worker = profile(ThreadPool._worker) reactor.run = profile(reactor.run) @@ -541,7 +539,9 @@ def run(hs): stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() - stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() + stats[ + "daily_active_rooms" + ] = yield hs.get_datastore().count_daily_active_rooms() stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() r30_results = yield hs.get_datastore().count_r30_users() @@ -565,8 +565,7 @@ def run(hs): logger.info("Reporting stats to matrix.org: %s" % (stats,)) try: yield hs.get_simple_http_client().put_json( - "https://matrix.org/report-usage-stats/push", - stats + "https://matrix.org/report-usage-stats/push", stats ) except Exception as e: logger.warn("Error reporting stats: %s", e) @@ -581,14 +580,11 @@ def run(hs): logger.info("report_stats can use psutil") stats_process.append(process) except (AttributeError): - logger.warning( - "Unable to read memory/cpu stats. Disabling reporting." - ) + logger.warning("Unable to read memory/cpu stats. Disabling reporting.") def generate_user_daily_visit_stats(): return run_as_background_process( - "generate_user_daily_visits", - hs.get_datastore().generate_user_daily_visits, + "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits ) # Rather than update on per session basis, batch up the requests. @@ -599,9 +595,9 @@ def run(hs): # monthly active user limiting functionality def reap_monthly_active_users(): return run_as_background_process( - "reap_monthly_active_users", - hs.get_datastore().reap_monthly_active_users, + "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users ) + clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) reap_monthly_active_users() @@ -619,8 +615,7 @@ def run(hs): def start_generate_monthly_active_users(): return run_as_background_process( - "generate_monthly_active_users", - generate_monthly_active_users, + "generate_monthly_active_users", generate_monthly_active_users ) start_generate_monthly_active_users() @@ -660,5 +655,5 @@ def main(): run(hs) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index d4cc4e9443..cf0e2036c3 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -72,13 +72,15 @@ class MediaRepositoryServer(HomeServer): resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "media": media_repo = self.get_media_repository_resource() - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) + resources.update( + { + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -91,7 +93,7 @@ class MediaRepositoryServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse media repository now listening on port %d", port) @@ -105,18 +107,19 @@ class MediaRepositoryServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -164,6 +167,6 @@ def start(config_options): _base.start_worker_reactor("synapse-media-repository", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index cbf0d67f51..df29ea5ecb 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -46,36 +46,27 @@ logger = logging.getLogger("synapse.app.pusher") class PusherSlaveStore( - SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, - SlavedAccountDataStore + SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, SlavedAccountDataStore ): - update_pusher_last_stream_ordering_and_success = ( - __func__(DataStore.update_pusher_last_stream_ordering_and_success) + update_pusher_last_stream_ordering_and_success = __func__( + DataStore.update_pusher_last_stream_ordering_and_success ) - update_pusher_failing_since = ( - __func__(DataStore.update_pusher_failing_since) - ) + update_pusher_failing_since = __func__(DataStore.update_pusher_failing_since) - update_pusher_last_stream_ordering = ( - __func__(DataStore.update_pusher_last_stream_ordering) + update_pusher_last_stream_ordering = __func__( + DataStore.update_pusher_last_stream_ordering ) - get_throttle_params_by_room = ( - __func__(DataStore.get_throttle_params_by_room) - ) + get_throttle_params_by_room = __func__(DataStore.get_throttle_params_by_room) - set_throttle_params = ( - __func__(DataStore.set_throttle_params) - ) + set_throttle_params = __func__(DataStore.set_throttle_params) - get_time_of_last_push_action_before = ( - __func__(DataStore.get_time_of_last_push_action_before) + get_time_of_last_push_action_before = __func__( + DataStore.get_time_of_last_push_action_before ) - get_profile_displayname = ( - __func__(DataStore.get_profile_displayname) - ) + get_profile_displayname = __func__(DataStore.get_profile_displayname) class PusherServer(HomeServer): @@ -105,7 +96,7 @@ class PusherServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse pusher now listening on port %d", port) @@ -119,18 +110,19 @@ class PusherServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -161,9 +153,7 @@ class PusherReplicationHandler(ReplicationClientHandler): else: yield self.start_pusher(row.user_id, row.app_id, row.pushkey) elif stream_name == "events": - yield self.pusher_pool.on_new_notifications( - token, token, - ) + yield self.pusher_pool.on_new_notifications(token, token) elif stream_name == "receipts": yield self.pusher_pool.on_new_receipts( token, token, set(row.room_id for row in rows) @@ -188,9 +178,7 @@ class PusherReplicationHandler(ReplicationClientHandler): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse pusher", config_options - ) + config = HomeServerConfig.load_config("Synapse pusher", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -234,6 +222,6 @@ def start(config_options): _base.start_worker_reactor("synapse-pusher", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): ps = start(sys.argv[1:]) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 5388def28a..858949910d 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -98,10 +98,7 @@ class SynchrotronPresence(object): self.notifier = hs.get_notifier() active_presence = self.store.take_presence_startup_info() - self.user_to_current_state = { - state.user_id: state - for state in active_presence - } + self.user_to_current_state = {state.user_id: state for state in active_presence} # user_id -> last_sync_ms. Lists the users that have stopped syncing # but we haven't notified the master of that yet @@ -196,17 +193,26 @@ class SynchrotronPresence(object): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=users_to_states.keys() + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=users_to_states.keys(), ) @defer.inlineCallbacks def process_replication_rows(self, token, rows): - states = [UserPresenceState( - row.user_id, row.state, row.last_active_ts, - row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg, - row.currently_active - ) for row in rows] + states = [ + UserPresenceState( + row.user_id, + row.state, + row.last_active_ts, + row.last_federation_update_ts, + row.last_user_sync_ts, + row.status_msg, + row.currently_active, + ) + for row in rows + ] for state in states: self.user_to_current_state[state.user_id] = state @@ -217,7 +223,8 @@ class SynchrotronPresence(object): def get_currently_syncing_users(self): if self.hs.config.use_presence: return [ - user_id for user_id, count in iteritems(self.user_to_num_current_syncs) + user_id + for user_id, count in iteritems(self.user_to_num_current_syncs) if count > 0 ] else: @@ -281,12 +288,14 @@ class SynchrotronServer(HomeServer): events.register_servlets(self, resource) InitialSyncRestServlet(self).register(resource) RoomInitialSyncRestServlet(self).register(resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -299,7 +308,7 @@ class SynchrotronServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse synchrotron now listening on port %d", port) @@ -313,18 +322,19 @@ class SynchrotronServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -382,40 +392,36 @@ class SyncReplicationHandler(ReplicationClientHandler): ) elif stream_name == "push_rules": self.notifier.on_new_event( - "push_rules_key", token, users=[row.user_id for row in rows], + "push_rules_key", token, users=[row.user_id for row in rows] ) - elif stream_name in ("account_data", "tag_account_data",): + elif stream_name in ("account_data", "tag_account_data"): self.notifier.on_new_event( - "account_data_key", token, users=[row.user_id for row in rows], + "account_data_key", token, users=[row.user_id for row in rows] ) elif stream_name == "receipts": self.notifier.on_new_event( - "receipt_key", token, rooms=[row.room_id for row in rows], + "receipt_key", token, rooms=[row.room_id for row in rows] ) elif stream_name == "typing": self.typing_handler.process_replication_rows(token, rows) self.notifier.on_new_event( - "typing_key", token, rooms=[row.room_id for row in rows], + "typing_key", token, rooms=[row.room_id for row in rows] ) elif stream_name == "to_device": entities = [row.entity for row in rows if row.entity.startswith("@")] if entities: - self.notifier.on_new_event( - "to_device_key", token, users=entities, - ) + self.notifier.on_new_event("to_device_key", token, users=entities) elif stream_name == "device_lists": all_room_ids = set() for row in rows: room_ids = yield self.store.get_rooms_for_user(row.user_id) all_room_ids.update(room_ids) - self.notifier.on_new_event( - "device_list_key", token, rooms=all_room_ids, - ) + self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) elif stream_name == "presence": yield self.presence_handler.process_replication_rows(token, rows) elif stream_name == "receipts": self.notifier.on_new_event( - "groups_key", token, users=[row.user_id for row in rows], + "groups_key", token, users=[row.user_id for row in rows] ) except Exception: logger.exception("Error processing replication") @@ -423,9 +429,7 @@ class SyncReplicationHandler(ReplicationClientHandler): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse synchrotron", config_options - ) + config = HomeServerConfig.load_config("Synapse synchrotron", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -453,6 +457,6 @@ def start(config_options): _base.start_worker_reactor("synapse-synchrotron", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 355f5aa71d..2d9d2e1bbc 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -66,14 +66,16 @@ class UserDirectorySlaveStore( events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( - db_conn, "current_state_delta_stream", + db_conn, + "current_state_delta_stream", entity_column="room_id", stream_column="stream_id", max_value=events_max, # As we share the stream id with events token limit=1000, ) self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", min_curr_state_delta_id, + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, prefilled_cache=curr_state_delta_prefill, ) @@ -110,12 +112,14 @@ class UserDirectoryServer(HomeServer): elif name == "client": resource = JsonResource(self, canonical_json=False) user_directory.register_servlets(self, resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -128,7 +132,7 @@ class UserDirectoryServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse user_dir now listening on port %d", port) @@ -142,18 +146,19 @@ class UserDirectoryServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -186,9 +191,7 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse user directory", config_options - ) + config = HomeServerConfig.load_config("Synapse user directory", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -227,6 +230,6 @@ def start(config_options): _base.start_worker_reactor("synapse-user-dir", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 57ed8a3ca2..b26a31dd54 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -48,9 +48,7 @@ class AppServiceTransaction(object): A Deferred which resolves to True if the transaction was sent. """ return as_api.push_bulk( - service=self.service, - events=self.events, - txn_id=self.id + service=self.service, events=self.events, txn_id=self.id ) def complete(self, store): @@ -64,10 +62,7 @@ class AppServiceTransaction(object): Returns: A Deferred which resolves to True if the transaction was completed. """ - return store.complete_appservice_txn( - service=self.service, - txn_id=self.id - ) + return store.complete_appservice_txn(service=self.service, txn_id=self.id) class ApplicationService(object): @@ -76,6 +71,7 @@ class ApplicationService(object): Provides methods to check if this service is "interested" in events. """ + NS_USERS = "users" NS_ALIASES = "aliases" NS_ROOMS = "rooms" @@ -84,9 +80,19 @@ class ApplicationService(object): # values. NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] - def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None, - sender=None, id=None, protocols=None, rate_limited=True, - ip_range_whitelist=None): + def __init__( + self, + token, + hostname, + url=None, + namespaces=None, + hs_token=None, + sender=None, + id=None, + protocols=None, + rate_limited=True, + ip_range_whitelist=None, + ): self.token = token self.url = url self.hs_token = hs_token @@ -128,9 +134,7 @@ class ApplicationService(object): if not isinstance(regex_obj, dict): raise ValueError("Expected dict regex for ns '%s'" % ns) if not isinstance(regex_obj.get("exclusive"), bool): - raise ValueError( - "Expected bool for 'exclusive' in ns '%s'" % ns - ) + raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns) group_id = regex_obj.get("group_id") if group_id: if not isinstance(group_id, str): @@ -153,9 +157,7 @@ class ApplicationService(object): if isinstance(regex, string_types): regex_obj["regex"] = re.compile(regex) # Pre-compile regex else: - raise ValueError( - "Expected string for 'regex' in ns '%s'" % ns - ) + raise ValueError("Expected string for 'regex' in ns '%s'" % ns) return namespaces def _matches_regex(self, test_string, namespace_key): @@ -178,8 +180,9 @@ class ApplicationService(object): if self.is_interested_in_user(event.sender): defer.returnValue(True) # also check m.room.member state key - if (event.type == EventTypes.Member and - self.is_interested_in_user(event.state_key)): + if event.type == EventTypes.Member and self.is_interested_in_user( + event.state_key + ): defer.returnValue(True) if not store: diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 9ccc5a80fc..571881775b 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -32,19 +32,17 @@ logger = logging.getLogger(__name__) sent_transactions_counter = Counter( "synapse_appservice_api_sent_transactions", "Number of /transactions/ requests sent", - ["service"] + ["service"], ) failed_transactions_counter = Counter( "synapse_appservice_api_failed_transactions", "Number of /transactions/ requests that failed to send", - ["service"] + ["service"], ) sent_events_counter = Counter( - "synapse_appservice_api_sent_events", - "Number of events sent to the AS", - ["service"] + "synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"] ) HOUR_IN_MS = 60 * 60 * 1000 @@ -92,8 +90,9 @@ class ApplicationServiceApi(SimpleHttpClient): super(ApplicationServiceApi, self).__init__(hs) self.clock = hs.get_clock() - self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta", - timeout_ms=HOUR_IN_MS) + self.protocol_meta_cache = ResponseCache( + hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS + ) @defer.inlineCallbacks def query_user(self, service, user_id): @@ -102,9 +101,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) response = None try: - response = yield self.get_json(uri, { - "access_token": service.hs_token - }) + response = yield self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object defer.returnValue(True) except CodeMessageException as e: @@ -123,9 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) response = None try: - response = yield self.get_json(uri, { - "access_token": service.hs_token - }) + response = yield self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object defer.returnValue(True) except CodeMessageException as e: @@ -144,9 +139,7 @@ class ApplicationServiceApi(SimpleHttpClient): elif kind == ThirdPartyEntityKind.LOCATION: required_field = "alias" else: - raise ValueError( - "Unrecognised 'kind' argument %r to query_3pe()", kind - ) + raise ValueError("Unrecognised 'kind' argument %r to query_3pe()", kind) if service.url is None: defer.returnValue([]) @@ -154,14 +147,13 @@ class ApplicationServiceApi(SimpleHttpClient): service.url, APP_SERVICE_PREFIX, kind, - urllib.parse.quote(protocol) + urllib.parse.quote(protocol), ) try: response = yield self.get_json(uri, fields) if not isinstance(response, list): logger.warning( - "query_3pe to %s returned an invalid response %r", - uri, response + "query_3pe to %s returned an invalid response %r", uri, response ) defer.returnValue([]) @@ -171,8 +163,7 @@ class ApplicationServiceApi(SimpleHttpClient): ret.append(r) else: logger.warning( - "query_3pe to %s returned an invalid result %r", - uri, r + "query_3pe to %s returned an invalid result %r", uri, r ) defer.returnValue(ret) @@ -189,27 +180,27 @@ class ApplicationServiceApi(SimpleHttpClient): uri = "%s%s/thirdparty/protocol/%s" % ( service.url, APP_SERVICE_PREFIX, - urllib.parse.quote(protocol) + urllib.parse.quote(protocol), ) try: info = yield self.get_json(uri, {}) if not _is_valid_3pe_metadata(info): - logger.warning("query_3pe_protocol to %s did not return a" - " valid result", uri) + logger.warning( + "query_3pe_protocol to %s did not return a" " valid result", uri + ) defer.returnValue(None) for instance in info.get("instances", []): network_id = instance.get("network_id", None) if network_id is not None: instance["instance_id"] = ThirdPartyInstanceID( - service.id, network_id, + service.id, network_id ).to_string() defer.returnValue(info) except Exception as ex: - logger.warning("query_3pe_protocol to %s threw exception %s", - uri, ex) + logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex) defer.returnValue(None) key = (service.id, protocol) @@ -223,22 +214,19 @@ class ApplicationServiceApi(SimpleHttpClient): events = self._serialize(events) if txn_id is None: - logger.warning("push_bulk: Missing txn ID sending events to %s", - service.url) + logger.warning( + "push_bulk: Missing txn ID sending events to %s", service.url + ) txn_id = str(0) txn_id = str(txn_id) - uri = service.url + ("/transactions/%s" % - urllib.parse.quote(txn_id)) + uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id)) try: yield self.put_json( uri=uri, - json_body={ - "events": events - }, - args={ - "access_token": service.hs_token - }) + json_body={"events": events}, + args={"access_token": service.hs_token}, + ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(events)) defer.returnValue(True) @@ -252,6 +240,4 @@ class ApplicationServiceApi(SimpleHttpClient): def _serialize(self, events): time_now = self.clock.time_msec() - return [ - serialize_event(e, time_now, as_client_event=True) for e in events - ] + return [serialize_event(e, time_now, as_client_event=True) for e in events] diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 685f15c061..b54bf5411f 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -112,15 +112,14 @@ class _ServiceQueuer(object): return run_as_background_process( - "as-sender-%s" % (service.id, ), - self._send_request, service, + "as-sender-%s" % (service.id,), self._send_request, service ) @defer.inlineCallbacks def _send_request(self, service): # sanity-check: we shouldn't get here if this service already has a sender # running. - assert(service.id not in self.requests_in_flight) + assert service.id not in self.requests_in_flight self.requests_in_flight.add(service.id) try: @@ -137,7 +136,6 @@ class _ServiceQueuer(object): class _TransactionController(object): - def __init__(self, clock, store, as_api, recoverer_fn): self.clock = clock self.store = store @@ -149,10 +147,7 @@ class _TransactionController(object): @defer.inlineCallbacks def send(self, service, events): try: - txn = yield self.store.create_appservice_txn( - service=service, - events=events - ) + txn = yield self.store.create_appservice_txn(service=service, events=events) service_is_up = yield self._is_service_up(service) if service_is_up: sent = yield txn.send(self.as_api) @@ -167,12 +162,12 @@ class _TransactionController(object): @defer.inlineCallbacks def on_recovered(self, recoverer): self.recoverers.remove(recoverer) - logger.info("Successfully recovered application service AS ID %s", - recoverer.service.id) + logger.info( + "Successfully recovered application service AS ID %s", recoverer.service.id + ) logger.info("Remaining active recoverers: %s", len(self.recoverers)) yield self.store.set_appservice_state( - recoverer.service, - ApplicationServiceState.UP + recoverer.service, ApplicationServiceState.UP ) def add_recoverers(self, recoverers): @@ -184,13 +179,10 @@ class _TransactionController(object): @defer.inlineCallbacks def _start_recoverer(self, service): try: - yield self.store.set_appservice_state( - service, - ApplicationServiceState.DOWN - ) + yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) logger.info( "Application service falling behind. Starting recoverer. AS ID %s", - service.id + service.id, ) recoverer = self.recoverer_fn(service, self.on_recovered) self.add_recoverers([recoverer]) @@ -205,19 +197,16 @@ class _TransactionController(object): class _Recoverer(object): - @staticmethod @defer.inlineCallbacks def start(clock, store, as_api, callback): - services = yield store.get_appservices_by_state( - ApplicationServiceState.DOWN - ) - recoverers = [ - _Recoverer(clock, store, as_api, s, callback) for s in services - ] + services = yield store.get_appservices_by_state(ApplicationServiceState.DOWN) + recoverers = [_Recoverer(clock, store, as_api, s, callback) for s in services] for r in recoverers: - logger.info("Starting recoverer for AS ID %s which was marked as " - "DOWN", r.service.id) + logger.info( + "Starting recoverer for AS ID %s which was marked as " "DOWN", + r.service.id, + ) r.recover() defer.returnValue(recoverers) @@ -232,9 +221,9 @@ class _Recoverer(object): def recover(self): def _retry(): run_as_background_process( - "as-recoverer-%s" % (self.service.id,), - self.retry, + "as-recoverer-%s" % (self.service.id,), self.retry ) + self.clock.call_later((2 ** self.backoff_counter), _retry) def _backoff(self): @@ -248,8 +237,9 @@ class _Recoverer(object): try: txn = yield self.store.get_oldest_unsent_txn(self.service) if txn: - logger.info("Retrying transaction %s for AS ID %s", - txn.id, txn.service.id) + logger.info( + "Retrying transaction %s for AS ID %s", txn.id, txn.service.id + ) sent = yield txn.send(self.as_api) if sent: yield txn.complete(self.store) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index f7d7f153bb..8284aa4c6d 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -284,8 +284,8 @@ class Config(object): if not config_files: config_parser.error( "Must supply a config file.\nA config file can be automatically" - " generated using \"--generate-config -H SERVER_NAME" - " -c CONFIG-FILE\"" + ' generated using "--generate-config -H SERVER_NAME' + ' -c CONFIG-FILE"' ) (config_path,) = config_files if not cls.path_exists(config_path): @@ -313,9 +313,7 @@ class Config(object): if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "w") as config_file: - config_file.write( - "# vim:ft=yaml\n\n" - ) + config_file.write("# vim:ft=yaml\n\n") config_file.write(config_str) config = yaml.safe_load(config_str) @@ -352,8 +350,8 @@ class Config(object): if not config_files: config_parser.error( "Must supply a config file.\nA config file can be automatically" - " generated using \"--generate-config -H SERVER_NAME" - " -c CONFIG-FILE\"" + ' generated using "--generate-config -H SERVER_NAME' + ' -c CONFIG-FILE"' ) obj.read_config_files( diff --git a/synapse/config/api.py b/synapse/config/api.py index 5eb4f86fa2..23b0ea6962 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -18,15 +18,17 @@ from ._base import Config class ApiConfig(Config): - def read_config(self, config): - self.room_invite_state_types = config.get("room_invite_state_types", [ - EventTypes.JoinRules, - EventTypes.CanonicalAlias, - EventTypes.RoomAvatar, - EventTypes.RoomEncryption, - EventTypes.Name, - ]) + self.room_invite_state_types = config.get( + "room_invite_state_types", + [ + EventTypes.JoinRules, + EventTypes.CanonicalAlias, + EventTypes.RoomAvatar, + EventTypes.RoomEncryption, + EventTypes.Name, + ], + ) def default_config(cls, **kwargs): return """\ @@ -40,4 +42,6 @@ class ApiConfig(Config): # - "{RoomAvatar}" # - "{RoomEncryption}" # - "{Name}" - """.format(**vars(EventTypes)) + """.format( + **vars(EventTypes) + ) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 7e89d345d8..679ee62480 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -29,7 +29,6 @@ logger = logging.getLogger(__name__) class AppServiceConfig(Config): - def read_config(self, config): self.app_service_config_files = config.get("app_service_config_files", []) self.notify_appservices = config.get("notify_appservices", True) @@ -53,9 +52,7 @@ class AppServiceConfig(Config): def load_appservices(hostname, config_files): """Returns a list of Application Services from the config files.""" if not isinstance(config_files, list): - logger.warning( - "Expected %s to be a list of AS config files.", config_files - ) + logger.warning("Expected %s to be a list of AS config files.", config_files) return [] # Dicts of value -> filename @@ -66,22 +63,20 @@ def load_appservices(hostname, config_files): for config_file in config_files: try: - with open(config_file, 'r') as f: - appservice = _load_appservice( - hostname, yaml.safe_load(f), config_file - ) + with open(config_file, "r") as f: + appservice = _load_appservice(hostname, yaml.safe_load(f), config_file) if appservice.id in seen_ids: raise ConfigError( "Cannot reuse ID across application services: " - "%s (files: %s, %s)" % ( - appservice.id, config_file, seen_ids[appservice.id], - ) + "%s (files: %s, %s)" + % (appservice.id, config_file, seen_ids[appservice.id]) ) seen_ids[appservice.id] = config_file if appservice.token in seen_as_tokens: raise ConfigError( "Cannot reuse as_token across application services: " - "%s (files: %s, %s)" % ( + "%s (files: %s, %s)" + % ( appservice.token, config_file, seen_as_tokens[appservice.token], @@ -98,28 +93,26 @@ def load_appservices(hostname, config_files): def _load_appservice(hostname, as_info, config_filename): - required_string_fields = [ - "id", "as_token", "hs_token", "sender_localpart" - ] + required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"] for field in required_string_fields: if not isinstance(as_info.get(field), string_types): - raise KeyError("Required string field: '%s' (%s)" % ( - field, config_filename, - )) + raise KeyError( + "Required string field: '%s' (%s)" % (field, config_filename) + ) # 'url' must either be a string or explicitly null, not missing # to avoid accidentally turning off push for ASes. - if (not isinstance(as_info.get("url"), string_types) and - as_info.get("url", "") is not None): + if ( + not isinstance(as_info.get("url"), string_types) + and as_info.get("url", "") is not None + ): raise KeyError( "Required string field or explicit null: 'url' (%s)" % (config_filename,) ) localpart = as_info["sender_localpart"] if urlparse.quote(localpart) != localpart: - raise ValueError( - "sender_localpart needs characters which are not URL encoded." - ) + raise ValueError("sender_localpart needs characters which are not URL encoded.") user = UserID(localpart, hostname) user_id = user.to_string() @@ -138,13 +131,12 @@ def _load_appservice(hostname, as_info, config_filename): for regex_obj in as_info["namespaces"][ns]: if not isinstance(regex_obj, dict): raise ValueError( - "Expected namespace entry in %s to be an object," - " but got %s", ns, regex_obj + "Expected namespace entry in %s to be an object," " but got %s", + ns, + regex_obj, ) if not isinstance(regex_obj.get("regex"), string_types): - raise ValueError( - "Missing/bad type 'regex' key in %s", regex_obj - ) + raise ValueError("Missing/bad type 'regex' key in %s", regex_obj) if not isinstance(regex_obj.get("exclusive"), bool): raise ValueError( "Missing/bad type 'exclusive' key in %s", regex_obj @@ -167,10 +159,8 @@ def _load_appservice(hostname, as_info, config_filename): ) ip_range_whitelist = None - if as_info.get('ip_range_whitelist'): - ip_range_whitelist = IPSet( - as_info.get('ip_range_whitelist') - ) + if as_info.get("ip_range_whitelist"): + ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist")) return ApplicationService( token=as_info["as_token"], diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index f7eebf26d2..e2eb473a92 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -16,7 +16,6 @@ from ._base import Config class CaptchaConfig(Config): - def read_config(self, config): self.recaptcha_private_key = config.get("recaptcha_private_key") self.recaptcha_public_key = config.get("recaptcha_public_key") diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index abeb0180d3..5b0bf919c7 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -89,29 +89,26 @@ class ConsentConfig(Config): if consent_config is None: return self.user_consent_version = str(consent_config["version"]) - self.user_consent_template_dir = self.abspath( - consent_config["template_dir"] - ) + self.user_consent_template_dir = self.abspath(consent_config["template_dir"]) if not path.isdir(self.user_consent_template_dir): raise ConfigError( - "Could not find template directory '%s'" % ( - self.user_consent_template_dir, - ), + "Could not find template directory '%s'" + % (self.user_consent_template_dir,) ) self.user_consent_server_notice_content = consent_config.get( - "server_notice_content", + "server_notice_content" ) self.block_events_without_consent_error = consent_config.get( - "block_events_error", + "block_events_error" + ) + self.user_consent_server_notice_to_guests = bool( + consent_config.get("send_server_notice_to_guests", False) + ) + self.user_consent_at_registration = bool( + consent_config.get("require_at_registration", False) ) - self.user_consent_server_notice_to_guests = bool(consent_config.get( - "send_server_notice_to_guests", False, - )) - self.user_consent_at_registration = bool(consent_config.get( - "require_at_registration", False, - )) self.user_consent_policy_name = consent_config.get( - "policy_name", "Privacy Policy", + "policy_name", "Privacy Policy" ) def default_config(self, **kwargs): diff --git a/synapse/config/database.py b/synapse/config/database.py index 3c27ed6b4a..adc0a47ddf 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -18,29 +18,21 @@ from ._base import Config class DatabaseConfig(Config): - def read_config(self, config): - self.event_cache_size = self.parse_size( - config.get("event_cache_size", "10K") - ) + self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) self.database_config = config.get("database") if self.database_config is None: - self.database_config = { - "name": "sqlite3", - "args": {}, - } + self.database_config = {"name": "sqlite3", "args": {}} name = self.database_config.get("name", None) if name == "psycopg2": pass elif name == "sqlite3": - self.database_config.setdefault("args", {}).update({ - "cp_min": 1, - "cp_max": 1, - "check_same_thread": False, - }) + self.database_config.setdefault("args", {}).update( + {"cp_min": 1, "cp_max": 1, "check_same_thread": False} + ) else: raise RuntimeError("Unsupported database type '%s'" % (name,)) @@ -48,7 +40,8 @@ class DatabaseConfig(Config): def default_config(self, data_dir_path, **kwargs): database_path = os.path.join(data_dir_path, "homeserver.db") - return """\ + return ( + """\ ## Database ## database: @@ -62,7 +55,9 @@ class DatabaseConfig(Config): # Number of events to cache in memory. # #event_cache_size: 10K - """ % locals() + """ + % locals() + ) def read_arguments(self, args): self.set_databasepath(args.database_path) @@ -77,6 +72,8 @@ class DatabaseConfig(Config): def add_arguments(self, parser): db_group = parser.add_argument_group("database") db_group.add_argument( - "-d", "--database-path", metavar="SQLITE_DATABASE_PATH", - help="The path to a sqlite database to use." + "-d", + "--database-path", + metavar="SQLITE_DATABASE_PATH", + help="The path to a sqlite database to use.", ) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 86018dfcce..3a6cb07206 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -56,7 +56,7 @@ class EmailConfig(Config): if self.email_notif_from is not None: # make sure it's valid parsed = email.utils.parseaddr(self.email_notif_from) - if parsed[1] == '': + if parsed[1] == "": raise RuntimeError("Invalid notif_from address") template_dir = email_config.get("template_dir") @@ -65,19 +65,17 @@ class EmailConfig(Config): # (Note that loading as package_resources with jinja.PackageLoader doesn't # work for the same reason.) if not template_dir: - template_dir = pkg_resources.resource_filename( - 'synapse', 'res/templates' - ) + template_dir = pkg_resources.resource_filename("synapse", "res/templates") self.email_template_dir = os.path.abspath(template_dir) self.email_enable_notifs = email_config.get("enable_notifs", False) - account_validity_renewal_enabled = config.get( - "account_validity", {}, - ).get("renew_at") + account_validity_renewal_enabled = config.get("account_validity", {}).get( + "renew_at" + ) email_trust_identity_server_for_password_resets = email_config.get( - "trust_identity_server_for_password_resets", False, + "trust_identity_server_for_password_resets", False ) self.email_password_reset_behaviour = ( "remote" if email_trust_identity_server_for_password_resets else "local" @@ -103,62 +101,59 @@ class EmailConfig(Config): # make sure we can import the required deps import jinja2 import bleach + # prevent unused warnings jinja2 bleach if self.email_password_reset_behaviour == "local": - required = [ - "smtp_host", - "smtp_port", - "notif_from", - ] + required = ["smtp_host", "smtp_port", "notif_from"] missing = [] for k in required: if k not in email_config: missing.append(k) - if (len(missing) > 0): + if len(missing) > 0: raise RuntimeError( "email.password_reset_behaviour is set to 'local' " - "but required keys are missing: %s" % - (", ".join(["email." + k for k in missing]),) + "but required keys are missing: %s" + % (", ".join(["email." + k for k in missing]),) ) # Templates for password reset emails self.email_password_reset_template_html = email_config.get( - "password_reset_template_html", "password_reset.html", + "password_reset_template_html", "password_reset.html" ) self.email_password_reset_template_text = email_config.get( - "password_reset_template_text", "password_reset.txt", + "password_reset_template_text", "password_reset.txt" ) self.email_password_reset_failure_template = email_config.get( - "password_reset_failure_template", "password_reset_failure.html", + "password_reset_failure_template", "password_reset_failure.html" ) # This template does not support any replaceable variables, so we will # read it from the disk once during setup email_password_reset_success_template = email_config.get( - "password_reset_success_template", "password_reset_success.html", + "password_reset_success_template", "password_reset_success.html" ) # Check templates exist - for f in [self.email_password_reset_template_html, - self.email_password_reset_template_text, - self.email_password_reset_failure_template, - email_password_reset_success_template]: + for f in [ + self.email_password_reset_template_html, + self.email_password_reset_template_text, + self.email_password_reset_failure_template, + email_password_reset_success_template, + ]: p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): - raise ConfigError("Unable to find template file %s" % (p, )) + raise ConfigError("Unable to find template file %s" % (p,)) # Retrieve content of web templates filepath = os.path.join( - self.email_template_dir, - email_password_reset_success_template, + self.email_template_dir, email_password_reset_success_template ) self.email_password_reset_success_html_content = self.read_file( - filepath, - "email.password_reset_template_success_html", + filepath, "email.password_reset_template_success_html" ) if config.get("public_baseurl") is None: @@ -182,10 +177,10 @@ class EmailConfig(Config): if k not in email_config: missing.append(k) - if (len(missing) > 0): + if len(missing) > 0: raise RuntimeError( - "email.enable_notifs is True but required keys are missing: %s" % - (", ".join(["email." + k for k in missing]),) + "email.enable_notifs is True but required keys are missing: %s" + % (", ".join(["email." + k for k in missing]),) ) if config.get("public_baseurl") is None: @@ -199,27 +194,25 @@ class EmailConfig(Config): for f in self.email_notif_template_text, self.email_notif_template_html: p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p, )) + raise ConfigError("Unable to find email template file %s" % (p,)) self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True ) - self.email_riot_base_url = email_config.get( - "riot_base_url", None - ) + self.email_riot_base_url = email_config.get("riot_base_url", None) if account_validity_renewal_enabled: self.email_expiry_template_html = email_config.get( - "expiry_template_html", "notice_expiry.html", + "expiry_template_html", "notice_expiry.html" ) self.email_expiry_template_text = email_config.get( - "expiry_template_text", "notice_expiry.txt", + "expiry_template_text", "notice_expiry.txt" ) for f in self.email_expiry_template_text, self.email_expiry_template_html: p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p, )) + raise ConfigError("Unable to find email template file %s" % (p,)) def default_config(self, config_dir_path, server_name, **kwargs): return """ diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py index ecb4124096..b190dcbe38 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py @@ -15,13 +15,11 @@ from ._base import Config, ConfigError -MISSING_JWT = ( - """Missing jwt library. This is required for jwt login. +MISSING_JWT = """Missing jwt library. This is required for jwt login. Install by running: pip install pyjwt """ -) class JWTConfig(Config): @@ -34,6 +32,7 @@ class JWTConfig(Config): try: import jwt + jwt # To stop unused lint. except ImportError: raise ConfigError(MISSING_JWT) diff --git a/synapse/config/key.py b/synapse/config/key.py index 424875feae..94a0f47ea4 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -348,9 +348,8 @@ def _parse_key_servers(key_servers, federation_verify_certificates): result.verify_keys[key_id] = verify_key - if ( - not federation_verify_certificates and - not server.get("accept_keys_insecurely") + if not federation_verify_certificates and not server.get( + "accept_keys_insecurely" ): _assert_keyserver_has_verify_keys(result) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index c1febbe9d3..a22655b125 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -29,7 +29,8 @@ from synapse.util.versionstring import get_version_string from ._base import Config -DEFAULT_LOG_CONFIG = Template(""" +DEFAULT_LOG_CONFIG = Template( + """ version: 1 formatters: @@ -68,11 +69,11 @@ loggers: root: level: INFO handlers: [file, console] -""") +""" +) class LoggingConfig(Config): - def read_config(self, config): self.verbosity = config.get("verbose", 0) self.no_redirect_stdio = config.get("no_redirect_stdio", False) @@ -81,13 +82,16 @@ class LoggingConfig(Config): def default_config(self, config_dir_path, server_name, **kwargs): log_config = os.path.join(config_dir_path, server_name + ".log.config") - return """\ + return ( + """\ ## Logging ## # A yaml python logging config file # log_config: "%(log_config)s" - """ % locals() + """ + % locals() + ) def read_arguments(self, args): if args.verbose is not None: @@ -102,22 +106,31 @@ class LoggingConfig(Config): def add_arguments(cls, parser): logging_group = parser.add_argument_group("logging") logging_group.add_argument( - '-v', '--verbose', dest="verbose", action='count', + "-v", + "--verbose", + dest="verbose", + action="count", help="The verbosity level. Specify multiple times to increase " - "verbosity. (Ignored if --log-config is specified.)" + "verbosity. (Ignored if --log-config is specified.)", ) logging_group.add_argument( - '-f', '--log-file', dest="log_file", - help="File to log to. (Ignored if --log-config is specified.)" + "-f", + "--log-file", + dest="log_file", + help="File to log to. (Ignored if --log-config is specified.)", ) logging_group.add_argument( - '--log-config', dest="log_config", default=None, - help="Python logging config file" + "--log-config", + dest="log_config", + default=None, + help="Python logging config file", ) logging_group.add_argument( - '-n', '--no-redirect-stdio', - action='store_true', default=None, - help="Do not redirect stdout/stderr to the log" + "-n", + "--no-redirect-stdio", + action="store_true", + default=None, + help="Do not redirect stdout/stderr to the log", ) def generate_files(self, config): @@ -125,9 +138,7 @@ class LoggingConfig(Config): if log_config and not os.path.exists(log_config): log_file = self.abspath("homeserver.log") with open(log_config, "w") as log_config_file: - log_config_file.write( - DEFAULT_LOG_CONFIG.substitute(log_file=log_file) - ) + log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file)) def setup_logging(config, use_worker_options=False): @@ -143,10 +154,8 @@ def setup_logging(config, use_worker_options=False): register_sighup (func | None): Function to call to register a sighup handler. """ - log_config = (config.worker_log_config if use_worker_options - else config.log_config) - log_file = (config.worker_log_file if use_worker_options - else config.log_file) + log_config = config.worker_log_config if use_worker_options else config.log_config + log_file = config.worker_log_file if use_worker_options else config.log_file log_format = ( "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" @@ -164,23 +173,23 @@ def setup_logging(config, use_worker_options=False): if config.verbosity > 1: level_for_storage = logging.DEBUG - logger = logging.getLogger('') + logger = logging.getLogger("") logger.setLevel(level) - logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage) + logging.getLogger("synapse.storage.SQL").setLevel(level_for_storage) formatter = logging.Formatter(log_format) if log_file: # TODO: Customisable file size / backup count handler = logging.handlers.RotatingFileHandler( - log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, - encoding='utf8' + log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, encoding="utf8" ) def sighup(signum, stack): logger.info("Closing log file due to SIGHUP") handler.doRollover() logger.info("Opened new log file due to SIGHUP") + else: handler = logging.StreamHandler() @@ -193,8 +202,9 @@ def setup_logging(config, use_worker_options=False): logger.addHandler(handler) else: + def load_log_config(): - with open(log_config, 'r') as f: + with open(log_config, "r") as f: logging.config.dictConfig(yaml.safe_load(f)) def sighup(*args): @@ -209,10 +219,7 @@ def setup_logging(config, use_worker_options=False): # make sure that the first thing we log is a thing we can grep backwards # for logging.warn("***** STARTING SERVER *****") - logging.warn( - "Server %s version %s", - sys.argv[0], get_version_string(synapse), - ) + logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) # It's critical to point twisted's internal logging somewhere, otherwise it @@ -242,8 +249,7 @@ def setup_logging(config, use_worker_options=False): return observer(event) globalLogBeginner.beginLoggingTo( - [_log], - redirectStandardIO=not config.no_redirect_stdio, + [_log], redirectStandardIO=not config.no_redirect_stdio ) if not config.no_redirect_stdio: print("Redirected stdout/stderr to logs") diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 2de51979d8..c85e234d22 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -15,11 +15,9 @@ from ._base import Config, ConfigError -MISSING_SENTRY = ( - """Missing sentry-sdk library. This is required to enable sentry +MISSING_SENTRY = """Missing sentry-sdk library. This is required to enable sentry integration. """ -) class MetricsConfig(Config): @@ -39,7 +37,7 @@ class MetricsConfig(Config): self.sentry_dsn = config["sentry"].get("dsn") if not self.sentry_dsn: raise ConfigError( - "sentry.dsn field is required when sentry integration is enabled", + "sentry.dsn field is required when sentry integration is enabled" ) def default_config(self, report_stats=None, **kwargs): @@ -66,6 +64,6 @@ class MetricsConfig(Config): if report_stats is None: res += "# report_stats: true|false\n" else: - res += "report_stats: %s\n" % ('true' if report_stats else 'false') + res += "report_stats: %s\n" % ("true" if report_stats else "false") return res diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index f0a6be0679..fcf279e8e1 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -17,7 +17,7 @@ from synapse.util.module_loader import load_module from ._base import Config -LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider' +LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider" class PasswordAuthProviderConfig(Config): @@ -29,24 +29,20 @@ class PasswordAuthProviderConfig(Config): # param. ldap_config = config.get("ldap_config", {}) if ldap_config.get("enabled", False): - providers.append({ - 'module': LDAP_PROVIDER, - 'config': ldap_config, - }) + providers.append({"module": LDAP_PROVIDER, "config": ldap_config}) providers.extend(config.get("password_providers", [])) for provider in providers: - mod_name = provider['module'] + mod_name = provider["module"] # This is for backwards compat when the ldap auth provider resided # in this package. if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider": mod_name = LDAP_PROVIDER - (provider_class, provider_config) = load_module({ - "module": mod_name, - "config": provider['config'], - }) + (provider_class, provider_config) = load_module( + {"module": mod_name, "config": provider["config"]} + ) self.password_providers.append((provider_class, provider_config)) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index aad3400819..a1e27ba66c 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -23,7 +23,7 @@ from synapse.util.stringutils import random_string_with_symbols class AccountValidityConfig(Config): def __init__(self, config, synapse_config): self.enabled = config.get("enabled", False) - self.renew_by_email_enabled = ("renew_at" in config) + self.renew_by_email_enabled = "renew_at" in config if self.enabled: if "period" in config: @@ -39,14 +39,13 @@ class AccountValidityConfig(Config): else: self.renew_email_subject = "Renew your %(app)s account" - self.startup_job_max_delta = self.period * 10. / 100. + self.startup_job_max_delta = self.period * 10.0 / 100.0 if self.renew_by_email_enabled and "public_baseurl" not in synapse_config: raise ConfigError("Can't send renewal emails without 'public_baseurl'") class RegistrationConfig(Config): - def read_config(self, config): self.enable_registration = bool( strtobool(str(config.get("enable_registration", False))) @@ -57,7 +56,7 @@ class RegistrationConfig(Config): ) self.account_validity = AccountValidityConfig( - config.get("account_validity", {}), config, + config.get("account_validity", {}), config ) self.registrations_require_3pid = config.get("registrations_require_3pid", []) @@ -67,24 +66,23 @@ class RegistrationConfig(Config): self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.trusted_third_party_id_servers = config.get( - "trusted_third_party_id_servers", - ["matrix.org", "vector.im"], + "trusted_third_party_id_servers", ["matrix.org", "vector.im"] ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) - self.invite_3pid_guest = ( - self.allow_guest_access and config.get("invite_3pid_guest", False) + self.invite_3pid_guest = self.allow_guest_access and config.get( + "invite_3pid_guest", False ) self.auto_join_rooms = config.get("auto_join_rooms", []) for room_alias in self.auto_join_rooms: if not RoomAlias.is_valid(room_alias): - raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) + raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) - self.disable_msisdn_registration = ( - config.get("disable_msisdn_registration", False) + self.disable_msisdn_registration = config.get( + "disable_msisdn_registration", False ) def default_config(self, generate_secrets=False, **kwargs): @@ -93,9 +91,12 @@ class RegistrationConfig(Config): random_string_with_symbols(50), ) else: - registration_shared_secret = '# registration_shared_secret: ' + registration_shared_secret = ( + "# registration_shared_secret: " + ) - return """\ + return ( + """\ ## Registration ## # # Registration can be rate-limited using the parameters in the "Ratelimiting" @@ -217,17 +218,19 @@ class RegistrationConfig(Config): # users cannot be auto-joined since they do not exist. # #autocreate_auto_join_rooms: true - """ % locals() + """ + % locals() + ) def add_arguments(self, parser): reg_group = parser.add_argument_group("registration") reg_group.add_argument( - "--enable-registration", action="store_true", default=None, - help="Enable registration for new users." + "--enable-registration", + action="store_true", + default=None, + help="Enable registration for new users.", ) def read_arguments(self, args): if args.enable_registration is not None: - self.enable_registration = bool( - strtobool(str(args.enable_registration)) - ) + self.enable_registration = bool(strtobool(str(args.enable_registration))) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index fbfcecc240..9f9669ebb1 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -20,27 +20,11 @@ from synapse.util.module_loader import load_module from ._base import Config, ConfigError DEFAULT_THUMBNAIL_SIZES = [ - { - "width": 32, - "height": 32, - "method": "crop", - }, { - "width": 96, - "height": 96, - "method": "crop", - }, { - "width": 320, - "height": 240, - "method": "scale", - }, { - "width": 640, - "height": 480, - "method": "scale", - }, { - "width": 800, - "height": 600, - "method": "scale" - }, + {"width": 32, "height": 32, "method": "crop"}, + {"width": 96, "height": 96, "method": "crop"}, + {"width": 320, "height": 240, "method": "scale"}, + {"width": 640, "height": 480, "method": "scale"}, + {"width": 800, "height": 600, "method": "scale"}, ] THUMBNAIL_SIZE_YAML = """\ @@ -49,19 +33,15 @@ THUMBNAIL_SIZE_YAML = """\ # method: %(method)s """ -MISSING_NETADDR = ( - "Missing netaddr library. This is required for URL preview API." -) +MISSING_NETADDR = "Missing netaddr library. This is required for URL preview API." -MISSING_LXML = ( - """Missing lxml library. This is required for URL preview API. +MISSING_LXML = """Missing lxml library. This is required for URL preview API. Install by running: pip install lxml Requires libxslt1-dev system package. """ -) ThumbnailRequirement = namedtuple( @@ -69,7 +49,8 @@ ThumbnailRequirement = namedtuple( ) MediaStorageProviderConfig = namedtuple( - "MediaStorageProviderConfig", ( + "MediaStorageProviderConfig", + ( "store_local", # Whether to store newly uploaded local files "store_remote", # Whether to store newly downloaded remote files "store_synchronous", # Whether to wait for successful storage for local uploads @@ -100,8 +81,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): requirements.setdefault("image/gif", []).append(png_thumbnail) requirements.setdefault("image/png", []).append(png_thumbnail) return { - media_type: tuple(thumbnails) - for media_type, thumbnails in requirements.items() + media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items() } @@ -127,15 +107,15 @@ class ContentRepositoryConfig(Config): "Cannot use both 'backup_media_store_path' and 'storage_providers'" ) - storage_providers = [{ - "module": "file_system", - "store_local": True, - "store_synchronous": synchronous_backup_media_store, - "store_remote": True, - "config": { - "directory": backup_media_store_path, + storage_providers = [ + { + "module": "file_system", + "store_local": True, + "store_synchronous": synchronous_backup_media_store, + "store_remote": True, + "config": {"directory": backup_media_store_path}, } - }] + ] # This is a list of config that can be used to create the storage # providers. The entries are tuples of (Class, class_config, @@ -165,18 +145,19 @@ class ContentRepositoryConfig(Config): ) self.media_storage_providers.append( - (provider_class, parsed_config, wrapper_config,) + (provider_class, parsed_config, wrapper_config) ) self.uploads_path = self.ensure_directory(config["uploads_path"]) self.dynamic_thumbnails = config.get("dynamic_thumbnails", False) self.thumbnail_requirements = parse_thumbnail_requirements( - config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES), + config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES) ) self.url_preview_enabled = config.get("url_preview_enabled", False) if self.url_preview_enabled: try: import lxml + lxml # To stop unused lint. except ImportError: raise ConfigError(MISSING_LXML) @@ -199,15 +180,13 @@ class ContentRepositoryConfig(Config): # we always blacklist '0.0.0.0' and '::', which are supposed to be # unroutable addresses. - self.url_preview_ip_range_blacklist.update(['0.0.0.0', '::']) + self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"]) self.url_preview_ip_range_whitelist = IPSet( config.get("url_preview_ip_range_whitelist", ()) ) - self.url_preview_url_blacklist = config.get( - "url_preview_url_blacklist", () - ) + self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ()) def default_config(self, data_dir_path, **kwargs): media_store = os.path.join(data_dir_path, "media_store") @@ -219,7 +198,8 @@ class ContentRepositoryConfig(Config): # strip final NL formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1] - return r""" + return ( + r""" # Directory where uploaded images and attachments are stored. # media_store_path: "%(media_store)s" @@ -342,4 +322,6 @@ class ContentRepositoryConfig(Config): # The largest allowed URL preview spidering size in bytes # #max_spider_size: 10M - """ % locals() + """ + % locals() + ) diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 8a9fded4c5..c1da0e20e0 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -20,9 +20,7 @@ from ._base import Config, ConfigError class RoomDirectoryConfig(Config): def read_config(self, config): - self.enable_room_list_search = config.get( - "enable_room_list_search", True, - ) + self.enable_room_list_search = config.get("enable_room_list_search", True) alias_creation_rules = config.get("alias_creation_rules") @@ -33,11 +31,7 @@ class RoomDirectoryConfig(Config): ] else: self._alias_creation_rules = [ - _RoomDirectoryRule( - "alias_creation_rules", { - "action": "allow", - } - ) + _RoomDirectoryRule("alias_creation_rules", {"action": "allow"}) ] room_list_publication_rules = config.get("room_list_publication_rules") @@ -49,11 +43,7 @@ class RoomDirectoryConfig(Config): ] else: self._room_list_publication_rules = [ - _RoomDirectoryRule( - "room_list_publication_rules", { - "action": "allow", - } - ) + _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"}) ] def default_config(self, config_dir_path, server_name, **kwargs): @@ -178,8 +168,7 @@ class _RoomDirectoryRule(object): self.action = action else: raise ConfigError( - "%s rules can only have action of 'allow'" - " or 'deny'" % (option_name,) + "%s rules can only have action of 'allow'" " or 'deny'" % (option_name,) ) self._alias_matches_all = alias == "*" diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index aa6eac271f..2ec38e48e9 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -28,6 +28,7 @@ class SAML2Config(Config): self.saml2_enabled = True import saml2.config + self.saml2_sp_config = saml2.config.SPConfig() self.saml2_sp_config.load(self._default_saml_config_dict()) self.saml2_sp_config.load(saml2_config.get("sp_config", {})) @@ -41,26 +42,23 @@ class SAML2Config(Config): public_baseurl = self.public_baseurl if public_baseurl is None: - raise ConfigError( - "saml2_config requires a public_baseurl to be set" - ) + raise ConfigError("saml2_config requires a public_baseurl to be set") metadata_url = public_baseurl + "_matrix/saml2/metadata.xml" response_url = public_baseurl + "_matrix/saml2/authn_response" return { "entityid": metadata_url, - "service": { "sp": { "endpoints": { "assertion_consumer_service": [ - (response_url, saml2.BINDING_HTTP_POST), - ], + (response_url, saml2.BINDING_HTTP_POST) + ] }, "required_attributes": ["uid"], "optional_attributes": ["mail", "surname", "givenname"], - }, - } + } + }, } def default_config(self, config_dir_path, server_name, **kwargs): @@ -106,4 +104,6 @@ class SAML2Config(Config): # # separate pysaml2 configuration file: # # # config_path: "%(config_dir_path)s/sp_conf.py" - """ % {"config_dir_path": config_dir_path} + """ % { + "config_dir_path": config_dir_path + } diff --git a/synapse/config/server.py b/synapse/config/server.py index 6e5b46e6c3..6d3f1da96c 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -34,13 +34,12 @@ logger = logging.Logger(__name__) # # We later check for errors when binding to 0.0.0.0 and ignore them if :: is also in # in the list. -DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0'] +DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"] DEFAULT_ROOM_VERSION = "4" class ServerConfig(Config): - def read_config(self, config): self.server_name = config["server_name"] self.server_context = config.get("server_context", None) @@ -81,27 +80,25 @@ class ServerConfig(Config): # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. self.require_auth_for_profile_requests = config.get( - "require_auth_for_profile_requests", False, + "require_auth_for_profile_requests", False ) # If set to 'True', requires authentication to access the server's # public rooms directory through the client API, and forbids any other # homeserver to fetch it via federation. self.restrict_public_rooms_to_local_users = config.get( - "restrict_public_rooms_to_local_users", False, + "restrict_public_rooms_to_local_users", False ) - default_room_version = config.get( - "default_room_version", DEFAULT_ROOM_VERSION, - ) + default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION) # Ensure room version is a str default_room_version = str(default_room_version) if default_room_version not in KNOWN_ROOM_VERSIONS: raise ConfigError( - "Unknown default_room_version: %s, known room versions: %s" % - (default_room_version, list(KNOWN_ROOM_VERSIONS.keys())) + "Unknown default_room_version: %s, known room versions: %s" + % (default_room_version, list(KNOWN_ROOM_VERSIONS.keys())) ) # Get the actual room version object rather than just the identifier @@ -116,31 +113,25 @@ class ServerConfig(Config): # Whether we should block invites sent to users on this server # (other than those sent by local server admins) - self.block_non_admin_invites = config.get( - "block_non_admin_invites", False, - ) + self.block_non_admin_invites = config.get("block_non_admin_invites", False) # Whether to enable experimental MSC1849 (aka relations) support self.experimental_msc1849_support_enabled = config.get( - "experimental_msc1849_support_enabled", False, + "experimental_msc1849_support_enabled", False ) # Options to control access by tracking MAU self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.max_mau_value = 0 if self.limit_usage_by_mau: - self.max_mau_value = config.get( - "max_mau_value", 0, - ) + self.max_mau_value = config.get("max_mau_value", 0) self.mau_stats_only = config.get("mau_stats_only", False) self.mau_limits_reserved_threepids = config.get( "mau_limit_reserved_threepids", [] ) - self.mau_trial_days = config.get( - "mau_trial_days", 0, - ) + self.mau_trial_days = config.get("mau_trial_days", 0) # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) @@ -153,9 +144,7 @@ class ServerConfig(Config): # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None - federation_domain_whitelist = config.get( - "federation_domain_whitelist", None, - ) + federation_domain_whitelist = config.get("federation_domain_whitelist", None) if federation_domain_whitelist is not None: # turn the whitelist into a hash for speed of lookup @@ -165,7 +154,7 @@ class ServerConfig(Config): self.federation_domain_whitelist[domain] = True self.federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", [], + "federation_ip_range_blacklist", [] ) # Attempt to create an IPSet from the given ranges @@ -178,13 +167,12 @@ class ServerConfig(Config): self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) except Exception as e: raise ConfigError( - "Invalid range(s) provided in " - "federation_ip_range_blacklist: %s" % e + "Invalid range(s) provided in " "federation_ip_range_blacklist: %s" % e ) if self.public_baseurl is not None: - if self.public_baseurl[-1] != '/': - self.public_baseurl += '/' + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, @@ -195,7 +183,7 @@ class ServerConfig(Config): # Whether to require a user to be in the room to add an alias to it. # Defaults to True. self.require_membership_for_aliases = config.get( - "require_membership_for_aliases", True, + "require_membership_for_aliases", True ) # Whether to allow per-room membership profiles through the send of membership @@ -227,9 +215,9 @@ class ServerConfig(Config): # if we still have an empty list of addresses, use the default list if not bind_addresses: - if listener['type'] == 'metrics': + if listener["type"] == "metrics": # the metrics listener doesn't support IPv6 - bind_addresses.append('0.0.0.0') + bind_addresses.append("0.0.0.0") else: bind_addresses.extend(DEFAULT_BIND_ADDRESSES) @@ -249,78 +237,72 @@ class ServerConfig(Config): bind_host = config.get("bind_host", "") gzip_responses = config.get("gzip_responses", True) - self.listeners.append({ - "port": bind_port, - "bind_addresses": [bind_host], - "tls": True, - "type": "http", - "resources": [ - { - "names": ["client"], - "compress": gzip_responses, - }, - { - "names": ["federation"], - "compress": False, - } - ] - }) - - unsecure_port = config.get("unsecure_port", bind_port - 400) - if unsecure_port: - self.listeners.append({ - "port": unsecure_port, + self.listeners.append( + { + "port": bind_port, "bind_addresses": [bind_host], - "tls": False, + "tls": True, "type": "http", "resources": [ - { - "names": ["client"], - "compress": gzip_responses, - }, - { - "names": ["federation"], - "compress": False, - } - ] - }) + {"names": ["client"], "compress": gzip_responses}, + {"names": ["federation"], "compress": False}, + ], + } + ) + + unsecure_port = config.get("unsecure_port", bind_port - 400) + if unsecure_port: + self.listeners.append( + { + "port": unsecure_port, + "bind_addresses": [bind_host], + "tls": False, + "type": "http", + "resources": [ + {"names": ["client"], "compress": gzip_responses}, + {"names": ["federation"], "compress": False}, + ], + } + ) manhole = config.get("manhole") if manhole: - self.listeners.append({ - "port": manhole, - "bind_addresses": ["127.0.0.1"], - "type": "manhole", - "tls": False, - }) + self.listeners.append( + { + "port": manhole, + "bind_addresses": ["127.0.0.1"], + "type": "manhole", + "tls": False, + } + ) metrics_port = config.get("metrics_port") if metrics_port: logger.warn( - ("The metrics_port configuration option is deprecated in Synapse 0.31 " - "in favour of a listener. Please see " - "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst" - " on how to configure the new listener.")) - - self.listeners.append({ - "port": metrics_port, - "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], - "tls": False, - "type": "http", - "resources": [ - { - "names": ["metrics"], - "compress": False, - }, - ] - }) + ( + "The metrics_port configuration option is deprecated in Synapse 0.31 " + "in favour of a listener. Please see " + "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst" + " on how to configure the new listener." + ) + ) + + self.listeners.append( + { + "port": metrics_port, + "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], + "tls": False, + "type": "http", + "resources": [{"names": ["metrics"], "compress": False}], + } + ) _check_resource_config(self.listeners) # An experimental option to try and periodically clean up extremities # by sending dummy events. self.cleanup_extremities_with_dummy_events = config.get( - "cleanup_extremities_with_dummy_events", False, + "cleanup_extremities_with_dummy_events", False ) def has_tls_listener(self): @@ -339,7 +321,8 @@ class ServerConfig(Config): # Bring DEFAULT_ROOM_VERSION into the local-scope for use in the # default config string default_room_version = DEFAULT_ROOM_VERSION - return """\ + return ( + """\ ## Server ## # The domain name of the server, with optional explicit port. @@ -637,7 +620,9 @@ class ServerConfig(Config): # Defaults to 'true'. # #allow_per_room_profiles: false - """ % locals() + """ + % locals() + ) def read_arguments(self, args): if args.manhole is not None: @@ -649,17 +634,26 @@ class ServerConfig(Config): def add_arguments(self, parser): server_group = parser.add_argument_group("server") - server_group.add_argument("-D", "--daemonize", action='store_true', - default=None, - help="Daemonize the home server") - server_group.add_argument("--print-pidfile", action='store_true', - default=None, - help="Print the path to the pidfile just" - " before daemonizing") - server_group.add_argument("--manhole", metavar="PORT", dest="manhole", - type=int, - help="Turn on the twisted telnet manhole" - " service on the given port.") + server_group.add_argument( + "-D", + "--daemonize", + action="store_true", + default=None, + help="Daemonize the home server", + ) + server_group.add_argument( + "--print-pidfile", + action="store_true", + default=None, + help="Print the path to the pidfile just" " before daemonizing", + ) + server_group.add_argument( + "--manhole", + metavar="PORT", + dest="manhole", + type=int, + help="Turn on the twisted telnet manhole" " service on the given port.", + ) def is_threepid_reserved(reserved_threepids, threepid): @@ -673,7 +667,7 @@ def is_threepid_reserved(reserved_threepids, threepid): """ for tp in reserved_threepids: - if (threepid['medium'] == tp['medium'] and threepid['address'] == tp['address']): + if threepid["medium"] == tp["medium"] and threepid["address"] == tp["address"]: return True return False @@ -686,9 +680,7 @@ def read_gc_thresholds(thresholds): return None try: assert len(thresholds) == 3 - return ( - int(thresholds[0]), int(thresholds[1]), int(thresholds[2]), - ) + return (int(thresholds[0]), int(thresholds[1]), int(thresholds[2])) except Exception: raise ConfigError( "Value of `gc_threshold` must be a list of three integers if set" @@ -706,22 +698,22 @@ def _warn_if_webclient_configured(listeners): for listener in listeners: for res in listener.get("resources", []): for name in res.get("names", []): - if name == 'webclient': + if name == "webclient": logger.warning(NO_MORE_WEB_CLIENT_WARNING) return KNOWN_RESOURCES = ( - 'client', - 'consent', - 'federation', - 'keys', - 'media', - 'metrics', - 'openid', - 'replication', - 'static', - 'webclient', + "client", + "consent", + "federation", + "keys", + "media", + "metrics", + "openid", + "replication", + "static", + "webclient", ) @@ -735,11 +727,9 @@ def _check_resource_config(listeners): for resource in resource_names: if resource not in KNOWN_RESOURCES: - raise ConfigError( - "Unknown listener resource '%s'" % (resource, ) - ) + raise ConfigError("Unknown listener resource '%s'" % (resource,)) if resource == "consent": try: - check_requirements('resources.consent') + check_requirements("resources.consent") except DependencyException as e: raise ConfigError(e.message) diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index 529dc0a617..d930eb33b5 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -58,6 +58,7 @@ class ServerNoticesConfig(Config): The name to use for the server notices room. None if server notices are not enabled. """ + def __init__(self): super(ServerNoticesConfig, self).__init__() self.server_notices_mxid = None @@ -70,18 +71,12 @@ class ServerNoticesConfig(Config): if c is None: return - mxid_localpart = c['system_mxid_localpart'] - self.server_notices_mxid = UserID( - mxid_localpart, self.server_name, - ).to_string() - self.server_notices_mxid_display_name = c.get( - 'system_mxid_display_name', None, - ) - self.server_notices_mxid_avatar_url = c.get( - 'system_mxid_avatar_url', None, - ) + mxid_localpart = c["system_mxid_localpart"] + self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string() + self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None) + self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None) # todo: i18n - self.server_notices_room_name = c.get('room_name', "Server Notices") + self.server_notices_room_name = c.get("room_name", "Server Notices") def default_config(self, **kwargs): return DEFAULT_CONFIG diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 658f9dd361..7951bf21fa 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -42,11 +42,11 @@ class TlsConfig(Config): self.acme_enabled = acme_config.get("enabled", False) # hyperlink complains on py2 if this is not a Unicode - self.acme_url = six.text_type(acme_config.get( - "url", u"https://acme-v01.api.letsencrypt.org/directory" - )) + self.acme_url = six.text_type( + acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory") + ) self.acme_port = acme_config.get("port", 80) - self.acme_bind_addresses = acme_config.get("bind_addresses", ['::', '0.0.0.0']) + self.acme_bind_addresses = acme_config.get("bind_addresses", ["::", "0.0.0.0"]) self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30) self.acme_domain = acme_config.get("domain", config.get("server_name")) @@ -74,12 +74,12 @@ class TlsConfig(Config): # Whether to verify certificates on outbound federation traffic self.federation_verify_certificates = config.get( - "federation_verify_certificates", True, + "federation_verify_certificates", True ) # Whitelist of domains to not verify certificates for fed_whitelist_entries = config.get( - "federation_certificate_verification_whitelist", [], + "federation_certificate_verification_whitelist", [] ) # Support globs (*) in whitelist values @@ -90,9 +90,7 @@ class TlsConfig(Config): self.federation_certificate_verification_whitelist.append(entry_regex) # List of custom certificate authorities for federation traffic validation - custom_ca_list = config.get( - "federation_custom_ca_list", None, - ) + custom_ca_list = config.get("federation_custom_ca_list", None) # Read in and parse custom CA certificates self.federation_ca_trust_root = None @@ -101,8 +99,10 @@ class TlsConfig(Config): # A trustroot cannot be generated without any CA certificates. # Raise an error if this option has been specified without any # corresponding certificates. - raise ConfigError("federation_custom_ca_list specified without " - "any certificate files") + raise ConfigError( + "federation_custom_ca_list specified without " + "any certificate files" + ) certs = [] for ca_file in custom_ca_list: @@ -114,8 +114,9 @@ class TlsConfig(Config): cert_base = Certificate.loadPEM(content) certs.append(cert_base) except Exception as e: - raise ConfigError("Error parsing custom CA certificate file %s: %s" - % (ca_file, e)) + raise ConfigError( + "Error parsing custom CA certificate file %s: %s" % (ca_file, e) + ) self.federation_ca_trust_root = trustRootFromCertificates(certs) @@ -146,17 +147,21 @@ class TlsConfig(Config): return None try: - with open(self.tls_certificate_file, 'rb') as f: + with open(self.tls_certificate_file, "rb") as f: cert_pem = f.read() except Exception as e: - raise ConfigError("Failed to read existing certificate file %s: %s" - % (self.tls_certificate_file, e)) + raise ConfigError( + "Failed to read existing certificate file %s: %s" + % (self.tls_certificate_file, e) + ) try: tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) except Exception as e: - raise ConfigError("Failed to parse existing certificate file %s: %s" - % (self.tls_certificate_file, e)) + raise ConfigError( + "Failed to parse existing certificate file %s: %s" + % (self.tls_certificate_file, e) + ) if not allow_self_signed: if tls_certificate.get_subject() == tls_certificate.get_issuer(): @@ -166,7 +171,7 @@ class TlsConfig(Config): # YYYYMMDDhhmmssZ -- in UTC expires_on = datetime.strptime( - tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ" + tls_certificate.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ" ) now = datetime.utcnow() days_remaining = (expires_on - now).days @@ -191,7 +196,8 @@ class TlsConfig(Config): except Exception as e: logger.info( "Unable to read TLS certificate (%s). Ignoring as no " - "tls listeners enabled.", e, + "tls listeners enabled.", + e, ) self.tls_fingerprints = list(self._original_tls_fingerprints) @@ -205,7 +211,7 @@ class TlsConfig(Config): sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest()) sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints) if sha256_fingerprint not in sha256_fingerprints: - self.tls_fingerprints.append({u"sha256": sha256_fingerprint}) + self.tls_fingerprints.append({"sha256": sha256_fingerprint}) def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) @@ -215,8 +221,8 @@ class TlsConfig(Config): # this is to avoid the max line length. Sorrynotsorry proxypassline = ( - 'ProxyPass /.well-known/acme-challenge ' - 'http://localhost:8009/.well-known/acme-challenge' + "ProxyPass /.well-known/acme-challenge " + "http://localhost:8009/.well-known/acme-challenge" ) return ( diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index 023997ccde..e031b11599 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -26,11 +26,11 @@ class UserDirectoryConfig(Config): self.user_directory_search_all_users = False user_directory_config = config.get("user_directory", None) if user_directory_config: - self.user_directory_search_enabled = ( - user_directory_config.get("enabled", True) + self.user_directory_search_enabled = user_directory_config.get( + "enabled", True ) - self.user_directory_search_all_users = ( - user_directory_config.get("search_all_users", False) + self.user_directory_search_all_users = user_directory_config.get( + "search_all_users", False ) def default_config(self, config_dir_path, server_name, **kwargs): diff --git a/synapse/config/voip.py b/synapse/config/voip.py index 2a1f005a37..82cf8c53a8 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -16,14 +16,13 @@ from ._base import Config class VoipConfig(Config): - def read_config(self, config): self.turn_uris = config.get("turn_uris", []) self.turn_shared_secret = config.get("turn_shared_secret") self.turn_username = config.get("turn_username") self.turn_password = config.get("turn_password") self.turn_user_lifetime = self.parse_duration( - config.get("turn_user_lifetime", "1h"), + config.get("turn_user_lifetime", "1h") ) self.turn_allow_guests = config.get("turn_allow_guests", True) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index bfbd8b6c91..75993abf35 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -52,12 +52,14 @@ class WorkerConfig(Config): # argument. manhole = config.get("worker_manhole") if manhole: - self.worker_listeners.append({ - "port": manhole, - "bind_addresses": ["127.0.0.1"], - "type": "manhole", - "tls": False, - }) + self.worker_listeners.append( + { + "port": manhole, + "bind_addresses": ["127.0.0.1"], + "type": "manhole", + "tls": False, + } + ) if self.worker_listeners: for listener in self.worker_listeners: @@ -67,7 +69,7 @@ class WorkerConfig(Config): if bind_address: bind_addresses.append(bind_address) elif not bind_addresses: - bind_addresses.append('') + bind_addresses.append("") def read_arguments(self, args): # We support a bunch of command line arguments that override options in diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 99a586655b..41eabbe717 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -46,9 +46,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): if name not in hashes: raise SynapseError( 400, - "Algorithm %s not in hashes %s" % ( - name, list(hashes), - ), + "Algorithm %s not in hashes %s" % (name, list(hashes)), Codes.UNAUTHORIZED, ) message_hash_base64 = hashes[name] @@ -56,9 +54,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): message_hash_bytes = decode_base64(message_hash_base64) except Exception: raise SynapseError( - 400, - "Invalid base64: %s" % (message_hash_base64,), - Codes.UNAUTHORIZED, + 400, "Invalid base64: %s" % (message_hash_base64,), Codes.UNAUTHORIZED ) return message_hash_bytes == expected_hash @@ -135,8 +131,9 @@ def compute_event_signature(event_dict, signature_name, signing_key): return redact_json["signatures"] -def add_hashes_and_signatures(event_dict, signature_name, signing_key, - hash_algorithm=hashlib.sha256): +def add_hashes_and_signatures( + event_dict, signature_name, signing_key, hash_algorithm=hashlib.sha256 +): """Add content hash and sign the event Args: @@ -153,7 +150,5 @@ def add_hashes_and_signatures(event_dict, signature_name, signing_key, event_dict.setdefault("hashes", {})[name] = encode_base64(digest) event_dict["signatures"] = compute_event_signature( - event_dict, - signature_name=signature_name, - signing_key=signing_key, + event_dict, signature_name=signature_name, signing_key=signing_key ) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 6f603f1961..10c2eb7f0f 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -505,7 +505,7 @@ class BaseV2KeyFetcher(object): Returns: Deferred[dict[str, FetchKeyResult]]: map from key_id to result object """ - ts_valid_until_ms = response_json[u"valid_until_ts"] + ts_valid_until_ms = response_json["valid_until_ts"] # start by extracting the keys from the response, since they may be required # to validate the signature on the response. @@ -614,10 +614,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): results = yield logcontext.make_deferred_yieldable( defer.gatherResults( - [ - run_in_background(get_key, server) - for server in self.key_servers - ], + [run_in_background(get_key, server) for server in self.key_servers], consumeErrors=True, ).addErrback(unwrapFirstError) ) @@ -630,9 +627,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): defer.returnValue(union_of_keys) @defer.inlineCallbacks - def get_server_verify_key_v2_indirect( - self, keys_to_fetch, key_server - ): + def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): """ Args: keys_to_fetch (dict[str, dict[str, int]]): @@ -661,9 +656,9 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): destination=perspective_name, path="/_matrix/key/v2/query", data={ - u"server_keys": { + "server_keys": { server_name: { - key_id: {u"minimum_valid_until_ts": min_valid_ts} + key_id: {"minimum_valid_until_ts": min_valid_ts} for key_id, min_valid_ts in server_keys.items() } for server_name, server_keys in keys_to_fetch.items() @@ -690,10 +685,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ) try: - self._validate_perspectives_response( - key_server, - response, - ) + self._validate_perspectives_response(key_server, response) processed_response = yield self.process_v2_response( perspective_name, response, time_added_ms=time_now_ms @@ -720,9 +712,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): defer.returnValue(keys) - def _validate_perspectives_response( - self, key_server, response, - ): + def _validate_perspectives_response(self, key_server, response): """Optionally check the signature on the result of a /key/query request Args: @@ -739,13 +729,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return if ( - u"signatures" not in response - or perspective_name not in response[u"signatures"] + "signatures" not in response + or perspective_name not in response["signatures"] ): raise KeyLookupError("Response not signed by the notary server") verified = False - for key_id in response[u"signatures"][perspective_name]: + for key_id in response["signatures"][perspective_name]: if key_id in perspective_keys: verify_signed_json(response, perspective_name, perspective_keys[key_id]) verified = True @@ -754,7 +744,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): raise KeyLookupError( "Response not signed with a known key: signed with: %r, known keys: %r" % ( - list(response[u"signatures"][perspective_name].keys()), + list(response["signatures"][perspective_name].keys()), list(perspective_keys.keys()), ) ) @@ -826,7 +816,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher): path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id), ignore_backoff=True, - # we only give the remote server 10s to respond. It should be an # easy request to handle, so if it doesn't reply within 10s, it's # probably not going to. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 203490fc36..cd52e3f867 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -85,17 +85,14 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru room_id_domain = get_domain_from_id(event.room_id) if room_id_domain != sender_domain: raise AuthError( - 403, - "Creation event's room_id domain does not match sender's" + 403, "Creation event's room_id domain does not match sender's" ) room_version = event.content.get("room_version", "1") if room_version not in KNOWN_ROOM_VERSIONS: raise AuthError( - 403, - "room appears to have unsupported version %s" % ( - room_version, - )) + 403, "room appears to have unsupported version %s" % (room_version,) + ) # FIXME logger.debug("Allowing! %s", event) return @@ -103,46 +100,30 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru creation_event = auth_events.get((EventTypes.Create, ""), None) if not creation_event: - raise AuthError( - 403, - "No create event in auth events", - ) + raise AuthError(403, "No create event in auth events") creating_domain = get_domain_from_id(event.room_id) originating_domain = get_domain_from_id(event.sender) if creating_domain != originating_domain: if not _can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) + raise AuthError(403, "This room has been marked as unfederatable.") # FIXME: Temp hack if event.type == EventTypes.Aliases: if not event.is_state(): - raise AuthError( - 403, - "Alias event must be a state event", - ) + raise AuthError(403, "Alias event must be a state event") if not event.state_key: - raise AuthError( - 403, - "Alias event must have non-empty state_key" - ) + raise AuthError(403, "Alias event must have non-empty state_key") sender_domain = get_domain_from_id(event.sender) if event.state_key != sender_domain: raise AuthError( - 403, - "Alias event's state_key does not match sender's domain" + 403, "Alias event's state_key does not match sender's domain" ) logger.debug("Allowing! %s", event) return if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] - ) + logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) if event.type == EventTypes.Member: _is_membership_change_allowed(event, auth_events) @@ -159,9 +140,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru invite_level = _get_named_level(auth_events, "invite", 0) if user_level < invite_level: - raise AuthError( - 403, "You don't have permission to invite users", - ) + raise AuthError(403, "You don't have permission to invite users") else: logger.debug("Allowing! %s", event) return @@ -207,7 +186,7 @@ def _is_membership_change_allowed(event, auth_events): # Check if this is the room creator joining: if len(event.prev_event_ids()) == 1 and Membership.JOIN == membership: # Get room creation event: - key = (EventTypes.Create, "", ) + key = (EventTypes.Create, "") create = auth_events.get(key) if create and event.prev_event_ids()[0] == create.event_id: if create.content["creator"] == event.state_key: @@ -219,38 +198,31 @@ def _is_membership_change_allowed(event, auth_events): target_domain = get_domain_from_id(target_user_id) if creating_domain != target_domain: if not _can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) + raise AuthError(403, "This room has been marked as unfederatable.") # get info about the caller - key = (EventTypes.Member, event.user_id, ) + key = (EventTypes.Member, event.user_id) caller = auth_events.get(key) caller_in_room = caller and caller.membership == Membership.JOIN caller_invited = caller and caller.membership == Membership.INVITE # get info about the target - key = (EventTypes.Member, target_user_id, ) + key = (EventTypes.Member, target_user_id) target = auth_events.get(key) target_in_room = target and target.membership == Membership.JOIN target_banned = target and target.membership == Membership.BAN - key = (EventTypes.JoinRules, "", ) + key = (EventTypes.JoinRules, "") join_rule_event = auth_events.get(key) if join_rule_event: - join_rule = join_rule_event.content.get( - "join_rule", JoinRules.INVITE - ) + join_rule = join_rule_event.content.get("join_rule", JoinRules.INVITE) else: join_rule = JoinRules.INVITE user_level = get_user_power_level(event.user_id, auth_events) - target_level = get_user_power_level( - target_user_id, auth_events - ) + target_level = get_user_power_level(target_user_id, auth_events) # FIXME (erikj): What should we do here as the default? ban_level = _get_named_level(auth_events, "ban", 50) @@ -266,29 +238,26 @@ def _is_membership_change_allowed(event, auth_events): "join_rule": join_rule, "target_user_id": target_user_id, "event.user_id": event.user_id, - } + }, ) if Membership.INVITE == membership and "third_party_invite" in event.content: if not _verify_third_party_invite(event, auth_events): raise AuthError(403, "You are not invited to this room.") if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) + raise AuthError(403, "%s is banned from the room" % (target_user_id,)) return if Membership.JOIN != membership: - if (caller_invited - and Membership.LEAVE == membership - and target_user_id == event.user_id): + if ( + caller_invited + and Membership.LEAVE == membership + and target_user_id == event.user_id + ): return if not caller_in_room: # caller isn't joined - raise AuthError( - 403, - "%s not in room %s." % (event.user_id, event.room_id,) - ) + raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id)) if Membership.INVITE == membership: # TODO (erikj): We should probably handle this more intelligently @@ -296,19 +265,14 @@ def _is_membership_change_allowed(event, auth_events): # Invites are valid iff caller is in the room and target isn't. if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) + raise AuthError(403, "%s is banned from the room" % (target_user_id,)) elif target_in_room: # the target is already in the room. - raise AuthError(403, "%s is already in the room." % - target_user_id) + raise AuthError(403, "%s is already in the room." % target_user_id) else: invite_level = _get_named_level(auth_events, "invite", 0) if user_level < invite_level: - raise AuthError( - 403, "You don't have permission to invite users", - ) + raise AuthError(403, "You don't have permission to invite users") elif Membership.JOIN == membership: # Joins are valid iff caller == target and they were: # invited: They are accepting the invitation @@ -329,16 +293,12 @@ def _is_membership_change_allowed(event, auth_events): elif Membership.LEAVE == membership: # TODO (erikj): Implement kicks. if target_banned and user_level < ban_level: - raise AuthError( - 403, "You cannot unban user %s." % (target_user_id,) - ) + raise AuthError(403, "You cannot unban user %s." % (target_user_id,)) elif target_user_id != event.user_id: kick_level = _get_named_level(auth_events, "kick", 50) if user_level < kick_level or user_level <= target_level: - raise AuthError( - 403, "You cannot kick user %s." % target_user_id - ) + raise AuthError(403, "You cannot kick user %s." % target_user_id) elif Membership.BAN == membership: if user_level < ban_level or user_level <= target_level: raise AuthError(403, "You don't have permission to ban") @@ -347,21 +307,17 @@ def _is_membership_change_allowed(event, auth_events): def _check_event_sender_in_room(event, auth_events): - key = (EventTypes.Member, event.user_id, ) + key = (EventTypes.Member, event.user_id) member_event = auth_events.get(key) - return _check_joined_room( - member_event, - event.user_id, - event.room_id - ) + return _check_joined_room(member_event, event.user_id, event.room_id) def _check_joined_room(member, user_id, room_id): if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s (%s)" % ( - user_id, room_id, repr(member) - )) + raise AuthError( + 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) + ) def get_send_level(etype, state_key, power_levels_event): @@ -402,26 +358,21 @@ def get_send_level(etype, state_key, power_levels_event): def _can_send_event(event, auth_events): power_levels_event = _get_power_level_event(auth_events) - send_level = get_send_level( - event.type, event.get("state_key"), power_levels_event, - ) + send_level = get_send_level(event.type, event.get("state_key"), power_levels_event) user_level = get_user_power_level(event.user_id, auth_events) if user_level < send_level: raise AuthError( 403, - "You don't have permission to post that to the room. " + - "user_level (%d) < send_level (%d)" % (user_level, send_level) + "You don't have permission to post that to the room. " + + "user_level (%d) < send_level (%d)" % (user_level, send_level), ) # Check state_key if hasattr(event, "state_key"): if event.state_key.startswith("@"): if event.state_key != event.user_id: - raise AuthError( - 403, - "You are not allowed to set others state" - ) + raise AuthError(403, "You are not allowed to set others state") return True @@ -459,10 +410,7 @@ def check_redaction(room_version, event, auth_events): event.internal_metadata.recheck_redaction = True return True - raise AuthError( - 403, - "You don't have permission to redact events" - ) + raise AuthError(403, "You don't have permission to redact events") def _check_power_levels(event, auth_events): @@ -479,7 +427,7 @@ def _check_power_levels(event, auth_events): except Exception: raise SynapseError(400, "Not a valid power level: %s" % (v,)) - key = (event.type, event.state_key, ) + key = (event.type, event.state_key) current_state = auth_events.get(key) if not current_state: @@ -500,16 +448,12 @@ def _check_power_levels(event, auth_events): old_list = current_state.content.get("users", {}) for user in set(list(old_list) + list(user_list)): - levels_to_check.append( - (user, "users") - ) + levels_to_check.append((user, "users")) old_list = current_state.content.get("events", {}) new_list = event.content.get("events", {}) for ev_id in set(list(old_list) + list(new_list)): - levels_to_check.append( - (ev_id, "events") - ) + levels_to_check.append((ev_id, "events")) old_state = current_state.content new_state = event.content @@ -540,7 +484,7 @@ def _check_power_levels(event, auth_events): raise AuthError( 403, "You don't have permission to remove ops level equal " - "to your own" + "to your own", ) # Check if the old and new levels are greater than the user level @@ -550,8 +494,7 @@ def _check_power_levels(event, auth_events): if old_level_too_big or new_level_too_big: raise AuthError( 403, - "You don't have permission to add ops level greater " - "than your own" + "You don't have permission to add ops level greater " "than your own", ) @@ -587,10 +530,9 @@ def get_user_power_level(user_id, auth_events): # some things which call this don't pass the create event: hack around # that. - key = (EventTypes.Create, "", ) + key = (EventTypes.Create, "") create_event = auth_events.get(key) - if (create_event is not None and - create_event.content["creator"] == user_id): + if create_event is not None and create_event.content["creator"] == user_id: return 100 else: return 0 @@ -636,9 +578,7 @@ def _verify_third_party_invite(event, auth_events): token = signed["token"] - invite_event = auth_events.get( - (EventTypes.ThirdPartyInvite, token,) - ) + invite_event = auth_events.get((EventTypes.ThirdPartyInvite, token)) if not invite_event: return False @@ -661,8 +601,7 @@ def _verify_third_party_invite(event, auth_events): if not key_name.startswith("ed25519:"): continue verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) + key_name, decode_base64(public_key) ) verify_signed_json(signed, server, verify_key) @@ -671,7 +610,7 @@ def _verify_third_party_invite(event, auth_events): # The caller is responsible for checking that the signing # server has not revoked that public key. return True - except (KeyError, SignatureVerifyException,): + except (KeyError, SignatureVerifyException): continue return False @@ -679,9 +618,7 @@ def _verify_third_party_invite(event, auth_events): def get_public_keys(invite_event): public_keys = [] if "public_key" in invite_event.content: - o = { - "public_key": invite_event.content["public_key"], - } + o = {"public_key": invite_event.content["public_key"]} if "key_validity_url" in invite_event.content: o["key_validity_url"] = invite_event.content["key_validity_url"] public_keys.append(o) @@ -702,22 +639,22 @@ def auth_types_for_event(event): auth_types = [] - auth_types.append((EventTypes.PowerLevels, "", )) - auth_types.append((EventTypes.Member, event.sender, )) - auth_types.append((EventTypes.Create, "", )) + auth_types.append((EventTypes.PowerLevels, "")) + auth_types.append((EventTypes.Member, event.sender)) + auth_types.append((EventTypes.Create, "")) if event.type == EventTypes.Member: membership = event.content["membership"] if membership in [Membership.JOIN, Membership.INVITE]: - auth_types.append((EventTypes.JoinRules, "", )) + auth_types.append((EventTypes.JoinRules, "")) - auth_types.append((EventTypes.Member, event.state_key, )) + auth_types.append((EventTypes.Member, event.state_key)) if membership == Membership.INVITE: if "third_party_invite" in event.content: key = ( EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"] + event.content["third_party_invite"]["signed"]["token"], ) auth_types.append(key) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 7154bcbea6..d3de70e671 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -127,25 +127,25 @@ def _event_dict_property(key): except KeyError: raise AttributeError(key) - return property( - getter, - setter, - delete, - ) + return property(getter, setter, delete) class EventBase(object): - def __init__(self, event_dict, signatures={}, unsigned={}, - internal_metadata_dict={}, rejected_reason=None): + def __init__( + self, + event_dict, + signatures={}, + unsigned={}, + internal_metadata_dict={}, + rejected_reason=None, + ): self.signatures = signatures self.unsigned = unsigned self.rejected_reason = rejected_reason self._event_dict = event_dict - self.internal_metadata = _EventInternalMetadata( - internal_metadata_dict - ) + self.internal_metadata = _EventInternalMetadata(internal_metadata_dict) auth_events = _event_dict_property("auth_events") depth = _event_dict_property("depth") @@ -168,10 +168,7 @@ class EventBase(object): def get_dict(self): d = dict(self._event_dict) - d.update({ - "signatures": self.signatures, - "unsigned": dict(self.unsigned), - }) + d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)}) return d @@ -358,6 +355,7 @@ class FrozenEventV2(EventBase): class FrozenEventV3(FrozenEventV2): """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format""" + format_version = EventFormatVersions.V3 # All events of this type are V3 @property @@ -414,6 +412,4 @@ def event_type_from_format_version(format_version): elif format_version == EventFormatVersions.V3: return FrozenEventV3 else: - raise Exception( - "No event format %r" % (format_version,) - ) + raise Exception("No event format %r" % (format_version,)) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 546b6f4982..db011e0407 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -78,7 +78,9 @@ class EventBuilder(object): _redacts = attr.ib(default=None) _origin_server_ts = attr.ib(default=None) - internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({}))) + internal_metadata = attr.ib( + default=attr.Factory(lambda: _EventInternalMetadata({})) + ) @property def state_key(self): @@ -102,11 +104,9 @@ class EventBuilder(object): """ state_ids = yield self._state.get_current_state_ids( - self.room_id, prev_event_ids, - ) - auth_ids = yield self._auth.compute_auth_events( - self, state_ids, + self.room_id, prev_event_ids ) + auth_ids = yield self._auth.compute_auth_events(self, state_ids) if self.format_version == EventFormatVersions.V1: auth_events = yield self._store.add_event_hashes(auth_ids) @@ -115,9 +115,7 @@ class EventBuilder(object): auth_events = auth_ids prev_events = prev_event_ids - old_depth = yield self._store.get_max_depth_of( - prev_event_ids, - ) + old_depth = yield self._store.get_max_depth_of(prev_event_ids) depth = old_depth + 1 # we cap depth of generated events, to ensure that they are not @@ -217,9 +215,14 @@ class EventBuilderFactory(object): ) -def create_local_event_from_event_dict(clock, hostname, signing_key, - format_version, event_dict, - internal_metadata_dict=None): +def create_local_event_from_event_dict( + clock, + hostname, + signing_key, + format_version, + event_dict, + internal_metadata_dict=None, +): """Takes a fully formed event dict, ensuring that fields like `origin` and `origin_server_ts` have correct values for a locally produced event, then signs and hashes it. @@ -237,9 +240,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key, """ if format_version not in KNOWN_EVENT_FORMAT_VERSIONS: - raise Exception( - "No event format defined for version %r" % (format_version,) - ) + raise Exception("No event format defined for version %r" % (format_version,)) if internal_metadata_dict is None: internal_metadata_dict = {} @@ -258,13 +259,9 @@ def create_local_event_from_event_dict(clock, hostname, signing_key, event_dict.setdefault("signatures", {}) - add_hashes_and_signatures( - event_dict, - hostname, - signing_key, - ) + add_hashes_and_signatures(event_dict, hostname, signing_key) return event_type_from_format_version(format_version)( - event_dict, internal_metadata_dict=internal_metadata_dict, + event_dict, internal_metadata_dict=internal_metadata_dict ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index fa09c132a0..a96cdada3d 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -88,8 +88,9 @@ class EventContext(object): self.app_service = None @staticmethod - def with_state(state_group, current_state_ids, prev_state_ids, - prev_group=None, delta_ids=None): + def with_state( + state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None + ): context = EventContext() # The current state including the current event @@ -132,17 +133,19 @@ class EventContext(object): else: prev_state_id = None - defer.returnValue({ - "prev_state_id": prev_state_id, - "event_type": event.type, - "event_state_key": event.state_key if event.is_state() else None, - "state_group": self.state_group, - "rejected": self.rejected, - "prev_group": self.prev_group, - "delta_ids": _encode_state_dict(self.delta_ids), - "prev_state_events": self.prev_state_events, - "app_service_id": self.app_service.id if self.app_service else None - }) + defer.returnValue( + { + "prev_state_id": prev_state_id, + "event_type": event.type, + "event_state_key": event.state_key if event.is_state() else None, + "state_group": self.state_group, + "rejected": self.rejected, + "prev_group": self.prev_group, + "delta_ids": _encode_state_dict(self.delta_ids), + "prev_state_events": self.prev_state_events, + "app_service_id": self.app_service.id if self.app_service else None, + } + ) @staticmethod def deserialize(store, input): @@ -194,7 +197,7 @@ class EventContext(object): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background( - self._fill_out_state, store, + self._fill_out_state, store ) yield make_deferred_yieldable(self._fetching_state_deferred) @@ -214,7 +217,7 @@ class EventContext(object): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background( - self._fill_out_state, store, + self._fill_out_state, store ) yield make_deferred_yieldable(self._fetching_state_deferred) @@ -240,9 +243,7 @@ class EventContext(object): if self.state_group is None: return - self._current_state_ids = yield store.get_state_ids_for_group( - self.state_group, - ) + self._current_state_ids = yield store.get_state_ids_for_group(self.state_group) if self._prev_state_id and self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) @@ -252,8 +253,9 @@ class EventContext(object): self._prev_state_ids = self._current_state_ids @defer.inlineCallbacks - def update_state(self, state_group, prev_state_ids, current_state_ids, - prev_group, delta_ids): + def update_state( + self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids + ): """Replace the state in the context """ @@ -279,10 +281,7 @@ def _encode_state_dict(state_dict): if state_dict is None: return None - return [ - (etype, state_key, v) - for (etype, state_key), v in iteritems(state_dict) - ] + return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)] def _decode_state_dict(input): @@ -291,4 +290,4 @@ def _decode_state_dict(input): if input is None: return None - return frozendict({(etype, state_key,): v for etype, state_key, v in input}) + return frozendict({(etype, state_key): v for etype, state_key, v in input}) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 6058077f75..129771f183 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -60,7 +60,9 @@ class SpamChecker(object): if self.spam_checker is None: return True - return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) + return self.spam_checker.user_may_invite( + inviter_userid, invitee_userid, room_id + ) def user_may_create_room(self, userid): """Checks if a given user may create a room diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 50ceeb1e8e..8f5d95696b 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -36,8 +36,7 @@ class ThirdPartyEventRules(object): if module is not None: self.third_party_rules = module( - config=config, - http_client=hs.get_simple_http_client(), + config=config, http_client=hs.get_simple_http_client() ) @defer.inlineCallbacks @@ -109,6 +108,6 @@ class ThirdPartyEventRules(object): state_events[key] = room_state_events[event_id] ret = yield self.third_party_rules.check_threepid_can_be_invited( - medium, address, state_events, + medium, address, state_events ) defer.returnValue(ret) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index e2d4384de1..f24f0c16f0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -31,7 +31,7 @@ from . import EventBase # by a match for 'stuff'. # TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as # the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" -SPLIT_FIELD_REGEX = re.compile(r'(? MAX_ALIAS_LENGTH: raise SynapseError( 400, - ("Can't create aliases longer than" - " %d characters" % (MAX_ALIAS_LENGTH,)), + ( + "Can't create aliases longer than" + " %d characters" % (MAX_ALIAS_LENGTH,) + ), Codes.INVALID_PARAM, ) @@ -76,11 +76,7 @@ class EventValidator(object): event (EventBuilder|FrozenEvent) """ - strings = [ - "room_id", - "sender", - "type", - ] + strings = ["room_id", "sender", "type"] if hasattr(event, "state_key"): strings.append("state_key") @@ -93,10 +89,7 @@ class EventValidator(object): UserID.from_string(event.sender) if event.type == EventTypes.Message: - strings = [ - "body", - "msgtype", - ] + strings = ["body", "msgtype"] self._ensure_strings(event.content, strings) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index fc5cfb7d83..58b929363f 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -44,8 +44,9 @@ class FederationBase(object): self._clock = hs.get_clock() @defer.inlineCallbacks - def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version, - outlier=False, include_none=False): + def _check_sigs_and_hash_and_fetch( + self, origin, pdus, room_version, outlier=False, include_none=False + ): """Takes a list of PDUs and checks the signatures and hashs of each one. If a PDU fails its signature check then we check if we have it in the database and if not then request if from the originating server of @@ -79,9 +80,7 @@ class FederationBase(object): if not res: # Check local db. res = yield self.store.get_event( - pdu.event_id, - allow_rejected=True, - allow_none=True, + pdu.event_id, allow_rejected=True, allow_none=True ) if not res and pdu.origin != origin: @@ -98,23 +97,16 @@ class FederationBase(object): if not res: logger.warn( - "Failed to find copy of %s with valid signature", - pdu.event_id, + "Failed to find copy of %s with valid signature", pdu.event_id ) defer.returnValue(res) handle = logcontext.preserve_fn(handle_check_result) - deferreds2 = [ - handle(pdu, deferred) - for pdu, deferred in zip(pdus, deferreds) - ] + deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] valid_pdus = yield logcontext.make_deferred_yieldable( - defer.gatherResults( - deferreds2, - consumeErrors=True, - ) + defer.gatherResults(deferreds2, consumeErrors=True) ).addErrback(unwrapFirstError) if include_none: @@ -124,7 +116,7 @@ class FederationBase(object): def _check_sigs_and_hash(self, room_version, pdu): return logcontext.make_deferred_yieldable( - self._check_sigs_and_hashes(room_version, [pdu])[0], + self._check_sigs_and_hashes(room_version, [pdu])[0] ) def _check_sigs_and_hashes(self, room_version, pdus): @@ -159,11 +151,9 @@ class FederationBase(object): # received event was probably a redacted copy (but we then use our # *actual* redacted copy to be on the safe side.) redacted_event = prune_event(pdu) - if ( - set(redacted_event.keys()) == set(pdu.keys()) and - set(six.iterkeys(redacted_event.content)) - == set(six.iterkeys(pdu.content)) - ): + if set(redacted_event.keys()) == set(pdu.keys()) and set( + six.iterkeys(redacted_event.content) + ) == set(six.iterkeys(pdu.content)): logger.info( "Event %s seems to have been redacted; using our redacted " "copy", @@ -172,14 +162,16 @@ class FederationBase(object): else: logger.warning( "Event %s content has been tampered, redacting", - pdu.event_id, pdu.get_pdu_json(), + pdu.event_id, + pdu.get_pdu_json(), ) return redacted_event if self.spam_checker.check_event_for_spam(pdu): logger.warn( "Event contains spam, redacting %s: %s", - pdu.event_id, pdu.get_pdu_json() + pdu.event_id, + pdu.get_pdu_json(), ) return prune_event(pdu) @@ -190,23 +182,24 @@ class FederationBase(object): with logcontext.PreserveLoggingContext(ctx): logger.warn( "Signature check failed for %s: %s", - pdu.event_id, failure.getErrorMessage(), + pdu.event_id, + failure.getErrorMessage(), ) return failure for deferred, pdu in zip(deferreds, pdus): deferred.addCallbacks( - callback, errback, - callbackArgs=[pdu], - errbackArgs=[pdu], + callback, errback, callbackArgs=[pdu], errbackArgs=[pdu] ) return deferreds -class PduToCheckSig(namedtuple("PduToCheckSig", [ - "pdu", "redacted_pdu_json", "sender_domain", "deferreds", -])): +class PduToCheckSig( + namedtuple( + "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"] + ) +): pass @@ -260,10 +253,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): # First we check that the sender event is signed by the sender's domain # (except if its a 3pid invite, in which case it may be sent by any server) - pdus_to_check_sender = [ - p for p in pdus_to_check - if not _is_invite_via_3pid(p.pdu) - ] + pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] more_deferreds = keyring.verify_json_objects_for_server( [ @@ -297,7 +287,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): # (ie, the room version uses old-style non-hash event IDs). if v.event_format == EventFormatVersions.V1: pdus_to_check_event_id = [ - p for p in pdus_to_check + p + for p in pdus_to_check if p.sender_domain != get_domain_from_id(p.pdu.event_id) ] @@ -315,10 +306,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): def event_err(e, pdu_to_check): errmsg = ( - "event id %s: unable to verify signature for event id domain: %s" % ( - pdu_to_check.pdu.event_id, - e.getErrorMessage(), - ) + "event id %s: unable to verify signature for event id domain: %s" + % (pdu_to_check.pdu.event_id, e.getErrorMessage()) ) # XX as above: not really sure if these are the right codes raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) @@ -368,21 +357,18 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False): """ # we could probably enforce a bunch of other fields here (room_id, sender, # origin, etc etc) - assert_params_in_dict(pdu_json, ('type', 'depth')) + assert_params_in_dict(pdu_json, ("type", "depth")) - depth = pdu_json['depth'] + depth = pdu_json["depth"] if not isinstance(depth, six.integer_types): - raise SynapseError(400, "Depth %r not an intger" % (depth, ), - Codes.BAD_JSON) + raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: raise SynapseError(400, "Depth too small", Codes.BAD_JSON) elif depth > MAX_DEPTH: raise SynapseError(400, "Depth too large", Codes.BAD_JSON) - event = event_type_from_format_version(event_format_version)( - pdu_json, - ) + event = event_type_from_format_version(event_format_version)(pdu_json) event.internal_metadata.outlier = outlier diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 70573746d6..3883eb525e 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -57,6 +57,7 @@ class InvalidResponseError(RuntimeError): """Helper for _try_destination_list: indicates that the server returned a response we couldn't parse """ + pass @@ -65,9 +66,7 @@ class FederationClient(FederationBase): super(FederationClient, self).__init__(hs) self.pdu_destination_tried = {} - self._clock.looping_call( - self._clear_tried_cache, 60 * 1000, - ) + self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() @@ -99,8 +98,14 @@ class FederationClient(FederationBase): self.pdu_destination_tried[event_id] = destination_dict @log_function - def make_query(self, destination, query_type, args, - retry_on_dns_fail=False, ignore_backoff=False): + def make_query( + self, + destination, + query_type, + args, + retry_on_dns_fail=False, + ignore_backoff=False, + ): """Sends a federation Query to a remote homeserver of the given type and arguments. @@ -120,7 +125,10 @@ class FederationClient(FederationBase): sent_queries_counter.labels(query_type).inc() return self.transport_layer.make_query( - destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail, + destination, + query_type, + args, + retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff, ) @@ -137,9 +145,7 @@ class FederationClient(FederationBase): response """ sent_queries_counter.labels("client_device_keys").inc() - return self.transport_layer.query_client_keys( - destination, content, timeout - ) + return self.transport_layer.query_client_keys(destination, content, timeout) @log_function def query_user_devices(self, destination, user_id, timeout=30000): @@ -147,9 +153,7 @@ class FederationClient(FederationBase): server. """ sent_queries_counter.labels("user_devices").inc() - return self.transport_layer.query_user_devices( - destination, user_id, timeout - ) + return self.transport_layer.query_user_devices(destination, user_id, timeout) @log_function def claim_client_keys(self, destination, content, timeout): @@ -164,9 +168,7 @@ class FederationClient(FederationBase): response """ sent_queries_counter.labels("client_one_time_keys").inc() - return self.transport_layer.claim_client_keys( - destination, content, timeout - ) + return self.transport_layer.claim_client_keys(destination, content, timeout) @defer.inlineCallbacks @log_function @@ -191,7 +193,8 @@ class FederationClient(FederationBase): return transaction_data = yield self.transport_layer.backfill( - dest, room_id, extremities, limit) + dest, room_id, extremities, limit + ) logger.debug("backfill transaction_data=%s", repr(transaction_data)) @@ -204,17 +207,19 @@ class FederationClient(FederationBase): ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults( - self._check_sigs_and_hashes(room_version, pdus), - consumeErrors=True, - ).addErrback(unwrapFirstError)) + pdus[:] = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True + ).addErrback(unwrapFirstError) + ) defer.returnValue(pdus) @defer.inlineCallbacks @log_function - def get_pdu(self, destinations, event_id, room_version, outlier=False, - timeout=None): + def get_pdu( + self, destinations, event_id, room_version, outlier=False, timeout=None + ): """Requests the PDU with given origin and ID from the remote home servers. @@ -255,7 +260,7 @@ class FederationClient(FederationBase): try: transaction_data = yield self.transport_layer.get_event( - destination, event_id, timeout=timeout, + destination, event_id, timeout=timeout ) logger.debug( @@ -282,8 +287,7 @@ class FederationClient(FederationBase): except SynapseError as e: logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, + "Failed to get PDU %s from %s because %s", event_id, destination, e ) continue except NotRetryingDestination as e: @@ -296,8 +300,7 @@ class FederationClient(FederationBase): pdu_attempts[destination] = now logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, + "Failed to get PDU %s from %s because %s", event_id, destination, e ) continue @@ -326,7 +329,7 @@ class FederationClient(FederationBase): # we have most of the state and auth_chain already. # However, this may 404 if the other side has an old synapse. result = yield self.transport_layer.get_room_state_ids( - destination, room_id, event_id=event_id, + destination, room_id, event_id=event_id ) state_event_ids = result["pdu_ids"] @@ -340,12 +343,10 @@ class FederationClient(FederationBase): logger.warning( "Failed to fetch missing state/auth events for %s: %s", room_id, - failed_to_fetch + failed_to_fetch, ) - event_map = { - ev.event_id: ev for ev in fetched_events - } + event_map = {ev.event_id: ev for ev in fetched_events} pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] auth_chain = [ @@ -362,15 +363,14 @@ class FederationClient(FederationBase): raise e result = yield self.transport_layer.get_room_state( - destination, room_id, event_id=event_id, + destination, room_id, event_id=event_id ) room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdus = [ - event_from_pdu_json(p, format_ver, outlier=True) - for p in result["pdus"] + event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"] ] auth_chain = [ @@ -378,9 +378,9 @@ class FederationClient(FederationBase): for p in result.get("auth_chain", []) ] - seen_events = yield self.store.get_events([ - ev.event_id for ev in itertools.chain(pdus, auth_chain) - ]) + seen_events = yield self.store.get_events( + [ev.event_id for ev in itertools.chain(pdus, auth_chain)] + ) signed_pdus = yield self._check_sigs_and_hash_and_fetch( destination, @@ -442,7 +442,7 @@ class FederationClient(FederationBase): batch_size = 20 missing_events = list(missing_events) for i in range(0, len(missing_events), batch_size): - batch = set(missing_events[i:i + batch_size]) + batch = set(missing_events[i : i + batch_size]) deferreds = [ run_in_background( @@ -470,21 +470,17 @@ class FederationClient(FederationBase): @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): - res = yield self.transport_layer.get_event_auth( - destination, room_id, event_id, - ) + res = yield self.transport_layer.get_event_auth(destination, room_id, event_id) room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ - event_from_pdu_json(p, format_ver, outlier=True) - for p in res["auth_chain"] + event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"] ] signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, - outlier=True, room_version=room_version, + destination, auth_chain, outlier=True, room_version=room_version ) signed_auth.sort(key=lambda e: e.depth) @@ -527,28 +523,26 @@ class FederationClient(FederationBase): res = yield callback(destination) defer.returnValue(res) except InvalidResponseError as e: - logger.warn( - "Failed to %s via %s: %s", - description, destination, e, - ) + logger.warn("Failed to %s via %s: %s", description, destination, e) except HttpResponseException as e: if not 500 <= e.code < 600: raise e.to_synapse_error() else: logger.warn( "Failed to %s via %s: %i %s", - description, destination, e.code, e.args[0], + description, + destination, + e.code, + e.args[0], ) except Exception: - logger.warn( - "Failed to %s via %s", - description, destination, exc_info=1, - ) + logger.warn("Failed to %s via %s", description, destination, exc_info=1) - raise RuntimeError("Failed to %s via any server" % (description, )) + raise RuntimeError("Failed to %s via any server" % (description,)) - def make_membership_event(self, destinations, room_id, user_id, membership, - content, params): + def make_membership_event( + self, destinations, room_id, user_id, membership, content, params + ): """ Creates an m.room.member event, with context, without participating in the room. @@ -584,14 +578,14 @@ class FederationClient(FederationBase): valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: raise RuntimeError( - "make_membership_event called with membership='%s', must be one of %s" % - (membership, ",".join(valid_memberships)) + "make_membership_event called with membership='%s', must be one of %s" + % (membership, ",".join(valid_memberships)) ) @defer.inlineCallbacks def send_request(destination): ret = yield self.transport_layer.make_membership_event( - destination, room_id, user_id, membership, params, + destination, room_id, user_id, membership, params ) # Note: If not supplied, the room version may be either v1 or v2, @@ -614,16 +608,17 @@ class FederationClient(FederationBase): pdu_dict["prev_state"] = [] ev = builder.create_local_event_from_event_dict( - self._clock, self.hostname, self.signing_key, - format_version=event_format, event_dict=pdu_dict, + self._clock, + self.hostname, + self.signing_key, + format_version=event_format, + event_dict=pdu_dict, ) - defer.returnValue( - (destination, ev, event_format) - ) + defer.returnValue((destination, ev, event_format)) return self._try_destination_list( - "make_" + membership, destinations, send_request, + "make_" + membership, destinations, send_request ) def send_join(self, destinations, pdu, event_format_version): @@ -655,9 +650,7 @@ class FederationClient(FederationBase): create_event = e break else: - raise InvalidResponseError( - "no %s in auth chain" % (EventTypes.Create,), - ) + raise InvalidResponseError("no %s in auth chain" % (EventTypes.Create,)) # the room version should be sane. room_version = create_event.content.get("room_version", "1") @@ -665,9 +658,8 @@ class FederationClient(FederationBase): # This shouldn't be possible, because the remote server should have # rejected the join attempt during make_join. raise InvalidResponseError( - "room appears to have unsupported version %s" % ( - room_version, - )) + "room appears to have unsupported version %s" % (room_version,) + ) @defer.inlineCallbacks def send_request(destination): @@ -691,10 +683,7 @@ class FederationClient(FederationBase): for p in content.get("auth_chain", []) ] - pdus = { - p.event_id: p - for p in itertools.chain(state, auth_chain) - } + pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} room_version = None for e in state: @@ -710,15 +699,13 @@ class FederationClient(FederationBase): raise SynapseError(400, "No create event in state") valid_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, list(pdus.values()), + destination, + list(pdus.values()), outlier=True, room_version=room_version, ) - valid_pdus_map = { - p.event_id: p - for p in valid_pdus - } + valid_pdus_map = {p.event_id: p for p in valid_pdus} # NB: We *need* to copy to ensure that we don't have multiple # references being passed on, as that causes... issues. @@ -741,11 +728,14 @@ class FederationClient(FederationBase): check_authchain_validity(signed_auth) - defer.returnValue({ - "state": signed_state, - "auth_chain": signed_auth, - "origin": destination, - }) + defer.returnValue( + { + "state": signed_state, + "auth_chain": signed_auth, + "origin": destination, + } + ) + return self._try_destination_list("send_join", destinations, send_request) @defer.inlineCallbacks @@ -854,6 +844,7 @@ class FederationClient(FederationBase): Fails with a ``RuntimeError`` if no servers were reachable. """ + @defer.inlineCallbacks def send_request(destination): time_now = self._clock.time_msec() @@ -869,14 +860,23 @@ class FederationClient(FederationBase): return self._try_destination_list("send_leave", destinations, send_request) - def get_public_rooms(self, destination, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None): + def get_public_rooms( + self, + destination, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): if destination == self.server_name: return return self.transport_layer.get_public_rooms( - destination, limit, since_token, search_filter, + destination, + limit, + since_token, + search_filter, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) @@ -891,9 +891,7 @@ class FederationClient(FederationBase): """ time_now = self._clock.time_msec() - send_content = { - "auth_chain": [e.get_pdu_json(time_now) for e in local_auth], - } + send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]} code, content = yield self.transport_layer.send_query_auth( destination=destination, @@ -905,13 +903,10 @@ class FederationClient(FederationBase): room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) - auth_chain = [ - event_from_pdu_json(e, format_ver) - for e in content["auth_chain"] - ] + auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]] signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version, + destination, auth_chain, outlier=True, room_version=room_version ) signed_auth.sort(key=lambda e: e.depth) @@ -925,8 +920,16 @@ class FederationClient(FederationBase): defer.returnValue(ret) @defer.inlineCallbacks - def get_missing_events(self, destination, room_id, earliest_events_ids, - latest_events, limit, min_depth, timeout): + def get_missing_events( + self, + destination, + room_id, + earliest_events_ids, + latest_events, + limit, + min_depth, + timeout, + ): """Tries to fetch events we are missing. This is called when we receive an event without having received all of its ancestors. @@ -957,12 +960,11 @@ class FederationClient(FederationBase): format_ver = room_version_to_event_format(room_version) events = [ - event_from_pdu_json(e, format_ver) - for e in content.get("events", []) + event_from_pdu_json(e, format_ver) for e in content.get("events", []) ] signed_events = yield self._check_sigs_and_hash_and_fetch( - destination, events, outlier=False, room_version=room_version, + destination, events, outlier=False, room_version=room_version ) except HttpResponseException as e: if not e.code == 400: @@ -982,17 +984,14 @@ class FederationClient(FederationBase): try: yield self.transport_layer.exchange_third_party_invite( - destination=destination, - room_id=room_id, - event_dict=event_dict, + destination=destination, room_id=room_id, event_dict=event_dict ) defer.returnValue(None) except CodeMessageException: raise except Exception as e: logger.exception( - "Failed to send_third_party_invite via %s: %s", - destination, str(e) + "Failed to send_third_party_invite via %s: %s", destination, str(e) ) raise RuntimeError("Failed to send to any server.") diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 4c28c1dc3c..2e0cebb638 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -69,7 +69,6 @@ received_queries_counter = Counter( class FederationServer(FederationBase): - def __init__(self, hs): super(FederationServer, self).__init__(hs) @@ -118,11 +117,13 @@ class FederationServer(FederationBase): # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. - with (yield self._transaction_linearizer.queue( - (origin, transaction.transaction_id), - )): + with ( + yield self._transaction_linearizer.queue( + (origin, transaction.transaction_id) + ) + ): result = yield self._handle_incoming_transaction( - origin, transaction, request_time, + origin, transaction, request_time ) defer.returnValue(result) @@ -144,7 +145,7 @@ class FederationServer(FederationBase): if response: logger.debug( "[%s] We've already responded to this request", - transaction.transaction_id + transaction.transaction_id, ) defer.returnValue(response) return @@ -152,18 +153,15 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) # Reject if PDU count > 50 and EDU count > 100 - if (len(transaction.pdus) > 50 - or (hasattr(transaction, "edus") and len(transaction.edus) > 100)): + if len(transaction.pdus) > 50 or ( + hasattr(transaction, "edus") and len(transaction.edus) > 100 + ): - logger.info( - "Transaction PDU or EDU count too large. Returning 400", - ) + logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} yield self.transaction_actions.set_response( - origin, - transaction, - 400, response + origin, transaction, 400, response ) defer.returnValue((400, response)) @@ -230,9 +228,7 @@ class FederationServer(FederationBase): try: yield self.check_server_matches_acl(origin_host, room_id) except AuthError as e: - logger.warn( - "Ignoring PDUs for room %s from banned server", room_id, - ) + logger.warn("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: event_id = pdu.event_id pdu_results[event_id] = e.error_dict() @@ -242,9 +238,7 @@ class FederationServer(FederationBase): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu( - origin, pdu - ) + yield self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: logger.warn("Error handling PDU %s: %s", event_id, e) @@ -259,29 +253,18 @@ class FederationServer(FederationBase): ) yield concurrently_execute( - process_pdus_for_room, pdus_by_room.keys(), - TRANSACTION_CONCURRENCY_LIMIT, + process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu( - origin, - edu.edu_type, - edu.content - ) + yield self.received_edu(origin, edu.edu_type, edu.content) - response = { - "pdus": pdu_results, - } + response = {"pdus": pdu_results} logger.debug("Returning: %s", str(response)) - yield self.transaction_actions.set_response( - origin, - transaction, - 200, response - ) + yield self.transaction_actions.set_response(origin, transaction, 200, response) defer.returnValue((200, response)) @defer.inlineCallbacks @@ -311,7 +294,8 @@ class FederationServer(FederationBase): resp = yield self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, - room_id, event_id, + room_id, + event_id, ) defer.returnValue((200, resp)) @@ -328,24 +312,17 @@ class FederationServer(FederationBase): if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu( - room_id, event_id, - ) + state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) - defer.returnValue((200, { - "pdu_ids": state_ids, - "auth_chain_ids": auth_chain_ids, - })) + defer.returnValue( + (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}) + ) @defer.inlineCallbacks def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu( - room_id, event_id, - ) - auth_chain = yield self.store.get_auth_chain( - [pdu.event_id for pdu in pdus] - ) + pdus = yield self.handler.get_state_for_pdu(room_id, event_id) + auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) for event in auth_chain: # We sign these again because there was a bug where we @@ -355,14 +332,16 @@ class FederationServer(FederationBase): compute_event_signature( event.get_pdu_json(), self.hs.hostname, - self.hs.config.signing_key[0] + self.hs.config.signing_key[0], ) ) - defer.returnValue({ - "pdus": [pdu.get_pdu_json() for pdu in pdus], - "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], - }) + defer.returnValue( + { + "pdus": [pdu.get_pdu_json() for pdu in pdus], + "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], + } + ) @defer.inlineCallbacks @log_function @@ -370,9 +349,7 @@ class FederationServer(FederationBase): pdu = yield self.handler.get_persisted_pdu(origin, event_id) if pdu: - defer.returnValue( - (200, self._transaction_from_pdus([pdu]).get_dict()) - ) + defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict())) else: defer.returnValue((404, "")) @@ -394,10 +371,9 @@ class FederationServer(FederationBase): pdu = yield self.handler.on_make_join_request(room_id, user_id) time_now = self._clock.time_msec() - defer.returnValue({ - "event": pdu.get_pdu_json(time_now), - "room_version": room_version, - }) + defer.returnValue( + {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + ) @defer.inlineCallbacks def on_invite_request(self, origin, content, room_version): @@ -431,12 +407,17 @@ class FederationServer(FederationBase): logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) res_pdus = yield self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() - defer.returnValue((200, { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - })) + defer.returnValue( + ( + 200, + { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [ + p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] + ], + }, + ) + ) @defer.inlineCallbacks def on_make_leave_request(self, origin, room_id, user_id): @@ -447,10 +428,9 @@ class FederationServer(FederationBase): room_version = yield self.store.get_room_version(room_id) time_now = self._clock.time_msec() - defer.returnValue({ - "event": pdu.get_pdu_json(time_now), - "room_version": room_version, - }) + defer.returnValue( + {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + ) @defer.inlineCallbacks def on_send_leave_request(self, origin, content, room_id): @@ -475,9 +455,7 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() auth_pdus = yield self.handler.on_event_auth(event_id) - res = { - "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], - } + res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} defer.returnValue((200, res)) @defer.inlineCallbacks @@ -508,12 +486,11 @@ class FederationServer(FederationBase): format_ver = room_version_to_event_format(room_version) auth_chain = [ - event_from_pdu_json(e, format_ver) - for e in content["auth_chain"] + event_from_pdu_json(e, format_ver) for e in content["auth_chain"] ] signed_auth = yield self._check_sigs_and_hash_and_fetch( - origin, auth_chain, outlier=True, room_version=room_version, + origin, auth_chain, outlier=True, room_version=room_version ) ret = yield self.handler.on_query_auth( @@ -527,17 +504,12 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() send_content = { - "auth_chain": [ - e.get_pdu_json(time_now) - for e in ret["auth_chain"] - ], + "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]], "rejects": ret.get("rejects", []), "missing": ret.get("missing", []), } - defer.returnValue( - (200, send_content) - ) + defer.returnValue((200, send_content)) @log_function def on_query_client_keys(self, origin, content): @@ -566,20 +538,23 @@ class FederationServer(FederationBase): logger.info( "Claimed one-time-keys: %s", - ",".join(( - "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) - )), + ",".join( + ( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) + ) + ), ) defer.returnValue({"one_time_keys": json_result}) @defer.inlineCallbacks @log_function - def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit): + def on_get_missing_events( + self, origin, room_id, earliest_events, latest_events, limit + ): with (yield self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) yield self.check_server_matches_acl(origin_host, room_id) @@ -587,11 +562,13 @@ class FederationServer(FederationBase): logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," " limit: %d", - earliest_events, latest_events, limit, + earliest_events, + latest_events, + limit, ) missing_events = yield self.handler.on_get_missing_events( - origin, room_id, earliest_events, latest_events, limit, + origin, room_id, earliest_events, latest_events, limit ) if len(missing_events) < 5: @@ -603,9 +580,9 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() - defer.returnValue({ - "events": [ev.get_pdu_json(time_now) for ev in missing_events], - }) + defer.returnValue( + {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} + ) @log_function def on_openid_userinfo(self, token): @@ -666,22 +643,17 @@ class FederationServer(FederationBase): # origin. See bug #1893. This is also true for some third party # invites). if not ( - pdu.type == 'm.room.member' and - pdu.content and - pdu.content.get("membership", None) in ( - Membership.JOIN, Membership.INVITE, - ) + pdu.type == "m.room.member" + and pdu.content + and pdu.content.get("membership", None) + in (Membership.JOIN, Membership.INVITE) ): logger.info( - "Discarding PDU %s from invalid origin %s", - pdu.event_id, origin + "Discarding PDU %s from invalid origin %s", pdu.event_id, origin ) return else: - logger.info( - "Accepting join PDU %s from %s", - pdu.event_id, origin - ) + logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point room_version = yield self.store.get_room_version(pdu.room_id) @@ -690,33 +662,19 @@ class FederationServer(FederationBase): try: pdu = yield self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=pdu.event_id, - ) + raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu( - origin, pdu, sent_to_us_directly=True, - ) + yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "" % self.server_name @defer.inlineCallbacks def exchange_third_party_invite( - self, - sender_user_id, - target_user_id, - room_id, - signed, + self, sender_user_id, target_user_id, room_id, signed ): ret = yield self.handler.exchange_third_party_invite( - sender_user_id, - target_user_id, - room_id, - signed, + sender_user_id, target_user_id, room_id, signed ) defer.returnValue(ret) @@ -771,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event): allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. - if server_name[0] == '[': + if server_name[0] == "[": return False # check for ipv4 literals. We can just lift the routine from twisted. @@ -805,7 +763,9 @@ def server_matches_acl_event(server_name, acl_event): def _acl_entry_matches(server_name, acl_entry): if not isinstance(acl_entry, six.string_types): - logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) + logger.warn( + "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) + ) return False regex = glob_to_regex(acl_entry) return regex.match(server_name) @@ -815,6 +775,7 @@ class FederationHandlerRegistry(object): """Allows classes to register themselves as handlers for a given EDU or query type for incoming federation traffic. """ + def __init__(self): self.edu_handlers = {} self.query_handlers = {} @@ -848,9 +809,7 @@ class FederationHandlerRegistry(object): on and the result used as the response to the query request. """ if query_type in self.query_handlers: - raise KeyError( - "Already have a Query handler for %s" % (query_type,) - ) + raise KeyError("Already have a Query handler for %s" % (query_type,)) logger.info("Registering federation query handler for %r", query_type) @@ -905,14 +864,10 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): handler = self.edu_handlers.get(edu_type) if handler: return super(ReplicationFederationHandlerRegistry, self).on_edu( - edu_type, origin, content, + edu_type, origin, content ) - return self._send_edu( - edu_type=edu_type, - origin=origin, - content=content, - ) + return self._send_edu(edu_type=edu_type, origin=origin, content=content) def on_query(self, query_type, args): """Overrides FederationHandlerRegistry @@ -921,7 +876,4 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): if handler: return handler(args) - return self._get_query_client( - query_type=query_type, - args=args, - ) + return self._get_query_client(query_type=query_type, args=args) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 74ffd13b4f..7535f79203 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -46,12 +46,9 @@ class TransactionActions(object): response code and response body. """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " - "transaction_id") + raise RuntimeError("Cannot persist a transaction with no " "transaction_id") - return self.store.get_received_txn_response( - transaction.transaction_id, origin - ) + return self.store.get_received_txn_response(transaction.transaction_id, origin) @log_function def set_response(self, origin, transaction, code, response): @@ -61,14 +58,10 @@ class TransactionActions(object): Deferred """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " - "transaction_id") + raise RuntimeError("Cannot persist a transaction with no " "transaction_id") return self.store.set_received_txn_response( - transaction.transaction_id, - origin, - code, - response, + transaction.transaction_id, origin, code, response ) @defer.inlineCallbacks diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 0240b339b0..454456a52d 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -77,12 +77,22 @@ class FederationRemoteSendQueue(object): # lambda binds to the queue rather than to the name of the queue which # changes. ARGH. def register(name, queue): - LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,), - "", [], lambda: len(queue)) + LaterGauge( + "synapse_federation_send_queue_%s_size" % (queue_name,), + "", + [], + lambda: len(queue), + ) for queue_name in [ - "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", - "edus", "device_messages", "pos_time", "presence_destinations", + "presence_map", + "presence_changed", + "keyed_edu", + "keyed_edu_changed", + "edus", + "device_messages", + "pos_time", + "presence_destinations", ]: register(queue_name, getattr(self, queue_name)) @@ -121,9 +131,7 @@ class FederationRemoteSendQueue(object): del self.presence_changed[key] user_ids = set( - user_id - for uids in self.presence_changed.values() - for user_id in uids + user_id for uids in self.presence_changed.values() for user_id in uids ) keys = self.presence_destinations.keys() @@ -285,19 +293,21 @@ class FederationRemoteSendQueue(object): ] for (key, user_id) in dest_user_ids: - rows.append((key, PresenceRow( - state=self.presence_map[user_id], - ))) + rows.append((key, PresenceRow(state=self.presence_map[user_id]))) # Fetch presence to send to destinations i = self.presence_destinations.bisect_right(from_token) j = self.presence_destinations.bisect_right(to_token) + 1 for pos, (user_id, dests) in self.presence_destinations.items()[i:j]: - rows.append((pos, PresenceDestinationsRow( - state=self.presence_map[user_id], - destinations=list(dests), - ))) + rows.append( + ( + pos, + PresenceDestinationsRow( + state=self.presence_map[user_id], destinations=list(dests) + ), + ) + ) # Fetch changes keyed edus i = self.keyed_edu_changed.bisect_right(from_token) @@ -308,10 +318,14 @@ class FederationRemoteSendQueue(object): keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} for ((destination, edu_key), pos) in iteritems(keyed_edus): - rows.append((pos, KeyedEduRow( - key=edu_key, - edu=self.keyed_edu[(destination, edu_key)], - ))) + rows.append( + ( + pos, + KeyedEduRow( + key=edu_key, edu=self.keyed_edu[(destination, edu_key)] + ), + ) + ) # Fetch changed edus i = self.edus.bisect_right(from_token) @@ -327,9 +341,7 @@ class FederationRemoteSendQueue(object): device_messages = {v: k for k, v in self.device_messages.items()[i:j]} for (destination, pos) in iteritems(device_messages): - rows.append((pos, DeviceRow( - destination=destination, - ))) + rows.append((pos, DeviceRow(destination=destination))) # Sort rows based on pos rows.sort() @@ -377,16 +389,14 @@ class BaseFederationRow(object): raise NotImplementedError() -class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( - "state", # UserPresenceState -))): +class PresenceRow( + BaseFederationRow, namedtuple("PresenceRow", ("state",)) # UserPresenceState +): TypeId = "p" @staticmethod def from_data(data): - return PresenceRow( - state=UserPresenceState.from_dict(data) - ) + return PresenceRow(state=UserPresenceState.from_dict(data)) def to_data(self): return self.state.as_dict() @@ -395,33 +405,35 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( buff.presence.append(self.state) -class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", ( - "state", # UserPresenceState - "destinations", # list[str] -))): +class PresenceDestinationsRow( + BaseFederationRow, + namedtuple( + "PresenceDestinationsRow", + ("state", "destinations"), # UserPresenceState # list[str] + ), +): TypeId = "pd" @staticmethod def from_data(data): return PresenceDestinationsRow( - state=UserPresenceState.from_dict(data["state"]), - destinations=data["dests"], + state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"] ) def to_data(self): - return { - "state": self.state.as_dict(), - "dests": self.destinations, - } + return {"state": self.state.as_dict(), "dests": self.destinations} def add_to_buffer(self, buff): buff.presence_destinations.append((self.state, self.destinations)) -class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( - "key", # tuple(str) - the edu key passed to send_edu - "edu", # Edu -))): +class KeyedEduRow( + BaseFederationRow, + namedtuple( + "KeyedEduRow", + ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu + ), +): """Streams EDUs that have an associated key that is ued to clobber. For example, typing EDUs clobber based on room_id. """ @@ -430,28 +442,19 @@ class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( @staticmethod def from_data(data): - return KeyedEduRow( - key=tuple(data["key"]), - edu=Edu(**data["edu"]), - ) + return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"])) def to_data(self): - return { - "key": self.key, - "edu": self.edu.get_internal_dict(), - } + return {"key": self.key, "edu": self.edu.get_internal_dict()} def add_to_buffer(self, buff): - buff.keyed_edus.setdefault( - self.edu.destination, {} - )[self.key] = self.edu + buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu -class EduRow(BaseFederationRow, namedtuple("EduRow", ( - "edu", # Edu -))): +class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu """Streams EDUs that don't have keys. See KeyedEduRow """ + TypeId = "e" @staticmethod @@ -465,13 +468,12 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ( buff.edus.setdefault(self.edu.destination, []).append(self.edu) -class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( - "destination", # str -))): +class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ("destination",))): # str """Streams the fact that either a) there is pending to device messages for users on the remote, or b) a local users device has changed and needs to be sent to the remote. """ + TypeId = "d" @staticmethod @@ -487,23 +489,20 @@ class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( TypeToRow = { Row.TypeId: Row - for Row in ( - PresenceRow, - PresenceDestinationsRow, - KeyedEduRow, - EduRow, - DeviceRow, - ) + for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow, DeviceRow) } -ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( - "presence", # list(UserPresenceState) - "presence_destinations", # list of tuples of UserPresenceState and destinations - "keyed_edus", # dict of destination -> { key -> Edu } - "edus", # dict of destination -> [Edu] - "device_destinations", # set of destinations -)) +ParsedFederationStreamData = namedtuple( + "ParsedFederationStreamData", + ( + "presence", # list(UserPresenceState) + "presence_destinations", # list of tuples of UserPresenceState and destinations + "keyed_edus", # dict of destination -> { key -> Edu } + "edus", # dict of destination -> [Edu] + "device_destinations", # set of destinations + ), +) def process_rows_for_federation(transaction_queue, rows): @@ -542,7 +541,7 @@ def process_rows_for_federation(transaction_queue, rows): for state, destinations in buff.presence_destinations: transaction_queue.send_presence_to_destinations( - states=[state], destinations=destinations, + states=[state], destinations=destinations ) for destination, edu_map in iteritems(buff.keyed_edus): diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4224b29ecf..766c5a37cd 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -44,8 +44,8 @@ sent_pdus_destination_dist_count = Counter( ) sent_pdus_destination_dist_total = Counter( - "synapse_federation_client_sent_pdu_destinations:total", "" - "Total number of PDUs queued for sending across all destinations", + "synapse_federation_client_sent_pdu_destinations:total", + "" "Total number of PDUs queued for sending across all destinations", ) @@ -63,14 +63,15 @@ class FederationSender(object): self._transaction_manager = TransactionManager(hs) # map from destination to PerDestinationQueue - self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] + self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] LaterGauge( "synapse_federation_transaction_queue_pending_destinations", "", [], lambda: sum( - 1 for d in self._per_destination_queues.values() + 1 + for d in self._per_destination_queues.values() if d.transmission_loop_running ), ) @@ -108,8 +109,9 @@ class FederationSender(object): # awaiting a call to flush_read_receipts_for_room. The presence of an entry # here for a given room means that we are rate-limiting RR flushes to that room, # and that there is a pending call to _flush_rrs_for_room in the system. - self._queues_awaiting_rr_flush_by_room = { - } # type: dict[str, set[PerDestinationQueue]] + self._queues_awaiting_rr_flush_by_room = ( + {} + ) # type: dict[str, set[PerDestinationQueue]] self._rr_txn_interval_per_room_ms = ( 1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second @@ -141,8 +143,7 @@ class FederationSender(object): # fire off a processing loop in the background run_as_background_process( - "process_event_queue_for_federation", - self._process_event_queue_loop, + "process_event_queue_for_federation", self._process_event_queue_loop ) @defer.inlineCallbacks @@ -152,7 +153,7 @@ class FederationSender(object): while True: last_token = yield self.store.get_federation_out_pos("events") next_token, events = yield self.store.get_all_new_events_stream( - last_token, self._last_poked_id, limit=100, + last_token, self._last_poked_id, limit=100 ) logger.debug("Handling %s -> %s", last_token, next_token) @@ -179,7 +180,7 @@ class FederationSender(object): # banned then it won't receive the event because it won't # be in the room after the ban. destinations = yield self.state.get_current_hosts_in_room( - event.room_id, latest_event_ids=event.prev_event_ids(), + event.room_id, latest_event_ids=event.prev_event_ids() ) except Exception: logger.exception( @@ -209,37 +210,40 @@ class FederationSender(object): for event in events: events_by_room.setdefault(event.room_id, []).append(event) - yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) - ], - consumeErrors=True - )) - - yield self.store.update_federation_out_pos( - "events", next_token + yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background(handle_room_events, evs) + for evs in itervalues(events_by_room) + ], + consumeErrors=True, + ) ) + yield self.store.update_federation_out_pos("events", next_token) + if events: now = self.clock.time_msec() ts = yield self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_lag.labels( - "federation_sender").set(now - ts) + "federation_sender" + ).set(now - ts) synapse.metrics.event_processing_last_ts.labels( - "federation_sender").set(ts) + "federation_sender" + ).set(ts) events_processed_counter.inc(len(events)) - event_processing_loop_room_count.labels( - "federation_sender" - ).inc(len(events_by_room)) + event_processing_loop_room_count.labels("federation_sender").inc( + len(events_by_room) + ) event_processing_loop_counter.labels("federation_sender").inc() synapse.metrics.event_processing_positions.labels( - "federation_sender").set(next_token) + "federation_sender" + ).set(next_token) finally: self._is_processing = False @@ -312,9 +316,7 @@ class FederationSender(object): if not domains: return - queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get( - room_id - ) + queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(room_id) # if there is no flush yet scheduled, we will send out these receipts with # immediate flushes, and schedule the next flush for this room. @@ -377,10 +379,9 @@ class FederationSender(object): # updates in quick succession are correctly handled. # We only want to send presence for our own users, so lets always just # filter here just in case. - self.pending_presence.update({ - state.user_id: state for state in states - if self.is_mine_id(state.user_id) - }) + self.pending_presence.update( + {state.user_id: state for state in states if self.is_mine_id(state.user_id)} + ) # We then handle the new pending presence in batches, first figuring # out the destinations we need to send each state to and then poking it diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 22a2735405..9aab12c0d3 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -360,7 +360,7 @@ class PerDestinationQueue(object): # Retrieve list of new device updates to send to the destination now_stream_id, results = yield self._store.get_devices_by_remote( - self._destination, last_device_list, limit=limit, + self._destination, last_device_list, limit=limit ) edus = [ Edu( @@ -381,10 +381,7 @@ class PerDestinationQueue(object): last_device_stream_id = self._last_device_stream_id to_device_stream_id = self._store.get_to_device_stream_token() contents, stream_id = yield self._store.get_new_device_msgs_for_remote( - self._destination, - last_device_stream_id, - to_device_stream_id, - limit, + self._destination, last_device_stream_id, to_device_stream_id, limit ) edus = [ Edu( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 35e6b8ff5b..c987bb9a0d 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -29,9 +29,10 @@ class TransactionManager(object): shared between PerDestinationQueue objects """ + def __init__(self, hs): self._server_name = hs.hostname - self.clock = hs.get_clock() # nb must be called this for @measure_func + self.clock = hs.get_clock() # nb must be called this for @measure_func self._store = hs.get_datastore() self._transaction_actions = TransactionActions(self._store) self._transport_layer = hs.get_federation_transport_client() @@ -55,9 +56,9 @@ class TransactionManager(object): txn_id = str(self._next_txn_id) logger.debug( - "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d)", - destination, txn_id, + "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)", + destination, + txn_id, len(pdus), len(edus), ) @@ -79,9 +80,9 @@ class TransactionManager(object): logger.debug("TX [%s] Persisted transaction", destination) logger.info( - "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d)", - destination, txn_id, + "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)", + destination, + txn_id, transaction.transaction_id, len(pdus), len(edus), @@ -112,20 +113,12 @@ class TransactionManager(object): response = e.response if e.code in (401, 404, 429) or 500 <= e.code: - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code - ) + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) raise e - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code - ) + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) - yield self._transaction_actions.delivered( - transaction, code, response - ) + yield self._transaction_actions.delivered(transaction, code, response) logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id) @@ -134,13 +127,18 @@ class TransactionManager(object): if "error" in r: logger.warn( "TX [%s] {%s} Remote returned error for %s: %s", - destination, txn_id, e_id, r, + destination, + txn_id, + e_id, + r, ) else: for p in pdus: logger.warn( "TX [%s] {%s} Failed to send event %s", - destination, txn_id, p.event_id, + destination, + txn_id, + p.event_id, ) success = False diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index e424c40fdf..aecd142309 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -48,12 +48,13 @@ class TransportLayerClient(object): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_room_state dest=%s, room=%s", - destination, room_id) + logger.debug("get_room_state dest=%s, room=%s", destination, room_id) path = _create_v1_path("/state/%s", room_id) return self.client.get_json( - destination, path=path, args={"event_id": event_id}, + destination, + path=path, + args={"event_id": event_id}, try_trailing_slash_on_400=True, ) @@ -71,12 +72,13 @@ class TransportLayerClient(object): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_room_state_ids dest=%s, room=%s", - destination, room_id) + logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) path = _create_v1_path("/state_ids/%s", room_id) return self.client.get_json( - destination, path=path, args={"event_id": event_id}, + destination, + path=path, + args={"event_id": event_id}, try_trailing_slash_on_400=True, ) @@ -94,13 +96,11 @@ class TransportLayerClient(object): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_pdu dest=%s, event_id=%s", - destination, event_id) + logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) path = _create_v1_path("/event/%s", event_id) return self.client.get_json( - destination, path=path, timeout=timeout, - try_trailing_slash_on_400=True, + destination, path=path, timeout=timeout, try_trailing_slash_on_400=True ) @log_function @@ -119,7 +119,10 @@ class TransportLayerClient(object): """ logger.debug( "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s", - destination, room_id, repr(event_tuples), str(limit) + destination, + room_id, + repr(event_tuples), + str(limit), ) if not event_tuples: @@ -128,16 +131,10 @@ class TransportLayerClient(object): path = _create_v1_path("/backfill/%s", room_id) - args = { - "v": event_tuples, - "limit": [str(limit)], - } + args = {"v": event_tuples, "limit": [str(limit)]} return self.client.get_json( - destination, - path=path, - args=args, - try_trailing_slash_on_400=True, + destination, path=path, args=args, try_trailing_slash_on_400=True ) @defer.inlineCallbacks @@ -163,7 +160,8 @@ class TransportLayerClient(object): """ logger.debug( "send_data dest=%s, txid=%s", - transaction.destination, transaction.transaction_id + transaction.destination, + transaction.transaction_id, ) if transaction.destination == self.server_name: @@ -189,8 +187,9 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def make_query(self, destination, query_type, args, retry_on_dns_fail, - ignore_backoff=False): + def make_query( + self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False + ): path = _create_v1_path("/query/%s", query_type) content = yield self.client.get_json( @@ -235,8 +234,8 @@ class TransportLayerClient(object): valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: raise RuntimeError( - "make_membership_event called with membership='%s', must be one of %s" % - (membership, ",".join(valid_memberships)) + "make_membership_event called with membership='%s', must be one of %s" + % (membership, ",".join(valid_memberships)) ) path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id) @@ -268,9 +267,7 @@ class TransportLayerClient(object): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, + destination=destination, path=path, data=content ) defer.returnValue(response) @@ -284,7 +281,6 @@ class TransportLayerClient(object): destination=destination, path=path, data=content, - # we want to do our best to send this through. The problem is # that if it fails, we won't retry it later, so if the remote # server was just having a momentary blip, the room will be out of @@ -300,10 +296,7 @@ class TransportLayerClient(object): path = _create_v1_path("/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) defer.returnValue(response) @@ -314,26 +307,27 @@ class TransportLayerClient(object): path = _create_v2_path("/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) defer.returnValue(response) @defer.inlineCallbacks @log_function - def get_public_rooms(self, remote_server, limit, since_token, - search_filter=None, include_all_networks=False, - third_party_instance_id=None): + def get_public_rooms( + self, + remote_server, + limit, + since_token, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): path = _create_v1_path("/publicRooms") - args = { - "include_all_networks": "true" if include_all_networks else "false", - } + args = {"include_all_networks": "true" if include_all_networks else "false"} if third_party_instance_id: - args["third_party_instance_id"] = third_party_instance_id, + args["third_party_instance_id"] = (third_party_instance_id,) if limit: args["limit"] = [str(limit)] if since_token: @@ -342,10 +336,7 @@ class TransportLayerClient(object): # TODO(erikj): Actually send the search_filter across federation. response = yield self.client.get_json( - destination=remote_server, - path=path, - args=args, - ignore_backoff=True, + destination=remote_server, path=path, args=args, ignore_backoff=True ) defer.returnValue(response) @@ -353,12 +344,10 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): - path = _create_v1_path("/exchange_third_party_invite/%s", room_id,) + path = _create_v1_path("/exchange_third_party_invite/%s", room_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=event_dict, + destination=destination, path=path, data=event_dict ) defer.returnValue(response) @@ -368,10 +357,7 @@ class TransportLayerClient(object): def get_event_auth(self, destination, room_id, event_id): path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) - content = yield self.client.get_json( - destination=destination, - path=path, - ) + content = yield self.client.get_json(destination=destination, path=path) defer.returnValue(content) @@ -381,9 +367,7 @@ class TransportLayerClient(object): path = _create_v1_path("/query_auth/%s/%s", room_id, event_id) content = yield self.client.post_json( - destination=destination, - path=path, - data=content, + destination=destination, path=path, data=content ) defer.returnValue(content) @@ -416,10 +400,7 @@ class TransportLayerClient(object): path = _create_v1_path("/user/keys/query") content = yield self.client.post_json( - destination=destination, - path=path, - data=query_content, - timeout=timeout, + destination=destination, path=path, data=query_content, timeout=timeout ) defer.returnValue(content) @@ -443,9 +424,7 @@ class TransportLayerClient(object): path = _create_v1_path("/user/devices/%s", user_id) content = yield self.client.get_json( - destination=destination, - path=path, - timeout=timeout, + destination=destination, path=path, timeout=timeout ) defer.returnValue(content) @@ -479,18 +458,23 @@ class TransportLayerClient(object): path = _create_v1_path("/user/keys/claim") content = yield self.client.post_json( - destination=destination, - path=path, - data=query_content, - timeout=timeout, + destination=destination, path=path, data=query_content, timeout=timeout ) defer.returnValue(content) @defer.inlineCallbacks @log_function - def get_missing_events(self, destination, room_id, earliest_events, - latest_events, limit, min_depth, timeout): - path = _create_v1_path("/get_missing_events/%s", room_id,) + def get_missing_events( + self, + destination, + room_id, + earliest_events, + latest_events, + limit, + min_depth, + timeout, + ): + path = _create_v1_path("/get_missing_events/%s", room_id) content = yield self.client.post_json( destination=destination, @@ -510,7 +494,7 @@ class TransportLayerClient(object): def get_group_profile(self, destination, group_id, requester_user_id): """Get a group profile """ - path = _create_v1_path("/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id) return self.client.get_json( destination=destination, @@ -529,7 +513,7 @@ class TransportLayerClient(object): requester_user_id (str) content (dict): The new profile of the group """ - path = _create_v1_path("/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id) return self.client.post_json( destination=destination, @@ -543,7 +527,7 @@ class TransportLayerClient(object): def get_group_summary(self, destination, group_id, requester_user_id): """Get a group summary """ - path = _create_v1_path("/groups/%s/summary", group_id,) + path = _create_v1_path("/groups/%s/summary", group_id) return self.client.get_json( destination=destination, @@ -556,7 +540,7 @@ class TransportLayerClient(object): def get_rooms_in_group(self, destination, group_id, requester_user_id): """Get all rooms in a group """ - path = _create_v1_path("/groups/%s/rooms", group_id,) + path = _create_v1_path("/groups/%s/rooms", group_id) return self.client.get_json( destination=destination, @@ -565,11 +549,12 @@ class TransportLayerClient(object): ignore_backoff=True, ) - def add_room_to_group(self, destination, group_id, requester_user_id, room_id, - content): + def add_room_to_group( + self, destination, group_id, requester_user_id, room_id, content + ): """Add a room to a group """ - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.post_json( destination=destination, @@ -579,13 +564,13 @@ class TransportLayerClient(object): ignore_backoff=True, ) - def update_room_in_group(self, destination, group_id, requester_user_id, room_id, - config_key, content): + def update_room_in_group( + self, destination, group_id, requester_user_id, room_id, config_key, content + ): """Update room in group """ path = _create_v1_path( - "/groups/%s/room/%s/config/%s", - group_id, room_id, config_key, + "/groups/%s/room/%s/config/%s", group_id, room_id, config_key ) return self.client.post_json( @@ -599,7 +584,7 @@ class TransportLayerClient(object): def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): """Remove a room from a group """ - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.delete_json( destination=destination, @@ -612,7 +597,7 @@ class TransportLayerClient(object): def get_users_in_group(self, destination, group_id, requester_user_id): """Get users in a group """ - path = _create_v1_path("/groups/%s/users", group_id,) + path = _create_v1_path("/groups/%s/users", group_id) return self.client.get_json( destination=destination, @@ -625,7 +610,7 @@ class TransportLayerClient(object): def get_invited_users_in_group(self, destination, group_id, requester_user_id): """Get users that have been invited to a group """ - path = _create_v1_path("/groups/%s/invited_users", group_id,) + path = _create_v1_path("/groups/%s/invited_users", group_id) return self.client.get_json( destination=destination, @@ -638,16 +623,10 @@ class TransportLayerClient(object): def accept_group_invite(self, destination, group_id, user_id, content): """Accept a group invite """ - path = _create_v1_path( - "/groups/%s/users/%s/accept_invite", - group_id, user_id, - ) + path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function @@ -657,14 +636,13 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): + def invite_to_group( + self, destination, group_id, user_id, requester_user_id, content + ): """Invite a user to a group """ path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) @@ -686,15 +664,13 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def remove_user_from_group(self, destination, group_id, requester_user_id, - user_id, content): + def remove_user_from_group( + self, destination, group_id, requester_user_id, user_id, content + ): """Remove a user fron a group """ path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) @@ -708,8 +684,9 @@ class TransportLayerClient(object): ) @log_function - def remove_user_from_group_notification(self, destination, group_id, user_id, - content): + def remove_user_from_group_notification( + self, destination, group_id, user_id, content + ): """Sent by group server to inform a user's server that they have been kicked from the group. """ @@ -717,10 +694,7 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function @@ -732,24 +706,24 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def update_group_summary_room(self, destination, group_id, user_id, room_id, - category_id, content): + def update_group_summary_room( + self, destination, group_id, user_id, room_id, category_id, content + ): """Update a room entry in a group summary """ if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", - group_id, category_id, room_id, + group_id, + category_id, + room_id, ) else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) return self.client.post_json( destination=destination, @@ -760,17 +734,20 @@ class TransportLayerClient(object): ) @log_function - def delete_group_summary_room(self, destination, group_id, user_id, room_id, - category_id): + def delete_group_summary_room( + self, destination, group_id, user_id, room_id, category_id + ): """Delete a room entry in a group summary """ if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", - group_id, category_id, room_id, + group_id, + category_id, + room_id, ) else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) return self.client.delete_json( destination=destination, @@ -783,7 +760,7 @@ class TransportLayerClient(object): def get_group_categories(self, destination, group_id, requester_user_id): """Get all categories in a group """ - path = _create_v1_path("/groups/%s/categories", group_id,) + path = _create_v1_path("/groups/%s/categories", group_id) return self.client.get_json( destination=destination, @@ -796,7 +773,7 @@ class TransportLayerClient(object): def get_group_category(self, destination, group_id, requester_user_id, category_id): """Get category info in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.get_json( destination=destination, @@ -806,11 +783,12 @@ class TransportLayerClient(object): ) @log_function - def update_group_category(self, destination, group_id, requester_user_id, category_id, - content): + def update_group_category( + self, destination, group_id, requester_user_id, category_id, content + ): """Update a category in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.post_json( destination=destination, @@ -821,11 +799,12 @@ class TransportLayerClient(object): ) @log_function - def delete_group_category(self, destination, group_id, requester_user_id, - category_id): + def delete_group_category( + self, destination, group_id, requester_user_id, category_id + ): """Delete a category in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.delete_json( destination=destination, @@ -838,7 +817,7 @@ class TransportLayerClient(object): def get_group_roles(self, destination, group_id, requester_user_id): """Get all roles in a group """ - path = _create_v1_path("/groups/%s/roles", group_id,) + path = _create_v1_path("/groups/%s/roles", group_id) return self.client.get_json( destination=destination, @@ -851,7 +830,7 @@ class TransportLayerClient(object): def get_group_role(self, destination, group_id, requester_user_id, role_id): """Get a roles info """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.get_json( destination=destination, @@ -861,11 +840,12 @@ class TransportLayerClient(object): ) @log_function - def update_group_role(self, destination, group_id, requester_user_id, role_id, - content): + def update_group_role( + self, destination, group_id, requester_user_id, role_id, content + ): """Update a role in a group """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.post_json( destination=destination, @@ -879,7 +859,7 @@ class TransportLayerClient(object): def delete_group_role(self, destination, group_id, requester_user_id, role_id): """Delete a role in a group """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.delete_json( destination=destination, @@ -889,17 +869,17 @@ class TransportLayerClient(object): ) @log_function - def update_group_summary_user(self, destination, group_id, requester_user_id, - user_id, role_id, content): + def update_group_summary_user( + self, destination, group_id, requester_user_id, user_id, role_id, content + ): """Update a users entry in a group """ if role_id: path = _create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", - group_id, role_id, user_id, + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id ) else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) return self.client.post_json( destination=destination, @@ -910,11 +890,10 @@ class TransportLayerClient(object): ) @log_function - def set_group_join_policy(self, destination, group_id, requester_user_id, - content): + def set_group_join_policy(self, destination, group_id, requester_user_id, content): """Sets the join policy for a group """ - path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,) + path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) return self.client.put_json( destination=destination, @@ -925,17 +904,17 @@ class TransportLayerClient(object): ) @log_function - def delete_group_summary_user(self, destination, group_id, requester_user_id, - user_id, role_id): + def delete_group_summary_user( + self, destination, group_id, requester_user_id, user_id, role_id + ): """Delete a users entry in a group """ if role_id: path = _create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", - group_id, role_id, user_id, + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id ) else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) return self.client.delete_json( destination=destination, @@ -953,10 +932,7 @@ class TransportLayerClient(object): content = {"user_ids": user_ids} return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @@ -975,9 +951,8 @@ def _create_v1_path(path, *args): Returns: str """ - return ( - FEDERATION_V1_PREFIX - + path % tuple(urllib.parse.quote(arg, "") for arg in args) + return FEDERATION_V1_PREFIX + path % tuple( + urllib.parse.quote(arg, "") for arg in args ) @@ -996,7 +971,6 @@ def _create_v2_path(path, *args): Returns: str """ - return ( - FEDERATION_V2_PREFIX - + path % tuple(urllib.parse.quote(arg, "") for arg in args) + return FEDERATION_V2_PREFIX + path % tuple( + urllib.parse.quote(arg, "") for arg in args ) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 949a5fb2aa..b4854e82f6 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -66,8 +66,7 @@ class TransportLayerServer(JsonResource): self.authenticator = Authenticator(hs) self.ratelimiter = FederationRateLimiter( - self.clock, - config=hs.config.rc_federation, + self.clock, config=hs.config.rc_federation ) self.register_servlets() @@ -84,11 +83,13 @@ class TransportLayerServer(JsonResource): class AuthenticationError(SynapseError): """There was a problem authenticating the request""" + pass class NoAuthenticationError(AuthenticationError): """The request had no authentication information""" + pass @@ -105,8 +106,8 @@ class Authenticator(object): def authenticate_request(self, request, content): now = self._clock.time_msec() json_request = { - "method": request.method.decode('ascii'), - "uri": request.uri.decode('ascii'), + "method": request.method.decode("ascii"), + "uri": request.uri.decode("ascii"), "destination": self.server_name, "signatures": {}, } @@ -120,7 +121,7 @@ class Authenticator(object): if not auth_headers: raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED, + 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) for auth in auth_headers: @@ -130,14 +131,14 @@ class Authenticator(object): json_request["signatures"].setdefault(origin, {})[key] = sig if ( - self.federation_domain_whitelist is not None and - origin not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and origin not in self.federation_domain_whitelist ): raise FederationDeniedError(origin) if not json_request["signatures"]: raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED, + 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) yield self.keyring.verify_json_for_server( @@ -177,12 +178,12 @@ def _parse_auth_header(header_bytes): AuthenticationError if the header could not be parsed """ try: - header_str = header_bytes.decode('utf-8') + header_str = header_bytes.decode("utf-8") params = header_str.split(" ")[1].split(",") param_dict = dict(kv.split("=") for kv in params) def strip_quotes(value): - if value.startswith("\""): + if value.startswith('"'): return value[1:-1] else: return value @@ -198,11 +199,11 @@ def _parse_auth_header(header_bytes): except Exception as e: logger.warn( "Error parsing auth header '%s': %s", - header_bytes.decode('ascii', 'replace'), + header_bytes.decode("ascii", "replace"), e, ) raise AuthenticationError( - 400, "Malformed Authorization header", Codes.UNAUTHORIZED, + 400, "Malformed Authorization header", Codes.UNAUTHORIZED ) @@ -242,6 +243,7 @@ class BaseFederationServlet(object): Exception: other exceptions will be caught, logged, and a 500 will be returned. """ + REQUIRE_AUTH = True PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version @@ -293,9 +295,7 @@ class BaseFederationServlet(object): origin, content, request.args, *args, **kwargs ) else: - response = yield func( - origin, content, request.args, *args, **kwargs - ) + response = yield func(origin, content, request.args, *args, **kwargs) defer.returnValue(response) @@ -343,14 +343,12 @@ class FederationSendServlet(BaseFederationServlet): try: transaction_data = content - logger.debug( - "Decoded %s: %s", - transaction_id, str(transaction_data) - ) + logger.debug("Decoded %s: %s", transaction_id, str(transaction_data)) logger.info( "Received txn %s from %s. (PDUs: %d, EDUs: %d)", - transaction_id, origin, + transaction_id, + origin, len(transaction_data.get("pdus", [])), len(transaction_data.get("edus", [])), ) @@ -361,8 +359,7 @@ class FederationSendServlet(BaseFederationServlet): # Add some extra data to the transaction dict that isn't included # in the request body. transaction_data.update( - transaction_id=transaction_id, - destination=self.server_name + transaction_id=transaction_id, destination=self.server_name ) except Exception as e: @@ -372,7 +369,7 @@ class FederationSendServlet(BaseFederationServlet): try: code, response = yield self.handler.on_incoming_transaction( - origin, transaction_data, + origin, transaction_data ) except Exception: logger.exception("on_incoming_transaction failed") @@ -416,7 +413,7 @@ class FederationBackfillServlet(BaseFederationServlet): PATH = "/backfill/(?P[^/]*)/?" def on_GET(self, origin, content, query, context): - versions = [x.decode('ascii') for x in query[b"v"]] + versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) if not limit: @@ -432,7 +429,7 @@ class FederationQueryServlet(BaseFederationServlet): def on_GET(self, origin, content, query, query_type): return self.handler.on_query_request( query_type, - {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()} + {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}, ) @@ -456,15 +453,14 @@ class FederationMakeJoinServlet(BaseFederationServlet): Deferred[(int, object)|None]: either (response code, response object) to return a JSON response, or None if the request has already been handled. """ - versions = query.get(b'ver') + versions = query.get(b"ver") if versions is not None: supported_versions = [v.decode("utf-8") for v in versions] else: supported_versions = ["1"] content = yield self.handler.on_make_join_request( - origin, context, user_id, - supported_versions=supported_versions, + origin, context, user_id, supported_versions=supported_versions ) defer.returnValue((200, content)) @@ -474,9 +470,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_leave_request( - origin, context, user_id, - ) + content = yield self.handler.on_make_leave_request(origin, context, user_id) defer.returnValue((200, content)) @@ -517,7 +511,7 @@ class FederationV1InviteServlet(BaseFederationServlet): # state resolution algorithm, and we don't use that for processing # invites content = yield self.handler.on_invite_request( - origin, content, room_version=RoomVersions.V1.identifier, + origin, content, room_version=RoomVersions.V1.identifier ) # V1 federation API is defined to return a content of `[200, {...}]` @@ -545,7 +539,7 @@ class FederationV2InviteServlet(BaseFederationServlet): event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state content = yield self.handler.on_invite_request( - origin, event, room_version=room_version, + origin, event, room_version=room_version ) defer.returnValue((200, content)) @@ -629,8 +623,10 @@ class On3pidBindServlet(BaseFederationServlet): for invite in content["invites"]: try: if "signed" not in invite or "token" not in invite["signed"]: - message = ("Rejecting received notification of third-" - "party invite without signed: %s" % (invite,)) + message = ( + "Rejecting received notification of third-" + "party invite without signed: %s" % (invite,) + ) logger.info(message) raise SynapseError(400, message) yield self.handler.exchange_third_party_invite( @@ -671,18 +667,23 @@ class OpenIdUserInfo(BaseFederationServlet): def on_GET(self, origin, content, query): token = query.get(b"access_token", [None])[0] if token is None: - defer.returnValue((401, { - "errcode": "M_MISSING_TOKEN", "error": "Access Token required" - })) + defer.returnValue( + (401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}) + ) return - user_id = yield self.handler.on_openid_userinfo(token.decode('ascii')) + user_id = yield self.handler.on_openid_userinfo(token.decode("ascii")) if user_id is None: - defer.returnValue((401, { - "errcode": "M_UNKNOWN_TOKEN", - "error": "Access Token unknown or expired" - })) + defer.returnValue( + ( + 401, + { + "errcode": "M_UNKNOWN_TOKEN", + "error": "Access Token unknown or expired", + }, + ) + ) defer.returnValue((200, {"sub": user_id})) @@ -722,7 +723,7 @@ class PublicRoomList(BaseFederationServlet): def __init__(self, handler, authenticator, ratelimiter, server_name, deny_access): super(PublicRoomList, self).__init__( - handler, authenticator, ratelimiter, server_name, + handler, authenticator, ratelimiter, server_name ) self.deny_access = deny_access @@ -748,9 +749,7 @@ class PublicRoomList(BaseFederationServlet): network_tuple = ThirdPartyInstanceID(None, None) data = yield self.handler.get_local_public_room_list( - limit, since_token, - network_tuple=network_tuple, - from_federation=True, + limit, since_token, network_tuple=network_tuple, from_federation=True ) defer.returnValue((200, data)) @@ -761,17 +760,18 @@ class FederationVersionServlet(BaseFederationServlet): REQUIRE_AUTH = False def on_GET(self, origin, content, query): - return defer.succeed((200, { - "server": { - "name": "Synapse", - "version": get_version_string(synapse) - }, - })) + return defer.succeed( + ( + 200, + {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, + ) + ) class FederationGroupsProfileServlet(BaseFederationServlet): """Get/set the basic profile of a group on behalf of a user """ + PATH = "/groups/(?P[^/]*)/profile" @defer.inlineCallbacks @@ -780,9 +780,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_group_profile( - group_id, requester_user_id - ) + new_content = yield self.handler.get_group_profile(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -808,9 +806,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_group_summary( - group_id, requester_user_id - ) + new_content = yield self.handler.get_group_summary(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -818,6 +814,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet): class FederationGroupsRoomsServlet(BaseFederationServlet): """Get the rooms in a group on behalf of a user """ + PATH = "/groups/(?P[^/]*)/rooms" @defer.inlineCallbacks @@ -826,9 +823,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_rooms_in_group( - group_id, requester_user_id - ) + new_content = yield self.handler.get_rooms_in_group(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -836,6 +831,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): class FederationGroupsAddRoomsServlet(BaseFederationServlet): """Add/remove room from group """ + PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" @defer.inlineCallbacks @@ -857,7 +853,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.remove_room_from_group( - group_id, requester_user_id, room_id, + group_id, requester_user_id, room_id ) defer.returnValue((200, new_content)) @@ -866,6 +862,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): """Update room config in group """ + PATH = ( "/groups/(?P[^/]*)/room/(?P[^/]*)" "/config/(?P[^/]*)" @@ -878,7 +875,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") result = yield self.groups_handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content, + group_id, requester_user_id, room_id, config_key, content ) defer.returnValue((200, result)) @@ -887,6 +884,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): class FederationGroupsUsersServlet(BaseFederationServlet): """Get the users in a group on behalf of a user """ + PATH = "/groups/(?P[^/]*)/users" @defer.inlineCallbacks @@ -895,9 +893,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_users_in_group( - group_id, requester_user_id - ) + new_content = yield self.handler.get_users_in_group(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -905,6 +901,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet): class FederationGroupsInvitedUsersServlet(BaseFederationServlet): """Get the users that have been invited to a group """ + PATH = "/groups/(?P[^/]*)/invited_users" @defer.inlineCallbacks @@ -923,6 +920,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet): class FederationGroupsInviteServlet(BaseFederationServlet): """Ask a group server to invite someone to the group """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" @defer.inlineCallbacks @@ -932,7 +930,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.invite_to_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, new_content)) @@ -941,6 +939,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet): class FederationGroupsAcceptInviteServlet(BaseFederationServlet): """Accept an invitation from the group server """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" @defer.inlineCallbacks @@ -948,9 +947,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") - new_content = yield self.handler.accept_invite( - group_id, user_id, content, - ) + new_content = yield self.handler.accept_invite(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -958,6 +955,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): class FederationGroupsJoinServlet(BaseFederationServlet): """Attempt to join a group """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" @defer.inlineCallbacks @@ -965,9 +963,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet): if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") - new_content = yield self.handler.join_group( - group_id, user_id, content, - ) + new_content = yield self.handler.join_group(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -975,6 +971,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet): class FederationGroupsRemoveUserServlet(BaseFederationServlet): """Leave or kick a user from the group """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" @defer.inlineCallbacks @@ -984,7 +981,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, new_content)) @@ -993,6 +990,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): class FederationGroupsLocalInviteServlet(BaseFederationServlet): """A group server has invited a local user """ + PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" @defer.inlineCallbacks @@ -1000,9 +998,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") - new_content = yield self.handler.on_invite( - group_id, user_id, content, - ) + new_content = yield self.handler.on_invite(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -1010,6 +1006,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): """A group server has removed a local user """ + PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" @defer.inlineCallbacks @@ -1018,7 +1015,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): raise SynapseError(403, "user_id doesn't match origin") new_content = yield self.handler.user_removed_from_group( - group_id, user_id, content, + group_id, user_id, content ) defer.returnValue((200, new_content)) @@ -1027,6 +1024,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): """A group or user's server renews their attestation """ + PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" @defer.inlineCallbacks @@ -1047,6 +1045,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): - /groups/:group/summary/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id """ + PATH = ( "/groups/(?P[^/]*)/summary" "(/categories/(?P[^/]+))?" @@ -1063,7 +1062,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.update_group_summary_room( - group_id, requester_user_id, + group_id, + requester_user_id, room_id=room_id, category_id=category_id, content=content, @@ -1081,9 +1081,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.delete_group_summary_room( - group_id, requester_user_id, - room_id=room_id, - category_id=category_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -1092,9 +1090,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): class FederationGroupsCategoriesServlet(BaseFederationServlet): """Get all categories for a group """ - PATH = ( - "/groups/(?P[^/]*)/categories/?" - ) + + PATH = "/groups/(?P[^/]*)/categories/?" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -1102,9 +1099,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_categories( - group_id, requester_user_id, - ) + resp = yield self.handler.get_group_categories(group_id, requester_user_id) defer.returnValue((200, resp)) @@ -1112,9 +1107,8 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): class FederationGroupsCategoryServlet(BaseFederationServlet): """Add/remove/get a category in a group """ - PATH = ( - "/groups/(?P[^/]*)/categories/(?P[^/]+)" - ) + + PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id, category_id): @@ -1138,7 +1132,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.upsert_group_category( - group_id, requester_user_id, category_id, content, + group_id, requester_user_id, category_id, content ) defer.returnValue((200, resp)) @@ -1153,7 +1147,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.delete_group_category( - group_id, requester_user_id, category_id, + group_id, requester_user_id, category_id ) defer.returnValue((200, resp)) @@ -1162,9 +1156,8 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): class FederationGroupsRolesServlet(BaseFederationServlet): """Get roles in a group """ - PATH = ( - "/groups/(?P[^/]*)/roles/?" - ) + + PATH = "/groups/(?P[^/]*)/roles/?" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -1172,9 +1165,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_roles( - group_id, requester_user_id, - ) + resp = yield self.handler.get_group_roles(group_id, requester_user_id) defer.returnValue((200, resp)) @@ -1182,9 +1173,8 @@ class FederationGroupsRolesServlet(BaseFederationServlet): class FederationGroupsRoleServlet(BaseFederationServlet): """Add/remove/get a role in a group """ - PATH = ( - "/groups/(?P[^/]*)/roles/(?P[^/]+)" - ) + + PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id, role_id): @@ -1192,9 +1182,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_role( - group_id, requester_user_id, role_id - ) + resp = yield self.handler.get_group_role(group_id, requester_user_id, role_id) defer.returnValue((200, resp)) @@ -1208,7 +1196,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.update_group_role( - group_id, requester_user_id, role_id, content, + group_id, requester_user_id, role_id, content ) defer.returnValue((200, resp)) @@ -1223,7 +1211,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.delete_group_role( - group_id, requester_user_id, role_id, + group_id, requester_user_id, role_id ) defer.returnValue((200, resp)) @@ -1236,6 +1224,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): - /groups/:group/summary/users/:user_id - /groups/:group/summary/roles/:role/users/:user_id """ + PATH = ( "/groups/(?P[^/]*)/summary" "(/roles/(?P[^/]+))?" @@ -1252,7 +1241,8 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.update_group_summary_user( - group_id, requester_user_id, + group_id, + requester_user_id, user_id=user_id, role_id=role_id, content=content, @@ -1270,9 +1260,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.delete_group_summary_user( - group_id, requester_user_id, - user_id=user_id, - role_id=role_id, + group_id, requester_user_id, user_id=user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -1281,14 +1269,13 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): """Get roles in a group """ - PATH = ( - "/get_groups_publicised" - ) + + PATH = "/get_groups_publicised" @defer.inlineCallbacks def on_POST(self, origin, content, query): resp = yield self.handler.bulk_get_publicised_groups( - content["user_ids"], proxy=False, + content["user_ids"], proxy=False ) defer.returnValue((200, resp)) @@ -1297,6 +1284,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): """Sets whether a group is joinable without an invite or knock """ + PATH = "/groups/(?P[^/]*)/settings/m.join_policy" @defer.inlineCallbacks @@ -1317,6 +1305,7 @@ class RoomComplexityServlet(BaseFederationServlet): Indicates to other servers how complex (and therefore likely resource-intensive) a public room this server knows about is. """ + PATH = "/rooms/(?P[^/]*)/complexity" PREFIX = FEDERATION_UNSTABLE_PREFIX @@ -1325,9 +1314,7 @@ class RoomComplexityServlet(BaseFederationServlet): store = self.handler.hs.get_datastore() - is_public = yield store.is_room_world_readable_or_publicly_joinable( - room_id - ) + is_public = yield store.is_room_world_readable_or_publicly_joinable(room_id) if not is_public: raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) @@ -1362,13 +1349,9 @@ FEDERATION_SERVLET_CLASSES = ( RoomComplexityServlet, ) -OPENID_SERVLET_CLASSES = ( - OpenIdUserInfo, -) +OPENID_SERVLET_CLASSES = (OpenIdUserInfo,) -ROOM_LIST_CLASSES = ( - PublicRoomList, -) +ROOM_LIST_CLASSES = (PublicRoomList,) GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsProfileServlet, @@ -1399,9 +1382,7 @@ GROUP_LOCAL_SERVLET_CLASSES = ( ) -GROUP_ATTESTATION_SERVLET_CLASSES = ( - FederationGroupsRenewAttestaionServlet, -) +GROUP_ATTESTATION_SERVLET_CLASSES = (FederationGroupsRenewAttestaionServlet,) DEFAULT_SERVLET_GROUPS = ( "federation", diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 025a79c022..14aad8f09d 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -32,21 +32,11 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - valid_keys = [ - "origin", - "destination", - "edu_type", - "content", - ] + valid_keys = ["origin", "destination", "edu_type", "content"] - required_keys = [ - "edu_type", - ] + required_keys = ["edu_type"] - internal_keys = [ - "origin", - "destination", - ] + internal_keys = ["origin", "destination"] class Transaction(JsonEncodedObject): @@ -75,10 +65,7 @@ class Transaction(JsonEncodedObject): "edus", ] - internal_keys = [ - "transaction_id", - "destination", - ] + internal_keys = ["transaction_id", "destination"] required_keys = [ "transaction_id", @@ -98,9 +85,7 @@ class Transaction(JsonEncodedObject): del kwargs["edus"] super(Transaction, self).__init__( - transaction_id=transaction_id, - pdus=pdus, - **kwargs + transaction_id=transaction_id, pdus=pdus, **kwargs ) @staticmethod @@ -109,13 +94,9 @@ class Transaction(JsonEncodedObject): transaction_id and origin_server_ts keys. """ if "origin_server_ts" not in kwargs: - raise KeyError( - "Require 'origin_server_ts' to construct a Transaction" - ) + raise KeyError("Require 'origin_server_ts' to construct a Transaction") if "transaction_id" not in kwargs: - raise KeyError( - "Require 'transaction_id' to construct a Transaction" - ) + raise KeyError("Require 'transaction_id' to construct a Transaction") kwargs["pdus"] = [p.get_pdu_json() for p in pdus] diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 3ba74001d8..e73757570c 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -65,6 +65,7 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 class GroupAttestationSigning(object): """Creates and verifies group attestations. """ + def __init__(self, hs): self.keyring = hs.get_keyring() self.clock = hs.get_clock() @@ -113,11 +114,15 @@ class GroupAttestationSigning(object): validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER) valid_until_ms = int(self.clock.time_msec() + validity_period) - return sign_json({ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": valid_until_ms, - }, self.server_name, self.signing_key) + return sign_json( + { + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": valid_until_ms, + }, + self.server_name, + self.signing_key, + ) class GroupAttestionRenewer(object): @@ -134,7 +139,7 @@ class GroupAttestionRenewer(object): if not hs.config.worker_app: self._renew_attestations_loop = self.clock.looping_call( - self._start_renew_attestations, 30 * 60 * 1000, + self._start_renew_attestations, 30 * 60 * 1000 ) @defer.inlineCallbacks @@ -147,9 +152,7 @@ class GroupAttestionRenewer(object): raise SynapseError(400, "Neither user not group are on this server") yield self.attestations.verify_attestation( - attestation, - user_id=user_id, - group_id=group_id, + attestation, user_id=user_id, group_id=group_id ) yield self.store.update_remote_attestion(group_id, user_id, attestation) @@ -180,7 +183,8 @@ class GroupAttestionRenewer(object): else: logger.warn( "Incorrectly trying to do attestations for user: %r in %r", - user_id, group_id, + user_id, + group_id, ) yield self.store.remove_attestation_renewal(group_id, user_id) return @@ -188,8 +192,7 @@ class GroupAttestionRenewer(object): attestation = self.attestations.create_attestation(group_id, user_id) yield self.transport_client.renew_group_attestation( - destination, group_id, user_id, - content={"attestation": attestation}, + destination, group_id, user_id, content={"attestation": attestation} ) yield self.store.update_attestation_renewal( @@ -197,12 +200,12 @@ class GroupAttestionRenewer(object): ) except (RequestSendFailed, HttpResponseException) as e: logger.warning( - "Failed to renew attestation of %r in %r: %s", - user_id, group_id, e, + "Failed to renew attestation of %r in %r: %s", user_id, group_id, e ) except Exception: - logger.exception("Error renewing attestation of %r in %r", - user_id, group_id) + logger.exception( + "Error renewing attestation of %r in %r", user_id, group_id + ) for row in rows: group_id = row["group_id"] diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 817be40360..168c9e3f84 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -54,8 +54,9 @@ class GroupsServerHandler(object): hs.get_groups_attestation_renewer() @defer.inlineCallbacks - def check_group_is_ours(self, group_id, requester_user_id, - and_exists=False, and_is_admin=None): + def check_group_is_ours( + self, group_id, requester_user_id, and_exists=False, and_is_admin=None + ): """Check that the group is ours, and optionally if it exists. If group does exist then return group. @@ -73,7 +74,9 @@ class GroupsServerHandler(object): if and_exists and not group: raise SynapseError(404, "Unknown group") - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) if group and not is_user_in_group and not group["is_public"]: raise SynapseError(404, "Unknown group") @@ -96,25 +99,27 @@ class GroupsServerHandler(object): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) profile = yield self.get_group_profile(group_id, requester_user_id) users, roles = yield self.store.get_users_for_summary_by_role( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) # TODO: Add profiles to users rooms, categories = yield self.store.get_rooms_for_summary_by_category( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) for room_entry in rooms: room_id = room_entry["room_id"] joined_users = yield self.store.get_users_in_room(room_id) entry = yield self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True, + room_id, len(joined_users), with_alias=False, allow_private=True ) entry = dict(entry) # so we don't change whats cached entry.pop("room_id", None) @@ -134,7 +139,7 @@ class GroupsServerHandler(object): entry["attestation"] = attestation else: entry["attestation"] = self.attestations.create_attestation( - group_id, user_id, + group_id, user_id ) user_profile = yield self.profile_handler.get_profile_from_cache(user_id) @@ -143,34 +148,34 @@ class GroupsServerHandler(object): users.sort(key=lambda e: e.get("order", 0)) membership_info = yield self.store.get_users_membership_info_in_group( - group_id, requester_user_id, + group_id, requester_user_id ) - defer.returnValue({ - "profile": profile, - "users_section": { - "users": users, - "roles": roles, - "total_user_count_estimate": 0, # TODO - }, - "rooms_section": { - "rooms": rooms, - "categories": categories, - "total_room_count_estimate": 0, # TODO - }, - "user": membership_info, - }) + defer.returnValue( + { + "profile": profile, + "users_section": { + "users": users, + "roles": roles, + "total_user_count_estimate": 0, # TODO + }, + "rooms_section": { + "rooms": rooms, + "categories": categories, + "total_room_count_estimate": 0, # TODO + }, + "user": membership_info, + } + ) @defer.inlineCallbacks - def update_group_summary_room(self, group_id, requester_user_id, - room_id, category_id, content): + def update_group_summary_room( + self, group_id, requester_user_id, room_id, category_id, content + ): """Add/update a room to the group summary """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) RoomID.from_string(room_id) # Ensure valid room id @@ -190,21 +195,17 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def delete_group_summary_room(self, group_id, requester_user_id, - room_id, category_id): + def delete_group_summary_room( + self, group_id, requester_user_id, room_id, category_id + ): """Remove a room from the summary """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) yield self.store.remove_room_from_summary( - group_id=group_id, - room_id=room_id, - category_id=category_id, + group_id=group_id, room_id=room_id, category_id=category_id ) defer.returnValue({}) @@ -223,9 +224,7 @@ class GroupsServerHandler(object): join_policy = _parse_join_policy_from_contents(content) if join_policy is None: - raise SynapseError( - 400, "No value specified for 'm.join_policy'" - ) + raise SynapseError(400, "No value specified for 'm.join_policy'") yield self.store.set_group_join_policy(group_id, join_policy=join_policy) @@ -237,9 +236,7 @@ class GroupsServerHandler(object): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - categories = yield self.store.get_group_categories( - group_id=group_id, - ) + categories = yield self.store.get_group_categories(group_id=group_id) defer.returnValue({"categories": categories}) @defer.inlineCallbacks @@ -249,8 +246,7 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) res = yield self.store.get_group_category( - group_id=group_id, - category_id=category_id, + group_id=group_id, category_id=category_id ) defer.returnValue(res) @@ -260,10 +256,7 @@ class GroupsServerHandler(object): """Add/Update a group category """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) @@ -283,15 +276,11 @@ class GroupsServerHandler(object): """Delete a group category """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) yield self.store.remove_group_category( - group_id=group_id, - category_id=category_id, + group_id=group_id, category_id=category_id ) defer.returnValue({}) @@ -302,9 +291,7 @@ class GroupsServerHandler(object): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - roles = yield self.store.get_group_roles( - group_id=group_id, - ) + roles = yield self.store.get_group_roles(group_id=group_id) defer.returnValue({"roles": roles}) @defer.inlineCallbacks @@ -313,10 +300,7 @@ class GroupsServerHandler(object): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - res = yield self.store.get_group_role( - group_id=group_id, - role_id=role_id, - ) + res = yield self.store.get_group_role(group_id=group_id, role_id=role_id) defer.returnValue(res) @defer.inlineCallbacks @@ -324,10 +308,7 @@ class GroupsServerHandler(object): """Add/update a role in a group """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) @@ -335,10 +316,7 @@ class GroupsServerHandler(object): profile = content.get("profile") yield self.store.upsert_group_role( - group_id=group_id, - role_id=role_id, - is_public=is_public, - profile=profile, + group_id=group_id, role_id=role_id, is_public=is_public, profile=profile ) defer.returnValue({}) @@ -348,26 +326,21 @@ class GroupsServerHandler(object): """Remove role from group """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_group_role( - group_id=group_id, - role_id=role_id, - ) + yield self.store.remove_group_role(group_id=group_id, role_id=role_id) defer.returnValue({}) @defer.inlineCallbacks - def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id, - content): + def update_group_summary_user( + self, group_id, requester_user_id, user_id, role_id, content + ): """Add/update a users entry in the group summary """ yield self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) order = content.get("order", None) @@ -389,13 +362,11 @@ class GroupsServerHandler(object): """Remove a user from the group summary """ yield self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) yield self.store.remove_user_from_summary( - group_id=group_id, - user_id=user_id, - role_id=role_id, + group_id=group_id, user_id=user_id, role_id=role_id ) defer.returnValue({}) @@ -411,8 +382,11 @@ class GroupsServerHandler(object): if group: cols = [ - "name", "short_description", "long_description", - "avatar_url", "is_public", + "name", + "short_description", + "long_description", + "avatar_url", + "is_public", ] group_description = {key: group[key] for key in cols} group_description["is_openly_joinable"] = group["join_policy"] == "open" @@ -426,12 +400,11 @@ class GroupsServerHandler(object): """Update the group profile """ yield self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) profile = {} - for keyname in ("name", "avatar_url", "short_description", - "long_description"): + for keyname in ("name", "avatar_url", "short_description", "long_description"): if keyname in content: value = content[keyname] if not isinstance(value, string_types): @@ -449,10 +422,12 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) user_results = yield self.store.get_users_in_group( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) chunk = [] @@ -470,24 +445,25 @@ class GroupsServerHandler(object): entry["is_privileged"] = bool(is_privileged) if not self.is_mine_id(g_user_id): - attestation = yield self.store.get_remote_attestation(group_id, g_user_id) + attestation = yield self.store.get_remote_attestation( + group_id, g_user_id + ) if not attestation: continue entry["attestation"] = attestation else: entry["attestation"] = self.attestations.create_attestation( - group_id, g_user_id, + group_id, g_user_id ) chunk.append(entry) # TODO: If admin add lists of users whose attestations have timed out - defer.returnValue({ - "chunk": chunk, - "total_user_count_estimate": len(user_results), - }) + defer.returnValue( + {"chunk": chunk, "total_user_count_estimate": len(user_results)} + ) @defer.inlineCallbacks def get_invited_users_in_group(self, group_id, requester_user_id): @@ -498,7 +474,9 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) if not is_user_in_group: raise SynapseError(403, "User not in group") @@ -508,9 +486,7 @@ class GroupsServerHandler(object): user_profiles = [] for user_id in invited_users: - user_profile = { - "user_id": user_id - } + user_profile = {"user_id": user_id} try: profile = yield self.profile_handler.get_profile_from_cache(user_id) user_profile.update(profile) @@ -518,10 +494,9 @@ class GroupsServerHandler(object): logger.warn("Error getting profile for %s: %s", user_id, e) user_profiles.append(user_profile) - defer.returnValue({ - "chunk": user_profiles, - "total_user_count_estimate": len(invited_users), - }) + defer.returnValue( + {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} + ) @defer.inlineCallbacks def get_rooms_in_group(self, group_id, requester_user_id): @@ -532,10 +507,12 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) room_results = yield self.store.get_rooms_in_group( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) chunk = [] @@ -544,7 +521,7 @@ class GroupsServerHandler(object): joined_users = yield self.store.get_users_in_room(room_id) entry = yield self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True, + room_id, len(joined_users), with_alias=False, allow_private=True ) if not entry: @@ -556,10 +533,9 @@ class GroupsServerHandler(object): chunk.sort(key=lambda e: -e["num_joined_members"]) - defer.returnValue({ - "chunk": chunk, - "total_room_count_estimate": len(room_results), - }) + defer.returnValue( + {"chunk": chunk, "total_room_count_estimate": len(room_results)} + ) @defer.inlineCallbacks def add_room_to_group(self, group_id, requester_user_id, room_id, content): @@ -578,8 +554,9 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def update_room_in_group(self, group_id, requester_user_id, room_id, config_key, - content): + def update_room_in_group( + self, group_id, requester_user_id, room_id, config_key, content + ): """Update room in group """ RoomID.from_string(room_id) # Ensure valid room id @@ -592,8 +569,7 @@ class GroupsServerHandler(object): is_public = _parse_visibility_dict(content) yield self.store.update_room_in_group_visibility( - group_id, room_id, - is_public=is_public, + group_id, room_id, is_public=is_public ) else: raise SynapseError(400, "Uknown config option") @@ -625,10 +601,7 @@ class GroupsServerHandler(object): # TODO: Check if user is already invited content = { - "profile": { - "name": group["name"], - "avatar_url": group["avatar_url"], - }, + "profile": {"name": group["name"], "avatar_url": group["avatar_url"]}, "inviter": requester_user_id, } @@ -638,9 +611,7 @@ class GroupsServerHandler(object): local_attestation = None else: local_attestation = self.attestations.create_attestation(group_id, user_id) - content.update({ - "attestation": local_attestation, - }) + content.update({"attestation": local_attestation}) res = yield self.transport_client.invite_to_group_notification( get_domain_from_id(user_id), group_id, user_id, content @@ -658,31 +629,24 @@ class GroupsServerHandler(object): remote_attestation = res["attestation"] yield self.attestations.verify_attestation( - remote_attestation, - user_id=user_id, - group_id=group_id, + remote_attestation, user_id=user_id, group_id=group_id ) else: remote_attestation = None yield self.store.add_user_to_group( - group_id, user_id, + group_id, + user_id, is_admin=False, is_public=False, # TODO local_attestation=local_attestation, remote_attestation=remote_attestation, ) elif res["state"] == "invite": - yield self.store.add_group_invite( - group_id, user_id, - ) - defer.returnValue({ - "state": "invite" - }) + yield self.store.add_group_invite(group_id, user_id) + defer.returnValue({"state": "invite"}) elif res["state"] == "reject": - defer.returnValue({ - "state": "reject" - }) + defer.returnValue({"state": "reject"}) else: raise SynapseError(502, "Unknown state returned by HS") @@ -693,16 +657,12 @@ class GroupsServerHandler(object): See accept_invite, join_group. """ if not self.hs.is_mine_id(user_id): - local_attestation = self.attestations.create_attestation( - group_id, user_id, - ) + local_attestation = self.attestations.create_attestation(group_id, user_id) remote_attestation = content["attestation"] yield self.attestations.verify_attestation( - remote_attestation, - user_id=user_id, - group_id=group_id, + remote_attestation, user_id=user_id, group_id=group_id ) else: local_attestation = None @@ -711,7 +671,8 @@ class GroupsServerHandler(object): is_public = _parse_visibility_from_contents(content) yield self.store.add_user_to_group( - group_id, user_id, + group_id, + user_id, is_admin=False, is_public=is_public, local_attestation=local_attestation, @@ -731,17 +692,14 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_invited = yield self.store.is_user_invited_to_local_group( - group_id, requester_user_id, + group_id, requester_user_id ) if not is_invited: raise SynapseError(403, "User not invited to group") local_attestation = yield self._add_user(group_id, requester_user_id, content) - defer.returnValue({ - "state": "join", - "attestation": local_attestation, - }) + defer.returnValue({"state": "join", "attestation": local_attestation}) @defer.inlineCallbacks def join_group(self, group_id, requester_user_id, content): @@ -753,15 +711,12 @@ class GroupsServerHandler(object): group_info = yield self.check_group_is_ours( group_id, requester_user_id, and_exists=True ) - if group_info['join_policy'] != "open": + if group_info["join_policy"] != "open": raise SynapseError(403, "Group is not publicly joinable") local_attestation = yield self._add_user(group_id, requester_user_id, content) - defer.returnValue({ - "state": "join", - "attestation": local_attestation, - }) + defer.returnValue({"state": "join", "attestation": local_attestation}) @defer.inlineCallbacks def knock(self, group_id, requester_user_id, content): @@ -800,9 +755,7 @@ class GroupsServerHandler(object): is_kick = True - yield self.store.remove_user_from_group( - group_id, user_id, - ) + yield self.store.remove_user_from_group(group_id, user_id) if is_kick: if self.hs.is_mine_id(user_id): @@ -830,19 +783,20 @@ class GroupsServerHandler(object): if group: raise SynapseError(400, "Group already exists") - is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id)) + is_admin = yield self.auth.is_server_admin( + UserID.from_string(requester_user_id) + ) if not is_admin: if not self.hs.config.enable_group_creation: raise SynapseError( - 403, "Only a server admin can create groups on this server", + 403, "Only a server admin can create groups on this server" ) localpart = group_id_obj.localpart if not localpart.startswith(self.hs.config.group_creation_prefix): raise SynapseError( 400, - "Can only create groups with prefix %r on this server" % ( - self.hs.config.group_creation_prefix, - ), + "Can only create groups with prefix %r on this server" + % (self.hs.config.group_creation_prefix,), ) profile = content.get("profile", {}) @@ -865,21 +819,19 @@ class GroupsServerHandler(object): remote_attestation = content["attestation"] yield self.attestations.verify_attestation( - remote_attestation, - user_id=requester_user_id, - group_id=group_id, + remote_attestation, user_id=requester_user_id, group_id=group_id ) local_attestation = self.attestations.create_attestation( - group_id, - requester_user_id, + group_id, requester_user_id ) else: local_attestation = None remote_attestation = None yield self.store.add_user_to_group( - group_id, requester_user_id, + group_id, + requester_user_id, is_admin=True, is_public=True, # TODO local_attestation=local_attestation, @@ -893,9 +845,7 @@ class GroupsServerHandler(object): avatar_url=user_profile.get("avatar_url"), ) - defer.returnValue({ - "group_id": group_id, - }) + defer.returnValue({"group_id": group_id}) @defer.inlineCallbacks def delete_group(self, group_id, requester_user_id): @@ -911,29 +861,22 @@ class GroupsServerHandler(object): Deferred """ - yield self.check_group_is_ours( - group_id, requester_user_id, - and_exists=True, - ) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) # Only server admins or group admins can delete groups. - is_admin = yield self.store.is_user_admin_in_group( - group_id, requester_user_id - ) + is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id) if not is_admin: is_admin = yield self.auth.is_server_admin( - UserID.from_string(requester_user_id), + UserID.from_string(requester_user_id) ) if not is_admin: raise SynapseError(403, "User is not an admin") # Before deleting the group lets kick everyone out of it - users = yield self.store.get_users_in_group( - group_id, include_private=True, - ) + users = yield self.store.get_users_in_group(group_id, include_private=True) @defer.inlineCallbacks def _kick_user_from_group(user_id): @@ -989,9 +932,7 @@ def _parse_join_policy_dict(join_policy_dict): return "invite" if join_policy_type not in ("invite", "open"): - raise SynapseError( - 400, "Synapse only supports 'invite'/'open' join rule" - ) + raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule") return join_policy_type @@ -1018,7 +959,5 @@ def _parse_visibility_dict(visibility): return True if vis_type not in ("public", "private"): - raise SynapseError( - 400, "Synapse only supports 'public'/'private' visibility" - ) + raise SynapseError(400, "Synapse only supports 'public'/'private' visibility") return vis_type == "public" diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index dca337ec61..c29c78bd65 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -94,14 +94,15 @@ class BaseHandler(object): burst_count = self.hs.config.rc_message.burst_count allowed, time_allowed = self.ratelimiter.can_do_action( - user_id, time_now, + user_id, + time_now, rate_hz=messages_per_second, burst_count=burst_count, update=update, ) if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) @defer.inlineCallbacks @@ -139,7 +140,7 @@ class BaseHandler(object): if member_event.content["membership"] not in { Membership.JOIN, - Membership.INVITE + Membership.INVITE, }: continue @@ -156,8 +157,7 @@ class BaseHandler(object): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = synapse.types.create_requester( - target_user, is_guest=True) + requester = synapse.types.create_requester(target_user, is_guest=True) handler = self.hs.get_room_member_handler() yield handler.update_membership( requester, diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 7fa5d44d29..e62e6cab77 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -20,7 +20,7 @@ class AccountDataEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() - def get_current_key(self, direction='f'): + def get_current_key(self, direction="f"): return self.store.get_max_account_data_stream_id() @defer.inlineCallbacks @@ -34,29 +34,22 @@ class AccountDataEventSource(object): tags = yield self.store.get_updated_tags(user_id, last_stream_id) for room_id, room_tags in tags.items(): - results.append({ - "type": "m.tag", - "content": {"tags": room_tags}, - "room_id": room_id, - }) + results.append( + {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} + ) account_data, room_account_data = ( yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) ) for account_data_type, content in account_data.items(): - results.append({ - "type": account_data_type, - "content": content, - }) + results.append({"type": account_data_type, "content": content}) for room_id, account_data in room_account_data.items(): for account_data_type, content in account_data.items(): - results.append({ - "type": account_data_type, - "content": content, - "room_id": room_id, - }) + results.append( + {"type": account_data_type, "content": content, "room_id": room_id} + ) defer.returnValue((results, current_stream_id)) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 5e0b92eb1c..0719da3ab7 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -49,12 +49,10 @@ class AccountValidityHandler(object): app_name = self.hs.config.email_app_name self._subject = self._account_validity.renew_email_subject % { - "app": app_name, + "app": app_name } - self._from_string = self.hs.config.email_notif_from % { - "app": app_name, - } + self._from_string = self.hs.config.email_notif_from % {"app": app_name} except Exception: # If substitution failed, fall back to the bare strings. self._subject = self._account_validity.renew_email_subject @@ -69,10 +67,7 @@ class AccountValidityHandler(object): ) # Check the renewal emails to send and send them every 30min. - self.clock.looping_call( - self.send_renewal_emails, - 30 * 60 * 1000, - ) + self.clock.looping_call(self.send_renewal_emails, 30 * 60 * 1000) @defer.inlineCallbacks def send_renewal_emails(self): @@ -86,8 +81,7 @@ class AccountValidityHandler(object): if expiring_users: for user in expiring_users: yield self._send_renewal_email( - user_id=user["user_id"], - expiration_ts=user["expiration_ts_ms"], + user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) @defer.inlineCallbacks @@ -146,32 +140,33 @@ class AccountValidityHandler(object): for address in addresses: raw_to = email.utils.parseaddr(address)[1] - multipart_msg = MIMEMultipart('alternative') - multipart_msg['Subject'] = self._subject - multipart_msg['From'] = self._from_string - multipart_msg['To'] = address - multipart_msg['Date'] = email.utils.formatdate() - multipart_msg['Message-ID'] = email.utils.make_msgid() + multipart_msg = MIMEMultipart("alternative") + multipart_msg["Subject"] = self._subject + multipart_msg["From"] = self._from_string + multipart_msg["To"] = address + multipart_msg["Date"] = email.utils.formatdate() + multipart_msg["Message-ID"] = email.utils.make_msgid() multipart_msg.attach(text_part) multipart_msg.attach(html_part) logger.info("Sending renewal email to %s", address) - yield make_deferred_yieldable(self.sendmail( - self.hs.config.email_smtp_host, - self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'), - reactor=self.hs.get_reactor(), - port=self.hs.config.email_smtp_port, - requireAuthentication=self.hs.config.email_smtp_user is not None, - username=self.hs.config.email_smtp_user, - password=self.hs.config.email_smtp_pass, - requireTransportSecurity=self.hs.config.require_transport_security - )) - - yield self.store.set_renewal_mail_status( - user_id=user_id, - email_sent=True, - ) + yield make_deferred_yieldable( + self.sendmail( + self.hs.config.email_smtp_host, + self._raw_from, + raw_to, + multipart_msg.as_string().encode("utf8"), + reactor=self.hs.get_reactor(), + port=self.hs.config.email_smtp_port, + requireAuthentication=self.hs.config.email_smtp_user is not None, + username=self.hs.config.email_smtp_user, + password=self.hs.config.email_smtp_pass, + requireTransportSecurity=self.hs.config.require_transport_security, + ) + ) + + yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) @defer.inlineCallbacks def _get_email_addresses_for_user(self, user_id): @@ -248,9 +243,7 @@ class AccountValidityHandler(object): expiration_ts = self.clock.time_msec() + self._account_validity.period yield self.store.set_account_validity_for_user( - user_id=user_id, - expiration_ts=expiration_ts, - email_sent=email_sent, + user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) defer.returnValue(expiration_ts) diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 813777bf18..01e0ef408d 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -93,24 +93,20 @@ class AcmeHandler(object): ) well_known = Resource() - well_known.putChild(b'acme-challenge', responder.resource) + well_known.putChild(b"acme-challenge", responder.resource) responder_resource = Resource() - responder_resource.putChild(b'.well-known', well_known) - responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain')) + responder_resource.putChild(b".well-known", well_known) + responder_resource.putChild(b"check", static.Data(b"OK", b"text/plain")) srv = server.Site(responder_resource) bind_addresses = self.hs.config.acme_bind_addresses for host in bind_addresses: logger.info( - "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port, + "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port ) try: - self.reactor.listenTCP( - self.hs.config.acme_port, - srv, - interface=host, - ) + self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host) except twisted.internet.error.CannotListenError as e: check_bind_error(e, host, bind_addresses) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 5d629126fc..941ebfa107 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -23,7 +23,6 @@ logger = logging.getLogger(__name__) class AdminHandler(BaseHandler): - def __init__(self, hs): super(AdminHandler, self).__init__(hs) @@ -33,23 +32,17 @@ class AdminHandler(BaseHandler): sessions = yield self.store.get_user_ip_and_agents(user) for session in sessions: - connections.append({ - "ip": session["ip"], - "last_seen": session["last_seen"], - "user_agent": session["user_agent"], - }) + connections.append( + { + "ip": session["ip"], + "last_seen": session["last_seen"], + "user_agent": session["user_agent"], + } + ) ret = { "user_id": user.to_string(), - "devices": { - "": { - "sessions": [ - { - "connections": connections, - } - ] - }, - }, + "devices": {"": {"sessions": [{"connections": connections}]}}, } defer.returnValue(ret) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 17eedf4dbf..5cc89d43f6 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -38,7 +38,6 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed class ApplicationServicesHandler(object): - def __init__(self, hs): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id @@ -101,9 +100,10 @@ class ApplicationServicesHandler(object): yield self._check_user_exists(event.state_key) if not self.started_scheduler: + def start_scheduler(): return self.scheduler.start().addErrback( - log_failure, "Application Services Failure", + log_failure, "Application Services Failure" ) run_as_background_process("as_scheduler", start_scheduler) @@ -118,10 +118,15 @@ class ApplicationServicesHandler(object): for event in events: yield handle_event(event) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(handle_room_events, evs) + for evs in itervalues(events_by_room) + ], + consumeErrors=True, + ) + ) yield self.store.set_appservice_last_pos(upper_bound) @@ -129,20 +134,23 @@ class ApplicationServicesHandler(object): ts = yield self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_positions.labels( - "appservice_sender").set(upper_bound) + "appservice_sender" + ).set(upper_bound) events_processed_counter.inc(len(events)) - event_processing_loop_room_count.labels( - "appservice_sender" - ).inc(len(events_by_room)) + event_processing_loop_room_count.labels("appservice_sender").inc( + len(events_by_room) + ) event_processing_loop_counter.labels("appservice_sender").inc() synapse.metrics.event_processing_lag.labels( - "appservice_sender").set(now - ts) + "appservice_sender" + ).set(now - ts) synapse.metrics.event_processing_last_ts.labels( - "appservice_sender").set(ts) + "appservice_sender" + ).set(ts) finally: self.is_processing = False @@ -155,13 +163,9 @@ class ApplicationServicesHandler(object): Returns: True if this user exists on at least one application service. """ - user_query_services = yield self._get_services_for_user( - user_id=user_id - ) + user_query_services = yield self._get_services_for_user(user_id=user_id) for user_service in user_query_services: - is_known_user = yield self.appservice_api.query_user( - user_service, user_id - ) + is_known_user = yield self.appservice_api.query_user(user_service, user_id) if is_known_user: defer.returnValue(True) defer.returnValue(False) @@ -179,9 +183,7 @@ class ApplicationServicesHandler(object): room_alias_str = room_alias.to_string() services = self.store.get_app_services() alias_query_services = [ - s for s in services if ( - s.is_interested_in_alias(room_alias_str) - ) + s for s in services if (s.is_interested_in_alias(room_alias_str)) ] for alias_service in alias_query_services: is_known_alias = yield self.appservice_api.query_alias( @@ -189,22 +191,24 @@ class ApplicationServicesHandler(object): ) if is_known_alias: # the alias exists now so don't query more ASes. - result = yield self.store.get_association_from_room_alias( - room_alias - ) + result = yield self.store.get_association_from_room_alias(room_alias) defer.returnValue(result) @defer.inlineCallbacks def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) - results = yield make_deferred_yieldable(defer.DeferredList([ - run_in_background( - self.appservice_api.query_3pe, - service, kind, protocol, fields, + results = yield make_deferred_yieldable( + defer.DeferredList( + [ + run_in_background( + self.appservice_api.query_3pe, service, kind, protocol, fields + ) + for service in services + ], + consumeErrors=True, ) - for service in services - ], consumeErrors=True)) + ) ret = [] for (success, result) in results: @@ -276,18 +280,12 @@ class ApplicationServicesHandler(object): def _get_services_for_user(self, user_id): services = self.store.get_app_services() - interested_list = [ - s for s in services if ( - s.is_interested_in_user(user_id) - ) - ] + interested_list = [s for s in services if (s.is_interested_in_user(user_id))] return defer.succeed(interested_list) def _get_services_for_3pn(self, protocol): services = self.store.get_app_services() - interested_list = [ - s for s in services if s.is_interested_in_protocol(protocol) - ] + interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] return defer.succeed(interested_list) @defer.inlineCallbacks diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a0cf37a9f9..97b21c4093 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -134,13 +134,9 @@ class AuthHandler(BaseHandler): """ # build a list of supported flows - flows = [ - [login_type] for login_type in self._supported_login_types - ] + flows = [[login_type] for login_type in self._supported_login_types] - result, params, _ = yield self.check_auth( - flows, request_body, clientip, - ) + result, params, _ = yield self.check_auth(flows, request_body, clientip) # find the completed login type for login_type in self._supported_login_types: @@ -151,9 +147,7 @@ class AuthHandler(BaseHandler): break else: # this can't happen - raise Exception( - "check_auth returned True but no successful login type", - ) + raise Exception("check_auth returned True but no successful login type") # check that the UI auth matched the access token if user_id != requester.user.to_string(): @@ -215,11 +209,11 @@ class AuthHandler(BaseHandler): authdict = None sid = None - if clientdict and 'auth' in clientdict: - authdict = clientdict['auth'] - del clientdict['auth'] - if 'session' in authdict: - sid = authdict['session'] + if clientdict and "auth" in clientdict: + authdict = clientdict["auth"] + del clientdict["auth"] + if "session" in authdict: + sid = authdict["session"] session = self._get_session_info(sid) if len(clientdict) > 0: @@ -232,27 +226,27 @@ class AuthHandler(BaseHandler): # on a home server. # Revisit: Assumimg the REST APIs do sensible validation, the data # isn't arbintrary. - session['clientdict'] = clientdict + session["clientdict"] = clientdict self._save_session(session) - elif 'clientdict' in session: - clientdict = session['clientdict'] + elif "clientdict" in session: + clientdict = session["clientdict"] if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session), + self._auth_dict_for_flows(flows, session) ) - if 'creds' not in session: - session['creds'] = {} - creds = session['creds'] + if "creds" not in session: + session["creds"] = {} + creds = session["creds"] # check auth type currently being presented errordict = {} - if 'type' in authdict: - login_type = authdict['type'] + if "type" in authdict: + login_type = authdict["type"] try: result = yield self._check_auth_dict( - authdict, clientip, password_servlet=password_servlet, + authdict, clientip, password_servlet=password_servlet ) if result: creds[login_type] = result @@ -281,16 +275,15 @@ class AuthHandler(BaseHandler): # and is not sensitive). logger.info( "Auth completed with creds: %r. Client dict has keys: %r", - creds, list(clientdict) + creds, + list(clientdict), ) - defer.returnValue((creds, clientdict, session['id'])) + defer.returnValue((creds, clientdict, session["id"])) ret = self._auth_dict_for_flows(flows, session) - ret['completed'] = list(creds) + ret["completed"] = list(creds) ret.update(errordict) - raise InteractiveAuthIncompleteError( - ret, - ) + raise InteractiveAuthIncompleteError(ret) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): @@ -300,15 +293,13 @@ class AuthHandler(BaseHandler): """ if stagetype not in self.checkers: raise LoginError(400, "", Codes.MISSING_PARAM) - if 'session' not in authdict: + if "session" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - sess = self._get_session_info( - authdict['session'] - ) - if 'creds' not in sess: - sess['creds'] = {} - creds = sess['creds'] + sess = self._get_session_info(authdict["session"]) + if "creds" not in sess: + sess["creds"] = {} + creds = sess["creds"] result = yield self.checkers[stagetype](authdict, clientip) if result: @@ -329,10 +320,10 @@ class AuthHandler(BaseHandler): not send a session ID, returns None. """ sid = None - if clientdict and 'auth' in clientdict: - authdict = clientdict['auth'] - if 'session' in authdict: - sid = authdict['session'] + if clientdict and "auth" in clientdict: + authdict = clientdict["auth"] + if "session" in authdict: + sid = authdict["session"] return sid def set_session_data(self, session_id, key, value): @@ -347,7 +338,7 @@ class AuthHandler(BaseHandler): value (any): The data to store """ sess = self._get_session_info(session_id) - sess.setdefault('serverdict', {})[key] = value + sess.setdefault("serverdict", {})[key] = value self._save_session(sess) def get_session_data(self, session_id, key, default=None): @@ -360,7 +351,7 @@ class AuthHandler(BaseHandler): default (any): Value to return if the key has not been set """ sess = self._get_session_info(session_id) - return sess.setdefault('serverdict', {}).get(key, default) + return sess.setdefault("serverdict", {}).get(key, default) @defer.inlineCallbacks def _check_auth_dict(self, authdict, clientip, password_servlet=False): @@ -378,15 +369,13 @@ class AuthHandler(BaseHandler): SynapseError if there was a problem with the request LoginError if there was an authentication problem. """ - login_type = authdict['type'] + login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: # XXX: Temporary workaround for having Synapse handle password resets # See AuthHandler.check_auth for further details res = yield checker( - authdict, - clientip=clientip, - password_servlet=password_servlet, + authdict, clientip=clientip, password_servlet=password_servlet ) defer.returnValue(res) @@ -408,13 +397,11 @@ class AuthHandler(BaseHandler): # Client tried to provide captcha but didn't give the parameter: # bad request. raise LoginError( - 400, "Captcha response is required", - errcode=Codes.CAPTCHA_NEEDED + 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED ) logger.info( - "Submitting recaptcha response %s with remoteip %s", - user_response, clientip + "Submitting recaptcha response %s with remoteip %s", user_response, clientip ) # TODO: get this from the homeserver rather than creating a new one for @@ -424,34 +411,34 @@ class AuthHandler(BaseHandler): resp_body = yield client.post_urlencoded_get_json( self.hs.config.recaptcha_siteverify_api, args={ - 'secret': self.hs.config.recaptcha_private_key, - 'response': user_response, - 'remoteip': clientip, - } + "secret": self.hs.config.recaptcha_private_key, + "response": user_response, + "remoteip": clientip, + }, ) except PartialDownloadError as pde: # Twisted is silly data = pde.response resp_body = json.loads(data) - if 'success' in resp_body: + if "success" in resp_body: # Note that we do NOT check the hostname here: we explicitly # intend the CAPTCHA to be presented by whatever client the # user is using, we just care that they have completed a CAPTCHA. logger.info( "%s reCAPTCHA from hostname %s", - "Successful" if resp_body['success'] else "Failed", - resp_body.get('hostname') + "Successful" if resp_body["success"] else "Failed", + resp_body.get("hostname"), ) - if resp_body['success']: + if resp_body["success"]: defer.returnValue(True) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) def _check_email_identity(self, authdict, **kwargs): - return self._check_threepid('email', authdict, **kwargs) + return self._check_threepid("email", authdict, **kwargs) def _check_msisdn(self, authdict, **kwargs): - return self._check_threepid('msisdn', authdict) + return self._check_threepid("msisdn", authdict) def _check_dummy_auth(self, authdict, **kwargs): return defer.succeed(True) @@ -461,10 +448,10 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs): - if 'threepid_creds' not in authdict: + if "threepid_creds" not in authdict: raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) - threepid_creds = authdict['threepid_creds'] + threepid_creds = authdict["threepid_creds"] identity_handler = self.hs.get_handlers().identity_handler @@ -482,31 +469,36 @@ class AuthHandler(BaseHandler): validated=True, ) - threepid = { - "medium": row["medium"], - "address": row["address"], - "validated_at": row["validated_at"], - } if row else None + threepid = ( + { + "medium": row["medium"], + "address": row["address"], + "validated_at": row["validated_at"], + } + if row + else None + ) if row: # Valid threepid returned, delete from the db yield self.store.delete_threepid_session(threepid_creds["sid"]) else: - raise SynapseError(400, "Password resets are not enabled on this homeserver") + raise SynapseError( + 400, "Password resets are not enabled on this homeserver" + ) if not threepid: raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - if threepid['medium'] != medium: + if threepid["medium"] != medium: raise LoginError( 401, - "Expecting threepid of type '%s', got '%s'" % ( - medium, threepid['medium'], - ), - errcode=Codes.UNAUTHORIZED + "Expecting threepid of type '%s', got '%s'" + % (medium, threepid["medium"]), + errcode=Codes.UNAUTHORIZED, ) - threepid['threepid_creds'] = authdict['threepid_creds'] + threepid["threepid_creds"] = authdict["threepid_creds"] defer.returnValue(threepid) @@ -520,13 +512,14 @@ class AuthHandler(BaseHandler): "version": self.hs.config.user_consent_version, "en": { "name": self.hs.config.user_consent_policy_name, - "url": "%s_matrix/consent?v=%s" % ( + "url": "%s_matrix/consent?v=%s" + % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), }, - }, - }, + } + } } def _auth_dict_for_flows(self, flows, session): @@ -547,9 +540,9 @@ class AuthHandler(BaseHandler): params[stage] = get_params[stage]() return { - "session": session['id'], + "session": session["id"], "flows": [{"stages": f} for f in public_flows], - "params": params + "params": params, } def _get_session_info(self, session_id): @@ -560,9 +553,7 @@ class AuthHandler(BaseHandler): # create a new session while session_id is None or session_id in self.sessions: session_id = stringutils.random_string(24) - self.sessions[session_id] = { - "id": session_id, - } + self.sessions[session_id] = {"id": session_id} return self.sessions[session_id] @@ -652,7 +643,8 @@ class AuthHandler(BaseHandler): logger.warn( "Attempted to login as %s but it matches more than one user " "inexactly: %r", - user_id, user_infos.keys() + user_id, + user_infos.keys(), ) defer.returnValue(result) @@ -690,12 +682,10 @@ class AuthHandler(BaseHandler): user is too high too proceed. """ - if username.startswith('@'): + if username.startswith("@"): qualified_user_id = username else: - qualified_user_id = UserID( - username, self.hs.hostname - ).to_string() + qualified_user_id = UserID(username, self.hs.hostname).to_string() self.ratelimit_login_per_account(qualified_user_id) @@ -713,17 +703,15 @@ class AuthHandler(BaseHandler): raise SynapseError(400, "Missing parameter: password") for provider in self.password_providers: - if (hasattr(provider, "check_password") - and login_type == LoginType.PASSWORD): + if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = yield provider.check_password( - qualified_user_id, password, - ) + is_valid = yield provider.check_password(qualified_user_id, password) if is_valid: defer.returnValue((qualified_user_id, None)) - if (not hasattr(provider, "get_supported_login_types") - or not hasattr(provider, "check_auth")): + if not hasattr(provider, "get_supported_login_types") or not hasattr( + provider, "check_auth" + ): # this password provider doesn't understand custom login types continue @@ -744,15 +732,12 @@ class AuthHandler(BaseHandler): login_dict[f] = login_submission[f] if missing_fields: raise SynapseError( - 400, "Missing parameters for login type %s: %s" % ( - login_type, - missing_fields, - ), + 400, + "Missing parameters for login type %s: %s" + % (login_type, missing_fields), ) - result = yield provider.check_auth( - username, login_type, login_dict, - ) + result = yield provider.check_auth(username, login_type, login_dict) if result: if isinstance(result, str): result = (result, None) @@ -762,7 +747,7 @@ class AuthHandler(BaseHandler): known_login_type = True canonical_user_id = yield self._check_local_password( - qualified_user_id, password, + qualified_user_id, password ) if canonical_user_id: @@ -773,7 +758,8 @@ class AuthHandler(BaseHandler): # unknown username or invalid password. self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), time_now_s=self._clock.time(), + qualified_user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=True, @@ -781,10 +767,7 @@ class AuthHandler(BaseHandler): # We raise a 403 here, but note that if we're doing user-interactive # login, it turns all LoginErrors into a 401 anyway. - raise LoginError( - 403, "Invalid password", - errcode=Codes.FORBIDDEN - ) + raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks def check_password_provider_3pid(self, medium, address, password): @@ -810,9 +793,7 @@ class AuthHandler(BaseHandler): # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = yield provider.check_3pid_auth( - medium, address, password, - ) + result = yield provider.check_3pid_auth(medium, address, password) if result: # Check if the return value is a str or a tuple if isinstance(result, str): @@ -853,8 +834,7 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id, device_id=None): access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token, - device_id) + yield self.store.add_access_token_to_user(user_id, access_token, device_id) defer.returnValue(access_token) @defer.inlineCallbacks @@ -896,12 +876,13 @@ class AuthHandler(BaseHandler): # delete pushers associated with this access token if user_info["token_id"] is not None: yield self.hs.get_pusherpool().remove_pushers_by_access_token( - str(user_info["user"]), (user_info["token_id"], ) + str(user_info["user"]), (user_info["token_id"],) ) @defer.inlineCallbacks - def delete_access_tokens_for_user(self, user_id, except_token_id=None, - device_id=None): + def delete_access_tokens_for_user( + self, user_id, except_token_id=None, device_id=None + ): """Invalidate access tokens belonging to a user Args: @@ -915,7 +896,7 @@ class AuthHandler(BaseHandler): Deferred """ tokens_and_devices = yield self.store.user_delete_access_tokens( - user_id, except_token_id=except_token_id, device_id=device_id, + user_id, except_token_id=except_token_id, device_id=device_id ) # see if any of our auth providers want to know about this @@ -923,14 +904,12 @@ class AuthHandler(BaseHandler): if hasattr(provider, "on_logged_out"): for token, token_id, device_id in tokens_and_devices: yield provider.on_logged_out( - user_id=user_id, - device_id=device_id, - access_token=token, + user_id=user_id, device_id=device_id, access_token=token ) # delete pushers associated with the access tokens yield self.hs.get_pusherpool().remove_pushers_by_access_token( - user_id, (token_id for _, token_id, _ in tokens_and_devices), + user_id, (token_id for _, token_id, _ in tokens_and_devices) ) @defer.inlineCallbacks @@ -944,12 +923,11 @@ class AuthHandler(BaseHandler): # of specific types of threepid (and fixes the fact that checking # for the presence of an email address during password reset was # case sensitive). - if medium == 'email': + if medium == "email": address = address.lower() yield self.store.user_add_threepid( - user_id, medium, address, validated_at, - self.hs.get_clock().time_msec() + user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) @defer.inlineCallbacks @@ -973,22 +951,15 @@ class AuthHandler(BaseHandler): """ # 'Canonicalise' email addresses as per above - if medium == 'email': + if medium == "email": address = address.lower() identity_handler = self.hs.get_handlers().identity_handler result = yield identity_handler.try_unbind_threepid( - user_id, - { - 'medium': medium, - 'address': address, - 'id_server': id_server, - }, + user_id, {"medium": medium, "address": address, "id_server": id_server} ) - yield self.store.user_delete_threepid( - user_id, medium, address, - ) + yield self.store.user_delete_threepid(user_id, medium, address) defer.returnValue(result) def _save_session(self, session): @@ -1006,14 +977,15 @@ class AuthHandler(BaseHandler): Returns: Deferred(unicode): Hashed password. """ + def _do_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.hashpw( - pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), + pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), bcrypt.gensalt(self.bcrypt_rounds), - ).decode('ascii') + ).decode("ascii") return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash) @@ -1027,18 +999,19 @@ class AuthHandler(BaseHandler): Returns: Deferred(bool): Whether self.hash(password) == stored_hash. """ + def _do_validate_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( - pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), - stored_hash + pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), + stored_hash, ) if stored_hash: if not isinstance(stored_hash, bytes): - stored_hash = stored_hash.encode('ascii') + stored_hash = stored_hash.encode("ascii") return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: @@ -1058,14 +1031,16 @@ class AuthHandler(BaseHandler): for this user is too high too proceed. """ self._failed_attempts_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), + user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=False, ) self._account_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), + user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, update=True, @@ -1083,9 +1058,9 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat("type = access") # Include a nonce, to make sure that each login gets a different # access token. - macaroon.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) + macaroon.add_first_party_caveat( + "nonce = %s" % (stringutils.random_string_with_symbols(16),) + ) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) return macaroon.serialize() @@ -1116,7 +1091,8 @@ class MacaroonGenerator(object): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 7378b56c1d..e8f9da6098 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) class DeactivateAccountHandler(BaseHandler): """Handler which deals with deactivating user accounts.""" + def __init__(self, hs): super(DeactivateAccountHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() @@ -78,9 +79,9 @@ class DeactivateAccountHandler(BaseHandler): result = yield self._identity_handler.try_unbind_threepid( user_id, { - 'medium': threepid['medium'], - 'address': threepid['address'], - 'id_server': id_server, + "medium": threepid["medium"], + "address": threepid["address"], + "id_server": id_server, }, ) identity_server_supports_unbinding &= result @@ -89,7 +90,7 @@ class DeactivateAccountHandler(BaseHandler): logger.exception("Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server") yield self.store.user_delete_threepid( - user_id, threepid['medium'], threepid['address'], + user_id, threepid["medium"], threepid["address"] ) # delete any devices belonging to the user, which will also @@ -183,5 +184,6 @@ class DeactivateAccountHandler(BaseHandler): except Exception: logger.exception( "Failed to part user %r from room %r: ignoring and continuing", - user_id, room_id, + user_id, + room_id, ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index b398848079..f59d0479b5 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -58,9 +58,7 @@ class DeviceWorkerHandler(BaseHandler): device_map = yield self.store.get_devices_by_user(user_id) - ips = yield self.store.get_last_client_ip_by_device( - user_id, device_id=None - ) + ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None) devices = list(device_map.values()) for device in devices: @@ -85,9 +83,7 @@ class DeviceWorkerHandler(BaseHandler): device = yield self.store.get_device(user_id, device_id) except errors.StoreError: raise errors.NotFoundError - ips = yield self.store.get_last_client_ip_by_device( - user_id, device_id, - ) + ips = yield self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) defer.returnValue(device) @@ -114,13 +110,11 @@ class DeviceWorkerHandler(BaseHandler): rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) member_events = yield self.store.get_membership_changes_for_user( - user_id, from_token.room_key, now_room_key, + user_id, from_token.room_key, now_room_key ) rooms_changed.update(event.room_id for event in member_events) - stream_ordering = RoomStreamToken.parse_stream_token( - from_token.room_key - ).stream + stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream possibly_changed = set(changed) possibly_left = set() @@ -206,10 +200,9 @@ class DeviceWorkerHandler(BaseHandler): possibly_joined = [] possibly_left = [] - defer.returnValue({ - "changed": list(possibly_joined), - "left": list(possibly_left), - }) + defer.returnValue( + {"changed": list(possibly_joined), "left": list(possibly_left)} + ) class DeviceHandler(DeviceWorkerHandler): @@ -223,17 +216,18 @@ class DeviceHandler(DeviceWorkerHandler): federation_registry = hs.get_federation_registry() federation_registry.register_edu_handler( - "m.device_list_update", self._edu_updater.incoming_device_list_update, + "m.device_list_update", self._edu_updater.incoming_device_list_update ) federation_registry.register_query_handler( - "user_devices", self.on_federation_query_user_devices, + "user_devices", self.on_federation_query_user_devices ) hs.get_distributor().observe("user_left_room", self.user_left_room) @defer.inlineCallbacks - def check_device_registered(self, user_id, device_id, - initial_device_display_name=None): + def check_device_registered( + self, user_id, device_id, initial_device_display_name=None + ): """ If the given device has not been registered, register it with the supplied display name. @@ -297,12 +291,10 @@ class DeviceHandler(DeviceWorkerHandler): raise yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id, + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device( - user_id=user_id, device_id=device_id - ) + yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) yield self.notify_device_update(user_id, [device_id]) @@ -349,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler): # considered as part of a critical path. for device_id in device_ids: yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id, + user_id, device_id=device_id ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id @@ -372,9 +364,7 @@ class DeviceHandler(DeviceWorkerHandler): try: yield self.store.update_device( - user_id, - device_id, - new_display_name=content.get("display_name") + user_id, device_id, new_display_name=content.get("display_name") ) yield self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: @@ -404,29 +394,26 @@ class DeviceHandler(DeviceWorkerHandler): for device_id in device_ids: logger.debug( - "Notifying about update %r/%r, ID: %r", user_id, device_id, - position, + "Notifying about update %r/%r, ID: %r", user_id, device_id, position ) room_ids = yield self.store.get_rooms_for_user(user_id) - yield self.notifier.on_new_event( - "device_list_key", position, rooms=room_ids, - ) + yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids) if hosts: - logger.info("Sending device list update notif for %r to: %r", user_id, hosts) + logger.info( + "Sending device list update notif for %r to: %r", user_id, hosts + ) for host in hosts: self.federation_sender.send_device_messages(host) @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - defer.returnValue({ - "user_id": user_id, - "stream_id": stream_id, - "devices": devices, - }) + defer.returnValue( + {"user_id": user_id, "stream_id": stream_id, "devices": devices} + ) @defer.inlineCallbacks def user_left_room(self, user, room_id): @@ -440,10 +427,7 @@ class DeviceHandler(DeviceWorkerHandler): def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) - device.update({ - "last_seen_ts": ip.get("last_seen"), - "last_seen_ip": ip.get("ip"), - }) + device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) class DeviceListEduUpdater(object): @@ -481,13 +465,15 @@ class DeviceListEduUpdater(object): device_id = edu_content.pop("device_id") stream_id = str(edu_content.pop("stream_id")) # They may come as ints prev_ids = edu_content.pop("prev_id", []) - prev_ids = [str(p) for p in prev_ids] # They may come as ints + prev_ids = [str(p) for p in prev_ids] # They may come as ints if get_domain_from_id(user_id) != origin: # TODO: Raise? logger.warning( "Got device list update edu for %r/%r from %r", - user_id, device_id, origin, + user_id, + device_id, + origin, ) return @@ -497,13 +483,12 @@ class DeviceListEduUpdater(object): # probably won't get any further updates. logger.warning( "Got device list update edu for %r/%r, but don't share a room", - user_id, device_id, + user_id, + device_id, ) return - logger.debug( - "Received device list update for %r/%r", user_id, device_id, - ) + logger.debug("Received device list update for %r/%r", user_id, device_id) self._pending_updates.setdefault(user_id, []).append( (device_id, stream_id, prev_ids, edu_content) @@ -525,7 +510,10 @@ class DeviceListEduUpdater(object): for device_id, stream_id, prev_ids, content in pending_updates: logger.debug( "Handling update %r/%r, ID: %r, prev: %r ", - user_id, device_id, stream_id, prev_ids, + user_id, + device_id, + stream_id, + prev_ids, ) # Given a list of updates we check if we need to resync. This @@ -540,13 +528,13 @@ class DeviceListEduUpdater(object): try: result = yield self.federation.query_user_devices(origin, user_id) except ( - NotRetryingDestination, RequestSendFailed, HttpResponseException, + NotRetryingDestination, + RequestSendFailed, + HttpResponseException, ): # TODO: Remember that we are now out of sync and try again # later - logger.warn( - "Failed to handle device list update for %s", user_id, - ) + logger.warn("Failed to handle device list update for %s", user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync @@ -582,18 +570,21 @@ class DeviceListEduUpdater(object): if len(devices) > 1000: logger.warn( "Ignoring device list snapshot for %s as it has >1K devs (%d)", - user_id, len(devices) + user_id, + len(devices), ) devices = [] for device in devices: logger.debug( "Handling resync update %r/%r, ID: %r", - user_id, device["device_id"], stream_id, + user_id, + device["device_id"], + stream_id, ) yield self.store.update_remote_device_list_cache( - user_id, devices, stream_id, + user_id, devices, stream_id ) device_ids = [device["device_id"] for device in devices] yield self.device_handler.notify_device_update(user_id, device_ids) @@ -606,7 +597,7 @@ class DeviceListEduUpdater(object): # change (because of the single prev_id matching the current cache) for device_id, stream_id, prev_ids, content in pending_updates: yield self.store.update_remote_device_list_cache_entry( - user_id, device_id, content, stream_id, + user_id, device_id, content, stream_id ) yield self.device_handler.notify_device_update( @@ -624,14 +615,9 @@ class DeviceListEduUpdater(object): """ seen_updates = self._seen_updates.get(user_id, set()) - extremity = yield self.store.get_device_list_last_stream_id_for_remote( - user_id - ) + extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id) - logger.debug( - "Current extremity for %r: %r", - user_id, extremity, - ) + logger.debug("Current extremity for %r: %r", user_id, extremity) stream_id_in_updates = set() # stream_ids in updates list for _, stream_id, prev_ids, _ in updates: diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 2e2e5261de..e1ebb6346c 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -25,7 +25,6 @@ logger = logging.getLogger(__name__) class DeviceMessageHandler(object): - def __init__(self, hs): """ Args: @@ -47,15 +46,15 @@ class DeviceMessageHandler(object): if origin != get_domain_from_id(sender_user_id): logger.warn( "Dropping device message from %r with spoofed sender %r", - origin, sender_user_id + origin, + sender_user_id, ) message_type = content["type"] message_id = content["message_id"] for user_id, by_device in content["messages"].items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): - logger.warning("Request for keys for non-local user %s", - user_id) + logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") messages_by_device = { diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index a12f9508d8..42d5b3db30 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -36,7 +36,6 @@ logger = logging.getLogger(__name__) class DirectoryHandler(BaseHandler): - def __init__(self, hs): super(DirectoryHandler, self).__init__(hs) @@ -77,15 +76,19 @@ class DirectoryHandler(BaseHandler): raise SynapseError(400, "Failed to get server list") yield self.store.create_room_alias_association( - room_alias, - room_id, - servers, - creator=creator, + room_alias, room_id, servers, creator=creator ) @defer.inlineCallbacks - def create_association(self, requester, room_alias, room_id, servers=None, - send_event=True, check_membership=True): + def create_association( + self, + requester, + room_alias, + room_id, + servers=None, + send_event=True, + check_membership=True, + ): """Attempt to create a new alias Args: @@ -115,49 +118,40 @@ class DirectoryHandler(BaseHandler): if service: if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( - 400, "This application service has not reserved" - " this kind of alias.", errcode=Codes.EXCLUSIVE + 400, + "This application service has not reserved" " this kind of alias.", + errcode=Codes.EXCLUSIVE, ) else: if self.require_membership and check_membership: rooms_for_user = yield self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( - 403, - "You must be in the room to create an alias for it", + 403, "You must be in the room to create an alias for it" ) if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): - raise AuthError( - 403, "This user is not permitted to create this alias", - ) + raise AuthError(403, "This user is not permitted to create this alias") if not self.config.is_alias_creation_allowed( - user_id, room_id, room_alias.to_string(), + user_id, room_id, room_alias.to_string() ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? - raise SynapseError( - 403, "Not allowed to create alias", - ) + raise SynapseError(403, "Not allowed to create alias") - can_create = yield self.can_modify_alias( - room_alias, - user_id=user_id - ) + can_create = yield self.can_modify_alias(room_alias, user_id=user_id) if not can_create: raise AuthError( - 400, "This alias is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) yield self._create_association(room_alias, room_id, servers, creator=user_id) if send_event: - yield self.send_room_alias_update_event( - requester, - room_id - ) + yield self.send_room_alias_update_event(requester, room_id) @defer.inlineCallbacks def delete_association(self, requester, room_alias, send_event=True): @@ -194,34 +188,24 @@ class DirectoryHandler(BaseHandler): raise if not can_delete: - raise AuthError( - 403, "You don't have permission to delete the alias.", - ) + raise AuthError(403, "You don't have permission to delete the alias.") - can_delete = yield self.can_modify_alias( - room_alias, - user_id=user_id - ) + can_delete = yield self.can_modify_alias(room_alias, user_id=user_id) if not can_delete: raise SynapseError( - 400, "This alias is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) room_id = yield self._delete_association(room_alias) try: if send_event: - yield self.send_room_alias_update_event( - requester, - room_id - ) + yield self.send_room_alias_update_event(requester, room_id) yield self._update_canonical_alias( - requester, - requester.user.to_string(), - room_id, - room_alias, + requester, requester.user.to_string(), room_id, room_alias ) except AuthError as e: logger.info("Failed to update alias events: %s", e) @@ -234,7 +218,7 @@ class DirectoryHandler(BaseHandler): raise SynapseError( 400, "This application service has not reserved this kind of alias", - errcode=Codes.EXCLUSIVE + errcode=Codes.EXCLUSIVE, ) yield self._delete_association(room_alias) @@ -251,9 +235,7 @@ class DirectoryHandler(BaseHandler): def get_association(self, room_alias): room_id = None if self.hs.is_mine(room_alias): - result = yield self.get_association_from_room_alias( - room_alias - ) + result = yield self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id @@ -263,9 +245,7 @@ class DirectoryHandler(BaseHandler): result = yield self.federation.make_query( destination=room_alias.domain, query_type="directory", - args={ - "room_alias": room_alias.to_string(), - }, + args={"room_alias": room_alias.to_string()}, retry_on_dns_fail=False, ignore_backoff=True, ) @@ -284,7 +264,7 @@ class DirectoryHandler(BaseHandler): raise SynapseError( 404, "Room alias %s not found" % (room_alias.to_string(),), - Codes.NOT_FOUND + Codes.NOT_FOUND, ) users = yield self.state.get_current_users_in_room(room_id) @@ -293,41 +273,28 @@ class DirectoryHandler(BaseHandler): # If this server is in the list of servers, return it first. if self.server_name in servers: - servers = ( - [self.server_name] + - [s for s in servers if s != self.server_name] - ) + servers = [self.server_name] + [s for s in servers if s != self.server_name] else: servers = list(servers) - defer.returnValue({ - "room_id": room_id, - "servers": servers, - }) + defer.returnValue({"room_id": room_id, "servers": servers}) return @defer.inlineCallbacks def on_directory_query(self, args): room_alias = RoomAlias.from_string(args["room_alias"]) if not self.hs.is_mine(room_alias): - raise SynapseError( - 400, "Room Alias is not hosted on this Home Server" - ) + raise SynapseError(400, "Room Alias is not hosted on this Home Server") - result = yield self.get_association_from_room_alias( - room_alias - ) + result = yield self.get_association_from_room_alias(room_alias) if result is not None: - defer.returnValue({ - "room_id": result.room_id, - "servers": result.servers, - }) + defer.returnValue({"room_id": result.room_id, "servers": result.servers}) else: raise SynapseError( 404, "Room alias %r not found" % (room_alias.to_string(),), - Codes.NOT_FOUND + Codes.NOT_FOUND, ) @defer.inlineCallbacks @@ -343,7 +310,7 @@ class DirectoryHandler(BaseHandler): "sender": requester.user.to_string(), "content": {"aliases": aliases}, }, - ratelimit=False + ratelimit=False, ) @defer.inlineCallbacks @@ -365,14 +332,12 @@ class DirectoryHandler(BaseHandler): "sender": user_id, "content": {}, }, - ratelimit=False + ratelimit=False, ) @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): - result = yield self.store.get_association_from_room_alias( - room_alias - ) + result = yield self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists as_handler = self.appservice_handler @@ -421,8 +386,7 @@ class DirectoryHandler(BaseHandler): if not self.spam_checker.user_may_publish_room(user_id, room_id): raise AuthError( - 403, - "This user is not permitted to publish rooms to the room list" + 403, "This user is not permitted to publish rooms to the room list" ) if requester.is_guest: @@ -434,8 +398,7 @@ class DirectoryHandler(BaseHandler): if visibility == "public" and not self.enable_room_list_search: # The room list has been disabled. raise AuthError( - 403, - "This user is not permitted to publish rooms to the room list" + 403, "This user is not permitted to publish rooms to the room list" ) room = yield self.store.get_room(room_id) @@ -452,20 +415,19 @@ class DirectoryHandler(BaseHandler): room_aliases.append(canonical_alias) if not self.config.is_publishing_room_allowed( - user_id, room_id, room_aliases, + user_id, room_id, room_aliases ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? - raise SynapseError( - 403, "Not allowed to publish room", - ) + raise SynapseError(403, "Not allowed to publish room") yield self.store.set_room_is_public(room_id, making_public) @defer.inlineCallbacks - def edit_published_appservice_room_list(self, appservice_id, network_id, - room_id, visibility): + def edit_published_appservice_room_list( + self, appservice_id, network_id, room_id, visibility + ): """Add or remove a room from the appservice/network specific public room list. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9dc46aa15f..807900fe52 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -99,9 +99,7 @@ class E2eKeysHandler(object): query_list.append((user_id, None)) user_ids_not_in_cache, remote_results = ( - yield self.store.get_user_devices_from_cache( - query_list - ) + yield self.store.get_user_devices_from_cache(query_list) ) for user_id, devices in iteritems(remote_results): user_devices = results.setdefault(user_id, {}) @@ -126,9 +124,7 @@ class E2eKeysHandler(object): destination_query = remote_queries_not_in_cache[destination] try: remote_result = yield self.federation.query_client_keys( - destination, - {"device_keys": destination_query}, - timeout=timeout + destination, {"device_keys": destination_query}, timeout=timeout ) for user_id, keys in remote_result["device_keys"].items(): @@ -138,14 +134,17 @@ class E2eKeysHandler(object): except Exception as e: failures[destination] = _exception_to_failure(e) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(do_remote_query, destination) - for destination in remote_queries_not_in_cache - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(do_remote_query, destination) + for destination in remote_queries_not_in_cache + ], + consumeErrors=True, + ) + ) - defer.returnValue({ - "device_keys": results, "failures": failures, - }) + defer.returnValue({"device_keys": results, "failures": failures}) @defer.inlineCallbacks def query_local_devices(self, query): @@ -165,8 +164,7 @@ class E2eKeysHandler(object): for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): - logger.warning("Request for keys for non-local user %s", - user_id) + logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") if not device_ids: @@ -231,9 +229,7 @@ class E2eKeysHandler(object): device_keys = remote_queries[destination] try: remote_result = yield self.federation.claim_client_keys( - destination, - {"one_time_keys": device_keys}, - timeout=timeout + destination, {"one_time_keys": device_keys}, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: @@ -241,25 +237,29 @@ class E2eKeysHandler(object): except Exception as e: failures[destination] = _exception_to_failure(e) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(claim_client_keys, destination) - for destination in remote_queries - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(claim_client_keys, destination) + for destination in remote_queries + ], + consumeErrors=True, + ) + ) logger.info( "Claimed one-time-keys: %s", - ",".join(( - "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) - )), + ",".join( + ( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) + ) + ), ) - defer.returnValue({ - "one_time_keys": json_result, - "failures": failures - }) + defer.returnValue({"one_time_keys": json_result, "failures": failures}) @defer.inlineCallbacks def upload_keys_for_user(self, user_id, device_id, keys): @@ -270,11 +270,13 @@ class E2eKeysHandler(object): if device_keys: logger.info( "Updating device_keys for device %r for user %s at %d", - device_id, user_id, time_now + device_id, + user_id, + time_now, ) # TODO: Sign the JSON with the server key changed = yield self.store.set_e2e_device_keys( - user_id, device_id, time_now, device_keys, + user_id, device_id, time_now, device_keys ) if changed: # Only notify about device updates *if* the keys actually changed @@ -283,7 +285,7 @@ class E2eKeysHandler(object): one_time_keys = keys.get("one_time_keys", None) if one_time_keys: yield self._upload_one_time_keys_for_user( - user_id, device_id, time_now, one_time_keys, + user_id, device_id, time_now, one_time_keys ) # the device should have been registered already, but it may have been @@ -298,20 +300,22 @@ class E2eKeysHandler(object): defer.returnValue({"one_time_key_counts": result}) @defer.inlineCallbacks - def _upload_one_time_keys_for_user(self, user_id, device_id, time_now, - one_time_keys): + def _upload_one_time_keys_for_user( + self, user_id, device_id, time_now, one_time_keys + ): logger.info( "Adding one_time_keys %r for device %r for user %r at %d", - one_time_keys.keys(), device_id, user_id, time_now, + one_time_keys.keys(), + device_id, + user_id, + time_now, ) # make a list of (alg, id, key) tuples key_list = [] for key_id, key_obj in one_time_keys.items(): algorithm, key_id = key_id.split(":") - key_list.append(( - algorithm, key_id, key_obj - )) + key_list.append((algorithm, key_id, key_obj)) # First we check if we have already persisted any of the keys. existing_key_map = yield self.store.get_e2e_one_time_keys( @@ -325,42 +329,35 @@ class E2eKeysHandler(object): if not _one_time_keys_match(ex_json, key): raise SynapseError( 400, - ("One time key %s:%s already exists. " - "Old key: %s; new key: %r") % - (algorithm, key_id, ex_json, key) + ( + "One time key %s:%s already exists. " + "Old key: %s; new key: %r" + ) + % (algorithm, key_id, ex_json, key), ) else: - new_keys.append(( - algorithm, key_id, encode_canonical_json(key).decode('ascii'))) + new_keys.append( + (algorithm, key_id, encode_canonical_json(key).decode("ascii")) + ) - yield self.store.add_e2e_one_time_keys( - user_id, device_id, time_now, new_keys - ) + yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) def _exception_to_failure(e): if isinstance(e, CodeMessageException): - return { - "status": e.code, "message": str(e), - } + return {"status": e.code, "message": str(e)} if isinstance(e, NotRetryingDestination): - return { - "status": 503, "message": "Not ready for retry", - } + return {"status": 503, "message": "Not ready for retry"} if isinstance(e, FederationDeniedError): - return { - "status": 403, "message": "Federation Denied", - } + return {"status": 403, "message": "Federation Denied"} # include ConnectionRefused and other errors # # Note that some Exceptions (notably twisted's ResponseFailed etc) don't # give a string for e.message, which json then fails to serialize. - return { - "status": 503, "message": str(e), - } + return {"status": 503, "message": str(e)} def _one_time_keys_match(old_key_json, new_key): diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 7bc174070e..ebd807bca6 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -152,14 +152,14 @@ class E2eRoomKeysHandler(object): else: raise - if version_info['version'] != version: + if version_info["version"] != version: # Check that the version we're trying to upload actually exists try: version_info = yield self.store.get_e2e_room_keys_version_info( - user_id, version, + user_id, version ) # if we get this far, the version must exist - raise RoomKeysVersionError(current_version=version_info['version']) + raise RoomKeysVersionError(current_version=version_info["version"]) except StoreError as e: if e.code == 404: raise NotFoundError("Version '%s' not found" % (version,)) @@ -168,8 +168,8 @@ class E2eRoomKeysHandler(object): # go through the room_keys. # XXX: this should/could be done concurrently, given we're in a lock. - for room_id, room in iteritems(room_keys['rooms']): - for session_id, session in iteritems(room['sessions']): + for room_id, room in iteritems(room_keys["rooms"]): + for session_id, session in iteritems(room["sessions"]): yield self._upload_room_key( user_id, version, room_id, session_id, session ) @@ -223,14 +223,14 @@ class E2eRoomKeysHandler(object): # spelt out with if/elifs rather than nested boolean expressions # purely for legibility. - if room_key['is_verified'] and not current_room_key['is_verified']: + if room_key["is_verified"] and not current_room_key["is_verified"]: return True elif ( - room_key['first_message_index'] < - current_room_key['first_message_index'] + room_key["first_message_index"] + < current_room_key["first_message_index"] ): return True - elif room_key['forwarded_count'] < current_room_key['forwarded_count']: + elif room_key["forwarded_count"] < current_room_key["forwarded_count"]: return True else: return False @@ -328,16 +328,10 @@ class E2eRoomKeysHandler(object): A deferred of an empty dict. """ if "version" not in version_info: - raise SynapseError( - 400, - "Missing version in body", - Codes.MISSING_PARAM - ) + raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM) if version_info["version"] != version: raise SynapseError( - 400, - "Version in body does not match", - Codes.INVALID_PARAM + 400, "Version in body does not match", Codes.INVALID_PARAM ) with (yield self._upload_linearizer.queue(user_id)): try: @@ -350,12 +344,10 @@ class E2eRoomKeysHandler(object): else: raise if old_info["algorithm"] != version_info["algorithm"]: - raise SynapseError( - 400, - "Algorithm does not match", - Codes.INVALID_PARAM - ) + raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) - yield self.store.update_e2e_room_keys_version(user_id, version, version_info) + yield self.store.update_e2e_room_keys_version( + user_id, version, version_info + ) defer.returnValue({}) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index eb525070cf..5836d3c639 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -31,7 +31,6 @@ logger = logging.getLogger(__name__) class EventStreamHandler(BaseHandler): - def __init__(self, hs): super(EventStreamHandler, self).__init__(hs) @@ -53,9 +52,17 @@ class EventStreamHandler(BaseHandler): @defer.inlineCallbacks @log_function - def get_stream(self, auth_user_id, pagin_config, timeout=0, - as_client_event=True, affect_presence=True, - only_keys=None, room_id=None, is_guest=False): + def get_stream( + self, + auth_user_id, + pagin_config, + timeout=0, + as_client_event=True, + affect_presence=True, + only_keys=None, + room_id=None, + is_guest=False, + ): """Fetches the events stream for a given user. If `only_keys` is not None, events from keys will be sent down. @@ -73,7 +80,7 @@ class EventStreamHandler(BaseHandler): presence_handler = self.hs.get_presence_handler() context = yield presence_handler.user_syncing( - auth_user_id, affect_presence=affect_presence, + auth_user_id, affect_presence=affect_presence ) with context: if timeout: @@ -85,9 +92,12 @@ class EventStreamHandler(BaseHandler): timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) events, tokens = yield self.notifier.get_events_for( - auth_user, pagin_config, timeout, + auth_user, + pagin_config, + timeout, only_keys=only_keys, - is_guest=is_guest, explicit_room_id=room_id + is_guest=is_guest, + explicit_room_id=room_id, ) # When the user joins a new room, or another user joins a currently @@ -102,17 +112,15 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.state.get_current_users_in_room(event.room_id) - states = yield presence_handler.get_states( - users, - as_event=True, + users = yield self.state.get_current_users_in_room( + event.room_id ) + states = yield presence_handler.get_states(users, as_event=True) to_add.extend(states) else: ev = yield presence_handler.get_state( - UserID.from_string(event.state_key), - as_event=True, + UserID.from_string(event.state_key), as_event=True ) to_add.append(ev) @@ -121,7 +129,9 @@ class EventStreamHandler(BaseHandler): time_now = self.clock.time_msec() chunks = yield self._event_serializer.serialize_events( - events, time_now, as_client_event=as_client_event, + events, + time_now, + as_client_event=as_client_event, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. bundle_aggregations=False, @@ -137,7 +147,6 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): - @defer.inlineCallbacks def get_event(self, user, room_id, event_id): """Retrieve a single specified event. @@ -164,16 +173,10 @@ class EventHandler(BaseHandler): is_peeking = user.to_string() not in users filtered = yield filter_events_for_client( - self.store, - user.to_string(), - [event], - is_peeking=is_peeking + self.store, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: - raise AuthError( - 403, - "You don't have permission to access that event." - ) + raise AuthError(403, "You don't have permission to access that event.") defer.returnValue(event) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d5a605d3bd..02d397c498 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -82,7 +82,7 @@ def shortstr(iterable, maxitems=5): items = list(itertools.islice(iterable, maxitems + 1)) if len(items) <= maxitems: return str(items) - return u"[" + u", ".join(repr(r) for r in items[:maxitems]) + u", ...]" + return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" class FederationHandler(BaseHandler): @@ -115,14 +115,14 @@ class FederationHandler(BaseHandler): self.config = hs.config self.http_client = hs.get_simple_http_client() - self._send_events_to_master = ( - ReplicationFederationSendEventsRestServlet.make_client(hs) + self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( + hs ) - self._notify_user_membership_change = ( - ReplicationUserJoinedLeftRoomRestServlet.make_client(hs) + self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( + hs ) - self._clean_room_for_join_client = ( - ReplicationCleanRoomRestServlet.make_client(hs) + self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( + hs ) # When joining a room we need to queue any events for that room up @@ -132,9 +132,7 @@ class FederationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() @defer.inlineCallbacks - def on_receive_pdu( - self, origin, pdu, sent_to_us_directly=False, - ): + def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -151,26 +149,19 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - logger.info( - "[%s %s] handling received PDU: %s", - room_id, event_id, pdu, - ) + logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) # We reprocess pdus when we have seen them only as outliers existing = yield self.store.get_event( - event_id, - allow_none=True, - allow_rejected=True, + event_id, allow_none=True, allow_rejected=True ) # FIXME: Currently we fetch an event again when we already have it # if it has been marked as an outlier. - already_seen = ( - existing and ( - not existing.internal_metadata.is_outlier() - or pdu.internal_metadata.is_outlier() - ) + already_seen = existing and ( + not existing.internal_metadata.is_outlier() + or pdu.internal_metadata.is_outlier() ) if already_seen: logger.debug("[%s %s]: Already seen pdu", room_id, event_id) @@ -182,20 +173,19 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: - logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id) - raise FederationError( - "ERROR", - err.code, - err.msg, - affected=pdu.event_id, + logger.warn( + "[%s %s] Received event failed sanity checks", room_id, event_id ) + raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) # If we are currently in the process of joining this room, then we # queue up events for later processing. if room_id in self.room_queues: logger.info( "[%s %s] Queuing PDU from %s for now: join in progress", - room_id, event_id, origin, + room_id, + event_id, + origin, ) self.room_queues[room_id].append((pdu, origin)) return @@ -206,14 +196,13 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = yield self.auth.check_host_in_room( - room_id, - self.server_name - ) + is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( "[%s %s] Ignoring PDU from %s as we're not in the room", - room_id, event_id, origin, + room_id, + event_id, + origin, ) defer.returnValue(None) @@ -223,14 +212,9 @@ class FederationHandler(BaseHandler): # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. - min_depth = yield self.get_min_depth_for_context( - pdu.room_id - ) + min_depth = yield self.get_min_depth_for_context(pdu.room_id) - logger.debug( - "[%s %s] min_depth: %d", - room_id, event_id, min_depth, - ) + logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) seen = yield self.store.have_seen_events(prevs) @@ -248,12 +232,17 @@ class FederationHandler(BaseHandler): # at a time. logger.info( "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s", - room_id, event_id, len(missing_prevs), shortstr(missing_prevs), + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), ) with (yield self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( "[%s %s] Acquired room lock to fetch %d missing prev_events", - room_id, event_id, len(missing_prevs), + room_id, + event_id, + len(missing_prevs), ) yield self._get_missing_events_for_pdu( @@ -267,12 +256,16 @@ class FederationHandler(BaseHandler): if not prevs - seen: logger.info( "[%s %s] Found all missing prev_events", - room_id, event_id, + room_id, + event_id, ) elif missing_prevs: logger.info( "[%s %s] Not recursively fetching %d missing prev_events: %s", - room_id, event_id, len(missing_prevs), shortstr(missing_prevs), + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), ) if prevs - seen: @@ -303,7 +296,10 @@ class FederationHandler(BaseHandler): if sent_to_us_directly: logger.warn( "[%s %s] Rejecting: failed to fetch %d prev events: %s", - room_id, event_id, len(prevs - seen), shortstr(prevs - seen) + room_id, + event_id, + len(prevs - seen), + shortstr(prevs - seen), ) raise FederationError( "ERROR", @@ -318,9 +314,7 @@ class FederationHandler(BaseHandler): # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. auth_chains = set() - event_map = { - event_id: pdu, - } + event_map = {event_id: pdu} try: # Get the state of the events we know about ours = yield self.store.get_state_groups_ids(room_id, seen) @@ -337,7 +331,9 @@ class FederationHandler(BaseHandler): for p in prevs - seen: logger.info( "[%s %s] Requesting state at missing prev_event %s", - room_id, event_id, p, + room_id, + event_id, + p, ) room_version = yield self.store.get_room_version(room_id) @@ -348,19 +344,19 @@ class FederationHandler(BaseHandler): # by the get_pdu_cache in federation_client. remote_state, got_auth_chain = ( yield self.federation_client.get_state_for_room( - origin, room_id, p, + origin, room_id, p ) ) # we want the state *after* p; get_state_for_room returns the # state *before* p. remote_event = yield self.federation_client.get_pdu( - [origin], p, room_version, outlier=True, + [origin], p, room_version, outlier=True ) if remote_event is None: raise Exception( - "Unable to get missing prev_event %s" % (p, ) + "Unable to get missing prev_event %s" % (p,) ) if remote_event.is_state(): @@ -380,7 +376,9 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x state_map = yield resolve_events_with_store( - room_version, state_maps, event_map, + room_version, + state_maps, + event_map, state_res_store=StateResolutionStore(self.store), ) @@ -396,15 +394,15 @@ class FederationHandler(BaseHandler): ) event_map.update(evs) - state = [ - event_map[e] for e in six.itervalues(state_map) - ] + state = [event_map[e] for e in six.itervalues(state_map)] auth_chain = list(auth_chains) except Exception: logger.warn( "[%s %s] Error attempting to resolve state at missing " "prev_events", - room_id, event_id, exc_info=True, + room_id, + event_id, + exc_info=True, ) raise FederationError( "ERROR", @@ -414,10 +412,7 @@ class FederationHandler(BaseHandler): ) yield self._process_received_pdu( - origin, - pdu, - state=state, - auth_chain=auth_chain, + origin, pdu, state=state, auth_chain=auth_chain ) @defer.inlineCallbacks @@ -447,7 +442,10 @@ class FederationHandler(BaseHandler): logger.info( "[%s %s]: Requesting missing events between %s and %s", - room_id, event_id, shortstr(latest), event_id, + room_id, + event_id, + shortstr(latest), + event_id, ) # XXX: we set timeout to 10s to help workaround @@ -512,15 +510,15 @@ class FederationHandler(BaseHandler): # We failed to get the missing events, but since we need to handle # the case of `get_missing_events` not returning the necessary # events anyway, it is safe to simply log the error and continue. - logger.warn( - "[%s %s]: Failed to get prev_events: %s", - room_id, event_id, e, - ) + logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e) return logger.info( "[%s %s]: Got %d prev_events: %s", - room_id, event_id, len(missing_events), shortstr(missing_events), + room_id, + event_id, + len(missing_events), + shortstr(missing_events), ) # We want to sort these by depth so we process them and @@ -530,20 +528,20 @@ class FederationHandler(BaseHandler): for ev in missing_events: logger.info( "[%s %s] Handling received prev_event %s", - room_id, event_id, ev.event_id, + room_id, + event_id, + ev.event_id, ) with logcontext.nested_logging_context(ev.event_id): try: - yield self.on_receive_pdu( - origin, - ev, - sent_to_us_directly=False, - ) + yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: logger.warn( "[%s %s] Received prev_event %s failed history check.", - room_id, event_id, ev.event_id, + room_id, + event_id, + ev.event_id, ) else: raise @@ -556,10 +554,7 @@ class FederationHandler(BaseHandler): room_id = event.room_id event_id = event.event_id - logger.debug( - "[%s %s] Processing event: %s", - room_id, event_id, event, - ) + logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) event_ids = set() if state: @@ -581,43 +576,32 @@ class FederationHandler(BaseHandler): e.internal_metadata.outlier = True auth_ids = e.auth_event_ids() auth = { - (e.type, e.state_key): e for e in auth_chain + (e.type, e.state_key): e + for e in auth_chain if e.event_id in auth_ids or e.type == EventTypes.Create } - event_infos.append({ - "event": e, - "auth_events": auth, - }) + event_infos.append({"event": e, "auth_events": auth}) seen_ids.add(e.event_id) logger.info( "[%s %s] persisting newly-received auth/state events %s", - room_id, event_id, [e["event"].event_id for e in event_infos] + room_id, + event_id, + [e["event"].event_id for e in event_infos], ) yield self._handle_new_events(origin, event_infos) try: - context = yield self._handle_new_event( - origin, - event, - state=state, - ) + context = yield self._handle_new_event(origin, event, state=state) except AuthError as e: - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=event.event_id, - ) + raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) room = yield self.store.get_room(room_id) if not room: try: yield self.store.store_room( - room_id=room_id, - room_creator_user_id="", - is_public=False, + room_id=room_id, room_creator_user_id="", is_public=False ) except StoreError: logger.exception("Failed to store room.") @@ -631,12 +615,10 @@ class FederationHandler(BaseHandler): prev_state_ids = yield context.get_prev_state_ids(self.store) - prev_state_id = prev_state_ids.get( - (event.type, event.state_key) - ) + prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: prev_state = yield self.store.get_event( - prev_state_id, allow_none=True, + prev_state_id, allow_none=True ) if prev_state and prev_state.membership == Membership.JOIN: newly_joined = False @@ -667,10 +649,7 @@ class FederationHandler(BaseHandler): room_version = yield self.store.get_room_version(room_id) events = yield self.federation_client.backfill( - dest, - room_id, - limit=limit, - extremities=extremities, + dest, room_id, limit=limit, extremities=extremities ) # ideally we'd sanity check the events here for excess prev_events etc, @@ -697,16 +676,9 @@ class FederationHandler(BaseHandler): event_ids = set(e.event_id for e in events) - edges = [ - ev.event_id - for ev in events - if set(ev.prev_event_ids()) - event_ids - ] + edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] - logger.info( - "backfill: Got %d events with %d edges", - len(events), len(edges), - ) + logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) # For each edge get the current state. @@ -715,9 +687,7 @@ class FederationHandler(BaseHandler): events_to_state = {} for e_id in edges: state, auth = yield self.federation_client.get_state_for_room( - destination=dest, - room_id=room_id, - event_id=e_id + destination=dest, room_id=room_id, event_id=e_id ) auth_events.update({a.event_id: a for a in auth}) auth_events.update({s.event_id: s for s in state}) @@ -726,12 +696,14 @@ class FederationHandler(BaseHandler): required_auth = set( a_id - for event in events + list(state_events.values()) + list(auth_events.values()) + for event in events + + list(state_events.values()) + + list(auth_events.values()) for a_id in event.auth_event_ids() ) - auth_events.update({ - e_id: event_map[e_id] for e_id in required_auth if e_id in event_map - }) + auth_events.update( + {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} + ) missing_auth = required_auth - set(auth_events) failed_to_fetch = set() @@ -750,27 +722,30 @@ class FederationHandler(BaseHandler): if missing_auth - failed_to_fetch: logger.info( "Fetching missing auth for backfill: %r", - missing_auth - failed_to_fetch + missing_auth - failed_to_fetch, ) - results = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background( - self.federation_client.get_pdu, - [dest], - event_id, - room_version=room_version, - outlier=True, - timeout=10000, - ) - for event_id in missing_auth - failed_to_fetch - ], - consumeErrors=True - )).addErrback(unwrapFirstError) + results = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background( + self.federation_client.get_pdu, + [dest], + event_id, + room_version=room_version, + outlier=True, + timeout=10000, + ) + for event_id in missing_auth - failed_to_fetch + ], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) auth_events.update({a.event_id: a for a in results if a}) required_auth.update( a_id - for event in results if event + for event in results + if event for a_id in event.auth_event_ids() ) missing_auth = required_auth - set(auth_events) @@ -802,15 +777,19 @@ class FederationHandler(BaseHandler): continue a.internal_metadata.outlier = True - ev_infos.append({ - "event": a, - "auth_events": { - (auth_events[a_id].type, auth_events[a_id].state_key): - auth_events[a_id] - for a_id in a.auth_event_ids() - if a_id in auth_events + ev_infos.append( + { + "event": a, + "auth_events": { + ( + auth_events[a_id].type, + auth_events[a_id].state_key, + ): auth_events[a_id] + for a_id in a.auth_event_ids() + if a_id in auth_events + }, } - }) + ) # Step 1b: persist the events in the chunk we fetched state for (i.e. # the backwards extremities) as non-outliers. @@ -818,23 +797,24 @@ class FederationHandler(BaseHandler): # For paranoia we ensure that these events are marked as # non-outliers ev = event_map[e_id] - assert(not ev.internal_metadata.is_outlier()) - - ev_infos.append({ - "event": ev, - "state": events_to_state[e_id], - "auth_events": { - (auth_events[a_id].type, auth_events[a_id].state_key): - auth_events[a_id] - for a_id in ev.auth_event_ids() - if a_id in auth_events + assert not ev.internal_metadata.is_outlier() + + ev_infos.append( + { + "event": ev, + "state": events_to_state[e_id], + "auth_events": { + ( + auth_events[a_id].type, + auth_events[a_id].state_key, + ): auth_events[a_id] + for a_id in ev.auth_event_ids() + if a_id in auth_events + }, } - }) + ) - yield self._handle_new_events( - dest, ev_infos, - backfilled=True, - ) + yield self._handle_new_events(dest, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -845,14 +825,12 @@ class FederationHandler(BaseHandler): # For paranoia we ensure that these events are marked as # non-outliers - assert(not event.internal_metadata.is_outlier()) + assert not event.internal_metadata.is_outlier() # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - yield self._handle_new_event( - dest, event, backfilled=True, - ) + yield self._handle_new_event(dest, event, backfilled=True) defer.returnValue(events) @@ -861,9 +839,7 @@ class FederationHandler(BaseHandler): """Checks the database to see if we should backfill before paginating, and if so do. """ - extremities = yield self.store.get_oldest_events_with_depth_in_room( - room_id - ) + extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id) if not extremities: logger.debug("Not backfilling as no extremeties found.") @@ -895,31 +871,27 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = yield self.store.get_successor_events( - list(extremities), - ) + forward_events = yield self.store.get_successor_events(list(extremities)) extremities_events = yield self.store.get_events( - forward_events, - check_redacted=False, - get_prev_content=False, + forward_events, check_redacted=False, get_prev_content=False ) # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = yield filter_events_for_server( - self.store, self.server_name, list(extremities_events.values()), - redact=False, check_history_visibility_only=True, + self.store, + self.server_name, + list(extremities_events.values()), + redact=False, + check_history_visibility_only=True, ) if not filtered_extremities: defer.returnValue(False) # Check if we reached a point where we should start backfilling. - sorted_extremeties_tuple = sorted( - extremities.items(), - key=lambda e: -int(e[1]) - ) + sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1])) max_depth = sorted_extremeties_tuple[0][1] # We don't want to specify too many extremities as it causes the backfill @@ -928,8 +900,7 @@ class FederationHandler(BaseHandler): if current_depth > max_depth: logger.debug( - "Not backfilling as we don't need to. %d < %d", - max_depth, current_depth, + "Not backfilling as we don't need to. %d < %d", max_depth, current_depth ) return @@ -954,8 +925,7 @@ class FederationHandler(BaseHandler): joined_users = [ (state_key, int(event.depth)) for (e_type, state_key), event in iteritems(state) - if e_type == EventTypes.Member - and event.membership == Membership.JOIN + if e_type == EventTypes.Member and event.membership == Membership.JOIN ] joined_domains = {} @@ -975,8 +945,7 @@ class FederationHandler(BaseHandler): curr_domains = get_domains_from_state(curr_state) likely_domains = [ - domain for domain, depth in curr_domains - if domain != self.server_name + domain for domain, depth in curr_domains if domain != self.server_name ] @defer.inlineCallbacks @@ -985,28 +954,20 @@ class FederationHandler(BaseHandler): for dom in domains: try: yield self.backfill( - dom, room_id, - limit=100, - extremities=extremities, + dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the # appropriate stuff. # TODO: We can probably do something more intelligent here. defer.returnValue(True) except SynapseError as e: - logger.info( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.info("Failed to backfill from %s because %s", dom, e) continue except CodeMessageException as e: if 400 <= e.code < 500: raise - logger.info( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.info("Failed to backfill from %s because %s", dom, e) continue except NotRetryingDestination as e: logger.info(str(e)) @@ -1015,10 +976,7 @@ class FederationHandler(BaseHandler): logger.info(e) continue except Exception as e: - logger.exception( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.exception("Failed to backfill from %s because %s", dom, e) continue defer.returnValue(False) @@ -1039,10 +997,11 @@ class FederationHandler(BaseHandler): resolve = logcontext.preserve_fn( self.state_handler.resolve_state_groups_for_events ) - states = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [resolve(room_id, [e]) for e in event_ids], - consumeErrors=True, - )) + states = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [resolve(room_id, [e]) for e in event_ids], consumeErrors=True + ) + ) # dict[str, dict[tuple, str]], a map from event_id to state map of # event_ids. @@ -1050,23 +1009,23 @@ class FederationHandler(BaseHandler): state_map = yield self.store.get_events( [e_id for ids in itervalues(states) for e_id in itervalues(ids)], - get_prev_content=False + get_prev_content=False, ) states = { key: { k: state_map[e_id] for k, e_id in iteritems(state_dict) if e_id in state_map - } for key, state_dict in iteritems(states) + } + for key, state_dict in iteritems(states) } for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) - success = yield try_backfill([ - dom for dom, _ in likely_domains - if dom not in tried_domains - ]) + success = yield try_backfill( + [dom for dom, _ in likely_domains if dom not in tried_domains] + ) if success: defer.returnValue(True) @@ -1091,20 +1050,20 @@ class FederationHandler(BaseHandler): SynapseError if the event does not pass muster """ if len(ev.prev_event_ids()) > 20: - logger.warn("Rejecting event %s which has %i prev_events", - ev.event_id, len(ev.prev_event_ids())) - raise SynapseError( - http_client.BAD_REQUEST, - "Too many prev_events", + logger.warn( + "Rejecting event %s which has %i prev_events", + ev.event_id, + len(ev.prev_event_ids()), ) + raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: - logger.warn("Rejecting event %s which has %i auth_events", - ev.event_id, len(ev.auth_event_ids())) - raise SynapseError( - http_client.BAD_REQUEST, - "Too many auth_events", + logger.warn( + "Rejecting event %s which has %i auth_events", + ev.event_id, + len(ev.auth_event_ids()), ) + raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events") @defer.inlineCallbacks def send_invite(self, target_host, event): @@ -1116,7 +1075,7 @@ class FederationHandler(BaseHandler): destination=target_host, room_id=event.room_id, event_id=event.event_id, - pdu=event + pdu=event, ) defer.returnValue(pdu) @@ -1125,8 +1084,7 @@ class FederationHandler(BaseHandler): def on_event_auth(self, event_id): event = yield self.store.get_event(event_id) auth = yield self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], - include_given=True + [auth_id for auth_id in event.auth_event_ids()], include_given=True ) defer.returnValue([e for e in auth]) @@ -1152,15 +1110,13 @@ class FederationHandler(BaseHandler): joinee, "join", content, - params={ - "ver": KNOWN_ROOM_VERSIONS, - }, + params={"ver": KNOWN_ROOM_VERSIONS}, ) # This shouldn't happen, because the RoomMemberHandler has a # linearizer lock which only allows one operation per user per room # at a time - so this is just paranoia. - assert (room_id not in self.room_queues) + assert room_id not in self.room_queues self.room_queues[room_id] = [] @@ -1177,7 +1133,7 @@ class FederationHandler(BaseHandler): except ValueError: pass ret = yield self.federation_client.send_join( - target_hosts, event, event_format_version, + target_hosts, event, event_format_version ) origin = ret["origin"] @@ -1196,17 +1152,13 @@ class FederationHandler(BaseHandler): try: yield self.store.store_room( - room_id=room_id, - room_creator_user_id="", - is_public=False + room_id=room_id, room_creator_user_id="", is_public=False ) except Exception: # FIXME pass - yield self._persist_auth_tree( - origin, auth_chain, state, event - ) + yield self._persist_auth_tree(origin, auth_chain, state, event) logger.debug("Finished joining %s to %s", joinee, room_id) finally: @@ -1233,14 +1185,18 @@ class FederationHandler(BaseHandler): """ for p, origin in room_queue: try: - logger.info("Processing queued PDU %s which was received " - "while we were joining %s", p.event_id, p.room_id) + logger.info( + "Processing queued PDU %s which was received " + "while we were joining %s", + p.event_id, + p.room_id, + ) with logcontext.nested_logging_context(p.event_id): yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warn( - "Error handling queued PDU %s from %s: %s", - p.event_id, origin, e) + "Error handling queued PDU %s from %s: %s", p.event_id, origin, e + ) @defer.inlineCallbacks @log_function @@ -1261,30 +1217,30 @@ class FederationHandler(BaseHandler): "room_id": room_id, "sender": user_id, "state_key": user_id, - } + }, ) try: event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) except AuthError as e: logger.warn("Failed to create join %r because %s", event, e) raise e event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Creation of join %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` yield self.auth.check_from_context( - room_version, event, context, do_sig_check=False, + room_version, event, context, do_sig_check=False ) defer.returnValue(event) @@ -1319,17 +1275,15 @@ class FederationHandler(BaseHandler): # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin - context = yield self._handle_new_event( - origin, event - ) + context = yield self._handle_new_event(origin, event) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Sending of join %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) logger.debug( @@ -1350,10 +1304,7 @@ class FederationHandler(BaseHandler): state = yield self.store.get_events(list(prev_state_ids.values())) - defer.returnValue({ - "state": list(state.values()), - "auth_chain": auth_chain, - }) + defer.returnValue({"state": list(state.values()), "auth_chain": auth_chain}) @defer.inlineCallbacks def on_invite_request(self, origin, pdu): @@ -1374,7 +1325,7 @@ class FederationHandler(BaseHandler): raise SynapseError(403, "This server does not accept room invites") if not self.spam_checker.user_may_invite( - event.sender, event.state_key, event.room_id, + event.sender, event.state_key, event.room_id ): raise SynapseError( 403, "This user is not permitted to send invites to this server/user" @@ -1386,26 +1337,23 @@ class FederationHandler(BaseHandler): sender_domain = get_domain_from_id(event.sender) if sender_domain != origin: - raise SynapseError(400, "The invite event was not from the server sending it") + raise SynapseError( + 400, "The invite event was not from the server sending it" + ) if not self.is_mine_id(event.state_key): raise SynapseError(400, "The invite event must be for this server") # block any attempts to invite the server notices mxid if event.state_key == self._server_notices_mxid: - raise SynapseError( - http_client.FORBIDDEN, - "Cannot invite this user", - ) + raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True event.signatures.update( compute_event_signature( - event.get_pdu_json(), - self.hs.hostname, - self.hs.config.signing_key[0] + event.get_pdu_json(), self.hs.hostname, self.hs.config.signing_key[0] ) ) @@ -1417,10 +1365,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, - room_id, - user_id, - "leave" + target_hosts, room_id, user_id, "leave" ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. @@ -1435,10 +1380,7 @@ class FederationHandler(BaseHandler): except ValueError: pass - yield self.federation_client.send_leave( - target_hosts, - event - ) + yield self.federation_client.send_leave(target_hosts, event) context = yield self.state_handler.compute_event_context(event) yield self.persist_events_and_notify([(event, context)]) @@ -1446,25 +1388,21 @@ class FederationHandler(BaseHandler): defer.returnValue(event) @defer.inlineCallbacks - def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, - content={}, params=None): + def _make_and_verify_event( + self, target_hosts, room_id, user_id, membership, content={}, params=None + ): origin, event, format_ver = yield self.federation_client.make_membership_event( - target_hosts, - room_id, - user_id, - membership, - content, - params=params, + target_hosts, room_id, user_id, membership, content, params=params ) logger.debug("Got response to make_%s: %s", membership, event) # We should assert some things. # FIXME: Do this in a nicer way - assert(event.type == EventTypes.Member) - assert(event.user_id == user_id) - assert(event.state_key == user_id) - assert(event.room_id == room_id) + assert event.type == EventTypes.Member + assert event.user_id == user_id + assert event.state_key == user_id + assert event.room_id == room_id defer.returnValue((origin, event, format_ver)) @defer.inlineCallbacks @@ -1483,27 +1421,27 @@ class FederationHandler(BaseHandler): "room_id": room_id, "sender": user_id, "state_key": user_id, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.warning("Creation of leave %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` yield self.auth.check_from_context( - room_version, event, context, do_sig_check=False, + room_version, event, context, do_sig_check=False ) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) @@ -1525,17 +1463,15 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = yield self._handle_new_event( - origin, event - ) + context = yield self._handle_new_event(origin, event) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Sending of leave %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) logger.debug( @@ -1552,18 +1488,14 @@ class FederationHandler(BaseHandler): """ event = yield self.store.get_event( - event_id, allow_none=False, check_room_id=room_id, + event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups( - room_id, [event_id] - ) + state_groups = yield self.store.get_state_groups(room_id, [event_id]) if state_groups: _, state = list(iteritems(state_groups)).pop() - results = { - (e.type, e.state_key): e for e in state - } + results = {(e.type, e.state_key): e for e in state} if event.is_state(): # Get previous state @@ -1585,12 +1517,10 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ event = yield self.store.get_event( - event_id, allow_none=False, check_room_id=room_id, + event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups_ids( - room_id, [event_id] - ) + state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) if state_groups: _, state = list(state_groups.items()).pop() @@ -1616,11 +1546,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - events = yield self.store.get_backfill_events( - room_id, - pdu_list, - limit - ) + events = yield self.store.get_backfill_events(room_id, pdu_list, limit) events = yield filter_events_for_server(self.store, origin, events) @@ -1644,22 +1570,15 @@ class FederationHandler(BaseHandler): AuthError if the server is not currently in the room """ event = yield self.store.get_event( - event_id, - allow_none=True, - allow_rejected=True, + event_id, allow_none=True, allow_rejected=True ) if event: - in_room = yield self.auth.check_host_in_room( - event.room_id, - origin - ) + in_room = yield self.auth.check_host_in_room(event.room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - events = yield filter_events_for_server( - self.store, origin, [event], - ) + events = yield filter_events_for_server(self.store, origin, [event]) event = events[0] defer.returnValue(event) else: @@ -1669,13 +1588,11 @@ class FederationHandler(BaseHandler): return self.store.get_min_depth(context) @defer.inlineCallbacks - def _handle_new_event(self, origin, event, state=None, auth_events=None, - backfilled=False): + def _handle_new_event( + self, origin, event, state=None, auth_events=None, backfilled=False + ): context = yield self._prep_event( - origin, event, - state=state, - auth_events=auth_events, - backfilled=backfilled, + origin, event, state=state, auth_events=auth_events, backfilled=backfilled ) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we @@ -1688,15 +1605,13 @@ class FederationHandler(BaseHandler): ) yield self.persist_events_and_notify( - [(event, context)], - backfilled=backfilled, + [(event, context)], backfilled=backfilled ) success = True finally: if not success: logcontext.run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, + self.store.remove_push_actions_from_staging, event.event_id ) defer.returnValue(context) @@ -1724,12 +1639,15 @@ class FederationHandler(BaseHandler): ) defer.returnValue(res) - contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background(prep, ev_info) - for ev_info in event_infos - ], consumeErrors=True, - )) + contexts = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background(prep, ev_info) + for ev_info in event_infos + ], + consumeErrors=True, + ) + ) yield self.persist_events_and_notify( [ @@ -1764,8 +1682,7 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id] = ctx event_map = { - e.event_id: e - for e in itertools.chain(auth_events, state, [event]) + e.event_id: e for e in itertools.chain(auth_events, state, [event]) } create_event = None @@ -1780,7 +1697,7 @@ class FederationHandler(BaseHandler): raise SynapseError(400, "No create event in state") room_version = create_event.content.get( - "room_version", RoomVersions.V1.identifier, + "room_version", RoomVersions.V1.identifier ) missing_auth_events = set() @@ -1791,11 +1708,7 @@ class FederationHandler(BaseHandler): for e_id in missing_auth_events: m_ev = yield self.federation_client.get_pdu( - [origin], - e_id, - room_version=room_version, - outlier=True, - timeout=10000, + [origin], e_id, room_version=room_version, outlier=True, timeout=10000 ) if m_ev and m_ev.event_id == e_id: event_map[e_id] = m_ev @@ -1820,10 +1733,7 @@ class FederationHandler(BaseHandler): # cause SynapseErrors in auth.check. We don't want to give up # the attempt to federate altogether in such cases. - logger.warn( - "Rejecting %s because %s", - e.event_id, err.msg - ) + logger.warn("Rejecting %s because %s", e.event_id, err.msg) if e == event: raise @@ -1833,16 +1743,14 @@ class FederationHandler(BaseHandler): [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ], + ] ) new_event_context = yield self.state_handler.compute_event_context( event, old_state=state ) - yield self.persist_events_and_notify( - [(event, new_event_context)], - ) + yield self.persist_events_and_notify([(event, new_event_context)]) @defer.inlineCallbacks def _prep_event(self, origin, event, state, auth_events, backfilled): @@ -1858,40 +1766,30 @@ class FederationHandler(BaseHandler): Returns: Deferred, which resolves to synapse.events.snapshot.EventContext """ - context = yield self.state_handler.compute_event_context( - event, old_state=state, - ) + context = yield self.state_handler.compute_event_context(event, old_state=state) if not auth_events: prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in auth_events.values() - } + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. if event.type == EventTypes.Member and not event.auth_event_ids(): if len(event.prev_event_ids()) == 1 and event.depth < 5: c = yield self.store.get_event( - event.prev_event_ids()[0], - allow_none=True, + event.prev_event_ids()[0], allow_none=True ) if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c try: - yield self.do_auth( - origin, event, context, auth_events=auth_events - ) + yield self.do_auth(origin, event, context, auth_events=auth_events) except AuthError as e: - logger.warn( - "[%s %s] Rejecting: %s", - event.room_id, event.event_id, e.msg - ) + logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg) context.rejected = RejectedReason.AUTH_ERROR @@ -1922,9 +1820,7 @@ class FederationHandler(BaseHandler): # "soft-fail" the event. do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier() if do_soft_fail_check: - extrem_ids = yield self.store.get_latest_event_ids_in_room( - event.room_id, - ) + extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id) extrem_ids = set(extrem_ids) prev_event_ids = set(event.prev_event_ids()) @@ -1952,31 +1848,31 @@ class FederationHandler(BaseHandler): # like bans, especially with state res v2. state_sets = yield self.store.get_state_groups( - event.room_id, extrem_ids, + event.room_id, extrem_ids ) state_sets = list(state_sets.values()) state_sets.append(state) current_state_ids = yield self.state_handler.resolve_events( - room_version, state_sets, event, + room_version, state_sets, event ) current_state_ids = { k: e.event_id for k, e in iteritems(current_state_ids) } else: current_state_ids = yield self.state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids, + event.room_id, latest_event_ids=extrem_ids ) logger.debug( "Doing soft-fail check for %s: state %s", - event.event_id, current_state_ids, + event.event_id, + current_state_ids, ) # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) current_state_ids = [ - e for k, e in iteritems(current_state_ids) - if k in auth_types + e for k, e in iteritems(current_state_ids) if k in auth_types ] current_auth_events = yield self.store.get_events(current_state_ids) @@ -1987,19 +1883,14 @@ class FederationHandler(BaseHandler): try: self.auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: - logger.warn( - "Soft-failing %r because %s", - event, e, - ) + logger.warn("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True @defer.inlineCallbacks - def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects, - missing): - in_room = yield self.auth.check_host_in_room( - room_id, - origin - ) + def on_query_auth( + self, origin, event_id, room_id, remote_auth_chain, rejects, missing + ): + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2017,28 +1908,23 @@ class FederationHandler(BaseHandler): # Now get the current auth_chain for the event. local_auth_chain = yield self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], - include_given=True + [auth_id for auth_id in event.auth_event_ids()], include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell # everyone. - ret = yield self.construct_auth_difference( - local_auth_chain, remote_auth_chain - ) + ret = yield self.construct_auth_difference(local_auth_chain, remote_auth_chain) logger.debug("on_query_auth returning: %s", ret) defer.returnValue(ret) @defer.inlineCallbacks - def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit): - in_room = yield self.auth.check_host_in_room( - room_id, - origin - ) + def on_get_missing_events( + self, origin, room_id, earliest_events, latest_events, limit + ): + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2052,7 +1938,7 @@ class FederationHandler(BaseHandler): ) missing_events = yield filter_events_for_server( - self.store, origin, missing_events, + self.store, origin, missing_events ) defer.returnValue(missing_events) @@ -2140,25 +2026,17 @@ class FederationHandler(BaseHandler): if missing_auth: # TODO: can we use store.have_seen_events here instead? - have_events = yield self.store.get_seen_events_with_rejections( - missing_auth - ) + have_events = yield self.store.get_seen_events_with_rejections(missing_auth) logger.debug("Got events %s from store", have_events) missing_auth.difference_update(have_events.keys()) else: have_events = {} - have_events.update({ - e.event_id: "" - for e in auth_events.values() - }) + have_events.update({e.event_id: "" for e in auth_events.values()}) if missing_auth: # If we don't have all the auth events, we need to get them. - logger.info( - "auth_events contains unknown events: %s", - missing_auth, - ) + logger.info("auth_events contains unknown events: %s", missing_auth) try: try: remote_auth_chain = yield self.federation_client.get_event_auth( @@ -2184,18 +2062,16 @@ class FederationHandler(BaseHandler): try: auth_ids = e.auth_event_ids() auth = { - (e.type, e.state_key): e for e in remote_auth_chain + (e.type, e.state_key): e + for e in remote_auth_chain if e.event_id in auth_ids or e.type == EventTypes.Create } e.internal_metadata.outlier = True logger.debug( - "do_auth %s missing_auth: %s", - event.event_id, e.event_id - ) - yield self._handle_new_event( - origin, e, auth_events=auth + "do_auth %s missing_auth: %s", event.event_id, e.event_id ) + yield self._handle_new_event(origin, e, auth_events=auth) if e.event_id in event_auth_events: auth_events[(e.type, e.state_key)] = e @@ -2231,35 +2107,36 @@ class FederationHandler(BaseHandler): room_version = yield self.store.get_room_version(event.room_id) different_events = yield logcontext.make_deferred_yieldable( - defer.gatherResults([ - logcontext.run_in_background( - self.store.get_event, - d, - allow_none=True, - allow_rejected=False, - ) - for d in different_auth - if d in have_events and not have_events[d] - ], consumeErrors=True) + defer.gatherResults( + [ + logcontext.run_in_background( + self.store.get_event, d, allow_none=True, allow_rejected=False + ) + for d in different_auth + if d in have_events and not have_events[d] + ], + consumeErrors=True, + ) ).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) remote_view = dict(auth_events) - remote_view.update({ - (d.type, d.state_key): d for d in different_events if d - }) + remote_view.update( + {(d.type, d.state_key): d for d in different_events if d} + ) new_state = yield self.state_handler.resolve_events( room_version, [list(local_view.values()), list(remote_view.values())], - event + event, ) logger.info( "After state res: updating auth_events with new state %s", { - (d.type, d.state_key): d.event_id for d in new_state.values() + (d.type, d.state_key): d.event_id + for d in new_state.values() if auth_events.get((d.type, d.state_key)) != d }, ) @@ -2271,7 +2148,7 @@ class FederationHandler(BaseHandler): ) yield self._update_context_for_auth_events( - event, context, auth_events, event_key, + event, context, auth_events, event_key ) if not different_auth: @@ -2305,21 +2182,14 @@ class FederationHandler(BaseHandler): prev_state_ids = yield context.get_prev_state_ids(self.store) # 1. Get what we think is the auth chain. - auth_ids = yield self.auth.compute_auth_events( - event, prev_state_ids - ) - local_auth_chain = yield self.store.get_auth_chain( - auth_ids, include_given=True - ) + auth_ids = yield self.auth.compute_auth_events(event, prev_state_ids) + local_auth_chain = yield self.store.get_auth_chain(auth_ids, include_given=True) try: # 2. Get remote difference. try: result = yield self.federation_client.query_auth( - origin, - event.room_id, - event.event_id, - local_auth_chain, + origin, event.room_id, event.event_id, local_auth_chain ) except RequestSendFailed as e: # The other side isn't around or doesn't implement the @@ -2344,19 +2214,15 @@ class FederationHandler(BaseHandler): auth = { (e.type, e.state_key): e for e in result["auth_chain"] - if e.event_id in auth_ids - or event.type == EventTypes.Create + if e.event_id in auth_ids or event.type == EventTypes.Create } ev.internal_metadata.outlier = True logger.debug( - "do_auth %s different_auth: %s", - event.event_id, e.event_id + "do_auth %s different_auth: %s", event.event_id, e.event_id ) - yield self._handle_new_event( - origin, ev, auth_events=auth - ) + yield self._handle_new_event(origin, ev, auth_events=auth) if ev.event_id in event_auth_events: auth_events[(ev.type, ev.state_key)] = ev @@ -2371,12 +2237,11 @@ class FederationHandler(BaseHandler): # TODO. yield self._update_context_for_auth_events( - event, context, auth_events, event_key, + event, context, auth_events, event_key ) @defer.inlineCallbacks - def _update_context_for_auth_events(self, event, context, auth_events, - event_key): + def _update_context_for_auth_events(self, event, context, auth_events, event_key): """Update the state_ids in an event context after auth event resolution, storing the changes as a new state group. @@ -2393,8 +2258,7 @@ class FederationHandler(BaseHandler): this will not be included in the current_state in the context. """ state_updates = { - k: a.event_id for k, a in iteritems(auth_events) - if k != event_key + k: a.event_id for k, a in iteritems(auth_events) if k != event_key } current_state_ids = yield context.get_current_state_ids(self.store) current_state_ids = dict(current_state_ids) @@ -2404,9 +2268,7 @@ class FederationHandler(BaseHandler): prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = dict(prev_state_ids) - prev_state_ids.update({ - k: a.event_id for k, a in iteritems(auth_events) - }) + prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) # create a new state group as a delta from the existing one. prev_group = context.state_group @@ -2555,30 +2417,23 @@ class FederationHandler(BaseHandler): logger.debug("construct_auth_difference returning") - defer.returnValue({ - "auth_chain": local_auth, - "rejects": { - e.event_id: { - "reason": reason_map[e.event_id], - "proof": None, - } - for e in base_remote_rejected - }, - "missing": [e.event_id for e in missing_locals], - }) + defer.returnValue( + { + "auth_chain": local_auth, + "rejects": { + e.event_id: {"reason": reason_map[e.event_id], "proof": None} + for e in base_remote_rejected + }, + "missing": [e.event_id for e in missing_locals], + } + ) @defer.inlineCallbacks @log_function def exchange_third_party_invite( - self, - sender_user_id, - target_user_id, - room_id, - signed, + self, sender_user_id, target_user_id, room_id, signed ): - third_party_invite = { - "signed": signed, - } + third_party_invite = {"signed": signed} event_dict = { "type": EventTypes.Member, @@ -2601,7 +2456,7 @@ class FederationHandler(BaseHandler): ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info( @@ -2609,7 +2464,7 @@ class FederationHandler(BaseHandler): event, ) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) event, context = yield self.add_display_name_to_third_party_invite( @@ -2634,9 +2489,7 @@ class FederationHandler(BaseHandler): else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) yield self.federation_client.forward_third_party_invite( - destinations, - room_id, - event_dict, + destinations, room_id, event_dict ) @defer.inlineCallbacks @@ -2657,19 +2510,18 @@ class FederationHandler(BaseHandler): builder = self.event_builder_factory.new(room_version, event_dict) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.warning( - "Exchange of threepid invite %s forbidden by third-party rules", - event, + "Exchange of threepid invite %s forbidden by third-party rules", event ) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) event, context = yield self.add_display_name_to_third_party_invite( @@ -2691,11 +2543,12 @@ class FederationHandler(BaseHandler): yield member_handler.send_membership_event(None, event, context) @defer.inlineCallbacks - def add_display_name_to_third_party_invite(self, room_version, event_dict, - event, context): + def add_display_name_to_third_party_invite( + self, room_version, event_dict, event, context + ): key = ( EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"] + event.content["third_party_invite"]["signed"]["token"], ) original_invite = None prev_state_ids = yield context.get_prev_state_ids(self.store) @@ -2709,8 +2562,7 @@ class FederationHandler(BaseHandler): event_dict["content"]["third_party_invite"]["display_name"] = display_name else: logger.info( - "Could not find invite event for third_party_invite: %r", - event_dict + "Could not find invite event for third_party_invite: %r", event_dict ) # We don't discard here as this is not the appropriate place to do # auth checks. If we need the invite and don't have it then the @@ -2719,7 +2571,7 @@ class FederationHandler(BaseHandler): builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) EventValidator().validate_new(event) defer.returnValue((event, context)) @@ -2743,9 +2595,7 @@ class FederationHandler(BaseHandler): token = signed["token"] prev_state_ids = yield context.get_prev_state_ids(self.store) - invite_event_id = prev_state_ids.get( - (EventTypes.ThirdPartyInvite, token,) - ) + invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) invite_event = None if invite_event_id: @@ -2769,38 +2619,42 @@ class FederationHandler(BaseHandler): logger.debug( "Attempting to verify sig with key %s from %r " "against pubkey %r", - key_name, server, public_key_object, + key_name, + server, + public_key_object, ) try: public_key = public_key_object["public_key"] verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) + key_name, decode_base64(public_key) ) verify_signed_json(signed, server, verify_key) logger.debug( "Successfully verified sig with key %s from %r " "against pubkey %r", - key_name, server, public_key_object, + key_name, + server, + public_key_object, ) except Exception: logger.info( "Failed to verify sig with key %s from %r " "against pubkey %r", - key_name, server, public_key_object, + key_name, + server, + public_key_object, ) raise try: if "key_validity_url" in public_key_object: yield self._check_key_revocation( - public_key, - public_key_object["key_validity_url"] + public_key, public_key_object["key_validity_url"] ) except Exception: logger.info( "Failed to query key_validity_url %s", - public_key_object["key_validity_url"] + public_key_object["key_validity_url"], ) raise return @@ -2823,15 +2677,9 @@ class FederationHandler(BaseHandler): for revocation. """ try: - response = yield self.http_client.get_json( - url, - {"public_key": public_key} - ) + response = yield self.http_client.get_json(url, {"public_key": public_key}) except Exception: - raise SynapseError( - 502, - "Third party certificate could not be checked" - ) + raise SynapseError(502, "Third party certificate could not be checked") if "valid" not in response or not response["valid"]: raise AuthError(403, "Third party certificate was invalid") @@ -2852,12 +2700,11 @@ class FederationHandler(BaseHandler): yield self._send_events_to_master( store=self.store, event_and_contexts=event_and_contexts, - backfilled=backfilled + backfilled=backfilled, ) else: max_stream_id = yield self.store.persist_events( - event_and_contexts, - backfilled=backfilled, + event_and_contexts, backfilled=backfilled ) if not backfilled: # Never notify for backfilled events @@ -2891,13 +2738,10 @@ class FederationHandler(BaseHandler): event_stream_id = event.internal_metadata.stream_ordering self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users + event, event_stream_id, max_stream_id, extra_users=extra_users ) - return self.pusher_pool.on_new_notifications( - event_stream_id, max_stream_id, - ) + return self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _clean_room_for_join(self, room_id): """Called to clean up any data in DB for a given room, ready for the @@ -2916,9 +2760,7 @@ class FederationHandler(BaseHandler): """ if self.config.worker_app: return self._notify_user_membership_change( - room_id=room_id, - user_id=user.to_string(), - change="joined", + room_id=room_id, user_id=user.to_string(), change="joined" ) else: return user_joined_room(self.distributor, user, room_id) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index f60ace02e8..7da63bb643 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -30,6 +30,7 @@ def _create_rerouter(func_name): """Returns a function that looks at the group id and calls the function on federation or the local group server if the group is local """ + def f(self, group_id, *args, **kwargs): if self.is_mine_id(group_id): return getattr(self.groups_server_handler, func_name)( @@ -58,6 +59,7 @@ def _create_rerouter(func_name): d.addErrback(http_response_errback) d.addErrback(request_failed_errback) return d + return f @@ -125,7 +127,7 @@ class GroupsLocalHandler(object): ) else: res = yield self.transport_client.get_group_summary( - get_domain_from_id(group_id), group_id, requester_user_id, + get_domain_from_id(group_id), group_id, requester_user_id ) group_server_name = get_domain_from_id(group_id) @@ -182,7 +184,7 @@ class GroupsLocalHandler(object): content["user_profile"] = yield self.profile_handler.get_profile(user_id) res = yield self.transport_client.create_group( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -195,16 +197,15 @@ class GroupsLocalHandler(object): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=True, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue(res) @@ -221,7 +222,7 @@ class GroupsLocalHandler(object): group_server_name = get_domain_from_id(group_id) res = yield self.transport_client.get_users_in_group( - get_domain_from_id(group_id), group_id, requester_user_id, + get_domain_from_id(group_id), group_id, requester_user_id ) chunk = res["chunk"] @@ -250,9 +251,7 @@ class GroupsLocalHandler(object): """Request to join a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.join_group( - group_id, user_id, content - ) + yield self.groups_server_handler.join_group(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -260,7 +259,7 @@ class GroupsLocalHandler(object): content["attestation"] = local_attestation res = yield self.transport_client.join_group( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -276,16 +275,15 @@ class GroupsLocalHandler(object): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=False, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue({}) @@ -294,9 +292,7 @@ class GroupsLocalHandler(object): """Accept an invite to a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.accept_invite( - group_id, user_id, content - ) + yield self.groups_server_handler.accept_invite(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -304,7 +300,7 @@ class GroupsLocalHandler(object): content["attestation"] = local_attestation res = yield self.transport_client.accept_group_invite( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -320,16 +316,15 @@ class GroupsLocalHandler(object): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=False, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue({}) @@ -337,17 +332,17 @@ class GroupsLocalHandler(object): def invite(self, group_id, user_id, requester_user_id, config): """Invite a user to a group """ - content = { - "requester_user_id": requester_user_id, - "config": config, - } + content = {"requester_user_id": requester_user_id, "config": config} if self.is_mine_id(group_id): res = yield self.groups_server_handler.invite_to_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) else: res = yield self.transport_client.invite_to_group( - get_domain_from_id(group_id), group_id, user_id, requester_user_id, + get_domain_from_id(group_id), + group_id, + user_id, + requester_user_id, content, ) @@ -370,13 +365,12 @@ class GroupsLocalHandler(object): local_profile["avatar_url"] = content["profile"]["avatar_url"] token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="invite", content={"profile": local_profile, "inviter": content["inviter"]}, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) try: user_profile = yield self.profile_handler.get_profile(user_id) except Exception as e: @@ -391,25 +385,25 @@ class GroupsLocalHandler(object): """ if user_id == requester_user_id: token = yield self.store.register_user_group_membership( - group_id, user_id, - membership="leave", - ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], + group_id, user_id, membership="leave" ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) # TODO: Should probably remember that we tried to leave so that we can # retry if the group server is currently down. if self.is_mine_id(group_id): res = yield self.groups_server_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) else: content["requester_user_id"] = requester_user_id res = yield self.transport_client.remove_user_from_group( - get_domain_from_id(group_id), group_id, requester_user_id, - user_id, content, + get_domain_from_id(group_id), + group_id, + requester_user_id, + user_id, + content, ) defer.returnValue(res) @@ -420,12 +414,9 @@ class GroupsLocalHandler(object): """ # TODO: Check if user in group token = yield self.store.register_user_group_membership( - group_id, user_id, - membership="leave", - ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], + group_id, user_id, membership="leave" ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) @defer.inlineCallbacks def get_joined_groups(self, user_id): @@ -445,7 +436,7 @@ class GroupsLocalHandler(object): defer.returnValue({"groups": result}) else: bulk_result = yield self.transport_client.bulk_get_publicised_groups( - get_domain_from_id(user_id), [user_id], + get_domain_from_id(user_id), [user_id] ) result = bulk_result.get("users", {}).get(user_id) # TODO: Verify attestations @@ -460,9 +451,7 @@ class GroupsLocalHandler(object): if self.hs.is_mine_id(user_id): local_users.add(user_id) else: - destinations.setdefault( - get_domain_from_id(user_id), set() - ).add(user_id) + destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id) if not proxy and destinations: raise SynapseError(400, "Some user_ids are not local") @@ -472,16 +461,14 @@ class GroupsLocalHandler(object): for destination, dest_user_ids in iteritems(destinations): try: r = yield self.transport_client.bulk_get_publicised_groups( - destination, list(dest_user_ids), + destination, list(dest_user_ids) ) results.update(r["users"]) except Exception: failed_results.extend(dest_user_ids) for uid in local_users: - results[uid] = yield self.store.get_publicised_groups_for_user( - uid - ) + results[uid] = yield self.store.get_publicised_groups_for_user(uid) # Check AS associated groups for this user - this depends on the # RegExps in the AS registration file (under `users`) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 04caf65793..c82b1933f2 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -36,7 +36,6 @@ logger = logging.getLogger(__name__) class IdentityHandler(BaseHandler): - def __init__(self, hs): super(IdentityHandler, self).__init__(hs) @@ -64,40 +63,38 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def threepid_from_creds(self, creds): - if 'id_server' in creds: - id_server = creds['id_server'] - elif 'idServer' in creds: - id_server = creds['idServer'] + if "id_server" in creds: + id_server = creds["id_server"] + elif "idServer" in creds: + id_server = creds["idServer"] else: raise SynapseError(400, "No id_server in creds") - if 'client_secret' in creds: - client_secret = creds['client_secret'] - elif 'clientSecret' in creds: - client_secret = creds['clientSecret'] + if "client_secret" in creds: + client_secret = creds["client_secret"] + elif "clientSecret" in creds: + client_secret = creds["clientSecret"] else: raise SynapseError(400, "No client_secret in creds") if not self._should_trust_id_server(id_server): logger.warn( - '%s is not a trusted ID server: rejecting 3pid ' + - 'credentials', id_server + "%s is not a trusted ID server: rejecting 3pid " + "credentials", + id_server, ) defer.returnValue(None) try: data = yield self.http_client.get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/3pid/getValidated3pid" - ), - {'sid': creds['sid'], 'client_secret': client_secret} + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/3pid/getValidated3pid"), + {"sid": creds["sid"], "client_secret": client_secret}, ) except HttpResponseException as e: logger.info("getValidated3pid failed with Matrix error: %r", e) raise e.to_synapse_error() - if 'medium' in data: + if "medium" in data: defer.returnValue(data) defer.returnValue(None) @@ -106,30 +103,24 @@ class IdentityHandler(BaseHandler): logger.debug("binding threepid %r to %s", creds, mxid) data = None - if 'id_server' in creds: - id_server = creds['id_server'] - elif 'idServer' in creds: - id_server = creds['idServer'] + if "id_server" in creds: + id_server = creds["id_server"] + elif "idServer" in creds: + id_server = creds["idServer"] else: raise SynapseError(400, "No id_server in creds") - if 'client_secret' in creds: - client_secret = creds['client_secret'] - elif 'clientSecret' in creds: - client_secret = creds['clientSecret'] + if "client_secret" in creds: + client_secret = creds["client_secret"] + elif "clientSecret" in creds: + client_secret = creds["clientSecret"] else: raise SynapseError(400, "No client_secret in creds") try: data = yield self.http_client.post_urlencoded_get_json( - "https://%s%s" % ( - id_server, "/_matrix/identity/api/v1/3pid/bind" - ), - { - 'sid': creds['sid'], - 'client_secret': client_secret, - 'mxid': mxid, - } + "https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"), + {"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid}, ) logger.debug("bound threepid %r to %s", creds, mxid) @@ -165,9 +156,7 @@ class IdentityHandler(BaseHandler): id_servers = [threepid["id_server"]] else: id_servers = yield self.store.get_id_servers_user_bound( - user_id=mxid, - medium=threepid["medium"], - address=threepid["address"], + user_id=mxid, medium=threepid["medium"], address=threepid["address"] ) # We don't know where to unbind, so we don't have a choice but to return @@ -177,7 +166,7 @@ class IdentityHandler(BaseHandler): changed = True for id_server in id_servers: changed &= yield self.try_unbind_threepid_with_id_server( - mxid, threepid, id_server, + mxid, threepid, id_server ) defer.returnValue(changed) @@ -201,10 +190,7 @@ class IdentityHandler(BaseHandler): url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) content = { "mxid": mxid, - "threepid": { - "medium": threepid["medium"], - "address": threepid["address"], - }, + "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, } # we abuse the federation http client to sign the request, but we have to send it @@ -212,25 +198,19 @@ class IdentityHandler(BaseHandler): # 'browser-like' HTTPS. auth_headers = self.federation_http_client.build_auth_headers( destination=None, - method='POST', - url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'), + method="POST", + url_bytes="/_matrix/identity/api/v1/3pid/unbind".encode("ascii"), content=content, destination_is=id_server, ) - headers = { - b"Authorization": auth_headers, - } + headers = {b"Authorization": auth_headers} try: - yield self.http_client.post_json_get_json( - url, - content, - headers, - ) + yield self.http_client.post_json_get_json(url, content, headers) changed = True except HttpResponseException as e: changed = False - if e.code in (400, 404, 501,): + if e.code in (400, 404, 501): # The remote server probably doesn't support unbinding (yet) logger.warn("Received %d response while unbinding threepid", e.code) else: @@ -248,35 +228,27 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def requestEmailToken( - self, - id_server, - email, - client_secret, - send_attempt, - next_link=None, + self, id_server, email, client_secret, send_attempt, next_link=None ): if not self._should_trust_id_server(id_server): raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) params = { - 'email': email, - 'client_secret': client_secret, - 'send_attempt': send_attempt, + "email": email, + "client_secret": client_secret, + "send_attempt": send_attempt, } if next_link: - params.update({'next_link': next_link}) + params.update({"next_link": next_link}) try: data = yield self.http_client.post_json_get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/validate/email/requestToken" - ), - params + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/validate/email/requestToken"), + params, ) defer.returnValue(data) except HttpResponseException as e: @@ -285,30 +257,26 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def requestMsisdnToken( - self, id_server, country, phone_number, - client_secret, send_attempt, **kwargs + self, id_server, country, phone_number, client_secret, send_attempt, **kwargs ): if not self._should_trust_id_server(id_server): raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) params = { - 'country': country, - 'phone_number': phone_number, - 'client_secret': client_secret, - 'send_attempt': send_attempt, + "country": country, + "phone_number": phone_number, + "client_secret": client_secret, + "send_attempt": send_attempt, } params.update(kwargs) try: data = yield self.http_client.post_json_get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/validate/msisdn/requestToken" - ), - params + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/validate/msisdn/requestToken"), + params, ) defer.returnValue(data) except HttpResponseException as e: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index aaee5db0b7..a1fe9d116f 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -44,8 +44,13 @@ class InitialSyncHandler(BaseHandler): self.snapshot_cache = SnapshotCache() self._event_serializer = hs.get_event_client_serializer() - def snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): + def snapshot_all_rooms( + self, + user_id=None, + pagin_config=None, + as_client_event=True, + include_archived=False, + ): """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is @@ -77,13 +82,22 @@ class InitialSyncHandler(BaseHandler): if result is not None: return result - return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - )) + return self.snapshot_cache.set( + now_ms, + key, + self._snapshot_all_rooms( + user_id, pagin_config, as_client_event, include_archived + ), + ) @defer.inlineCallbacks - def _snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): + def _snapshot_all_rooms( + self, + user_id=None, + pagin_config=None, + as_client_event=True, + include_archived=False, + ): memberships = [Membership.INVITE, Membership.JOIN] if include_archived: @@ -128,8 +142,7 @@ class InitialSyncHandler(BaseHandler): "room_id": event.room_id, "membership": event.membership, "visibility": ( - "public" if event.room_id in public_room_ids - else "private" + "public" if event.room_id in public_room_ids else "private" ), } @@ -139,7 +152,7 @@ class InitialSyncHandler(BaseHandler): invite_event = yield self.store.get_event(event.event_id) d["invite"] = yield self._event_serializer.serialize_event( - invite_event, time_now, as_client_event, + invite_event, time_now, as_client_event ) rooms_ret.append(d) @@ -151,14 +164,12 @@ class InitialSyncHandler(BaseHandler): if event.membership == Membership.JOIN: room_end_token = now_token.room_key deferred_room_state = run_in_background( - self.state_handler.get_current_state, - event.room_id, + self.state_handler.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( - self.store.get_state_for_events, - [event.event_id], + self.store.get_state_for_events, [event.event_id] ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -178,9 +189,7 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client( - self.store, user_id, messages - ) + messages = yield filter_events_for_client(self.store, user_id, messages) start_token = now_token.copy_and_replace("room_key", token) end_token = now_token.copy_and_replace("room_key", room_end_token) @@ -189,8 +198,7 @@ class InitialSyncHandler(BaseHandler): d["messages"] = { "chunk": ( yield self._event_serializer.serialize_events( - messages, time_now=time_now, - as_client_event=as_client_event, + messages, time_now=time_now, as_client_event=as_client_event ) ), "start": start_token.to_string(), @@ -200,23 +208,21 @@ class InitialSyncHandler(BaseHandler): d["state"] = yield self._event_serializer.serialize_events( current_state.values(), time_now=time_now, - as_client_event=as_client_event + as_client_event=as_client_event, ) account_data_events = [] tags = tags_by_room.get(event.room_id) if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append( + {"type": "m.tag", "content": {"tags": tags}} + ) account_data = account_data_by_room.get(event.room_id, {}) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append( + {"type": account_data_type, "content": content} + ) d["account_data"] = account_data_events except Exception: @@ -226,10 +232,7 @@ class InitialSyncHandler(BaseHandler): account_data_events = [] for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) now = self.clock.time_msec() @@ -274,7 +277,7 @@ class InitialSyncHandler(BaseHandler): user_id = requester.user.to_string() membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id, + room_id, user_id ) is_peeking = member_event_id is None @@ -290,28 +293,21 @@ class InitialSyncHandler(BaseHandler): account_data_events = [] tags = yield self.store.get_tags_for_room(user_id, room_id) if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) account_data = yield self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) result["account_data"] = account_data_events defer.returnValue(result) @defer.inlineCallbacks - def _room_initial_sync_parted(self, user_id, room_id, pagin_config, - membership, member_event_id, is_peeking): - room_state = yield self.store.get_state_for_events( - [member_event_id], - ) + def _room_initial_sync_parted( + self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking + ): + room_state = yield self.store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -319,14 +315,10 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = yield self.store.get_stream_token_for_event( - member_event_id - ) + stream_token = yield self.store.get_stream_token_for_event(member_event_id) messages, token = yield self.store.get_recent_events_for_room( - room_id, - limit=limit, - end_token=stream_token + room_id, limit=limit, end_token=stream_token ) messages = yield filter_events_for_client( @@ -338,34 +330,39 @@ class InitialSyncHandler(BaseHandler): time_now = self.clock.time_msec() - defer.returnValue({ - "membership": membership, - "room_id": room_id, - "messages": { - "chunk": (yield self._event_serializer.serialize_events( - messages, time_now, - )), - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": (yield self._event_serializer.serialize_events( - room_state.values(), time_now, - )), - "presence": [], - "receipts": [], - }) + defer.returnValue( + { + "membership": membership, + "room_id": room_id, + "messages": { + "chunk": ( + yield self._event_serializer.serialize_events( + messages, time_now + ) + ), + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": ( + yield self._event_serializer.serialize_events( + room_state.values(), time_now + ) + ), + "presence": [], + "receipts": [], + } + ) @defer.inlineCallbacks - def _room_initial_sync_joined(self, user_id, room_id, pagin_config, - membership, is_peeking): - current_state = yield self.state.get_current_state( - room_id=room_id, - ) + def _room_initial_sync_joined( + self, user_id, room_id, pagin_config, membership, is_peeking + ): + current_state = yield self.state.get_current_state(room_id=room_id) # TODO: These concurrently time_now = self.clock.time_msec() state = yield self._event_serializer.serialize_events( - current_state.values(), time_now, + current_state.values(), time_now ) now_token = yield self.hs.get_event_sources().get_current_token() @@ -375,7 +372,8 @@ class InitialSyncHandler(BaseHandler): limit = 10 room_members = [ - m for m in current_state.values() + m + for m in current_state.values() if m.type == EventTypes.Member and m.content["membership"] == Membership.JOIN ] @@ -389,8 +387,7 @@ class InitialSyncHandler(BaseHandler): defer.returnValue([]) states = yield presence_handler.get_states( - [m.user_id for m in room_members], - as_event=True, + [m.user_id for m in room_members], as_event=True ) defer.returnValue(states) @@ -398,8 +395,7 @@ class InitialSyncHandler(BaseHandler): @defer.inlineCallbacks def get_receipts(): receipts = yield self.store.get_linearized_receipts_for_room( - room_id, - to_key=now_token.receipt_key, + room_id, to_key=now_token.receipt_key ) if not receipts: receipts = [] @@ -415,14 +411,14 @@ class InitialSyncHandler(BaseHandler): room_id, limit=limit, end_token=now_token.room_key, - ) + ), ], consumeErrors=True, - ).addErrback(unwrapFirstError), + ).addErrback(unwrapFirstError) ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking, + self.store, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace("room_key", token) @@ -433,9 +429,9 @@ class InitialSyncHandler(BaseHandler): ret = { "room_id": room_id, "messages": { - "chunk": (yield self._event_serializer.serialize_events( - messages, time_now, - )), + "chunk": ( + yield self._event_serializer.serialize_events(messages, time_now) + ), "start": start_token.to_string(), "end": end_token.to_string(), }, @@ -464,8 +460,8 @@ class InitialSyncHandler(BaseHandler): room_id, EventTypes.RoomHistoryVisibility, "" ) if ( - visibility and - visibility.content["history_visibility"] == "world_readable" + visibility + and visibility.content["history_visibility"] == "world_readable" ): defer.returnValue((Membership.JOIN, None)) return diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7728ea230d..683da6bf32 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -61,8 +61,9 @@ class MessageHandler(object): self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks - def get_room_data(self, user_id=None, room_id=None, - event_type=None, state_key="", is_guest=False): + def get_room_data( + self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False + ): """ Get data from a room. Args: @@ -77,9 +78,7 @@ class MessageHandler(object): ) if membership == Membership.JOIN: - data = yield self.state.get_current_state( - room_id, event_type, state_key - ) + data = yield self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( @@ -91,8 +90,12 @@ class MessageHandler(object): @defer.inlineCallbacks def get_state_events( - self, user_id, room_id, state_filter=StateFilter.all(), - at_token=None, is_guest=False, + self, + user_id, + room_id, + state_filter=StateFilter.all(), + at_token=None, + is_guest=False, ): """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has @@ -124,50 +127,48 @@ class MessageHandler(object): # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) last_events, _ = yield self.store.get_recent_events_for_room( - room_id, end_token=at_token.room_key, limit=1, + room_id, end_token=at_token.room_key, limit=1 ) if not last_events: - raise NotFoundError("Can't find event for token %s" % (at_token, )) + raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.store, user_id, last_events, + self.store, user_id, last_events ) event = last_events[0] if visible_events: room_state = yield self.store.get_state_for_events( - [event.event_id], state_filter=state_filter, + [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] else: raise AuthError( 403, - "User %s not allowed to view events in room %s at token %s" % ( - user_id, room_id, at_token, - ) + "User %s not allowed to view events in room %s at token %s" + % (user_id, room_id, at_token), ) else: membership, membership_event_id = ( - yield self.auth.check_in_room_or_world_readable( - room_id, user_id, - ) + yield self.auth.check_in_room_or_world_readable(room_id, user_id) ) if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( - room_id, state_filter=state_filter, + room_id, state_filter=state_filter ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - [membership_event_id], state_filter=state_filter, + [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] now = self.clock.time_msec() events = yield self._event_serializer.serialize_events( - room_state.values(), now, + room_state.values(), + now, # We don't bother bundling aggregations in when asked for state # events, as clients won't use them. bundle_aggregations=False, @@ -211,13 +212,15 @@ class MessageHandler(object): # Loop fell through, AS has no interested users in room raise AuthError(403, "Appservice not in room") - defer.returnValue({ - user_id: { - "avatar_url": profile.avatar_url, - "display_name": profile.display_name, + defer.returnValue( + { + user_id: { + "avatar_url": profile.avatar_url, + "display_name": profile.display_name, + } + for user_id, profile in iteritems(users_with_profile) } - for user_id, profile in iteritems(users_with_profile) - }) + ) class EventCreationHandler(object): @@ -269,14 +272,21 @@ class EventCreationHandler(object): self.clock.looping_call( lambda: run_as_background_process( "send_dummy_events_to_fill_extremities", - self._send_dummy_events_to_fill_extremities + self._send_dummy_events_to_fill_extremities, ), 5 * 60 * 1000, ) @defer.inlineCallbacks - def create_event(self, requester, event_dict, token_id=None, txn_id=None, - prev_events_and_hashes=None, require_consent=True): + def create_event( + self, + requester, + event_dict, + token_id=None, + txn_id=None, + prev_events_and_hashes=None, + require_consent=True, + ): """ Given a dict from a client, create a new event. @@ -336,8 +346,7 @@ class EventCreationHandler(object): content["avatar_url"] = yield profile.get_avatar_url(target) except Exception as e: logger.info( - "Failed to get profile information for %r: %s", - target, e + "Failed to get profile information for %r: %s", target, e ) is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester) @@ -373,16 +382,17 @@ class EventCreationHandler(object): prev_event = yield self.store.get_event(prev_event_id, allow_none=True) if not prev_event or prev_event.membership != Membership.JOIN: logger.warning( - ("Attempt to send `m.room.aliases` in room %s by user %s but" - " membership is %s"), + ( + "Attempt to send `m.room.aliases` in room %s by user %s but" + " membership is %s" + ), event.room_id, event.sender, prev_event.membership if prev_event else None, ) raise AuthError( - 403, - "You must be in the room to create an alias for it", + 403, "You must be in the room to create an alias for it" ) self.validator.validate_new(event) @@ -449,8 +459,8 @@ class EventCreationHandler(object): # exempt the system notices user if ( - self.config.server_notices_mxid is not None and - user_id == self.config.server_notices_mxid + self.config.server_notices_mxid is not None + and user_id == self.config.server_notices_mxid ): return @@ -463,15 +473,10 @@ class EventCreationHandler(object): return consent_uri = self._consent_uri_builder.build_user_consent_uri( - requester.user.localpart, - ) - msg = self._block_events_without_consent_error % { - 'consent_uri': consent_uri, - } - raise ConsentNotGivenError( - msg=msg, - consent_uri=consent_uri, + requester.user.localpart ) + msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} + raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) @defer.inlineCallbacks def send_nonmember_event(self, requester, event, context, ratelimit=True): @@ -486,8 +491,7 @@ class EventCreationHandler(object): """ if event.type == EventTypes.Member: raise SynapseError( - 500, - "Tried to send member event through non-member codepath" + 500, "Tried to send member event through non-member codepath" ) user = UserID.from_string(event.sender) @@ -499,15 +503,13 @@ class EventCreationHandler(object): if prev_state is not None: logger.info( "Not bothering to persist state event %s duplicated by %s", - event.event_id, prev_state.event_id, + event.event_id, + prev_state.event_id, ) defer.returnValue(prev_state) yield self.handle_new_client_event( - requester=requester, - event=event, - context=context, - ratelimit=ratelimit, + requester=requester, event=event, context=context, ratelimit=ratelimit ) @defer.inlineCallbacks @@ -533,11 +535,7 @@ class EventCreationHandler(object): @defer.inlineCallbacks def create_and_send_nonmember_event( - self, - requester, - event_dict, - ratelimit=True, - txn_id=None + self, requester, event_dict, ratelimit=True, txn_id=None ): """ Creates an event, then sends it. @@ -552,32 +550,25 @@ class EventCreationHandler(object): # taking longer. with (yield self.limiter.queue(event_dict["room_id"])): event, context = yield self.create_event( - requester, - event_dict, - token_id=requester.access_token_id, - txn_id=txn_id + requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id ) spam_error = self.spam_checker.check_event_for_spam(event) if spam_error: if not isinstance(spam_error, string_types): spam_error = "Spam is not permitted here" - raise SynapseError( - 403, spam_error, Codes.FORBIDDEN - ) + raise SynapseError(403, spam_error, Codes.FORBIDDEN) yield self.send_nonmember_event( - requester, - event, - context, - ratelimit=ratelimit, + requester, event, context, ratelimit=ratelimit ) defer.returnValue(event) @measure_func("create_new_client_event") @defer.inlineCallbacks - def create_new_client_event(self, builder, requester=None, - prev_events_and_hashes=None): + def create_new_client_event( + self, builder, requester=None, prev_events_and_hashes=None + ): """Create a new event for a local client Args: @@ -597,22 +588,21 @@ class EventCreationHandler(object): """ if prev_events_and_hashes is not None: - assert len(prev_events_and_hashes) <= 10, \ - "Attempting to create an event with %i prev_events" % ( - len(prev_events_and_hashes), + assert len(prev_events_and_hashes) <= 10, ( + "Attempting to create an event with %i prev_events" + % (len(prev_events_and_hashes),) ) else: - prev_events_and_hashes = \ - yield self.store.get_prev_events_for_room(builder.room_id) + prev_events_and_hashes = yield self.store.get_prev_events_for_room( + builder.room_id + ) prev_events = [ (event_id, prev_hashes) for event_id, prev_hashes, _ in prev_events_and_hashes ] - event = yield builder.build( - prev_event_ids=[p for p, _ in prev_events], - ) + event = yield builder.build(prev_event_ids=[p for p, _ in prev_events]) context = yield self.state.compute_event_context(event) if requester: context.app_service = requester.app_service @@ -628,29 +618,19 @@ class EventCreationHandler(object): aggregation_key = relation["key"] already_exists = yield self.store.has_user_annotated_event( - relates_to, event.type, aggregation_key, event.sender, + relates_to, event.type, aggregation_key, event.sender ) if already_exists: raise SynapseError(400, "Can't send same reaction twice") - logger.debug( - "Created event %s", - event.event_id, - ) + logger.debug("Created event %s", event.event_id) - defer.returnValue( - (event, context,) - ) + defer.returnValue((event, context)) @measure_func("handle_new_client_event") @defer.inlineCallbacks def handle_new_client_event( - self, - requester, - event, - context, - ratelimit=True, - extra_users=[], + self, requester, event, context, ratelimit=True, extra_users=[] ): """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -666,19 +646,20 @@ class EventCreationHandler(object): extra_users (list(UserID)): Any extra users to notify about event """ - if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""): - room_version = event.content.get( - "room_version", RoomVersions.V1.identifier - ) + if event.is_state() and (event.type, event.state_key) == ( + EventTypes.Create, + "", + ): + room_version = event.content.get("room_version", RoomVersions.V1.identifier) else: room_version = yield self.store.get_room_version(event.room_id) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) try: @@ -695,9 +676,7 @@ class EventCreationHandler(object): logger.exception("Failed to encode content: %r", event.content) raise - yield self.action_generator.handle_push_actions_for_event( - event, context - ) + yield self.action_generator.handle_push_actions_for_event(event, context) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we # hack around with a try/finally instead. @@ -718,11 +697,7 @@ class EventCreationHandler(object): return yield self.persist_and_notify_client_event( - requester, - event, - context, - ratelimit=ratelimit, - extra_users=extra_users, + requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) success = True @@ -731,18 +706,12 @@ class EventCreationHandler(object): # Ensure that we actually remove the entries in the push actions # staging area, if we calculated them. run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, + self.store.remove_push_actions_from_staging, event.event_id ) @defer.inlineCallbacks def persist_and_notify_client_event( - self, - requester, - event, - context, - ratelimit=True, - extra_users=[], + self, requester, event, context, ratelimit=True, extra_users=[] ): """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -767,20 +736,16 @@ class EventCreationHandler(object): if mapping["room_id"] != event.room_id: raise SynapseError( 400, - "Room alias %s does not point to the room" % ( - room_alias_str, - ) + "Room alias %s does not point to the room" % (room_alias_str,), ) federation_handler = self.hs.get_handlers().federation_handler if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: + def is_inviter_member_event(e): - return ( - e.type == EventTypes.Member and - e.sender == event.sender - ) + return e.type == EventTypes.Member and e.sender == event.sender current_state_ids = yield context.get_current_state_ids(self.store) @@ -810,26 +775,21 @@ class EventCreationHandler(object): # to get them to sign the event. returned_invite = yield federation_handler.send_invite( - invitee.domain, - event, + invitee.domain, event ) event.unsigned.pop("room_state", None) # TODO: Make sure the signatures actually are correct. - event.signatures.update( - returned_invite.signatures - ) + event.signatures.update(returned_invite.signatures) if event.type == EventTypes.Redaction: prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in auth_events.values() - } + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version = yield self.store.get_room_version(event.room_id) if self.auth.check_redaction(room_version, event, auth_events=auth_events): original_event = yield self.store.get_event( @@ -837,13 +797,10 @@ class EventCreationHandler(object): check_redacted=False, get_prev_content=False, allow_rejected=False, - allow_none=False + allow_none=False, ) if event.user_id != original_event.user_id: - raise AuthError( - 403, - "You don't have permission to redact events" - ) + raise AuthError(403, "You don't have permission to redact events") # We've already checked. event.internal_metadata.recheck_redaction = False @@ -851,24 +808,18 @@ class EventCreationHandler(object): if event.type == EventTypes.Create: prev_state_ids = yield context.get_prev_state_ids(self.store) if prev_state_ids: - raise AuthError( - 403, - "Changing the room create event is forbidden", - ) + raise AuthError(403, "Changing the room create event is forbidden") (event_stream_id, max_stream_id) = yield self.store.persist_event( event, context=context ) - yield self.pusher_pool.on_new_notifications( - event_stream_id, max_stream_id, - ) + yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): try: self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users + event, event_stream_id, max_stream_id, extra_users=extra_users ) except Exception: logger.exception("Error notifying about new room event") @@ -895,23 +846,19 @@ class EventCreationHandler(object): """ room_ids = yield self.store.get_rooms_with_many_extremities( - min_count=10, limit=5, + min_count=10, limit=5 ) for room_id in room_ids: # For each room we need to find a joined member we can use to send # the dummy event with. - prev_events_and_hashes = yield self.store.get_prev_events_for_room( - room_id, - ) + prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) - latest_event_ids = ( - event_id for (event_id, _, _) in prev_events_and_hashes - ) + latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) members = yield self.state.get_current_users_in_room( - room_id, latest_event_ids=latest_event_ids, + room_id, latest_event_ids=latest_event_ids ) user_id = None @@ -941,9 +888,4 @@ class EventCreationHandler(object): event.internal_metadata.proactively_send = False - yield self.send_nonmember_event( - requester, - event, - context, - ratelimit=False, - ) + yield self.send_nonmember_event(requester, event, context, ratelimit=False) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 8f811e24fe..062e026e5f 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -55,9 +55,7 @@ class PurgeStatus(object): self.status = PurgeStatus.STATUS_ACTIVE def asdict(self): - return { - "status": PurgeStatus.STATUS_TEXT[self.status] - } + return {"status": PurgeStatus.STATUS_TEXT[self.status]} class PaginationHandler(object): @@ -79,8 +77,7 @@ class PaginationHandler(object): self._purges_by_id = {} self._event_serializer = hs.get_event_client_serializer() - def start_purge_history(self, room_id, token, - delete_local_events=False): + def start_purge_history(self, room_id, token, delete_local_events=False): """Start off a history purge on a room. Args: @@ -95,8 +92,7 @@ class PaginationHandler(object): """ if room_id in self._purges_in_progress_by_room: raise SynapseError( - 400, - "History purge already in progress for %s" % (room_id, ), + 400, "History purge already in progress for %s" % (room_id,) ) purge_id = random_string(16) @@ -107,14 +103,12 @@ class PaginationHandler(object): self._purges_by_id[purge_id] = PurgeStatus() run_in_background( - self._purge_history, - purge_id, room_id, token, delete_local_events, + self._purge_history, purge_id, room_id, token, delete_local_events ) return purge_id @defer.inlineCallbacks - def _purge_history(self, purge_id, room_id, token, - delete_local_events): + def _purge_history(self, purge_id, room_id, token, delete_local_events): """Carry out a history purge on a room. Args: @@ -130,16 +124,13 @@ class PaginationHandler(object): self._purges_in_progress_by_room.add(room_id) try: with (yield self.pagination_lock.write(room_id)): - yield self.store.purge_history( - room_id, token, delete_local_events, - ) + yield self.store.purge_history(room_id, token, delete_local_events) logger.info("[purge] complete") self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE except Exception: f = Failure() logger.error( - "[purge] failed", - exc_info=(f.type, f.value, f.getTracebackObject()), + "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) ) self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED finally: @@ -148,6 +139,7 @@ class PaginationHandler(object): # remove the purge from the list 24 hours after it completes def clear_purge(): del self._purges_by_id[purge_id] + self.hs.get_reactor().callLater(24 * 3600, clear_purge) def get_purge_status(self, purge_id): @@ -162,8 +154,14 @@ class PaginationHandler(object): return self._purges_by_id.get(purge_id) @defer.inlineCallbacks - def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True, event_filter=None): + def get_messages( + self, + requester, + room_id=None, + pagin_config=None, + as_client_event=True, + event_filter=None, + ): """Get messages in a room. Args: @@ -201,7 +199,7 @@ class PaginationHandler(object): room_id, user_id ) - if source_config.direction == 'b': + if source_config.direction == "b": # if we're going backwards, we might need to backfill. This # requires that we have a topo token. if room_token.topological: @@ -235,27 +233,24 @@ class PaginationHandler(object): event_filter=event_filter, ) - next_token = pagin_config.from_token.copy_and_replace( - "room_key", next_key - ) + next_token = pagin_config.from_token.copy_and_replace("room_key", next_key) if events: if event_filter: events = event_filter.filter(events) events = yield filter_events_for_client( - self.store, - user_id, - events, - is_peeking=(member_event_id is None), + self.store, user_id, events, is_peeking=(member_event_id is None) ) if not events: - defer.returnValue({ - "chunk": [], - "start": pagin_config.from_token.to_string(), - "end": next_token.to_string(), - }) + defer.returnValue( + { + "chunk": [], + "start": pagin_config.from_token.to_string(), + "end": next_token.to_string(), + } + ) state = None if event_filter and event_filter.lazy_load_members() and len(events) > 0: @@ -263,12 +258,11 @@ class PaginationHandler(object): # FIXME: we also care about invite targets etc. state_filter = StateFilter.from_types( - (EventTypes.Member, event.sender) - for event in events + (EventTypes.Member, event.sender) for event in events ) state_ids = yield self.store.get_state_ids_for_event( - events[0].event_id, state_filter=state_filter, + events[0].event_id, state_filter=state_filter ) if state_ids: @@ -280,8 +274,7 @@ class PaginationHandler(object): chunk = { "chunk": ( yield self._event_serializer.serialize_events( - events, time_now, - as_client_event=as_client_event, + events, time_now, as_client_event=as_client_event ) ), "start": pagin_config.from_token.to_string(), @@ -291,8 +284,7 @@ class PaginationHandler(object): if state: chunk["state"] = ( yield self._event_serializer.serialize_events( - state, time_now, - as_client_event=as_client_event, + state, time_now, as_client_event=as_client_event ) ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 557fb5f83d..5204073a38 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -50,16 +50,20 @@ logger = logging.getLogger(__name__) notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "") federation_presence_out_counter = Counter( - "synapse_handler_presence_federation_presence_out", "") + "synapse_handler_presence_federation_presence_out", "" +) presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "") timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "") -federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "") +federation_presence_counter = Counter( + "synapse_handler_presence_federation_presence", "" +) bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "") get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"]) notify_reason_counter = Counter( - "synapse_handler_presence_notify_reason", "", ["reason"]) + "synapse_handler_presence_notify_reason", "", ["reason"] +) state_transition_counter = Counter( "synapse_handler_presence_state_transition", "", ["from", "to"] ) @@ -90,7 +94,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class PresenceHandler(object): - def __init__(self, hs): """ @@ -110,31 +113,26 @@ class PresenceHandler(object): federation_registry = hs.get_federation_registry() - federation_registry.register_edu_handler( - "m.presence", self.incoming_presence - ) + federation_registry.register_edu_handler("m.presence", self.incoming_presence) active_presence = self.store.take_presence_startup_info() # A dictionary of the current state of users. This is prefilled with # non-offline presence from the DB. We should fetch from the DB if # we can't find a users presence in here. - self.user_to_current_state = { - state.user_id: state - for state in active_presence - } + self.user_to_current_state = {state.user_id: state for state in active_presence} LaterGauge( - "synapse_handlers_presence_user_to_current_state_size", "", [], - lambda: len(self.user_to_current_state) + "synapse_handlers_presence_user_to_current_state_size", + "", + [], + lambda: len(self.user_to_current_state), ) now = self.clock.time_msec() for state in active_presence: self.wheel_timer.insert( - now=now, - obj=state.user_id, - then=state.last_active_ts + IDLE_TIMER, + now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER ) self.wheel_timer.insert( now=now, @@ -193,27 +191,21 @@ class PresenceHandler(object): "handle_presence_timeouts", self._handle_timeouts ) - self.clock.call_later( - 30, - self.clock.looping_call, - run_timeout_handler, - 5000, - ) + self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000) def run_persister(): return run_as_background_process( "persist_presence_changes", self._persist_unpersisted_changes ) - self.clock.call_later( - 60, - self.clock.looping_call, - run_persister, - 60 * 1000, - ) + self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000) - LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [], - lambda: len(self.wheel_timer)) + LaterGauge( + "synapse_handlers_presence_wheel_timer_size", + "", + [], + lambda: len(self.wheel_timer), + ) # Used to handle sending of presence to newly joined users/servers if hs.config.use_presence: @@ -241,15 +233,17 @@ class PresenceHandler(object): logger.info( "Performing _on_shutdown. Persisting %d unpersisted changes", - len(self.user_to_current_state) + len(self.user_to_current_state), ) if self.unpersisted_users_changes: - yield self.store.update_presence([ - self.user_to_current_state[user_id] - for user_id in self.unpersisted_users_changes - ]) + yield self.store.update_presence( + [ + self.user_to_current_state[user_id] + for user_id in self.unpersisted_users_changes + ] + ) logger.info("Finished _on_shutdown") @defer.inlineCallbacks @@ -261,13 +255,10 @@ class PresenceHandler(object): self.unpersisted_users_changes = set() if unpersisted: - logger.info( - "Persisting %d upersisted presence updates", len(unpersisted) + logger.info("Persisting %d upersisted presence updates", len(unpersisted)) + yield self.store.update_presence( + [self.user_to_current_state[user_id] for user_id in unpersisted] ) - yield self.store.update_presence([ - self.user_to_current_state[user_id] - for user_id in unpersisted - ]) @defer.inlineCallbacks def _update_states(self, new_states): @@ -303,10 +294,11 @@ class PresenceHandler(object): ) new_state, should_notify, should_ping = handle_update( - prev_state, new_state, + prev_state, + new_state, is_mine=self.is_mine_id(user_id), wheel_timer=self.wheel_timer, - now=now + now=now, ) self.user_to_current_state[user_id] = new_state @@ -328,7 +320,8 @@ class PresenceHandler(object): self.unpersisted_users_changes -= set(to_notify.keys()) to_federation_ping = { - user_id: state for user_id, state in to_federation_ping.items() + user_id: state + for user_id, state in to_federation_ping.items() if user_id not in to_notify } if to_federation_ping: @@ -351,8 +344,8 @@ class PresenceHandler(object): # Check whether the lists of syncing processes from an external # process have expired. expired_process_ids = [ - process_id for process_id, last_update - in self.external_process_last_updated_ms.items() + process_id + for process_id, last_update in self.external_process_last_updated_ms.items() if now - last_update > EXTERNAL_PROCESS_EXPIRY ] for process_id in expired_process_ids: @@ -362,9 +355,7 @@ class PresenceHandler(object): self.external_process_last_update.pop(process_id) states = [ - self.user_to_current_state.get( - user_id, UserPresenceState.default(user_id) - ) + self.user_to_current_state.get(user_id, UserPresenceState.default(user_id)) for user_id in users_to_check ] @@ -394,9 +385,7 @@ class PresenceHandler(object): prev_state = yield self.current_state_for_user(user_id) - new_fields = { - "last_active_ts": self.clock.time_msec(), - } + new_fields = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE @@ -430,15 +419,23 @@ class PresenceHandler(object): if prev_state.state == PresenceState.OFFLINE: # If they're currently offline then bring them online, otherwise # just update the last sync times. - yield self._update_states([prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=self.clock.time_msec(), - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=self.clock.time_msec(), + last_user_sync_ts=self.clock.time_msec(), + ) + ] + ) else: - yield self._update_states([prev_state.copy_and_replace( - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec() + ) + ] + ) @defer.inlineCallbacks def _end(): @@ -446,9 +443,13 @@ class PresenceHandler(object): self.user_to_num_current_syncs[user_id] -= 1 prev_state = yield self.current_state_for_user(user_id) - yield self._update_states([prev_state.copy_and_replace( - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec() + ) + ] + ) except Exception: logger.exception("Error updating presence after sync") @@ -469,7 +470,8 @@ class PresenceHandler(object): """ if self.hs.config.use_presence: syncing_user_ids = { - user_id for user_id, count in self.user_to_num_current_syncs.items() + user_id + for user_id, count in self.user_to_num_current_syncs.items() if count } for user_ids in self.external_process_to_current_syncs.values(): @@ -479,7 +481,9 @@ class PresenceHandler(object): return set() @defer.inlineCallbacks - def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): + def update_external_syncs_row( + self, process_id, user_id, is_syncing, sync_time_msec + ): """Update the syncing users for an external process as a delta. Args: @@ -500,20 +504,22 @@ class PresenceHandler(object): updates = [] if is_syncing and user_id not in process_presence: if prev_state.state == PresenceState.OFFLINE: - updates.append(prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=sync_time_msec, - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=sync_time_msec, + last_user_sync_ts=sync_time_msec, + ) + ) else: - updates.append(prev_state.copy_and_replace( - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + ) process_presence.add(user_id) elif user_id in process_presence: - updates.append(prev_state.copy_and_replace( - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + ) if not is_syncing: process_presence.discard(user_id) @@ -537,12 +543,12 @@ class PresenceHandler(object): prev_states = yield self.current_state_for_users(process_presence) time_now_ms = self.clock.time_msec() - yield self._update_states([ - prev_state.copy_and_replace( - last_user_sync_ts=time_now_ms, - ) - for prev_state in itervalues(prev_states) - ]) + yield self._update_states( + [ + prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) + for prev_state in itervalues(prev_states) + ] + ) self.external_process_last_updated_ms.pop(process_id, None) @defer.inlineCallbacks @@ -574,8 +580,7 @@ class PresenceHandler(object): missing = [user_id for user_id, state in iteritems(states) if not state] if missing: new = { - user_id: UserPresenceState.default(user_id) - for user_id in missing + user_id: UserPresenceState.default(user_id) for user_id in missing } states.update(new) self.user_to_current_state.update(new) @@ -593,8 +598,10 @@ class PresenceHandler(object): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states] + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states], ) self._push_to_remotes(states) @@ -605,8 +612,10 @@ class PresenceHandler(object): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states] + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states], ) def _push_to_remotes(self, states): @@ -631,15 +640,15 @@ class PresenceHandler(object): user_id = push.get("user_id", None) if not user_id: logger.info( - "Got presence update from %r with no 'user_id': %r", - origin, push, + "Got presence update from %r with no 'user_id': %r", origin, push ) continue if get_domain_from_id(user_id) != origin: logger.info( "Got presence update from %r with bad 'user_id': %r", - origin, user_id, + origin, + user_id, ) continue @@ -647,14 +656,12 @@ class PresenceHandler(object): if not presence_state: logger.info( "Got presence update from %r with no 'presence_state': %r", - origin, push, + origin, + push, ) continue - new_fields = { - "state": presence_state, - "last_federation_update_ts": now, - } + new_fields = {"state": presence_state, "last_federation_update_ts": now} last_active_ago = push.get("last_active_ago", None) if last_active_ago is not None: @@ -672,10 +679,7 @@ class PresenceHandler(object): @defer.inlineCallbacks def get_state(self, target_user, as_event=False): - results = yield self.get_states( - [target_user.to_string()], - as_event=as_event, - ) + results = yield self.get_states([target_user.to_string()], as_event=as_event) defer.returnValue(results[0]) @@ -699,13 +703,15 @@ class PresenceHandler(object): now = self.clock.time_msec() if as_event: - defer.returnValue([ - { - "type": "m.presence", - "content": format_user_presence_state(state, now), - } - for state in updates - ]) + defer.returnValue( + [ + { + "type": "m.presence", + "content": format_user_presence_state(state, now), + } + for state in updates + ] + ) else: defer.returnValue(updates) @@ -717,7 +723,9 @@ class PresenceHandler(object): presence = state["presence"] valid_presence = ( - PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, ) if presence not in valid_presence: raise SynapseError(400, "Invalid presence state") @@ -726,9 +734,7 @@ class PresenceHandler(object): prev_state = yield self.current_state_for_user(user_id) - new_fields = { - "state": presence - } + new_fields = {"state": presence} if not ignore_status_msg: msg = status_msg if presence != PresenceState.OFFLINE else None @@ -877,8 +883,7 @@ class PresenceHandler(object): hosts = set(host for host in hosts if host != self.server_name) self.federation.send_presence_to_destinations( - states=[state], - destinations=hosts, + states=[state], destinations=hosts ) else: # A remote user has joined the room, so we need to: @@ -904,7 +909,8 @@ class PresenceHandler(object): # default state. now = self.clock.time_msec() states = [ - state for state in states.values() + state + for state in states.values() if state.state != PresenceState.OFFLINE or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 or state.status_msg is not None @@ -912,8 +918,7 @@ class PresenceHandler(object): if states: self.federation.send_presence_to_destinations( - states=states, - destinations=[get_domain_from_id(user_id)], + states=states, destinations=[get_domain_from_id(user_id)] ) @@ -937,7 +942,10 @@ def should_notify(old_state, new_state): notify_reason_counter.labels("current_active_change").inc() return True - if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + if ( + new_state.last_active_ts - old_state.last_active_ts + > LAST_ACTIVE_GRANULARITY + ): # Only notify about last active bumps if we're not currently acive if not new_state.currently_active: notify_reason_counter.labels("last_active_change_online").inc() @@ -958,9 +966,7 @@ def format_user_presence_state(state, now, include_user_id=True): The "user_id" is optional so that this function can be used to format presence updates for client /sync responses and for federation /send requests. """ - content = { - "presence": state.state, - } + content = {"presence": state.state} if include_user_id: content["user_id"] = state.user_id if state.last_active_ts: @@ -986,8 +992,15 @@ class PresenceEventSource(object): @defer.inlineCallbacks @log_function - def get_new_events(self, user, from_key, room_ids=None, include_offline=True, - explicit_room_id=None, **kwargs): + def get_new_events( + self, + user, + from_key, + room_ids=None, + include_offline=True, + explicit_room_id=None, + **kwargs + ): # The process for getting presence events are: # 1. Get the rooms the user is in. # 2. Get the list of user in the rooms. @@ -1030,7 +1043,7 @@ class PresenceEventSource(object): if from_key: user_ids_changed = stream_change_cache.get_entities_changed( - users_interested_in, from_key, + users_interested_in, from_key ) else: user_ids_changed = users_interested_in @@ -1040,10 +1053,16 @@ class PresenceEventSource(object): if include_offline: defer.returnValue((list(updates.values()), max_token)) else: - defer.returnValue(([ - s for s in itervalues(updates) - if s.state != PresenceState.OFFLINE - ], max_token)) + defer.returnValue( + ( + [ + s + for s in itervalues(updates) + if s.state != PresenceState.OFFLINE + ], + max_token, + ) + ) def get_current_key(self): return self.store.get_current_presence_token() @@ -1061,13 +1080,13 @@ class PresenceEventSource(object): users_interested_in.add(user_id) # So that we receive our own presence users_who_share_room = yield self.store.get_users_who_share_room_with_user( - user_id, on_invalidate=cache_context.invalidate, + user_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(users_who_share_room) if explicit_room_id: user_ids = yield self.store.get_users_in_room( - explicit_room_id, on_invalidate=cache_context.invalidate, + explicit_room_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(user_ids) @@ -1123,9 +1142,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): if now - state.last_active_ts > IDLE_TIMER: # Currently online, but last activity ages ago so auto # idle - state = state.copy_and_replace( - state=PresenceState.UNAVAILABLE, - ) + state = state.copy_and_replace(state=PresenceState.UNAVAILABLE) changed = True elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY: # So that we send down a notification that we've @@ -1145,8 +1162,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) if now - sync_or_active > SYNC_ONLINE_TIMEOUT: state = state.copy_and_replace( - state=PresenceState.OFFLINE, - status_msg=None, + state=PresenceState.OFFLINE, status_msg=None ) changed = True else: @@ -1155,10 +1171,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): # no one gets stuck online forever. if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: # The other side seems to have disappeared. - state = state.copy_and_replace( - state=PresenceState.OFFLINE, - status_msg=None, - ) + state = state.copy_and_replace(state=PresenceState.OFFLINE, status_msg=None) changed = True return state if changed else None @@ -1193,21 +1206,17 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): if new_state.state == PresenceState.ONLINE: # Idle timer wheel_timer.insert( - now=now, - obj=user_id, - then=new_state.last_active_ts + IDLE_TIMER + now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER ) active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY - new_state = new_state.copy_and_replace( - currently_active=active, - ) + new_state = new_state.copy_and_replace(currently_active=active) if active: wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY, ) if new_state.state != PresenceState.OFFLINE: @@ -1215,29 +1224,25 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, ) last_federate = new_state.last_federation_update_ts if now - last_federate > FEDERATION_PING_INTERVAL: # Been a while since we've poked remote servers - new_state = new_state.copy_and_replace( - last_federation_update_ts=now, - ) + new_state = new_state.copy_and_replace(last_federation_update_ts=now) federation_ping = True else: wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT + then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT, ) # Check whether the change was something worth notifying about if should_notify(prev_state, new_state): - new_state = new_state.copy_and_replace( - last_federation_update_ts=now, - ) + new_state = new_state.copy_and_replace(last_federation_update_ts=now) persist_and_notify = True return new_state, persist_and_notify, federation_ping diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 3e04233394..d8462b75ec 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -73,18 +73,13 @@ class BaseProfileHandler(BaseHandler): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue({ - "displayname": displayname, - "avatar_url": avatar_url, - }) + defer.returnValue({"displayname": displayname, "avatar_url": avatar_url}) else: try: result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": user_id, - }, + args={"user_id": user_id}, ignore_backoff=True, ) defer.returnValue(result) @@ -113,10 +108,7 @@ class BaseProfileHandler(BaseHandler): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue({ - "displayname": displayname, - "avatar_url": avatar_url, - }) + defer.returnValue({"displayname": displayname, "avatar_url": avatar_url}) else: profile = yield self.store.get_from_remote_profile_cache(user_id) defer.returnValue(profile or {}) @@ -139,10 +131,7 @@ class BaseProfileHandler(BaseHandler): result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": target_user.to_string(), - "field": "displayname", - }, + args={"user_id": target_user.to_string(), "field": "displayname"}, ignore_backoff=True, ) except RequestSendFailed as e: @@ -170,15 +159,13 @@ class BaseProfileHandler(BaseHandler): if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( - 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN, ), + 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) - if new_displayname == '': + if new_displayname == "": new_displayname = None - yield self.store.set_profile_displayname( - target_user.localpart, new_displayname - ) + yield self.store.set_profile_displayname(target_user.localpart, new_displayname) if self.hs.config.user_directory_search_all_users: profile = yield self.store.get_profileinfo(target_user.localpart) @@ -205,10 +192,7 @@ class BaseProfileHandler(BaseHandler): result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": target_user.to_string(), - "field": "avatar_url", - }, + args={"user_id": target_user.to_string(), "field": "avatar_url"}, ignore_backoff=True, ) except RequestSendFailed as e: @@ -230,12 +214,10 @@ class BaseProfileHandler(BaseHandler): if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( - 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN, ), + 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) - yield self.store.set_profile_avatar_url( - target_user.localpart, new_avatar_url - ) + yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) if self.hs.config.user_directory_search_all_users: profile = yield self.store.get_profileinfo(target_user.localpart) @@ -278,9 +260,7 @@ class BaseProfileHandler(BaseHandler): yield self.ratelimit(requester) - room_ids = yield self.store.get_rooms_for_user( - target_user.to_string(), - ) + room_ids = yield self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: handler = self.hs.get_room_member_handler() @@ -296,8 +276,7 @@ class BaseProfileHandler(BaseHandler): ) except Exception as e: logger.warn( - "Failed to update join event for room %s - %s", - room_id, str(e) + "Failed to update join event for room %s - %s", room_id, str(e) ) @defer.inlineCallbacks @@ -325,11 +304,9 @@ class BaseProfileHandler(BaseHandler): return try: - requester_rooms = yield self.store.get_rooms_for_user( - requester.to_string() - ) + requester_rooms = yield self.store.get_rooms_for_user(requester.to_string()) target_user_rooms = yield self.store.get_rooms_for_user( - target_user.to_string(), + target_user.to_string() ) # Check if the room lists have no elements in common. @@ -353,12 +330,12 @@ class MasterProfileHandler(BaseProfileHandler): assert hs.config.worker_app is None self.clock.looping_call( - self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, + self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS ) def _start_update_remote_profile_cache(self): return run_as_background_process( - "Update remote profile", self._update_remote_profile_cache, + "Update remote profile", self._update_remote_profile_cache ) @defer.inlineCallbacks @@ -372,7 +349,7 @@ class MasterProfileHandler(BaseProfileHandler): for user_id, displayname, avatar_url in entries: is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( - user_id, + user_id ) if not is_subscribed: yield self.store.maybe_delete_remote_profile_cache(user_id) @@ -382,9 +359,7 @@ class MasterProfileHandler(BaseProfileHandler): profile = yield self.federation.make_query( destination=get_domain_from_id(user_id), query_type="profile", - args={ - "user_id": user_id, - }, + args={"user_id": user_id}, ignore_backoff=True, ) except Exception: @@ -399,6 +374,4 @@ class MasterProfileHandler(BaseProfileHandler): new_avatar = profile.get("avatar_url") # We always hit update to update the last_check timestamp - yield self.store.update_remote_profile_cache( - user_id, new_name, new_avatar - ) + yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 32108568c6..3e4d8c93a4 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -43,7 +43,7 @@ class ReadMarkerHandler(BaseHandler): with (yield self.read_marker_linearizer.queue((room_id, user_id))): existing_read_marker = yield self.store.get_account_data_for_room_and_type( - user_id, room_id, "m.fully_read", + user_id, room_id, "m.fully_read" ) should_update = True @@ -51,14 +51,11 @@ class ReadMarkerHandler(BaseHandler): if existing_read_marker: # Only update if the new marker is ahead in the stream should_update = yield self.store.is_event_after( - event_id, - existing_read_marker['event_id'] + event_id, existing_read_marker["event_id"] ) if should_update: - content = { - "event_id": event_id - } + content = {"event_id": event_id} max_id = yield self.store.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 274d2946ad..a85dd8cdee 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -88,19 +88,16 @@ class ReceiptsHandler(BaseHandler): affected_room_ids = list(set([r.room_id for r in receipts])) - self.notifier.on_new_event( - "receipt_key", max_batch_id, rooms=affected_room_ids - ) + self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. yield self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids, + min_batch_id, max_batch_id, affected_room_ids ) defer.returnValue(True) @defer.inlineCallbacks - def received_client_receipt(self, room_id, receipt_type, user_id, - event_id): + def received_client_receipt(self, room_id, receipt_type, user_id, event_id): """Called when a client tells us a local user has read up to the given event_id in the room. """ @@ -109,9 +106,7 @@ class ReceiptsHandler(BaseHandler): receipt_type=receipt_type, user_id=user_id, event_ids=[event_id], - data={ - "ts": int(self.clock.time_msec()), - }, + data={"ts": int(self.clock.time_msec())}, ) is_new = yield self._handle_new_receipts([receipt]) @@ -125,8 +120,7 @@ class ReceiptsHandler(BaseHandler): """Gets all receipts for a room, upto the given key. """ result = yield self.store.get_linearized_receipts_for_room( - room_id, - to_key=to_key, + room_id, to_key=to_key ) if not result: @@ -148,14 +142,12 @@ class ReceiptEventSource(object): defer.returnValue(([], to_key)) events = yield self.store.get_linearized_receipts_for_rooms( - room_ids, - from_key=from_key, - to_key=to_key, + room_ids, from_key=from_key, to_key=to_key ) defer.returnValue((events, to_key)) - def get_current_key(self, direction='f'): + def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() @defer.inlineCallbacks @@ -169,9 +161,7 @@ class ReceiptEventSource(object): room_ids = yield self.store.get_rooms_for_user(user.to_string()) events = yield self.store.get_linearized_receipts_for_rooms( - room_ids, - from_key=from_key, - to_key=to_key, + room_ids, from_key=from_key, to_key=to_key ) defer.returnValue((events, to_key)) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 9a388ea013..e487b90c08 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -47,7 +47,6 @@ logger = logging.getLogger(__name__) class RegistrationHandler(BaseHandler): - def __init__(self, hs): """ @@ -69,44 +68,37 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._generate_user_id_linearizer = Linearizer( - name="_generate_user_id_linearizer", + name="_generate_user_id_linearizer" ) self._server_notices_mxid = hs.config.server_notices_mxid if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) - self._register_device_client = ( - RegisterDeviceReplicationServlet.make_client(hs) + self._register_device_client = RegisterDeviceReplicationServlet.make_client( + hs ) - self._post_registration_client = ( - ReplicationPostRegisterActionsServlet.make_client(hs) + self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client( + hs ) else: self.device_handler = hs.get_device_handler() self.pusher_pool = hs.get_pusherpool() @defer.inlineCallbacks - def check_username(self, localpart, guest_access_token=None, - assigned_user_id=None): + def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", - Codes.INVALID_USERNAME + Codes.INVALID_USERNAME, ) if not localpart: - raise SynapseError( - 400, - "User ID cannot be empty", - Codes.INVALID_USERNAME - ) + raise SynapseError(400, "User ID cannot be empty", Codes.INVALID_USERNAME) - if localpart[0] == '_': + if localpart[0] == "_": raise SynapseError( - 400, - "User ID may not begin with _", - Codes.INVALID_USERNAME + 400, "User ID may not begin with _", Codes.INVALID_USERNAME ) user = UserID(localpart, self.hs.hostname) @@ -126,19 +118,15 @@ class RegistrationHandler(BaseHandler): if len(user_id) > MAX_USERID_LENGTH: raise SynapseError( 400, - "User ID may not be longer than %s characters" % ( - MAX_USERID_LENGTH, - ), - Codes.INVALID_USERNAME + "User ID may not be longer than %s characters" % (MAX_USERID_LENGTH,), + Codes.INVALID_USERNAME, ) users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: if not guest_access_token: raise SynapseError( - 400, - "User ID already taken.", - errcode=Codes.USER_IN_USE, + 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) user_data = yield self.auth.get_user_by_access_token(guest_access_token) if not user_data["is_guest"] or user_data["user"].localpart != localpart: @@ -203,8 +191,7 @@ class RegistrationHandler(BaseHandler): try: int(localpart) raise RegistrationError( - 400, - "Numeric user IDs are reserved for guest users." + 400, "Numeric user IDs are reserved for guest users." ) except ValueError: pass @@ -283,9 +270,7 @@ class RegistrationHandler(BaseHandler): } # Bind email to new account - yield self._register_email_threepid( - user_id, threepid_dict, None, False, - ) + yield self._register_email_threepid(user_id, threepid_dict, None, False) defer.returnValue((user_id, token)) @@ -318,8 +303,8 @@ class RegistrationHandler(BaseHandler): room_alias = RoomAlias.from_string(r) if self.hs.hostname != room_alias.domain: logger.warning( - 'Cannot create room alias %s, ' - 'it does not match server domain', + "Cannot create room alias %s, " + "it does not match server domain", r, ) else: @@ -332,7 +317,7 @@ class RegistrationHandler(BaseHandler): fake_requester, config={ "preset": "public_chat", - "room_alias_name": room_alias_localpart + "room_alias_name": room_alias_localpart, }, ratelimit=False, ) @@ -364,8 +349,9 @@ class RegistrationHandler(BaseHandler): raise AuthError(403, "Invalid application service token.") if not service.is_interested_in_user(user_id): raise SynapseError( - 400, "Invalid user localpart for this application service.", - errcode=Codes.EXCLUSIVE + 400, + "Invalid user localpart for this application service.", + errcode=Codes.EXCLUSIVE, ) service_id = service.id if service.is_exclusive_user(user_id) else None @@ -391,17 +377,15 @@ class RegistrationHandler(BaseHandler): """ captcha_response = yield self._validate_captcha( - ip, - private_key, - challenge, - response + ip, private_key, challenge, response ) if not captcha_response["valid"]: - logger.info("Invalid captcha entered from %s. Error: %s", - ip, captcha_response["error_url"]) - raise InvalidCaptchaError( - error_url=captcha_response["error_url"] + logger.info( + "Invalid captcha entered from %s. Error: %s", + ip, + captcha_response["error_url"], ) + raise InvalidCaptchaError(error_url=captcha_response["error_url"]) else: logger.info("Valid captcha entered from %s", ip) @@ -414,8 +398,11 @@ class RegistrationHandler(BaseHandler): """ for c in threepidCreds: - logger.info("validating threepidcred sid %s on id server %s", - c['sid'], c['idServer']) + logger.info( + "validating threepidcred sid %s on id server %s", + c["sid"], + c["idServer"], + ) try: threepid = yield self.identity_handler.threepid_from_creds(c) except Exception: @@ -424,13 +411,14 @@ class RegistrationHandler(BaseHandler): if not threepid: raise RegistrationError(400, "Couldn't validate 3pid") - logger.info("got threepid with medium '%s' and address '%s'", - threepid['medium'], threepid['address']) + logger.info( + "got threepid with medium '%s' and address '%s'", + threepid["medium"], + threepid["address"], + ) - if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']): - raise RegistrationError( - 403, "Third party identifier is not allowed" - ) + if not check_3pid_allowed(self.hs, threepid["medium"], threepid["address"]): + raise RegistrationError(403, "Third party identifier is not allowed") @defer.inlineCallbacks def bind_emails(self, user_id, threepidCreds): @@ -449,23 +437,23 @@ class RegistrationHandler(BaseHandler): if self._server_notices_mxid is not None: if user_id == self._server_notices_mxid: raise SynapseError( - 400, "This user ID is reserved.", - errcode=Codes.EXCLUSIVE + 400, "This user ID is reserved.", errcode=Codes.EXCLUSIVE ) # valid user IDs must not clash with any user ID namespaces claimed by # application services. services = self.store.get_app_services() interested_services = [ - s for s in services - if s.is_interested_in_user(user_id) - and s != allowed_appservice + s + for s in services + if s.is_interested_in_user(user_id) and s != allowed_appservice ] for service in interested_services: if service.is_exclusive_user(user_id): raise SynapseError( - 400, "This user ID is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This user ID is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) @defer.inlineCallbacks @@ -491,14 +479,13 @@ class RegistrationHandler(BaseHandler): dict: Containing 'valid'(bool) and 'error_url'(str) if invalid. """ - response = yield self._submit_captcha(ip_addr, private_key, challenge, - response) + response = yield self._submit_captcha(ip_addr, private_key, challenge, response) # parse Google's response. Lovely format.. - lines = response.split('\n') + lines = response.split("\n") json = { - "valid": lines[0] == 'true', - "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" + - "error=%s" % lines[1] + "valid": lines[0] == "true", + "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" + + "error=%s" % lines[1], } defer.returnValue(json) @@ -510,17 +497,16 @@ class RegistrationHandler(BaseHandler): data = yield self.captcha_client.post_urlencoded_get_raw( "http://www.recaptcha.net:80/recaptcha/api/verify", args={ - 'privatekey': private_key, - 'remoteip': ip_addr, - 'challenge': challenge, - 'response': response - } + "privatekey": private_key, + "remoteip": ip_addr, + "challenge": challenge, + "response": response, + }, ) defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, - password_hash=None): + def get_or_create_user(self, requester, localpart, displayname, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -565,7 +551,7 @@ class RegistrationHandler(BaseHandler): if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) yield self.profile_handler.set_displayname( - user, requester, displayname, by_admin=True, + user, requester, displayname, by_admin=True ) defer.returnValue((user_id, token)) @@ -587,15 +573,12 @@ class RegistrationHandler(BaseHandler): """ access_token = yield self.store.get_3pid_guest_access_token(medium, address) if access_token: - user_info = yield self.auth.get_user_by_access_token( - access_token - ) + user_info = yield self.auth.get_user_by_access_token(access_token) defer.returnValue((user_info["user"].to_string(), access_token)) user_id, access_token = yield self.register( - generate_token=True, - make_guest=True + generate_token=True, make_guest=True ) access_token = yield self.store.save_or_get_3pid_guest_access_token( medium, address, access_token, inviter_user_id @@ -616,9 +599,9 @@ class RegistrationHandler(BaseHandler): ) room_id = room_id.to_string() else: - raise SynapseError(400, "%s was not legal room ID or room alias" % ( - room_identifier, - )) + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) yield room_member_handler.update_membership( requester=requester, @@ -629,10 +612,19 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) - def register_with_store(self, user_id, token=None, password_hash=None, - was_guest=False, make_guest=False, appservice_id=None, - create_profile_with_displayname=None, admin=False, - user_type=None, address=None): + def register_with_store( + self, + user_id, + token=None, + password_hash=None, + was_guest=False, + make_guest=False, + appservice_id=None, + create_profile_with_displayname=None, + admin=False, + user_type=None, + address=None, + ): """Register user in the datastore. Args: @@ -661,14 +653,15 @@ class RegistrationHandler(BaseHandler): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.can_do_action( - address, time_now_s=time_now, + address, + time_now_s=time_now, rate_hz=self.hs.config.rc_registration.per_second, burst_count=self.hs.config.rc_registration.burst_count, ) if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) if self.hs.config.worker_app: @@ -698,8 +691,7 @@ class RegistrationHandler(BaseHandler): ) @defer.inlineCallbacks - def register_device(self, user_id, device_id, initial_display_name, - is_guest=False): + def register_device(self, user_id, device_id, initial_display_name, is_guest=False): """Register a device for a user and generate an access token. Args: @@ -732,14 +724,15 @@ class RegistrationHandler(BaseHandler): ) else: access_token = yield self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, + user_id, device_id=device_id ) defer.returnValue((device_id, access_token)) @defer.inlineCallbacks - def post_registration_actions(self, user_id, auth_result, access_token, - bind_email, bind_msisdn): + def post_registration_actions( + self, user_id, auth_result, access_token, bind_email, bind_msisdn + ): """A user has completed registration Args: @@ -773,20 +766,15 @@ class RegistrationHandler(BaseHandler): yield self.store.upsert_monthly_active_user(user_id) yield self._register_email_threepid( - user_id, threepid, access_token, - bind_email, + user_id, threepid, access_token, bind_email ) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] - yield self._register_msisdn_threepid( - user_id, threepid, bind_msisdn, - ) + yield self._register_msisdn_threepid(user_id, threepid, bind_msisdn) if auth_result and LoginType.TERMS in auth_result: - yield self._on_user_consented( - user_id, self.hs.config.user_consent_version, - ) + yield self._on_user_consented(user_id, self.hs.config.user_consent_version) @defer.inlineCallbacks def _on_user_consented(self, user_id, consent_version): @@ -798,9 +786,7 @@ class RegistrationHandler(BaseHandler): consented to. """ logger.info("%s has consented to the privacy policy", user_id) - yield self.store.user_set_consent_version( - user_id, consent_version, - ) + yield self.store.user_set_consent_version(user_id, consent_version) yield self.post_consent_actions(user_id) @defer.inlineCallbacks @@ -824,33 +810,30 @@ class RegistrationHandler(BaseHandler): Returns: defer.Deferred: """ - reqd = ('medium', 'address', 'validated_at') + reqd = ("medium", "address", "validated_at") if any(x not in threepid for x in reqd): # This will only happen if the ID server returns a malformed response logger.info("Can't add incomplete 3pid") return yield self._auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) # And we add an email pusher for them by default, but only # if email notifications are enabled (so people don't start # getting mail spam where they weren't before if email # notifs are set up on a home server) - if (self.hs.config.email_enable_notifs and - self.hs.config.email_notif_for_new_users - and token): + if ( + self.hs.config.email_enable_notifs + and self.hs.config.email_notif_for_new_users + and token + ): # Pull the ID of the access token back out of the db # It would really make more sense for this to be passed # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. - user_tuple = yield self.store.get_user_by_access_token( - token - ) + user_tuple = yield self.store.get_user_by_access_token(token) token_id = user_tuple["token_id"] yield self.pusher_pool.add_pusher( @@ -867,11 +850,9 @@ class RegistrationHandler(BaseHandler): if bind_email: logger.info("bind_email specified: binding") - logger.debug("Binding emails %s to %s" % ( - threepid, user_id - )) + logger.debug("Binding emails %s to %s" % (threepid, user_id)) yield self.identity_handler.bind_threepid( - threepid['threepid_creds'], user_id + threepid["threepid_creds"], user_id ) else: logger.info("bind_email not specified: not binding email") @@ -894,7 +875,7 @@ class RegistrationHandler(BaseHandler): defer.Deferred: """ try: - assert_params_in_dict(threepid, ['medium', 'address', 'validated_at']) + assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) except SynapseError as ex: if ex.errcode == Codes.MISSING_PARAM: # This will only happen if the ID server returns a malformed response @@ -903,17 +884,14 @@ class RegistrationHandler(BaseHandler): raise yield self._auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) if bind_msisdn: logger.info("bind_msisdn specified: binding") logger.debug("Binding msisdn %s to %s", threepid, user_id) yield self.identity_handler.bind_threepid( - threepid['threepid_creds'], user_id + threepid["threepid_creds"], user_id ) else: logger.info("bind_msisdn not specified: not binding msisdn") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 74793bab33..89d89fc27c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -101,7 +101,7 @@ class RoomCreationHandler(BaseHandler): if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) new_room_id = yield self._generate_room_id( - creator_id=user_id, is_public=r["is_public"], + creator_id=user_id, is_public=r["is_public"] ) logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) @@ -110,7 +110,8 @@ class RoomCreationHandler(BaseHandler): # room, to check our user has perms in the old room. tombstone_event, tombstone_context = ( yield self.event_creation_handler.create_event( - requester, { + requester, + { "type": EventTypes.Tombstone, "state_key": "", "room_id": old_room_id, @@ -118,14 +119,14 @@ class RoomCreationHandler(BaseHandler): "content": { "body": "This room has been replaced", "replacement_room": new_room_id, - } + }, }, token_id=requester.access_token_id, ) ) old_room_version = yield self.store.get_room_version(old_room_id) yield self.auth.check_from_context( - old_room_version, tombstone_event, tombstone_context, + old_room_version, tombstone_event, tombstone_context ) yield self.clone_existing_room( @@ -138,27 +139,27 @@ class RoomCreationHandler(BaseHandler): # now send the tombstone yield self.event_creation_handler.send_nonmember_event( - requester, tombstone_event, tombstone_context, + requester, tombstone_event, tombstone_context ) old_room_state = yield tombstone_context.get_current_state_ids(self.store) # update any aliases yield self._move_aliases_to_new_room( - requester, old_room_id, new_room_id, old_room_state, + requester, old_room_id, new_room_id, old_room_state ) # and finally, shut down the PLs in the old room, and update them in the new # room. yield self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state, + requester, old_room_id, new_room_id, old_room_state ) defer.returnValue(new_room_id) @defer.inlineCallbacks def _update_upgraded_room_pls( - self, requester, old_room_id, new_room_id, old_room_state, + self, requester, old_room_id, new_room_id, old_room_state ): """Send updated power levels in both rooms after an upgrade @@ -176,7 +177,7 @@ class RoomCreationHandler(BaseHandler): if old_room_pl_event_id is None: logger.warning( "Not supported: upgrading a room with no PL event. Not setting PLs " - "in old room.", + "in old room." ) return @@ -197,45 +198,48 @@ class RoomCreationHandler(BaseHandler): if current < restricted_level: logger.info( "Setting level for %s in %s to %i (was %i)", - v, old_room_id, restricted_level, current, + v, + old_room_id, + restricted_level, + current, ) pl_content[v] = restricted_level updated = True else: - logger.info( - "Not setting level for %s (already %i)", - v, current, - ) + logger.info("Not setting level for %s (already %i)", v, current) if updated: try: yield self.event_creation_handler.create_and_send_nonmember_event( - requester, { + requester, + { "type": EventTypes.PowerLevels, - "state_key": '', + "state_key": "", "room_id": old_room_id, "sender": requester.user.to_string(), "content": pl_content, - }, ratelimit=False, + }, + ratelimit=False, ) except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) logger.info("Setting correct PLs in new room") yield self.event_creation_handler.create_and_send_nonmember_event( - requester, { + requester, + { "type": EventTypes.PowerLevels, - "state_key": '', + "state_key": "", "room_id": new_room_id, "sender": requester.user.to_string(), "content": old_room_pl_state.content, - }, ratelimit=False, + }, + ratelimit=False, ) @defer.inlineCallbacks def clone_existing_room( - self, requester, old_room_id, new_room_id, new_room_version, - tombstone_event_id, + self, requester, old_room_id, new_room_id, new_room_version, tombstone_event_id ): """Populate a new room based on an old room @@ -257,10 +261,7 @@ class RoomCreationHandler(BaseHandler): creation_content = { "room_version": new_room_version, - "predecessor": { - "room_id": old_room_id, - "event_id": tombstone_event_id, - } + "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, } # Check if old room was non-federatable @@ -289,7 +290,7 @@ class RoomCreationHandler(BaseHandler): ) old_room_state_ids = yield self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types(types_to_copy), + old_room_id, StateFilter.from_types(types_to_copy) ) # map from event_id to BaseEvent old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) @@ -302,11 +303,9 @@ class RoomCreationHandler(BaseHandler): yield self._send_events_for_new_room( requester, new_room_id, - # we expect to override all the presets with initial_state, so this is # somewhat arbitrary. preset_config=RoomCreationPreset.PRIVATE_CHAT, - invite_list=[], initial_state=initial_state, creation_content=creation_content, @@ -314,20 +313,22 @@ class RoomCreationHandler(BaseHandler): # Transfer membership events old_room_member_state_ids = yield self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types([(EventTypes.Member, None)]), + old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) ) # map from event_id to BaseEvent old_room_member_state_events = yield self.store.get_events( - old_room_member_state_ids.values(), + old_room_member_state_ids.values() ) for k, old_event in iteritems(old_room_member_state_events): # Only transfer ban events - if ("membership" in old_event.content and - old_event.content["membership"] == "ban"): + if ( + "membership" in old_event.content + and old_event.content["membership"] == "ban" + ): yield self.room_member_handler.update_membership( requester, - UserID.from_string(old_event['state_key']), + UserID.from_string(old_event["state_key"]), new_room_id, "ban", ratelimit=False, @@ -339,7 +340,7 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def _move_aliases_to_new_room( - self, requester, old_room_id, new_room_id, old_room_state, + self, requester, old_room_id, new_room_id, old_room_state ): directory_handler = self.hs.get_handlers().directory_handler @@ -370,14 +371,11 @@ class RoomCreationHandler(BaseHandler): alias = RoomAlias.from_string(alias_str) try: yield directory_handler.delete_association( - requester, alias, send_event=False, + requester, alias, send_event=False ) removed_aliases.append(alias_str) except SynapseError as e: - logger.warning( - "Unable to remove alias %s from old room: %s", - alias, e, - ) + logger.warning("Unable to remove alias %s from old room: %s", alias, e) # if we didn't find any aliases, or couldn't remove anyway, we can skip the rest # of this. @@ -393,30 +391,26 @@ class RoomCreationHandler(BaseHandler): # as when you remove an alias from the directory normally - it just means that # the aliases event gets out of sync with the directory # (cf https://github.com/vector-im/riot-web/issues/2369) - yield directory_handler.send_room_alias_update_event( - requester, old_room_id, - ) + yield directory_handler.send_room_alias_update_event(requester, old_room_id) except AuthError as e: - logger.warning( - "Failed to send updated alias event on old room: %s", e, - ) + logger.warning("Failed to send updated alias event on old room: %s", e) # we can now add any aliases we successfully removed to the new room. for alias in removed_aliases: try: yield directory_handler.create_association( - requester, RoomAlias.from_string(alias), - new_room_id, servers=(self.hs.hostname, ), - send_event=False, check_membership=False, + requester, + RoomAlias.from_string(alias), + new_room_id, + servers=(self.hs.hostname,), + send_event=False, + check_membership=False, ) logger.info("Moved alias %s to new room", alias) except SynapseError as e: # I'm not really expecting this to happen, but it could if the spam # checking module decides it shouldn't, or similar. - logger.error( - "Error adding alias %s to new room: %s", - alias, e, - ) + logger.error("Error adding alias %s to new room: %s", alias, e) try: if canonical_alias and (canonical_alias in removed_aliases): @@ -427,24 +421,19 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "room_id": new_room_id, "sender": requester.user.to_string(), - "content": {"alias": canonical_alias, }, + "content": {"alias": canonical_alias}, }, - ratelimit=False + ratelimit=False, ) - yield directory_handler.send_room_alias_update_event( - requester, new_room_id, - ) + yield directory_handler.send_room_alias_update_event(requester, new_room_id) except SynapseError as e: # again I'm not really expecting this to fail, but if it does, I'd rather # we returned the new room to the client at this point. - logger.error( - "Unable to send updated alias events in new room: %s", e, - ) + logger.error("Unable to send updated alias events in new room: %s", e) @defer.inlineCallbacks - def create_room(self, requester, config, ratelimit=True, - creator_join_profile=None): + def create_room(self, requester, config, ratelimit=True, creator_join_profile=None): """ Creates a new room. Args: @@ -474,25 +463,23 @@ class RoomCreationHandler(BaseHandler): yield self.auth.check_auth_blocking(user_id) - if (self._server_notices_mxid is not None and - requester.user.to_string() == self._server_notices_mxid): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) # Check whether the third party rules allows/changes the room create # request. yield self.third_party_event_rules.on_create_room( - requester, - config, - is_requester_admin=is_requester_admin, + requester, config, is_requester_admin=is_requester_admin ) if not is_requester_admin and not self.spam_checker.user_may_create_room( - user_id, + user_id ): raise SynapseError(403, "You are not permitted to create rooms") @@ -500,16 +487,11 @@ class RoomCreationHandler(BaseHandler): yield self.ratelimit(requester) room_version = config.get( - "room_version", - self.config.default_room_version.identifier, + "room_version", self.config.default_room_version.identifier ) if not isinstance(room_version, string_types): - raise SynapseError( - 400, - "room_version must be a string", - Codes.BAD_JSON, - ) + raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON) if room_version not in KNOWN_ROOM_VERSIONS: raise SynapseError( @@ -523,20 +505,11 @@ class RoomCreationHandler(BaseHandler): if wchar in config["room_alias_name"]: raise SynapseError(400, "Invalid characters in room alias") - room_alias = RoomAlias( - config["room_alias_name"], - self.hs.hostname, - ) - mapping = yield self.store.get_association_from_room_alias( - room_alias - ) + room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname) + mapping = yield self.store.get_association_from_room_alias(room_alias) if mapping: - raise SynapseError( - 400, - "Room alias already taken", - Codes.ROOM_IN_USE - ) + raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) else: room_alias = None @@ -547,9 +520,7 @@ class RoomCreationHandler(BaseHandler): except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) - yield self.event_creation_handler.assert_accepted_privacy_policy( - requester, - ) + yield self.event_creation_handler.assert_accepted_privacy_policy(requester) invite_3pid_list = config.get("invite_3pid", []) @@ -573,7 +544,7 @@ class RoomCreationHandler(BaseHandler): "preset", RoomCreationPreset.PRIVATE_CHAT if visibility == "private" - else RoomCreationPreset.PUBLIC_CHAT + else RoomCreationPreset.PUBLIC_CHAT, ) raw_initial_state = config.get("initial_state", []) @@ -610,7 +581,8 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "content": {"name": name}, }, - ratelimit=False) + ratelimit=False, + ) if "topic" in config: topic = config["topic"] @@ -623,7 +595,8 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "content": {"topic": topic}, }, - ratelimit=False) + ratelimit=False, + ) for invitee in invite_list: content = {} @@ -658,30 +631,25 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - yield directory_handler.send_room_alias_update_event( - requester, room_id - ) + yield directory_handler.send_room_alias_update_event(requester, room_id) defer.returnValue(result) @defer.inlineCallbacks def _send_events_for_new_room( - self, - creator, # A Requester object. - room_id, - preset_config, - invite_list, - initial_state, - creation_content, - room_alias=None, - power_level_content_override=None, - creator_join_profile=None, + self, + creator, # A Requester object. + room_id, + preset_config, + invite_list, + initial_state, + creation_content, + room_alias=None, + power_level_content_override=None, + creator_join_profile=None, ): def create(etype, content, **kwargs): - e = { - "type": etype, - "content": content, - } + e = {"type": etype, "content": content} e.update(event_keys) e.update(kwargs) @@ -693,26 +661,17 @@ class RoomCreationHandler(BaseHandler): event = create(etype, content, **kwargs) logger.info("Sending %s in new room", etype) yield self.event_creation_handler.create_and_send_nonmember_event( - creator, - event, - ratelimit=False + creator, event, ratelimit=False ) config = RoomCreationHandler.PRESETS_DICT[preset_config] creator_id = creator.user.to_string() - event_keys = { - "room_id": room_id, - "sender": creator_id, - "state_key": "", - } + event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} creation_content.update({"creator": creator_id}) - yield send( - etype=EventTypes.Create, - content=creation_content, - ) + yield send(etype=EventTypes.Create, content=creation_content) logger.info("Sending %s in new room", EventTypes.Member) yield self.room_member_handler.update_membership( @@ -726,17 +685,12 @@ class RoomCreationHandler(BaseHandler): # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. - pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None) + pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - yield send( - etype=EventTypes.PowerLevels, - content=pl_content, - ) + yield send(etype=EventTypes.PowerLevels, content=pl_content) else: power_level_content = { - "users": { - creator_id: 100, - }, + "users": {creator_id: 100}, "users_default": 0, "events": { EventTypes.Name: 50, @@ -760,42 +714,33 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - yield send( - etype=EventTypes.PowerLevels, - content=power_level_content, - ) + yield send(etype=EventTypes.PowerLevels, content=power_level_content) - if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state: + if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: yield send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) - if (EventTypes.JoinRules, '') not in initial_state: + if (EventTypes.JoinRules, "") not in initial_state: yield send( - etype=EventTypes.JoinRules, - content={"join_rule": config["join_rules"]}, + etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) - if (EventTypes.RoomHistoryVisibility, '') not in initial_state: + if (EventTypes.RoomHistoryVisibility, "") not in initial_state: yield send( etype=EventTypes.RoomHistoryVisibility, - content={"history_visibility": config["history_visibility"]} + content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: - if (EventTypes.GuestAccess, '') not in initial_state: + if (EventTypes.GuestAccess, "") not in initial_state: yield send( - etype=EventTypes.GuestAccess, - content={"guest_access": "can_join"} + etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - yield send( - etype=etype, - state_key=state_key, - content=content, - ) + yield send(etype=etype, state_key=state_key, content=content) @defer.inlineCallbacks def _generate_room_id(self, creator_id, is_public): @@ -805,12 +750,9 @@ class RoomCreationHandler(BaseHandler): while attempts < 5: try: random_string = stringutils.random_string(18) - gen_room_id = RoomID( - random_string, - self.hs.hostname, - ).to_string() + gen_room_id = RoomID(random_string, self.hs.hostname).to_string() if isinstance(gen_room_id, bytes): - gen_room_id = gen_room_id.decode('utf-8') + gen_room_id = gen_room_id.decode("utf-8") yield self.store.store_room( room_id=gen_room_id, room_creator_user_id=creator_id, @@ -844,7 +786,7 @@ class RoomContextHandler(object): Returns: dict, or None if the event isn't found """ - before_limit = math.floor(limit / 2.) + before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit users = yield self.store.get_users_in_room(room_id) @@ -852,24 +794,19 @@ class RoomContextHandler(object): def filter_evts(events): return filter_events_for_client( - self.store, - user.to_string(), - events, - is_peeking=is_peeking + self.store, user.to_string(), events, is_peeking=is_peeking ) - event = yield self.store.get_event(event_id, get_prev_content=True, - allow_none=True) + event = yield self.store.get_event( + event_id, get_prev_content=True, allow_none=True + ) if not event: defer.returnValue(None) return - filtered = yield(filter_evts([event])) + filtered = yield (filter_evts([event])) if not filtered: - raise AuthError( - 403, - "You don't have permission to access that event." - ) + raise AuthError(403, "You don't have permission to access that event.") results = yield self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter @@ -901,7 +838,7 @@ class RoomContextHandler(object): # https://github.com/matrix-org/matrix-doc/issues/687 state = yield self.store.get_state_for_events( - [last_event_id], state_filter=state_filter, + [last_event_id], state_filter=state_filter ) results["state"] = list(state[last_event_id].values()) @@ -913,9 +850,7 @@ class RoomContextHandler(object): "room_key", results["start"] ).to_string() - results["end"] = token.copy_and_replace( - "room_key", results["end"] - ).to_string() + results["end"] = token.copy_and_replace("room_key", results["end"]).to_string() defer.returnValue(results) @@ -926,13 +861,7 @@ class RoomEventSource(object): @defer.inlineCallbacks def get_new_events( - self, - user, - from_key, - limit, - room_ids, - is_guest, - explicit_room_id=None, + self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None ): # We just ignore the key for now. @@ -943,9 +872,7 @@ class RoomEventSource(object): logger.warn("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) - app_service = self.store.get_app_service_by_user_id( - user.to_string() - ) + app_service = self.store.get_app_service_by_user_id(user.to_string()) if app_service: # We no longer support AS users using /sync directly. # See https://github.com/matrix-org/matrix-doc/issues/1144 @@ -960,7 +887,7 @@ class RoomEventSource(object): from_key=from_key, to_key=to_key, limit=limit or 10, - order='ASC', + order="ASC", ) events = list(room_events) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 617d1c9ef8..aae696a7e8 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -46,13 +46,18 @@ class RoomListHandler(BaseHandler): super(RoomListHandler, self).__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache(hs, "room_list") - self.remote_response_cache = ResponseCache(hs, "remote_room_list", - timeout_ms=30 * 1000) + self.remote_response_cache = ResponseCache( + hs, "remote_room_list", timeout_ms=30 * 1000 + ) - def get_local_public_room_list(self, limit=None, since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False): + def get_local_public_room_list( + self, + limit=None, + since_token=None, + search_filter=None, + network_tuple=EMPTY_THIRD_PARTY_ID, + from_federation=False, + ): """Generate a local public room list. There are multiple different lists: the main one plus one per third @@ -68,14 +73,14 @@ class RoomListHandler(BaseHandler): Setting to None returns all public rooms across all lists. """ if not self.enable_room_list_search: - return defer.succeed({ - "chunk": [], - "total_room_count_estimate": 0, - }) + return defer.succeed({"chunk": [], "total_room_count_estimate": 0}) logger.info( "Getting public room list: limit=%r, since=%r, search=%r, network=%r", - limit, since_token, bool(search_filter), network_tuple, + limit, + since_token, + bool(search_filter), + network_tuple, ) if search_filter: @@ -88,24 +93,33 @@ class RoomListHandler(BaseHandler): # solution at some point timeout = self.clock.time() + 60 return self._get_public_room_list( - limit, since_token, search_filter, - network_tuple=network_tuple, timeout=timeout, + limit, + since_token, + search_filter, + network_tuple=network_tuple, + timeout=timeout, ) key = (limit, since_token, network_tuple) return self.response_cache.wrap( key, self._get_public_room_list, - limit, since_token, - network_tuple=network_tuple, from_federation=from_federation, + limit, + since_token, + network_tuple=network_tuple, + from_federation=from_federation, ) @defer.inlineCallbacks - def _get_public_room_list(self, limit=None, since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False, - timeout=None,): + def _get_public_room_list( + self, + limit=None, + since_token=None, + search_filter=None, + network_tuple=EMPTY_THIRD_PARTY_ID, + from_federation=False, + timeout=None, + ): """Generate a public room list. Args: limit (int|None): Maximum amount of rooms to return. @@ -135,15 +149,14 @@ class RoomListHandler(BaseHandler): current_public_id = yield self.store.get_current_public_room_stream_id() public_room_stream_id = since_token.public_room_stream_id newly_visible, newly_unpublished = yield self.store.get_public_room_changes( - public_room_stream_id, current_public_id, - network_tuple=network_tuple, + public_room_stream_id, current_public_id, network_tuple=network_tuple ) else: stream_token = yield self.store.get_room_max_stream_ordering() public_room_stream_id = yield self.store.get_current_public_room_stream_id() room_ids = yield self.store.get_public_room_ids_at_stream_id( - public_room_stream_id, network_tuple=network_tuple, + public_room_stream_id, network_tuple=network_tuple ) # We want to return rooms in a particular order: the number of joined @@ -168,7 +181,7 @@ class RoomListHandler(BaseHandler): return joined_users = yield self.state_handler.get_current_users_in_room( - room_id, latest_event_ids, + room_id, latest_event_ids ) num_joined_users = len(joined_users) @@ -180,8 +193,9 @@ class RoomListHandler(BaseHandler): # We want larger rooms to be first, hence negating num_joined_users rooms_to_order_value[room_id] = (-num_joined_users, room_id) - logger.info("Getting ordering for %i rooms since %s", - len(room_ids), stream_token) + logger.info( + "Getting ordering for %i rooms since %s", len(room_ids), stream_token + ) yield concurrently_execute(get_order_for_room, room_ids, 10) sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) @@ -193,7 +207,8 @@ class RoomListHandler(BaseHandler): # Filter out rooms that we don't want to return rooms_to_scan = [ - r for r in sorted_rooms + r + for r in sorted_rooms if r not in newly_unpublished and rooms_to_num_joined[r] > 0 ] @@ -204,13 +219,12 @@ class RoomListHandler(BaseHandler): # `since_token.current_limit` is the index of the last room we # sent down, so we exclude it and everything before/after it. if since_token.direction_is_forward: - rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:] + rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :] else: - rooms_to_scan = rooms_to_scan[:since_token.current_limit] + rooms_to_scan = rooms_to_scan[: since_token.current_limit] rooms_to_scan.reverse() - logger.info("After sorting and filtering, %i rooms remain", - len(rooms_to_scan)) + logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan)) # _append_room_entry_to_chunk will append to chunk but will stop if # len(chunk) > limit @@ -237,15 +251,19 @@ class RoomListHandler(BaseHandler): if timeout and self.clock.time() > timeout: raise Exception("Timed out searching room directory") - batch = rooms_to_scan[i:i + step] + batch = rooms_to_scan[i : i + step] logger.info("Processing %i rooms for result", len(batch)) yield concurrently_execute( lambda r: self._append_room_entry_to_chunk( - r, rooms_to_num_joined[r], - chunk, limit, search_filter, + r, + rooms_to_num_joined[r], + chunk, + limit, + search_filter, from_federation=from_federation, ), - batch, 5, + batch, + 5, ) logger.info("Now %i rooms in result", len(chunk)) if len(chunk) >= limit + 1: @@ -273,10 +291,7 @@ class RoomListHandler(BaseHandler): new_limit = sorted_rooms.index(last_room_id) - results = { - "chunk": chunk, - "total_room_count_estimate": total_room_count, - } + results = {"chunk": chunk, "total_room_count_estimate": total_room_count} if since_token: results["new_rooms"] = bool(newly_visible) @@ -313,8 +328,15 @@ class RoomListHandler(BaseHandler): defer.returnValue(results) @defer.inlineCallbacks - def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit, - search_filter, from_federation=False): + def _append_room_entry_to_chunk( + self, + room_id, + num_joined_users, + chunk, + limit, + search_filter, + from_federation=False, + ): """Generate the entry for a room in the public room list and append it to the `chunk` if it matches the search filter @@ -345,8 +367,14 @@ class RoomListHandler(BaseHandler): chunk.append(result) @cachedInlineCallbacks(num_args=1, cache_context=True) - def generate_room_entry(self, room_id, num_joined_users, cache_context, - with_alias=True, allow_private=False): + def generate_room_entry( + self, + room_id, + num_joined_users, + cache_context, + with_alias=True, + allow_private=False, + ): """Returns the entry for a room Args: @@ -360,33 +388,31 @@ class RoomListHandler(BaseHandler): Deferred[dict|None]: Returns a room entry as a dictionary, or None if this room was determined not to be shown publicly. """ - result = { - "room_id": room_id, - "num_joined_members": num_joined_users, - } + result = {"room_id": room_id, "num_joined_members": num_joined_users} current_state_ids = yield self.store.get_current_state_ids( - room_id, on_invalidate=cache_context.invalidate, + room_id, on_invalidate=cache_context.invalidate ) - event_map = yield self.store.get_events([ - event_id for key, event_id in iteritems(current_state_ids) - if key[0] in ( - EventTypes.Create, - EventTypes.JoinRules, - EventTypes.Name, - EventTypes.Topic, - EventTypes.CanonicalAlias, - EventTypes.RoomHistoryVisibility, - EventTypes.GuestAccess, - "m.room.avatar", - ) - ]) + event_map = yield self.store.get_events( + [ + event_id + for key, event_id in iteritems(current_state_ids) + if key[0] + in ( + EventTypes.Create, + EventTypes.JoinRules, + EventTypes.Name, + EventTypes.Topic, + EventTypes.CanonicalAlias, + EventTypes.RoomHistoryVisibility, + EventTypes.GuestAccess, + "m.room.avatar", + ) + ] + ) - current_state = { - (ev.type, ev.state_key): ev - for ev in event_map.values() - } + current_state = {(ev.type, ev.state_key): ev for ev in event_map.values()} # Double check that this is actually a public room. @@ -446,14 +472,17 @@ class RoomListHandler(BaseHandler): defer.returnValue(result) @defer.inlineCallbacks - def get_remote_public_room_list(self, server_name, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None,): + def get_remote_public_room_list( + self, + server_name, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): if not self.enable_room_list_search: - defer.returnValue({ - "chunk": [], - "total_room_count_estimate": 0, - }) + defer.returnValue({"chunk": [], "total_room_count_estimate": 0}) if search_filter: # We currently don't support searching across federation, so we have @@ -462,52 +491,75 @@ class RoomListHandler(BaseHandler): since_token = None res = yield self._get_remote_list_cached( - server_name, limit=limit, since_token=since_token, + server_name, + limit=limit, + since_token=since_token, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) if search_filter: - res = {"chunk": [ - entry - for entry in list(res.get("chunk", [])) - if _matches_room_entry(entry, search_filter) - ]} + res = { + "chunk": [ + entry + for entry in list(res.get("chunk", [])) + if _matches_room_entry(entry, search_filter) + ] + } defer.returnValue(res) - def _get_remote_list_cached(self, server_name, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None,): + def _get_remote_list_cached( + self, + server_name, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search return repl_layer.get_public_rooms( - server_name, limit=limit, since_token=since_token, - search_filter=search_filter, include_all_networks=include_all_networks, + server_name, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) key = ( - server_name, limit, since_token, include_all_networks, + server_name, + limit, + since_token, + include_all_networks, third_party_instance_id, ) return self.remote_response_cache.wrap( key, repl_layer.get_public_rooms, - server_name, limit=limit, since_token=since_token, + server_name, + limit=limit, + since_token=since_token, search_filter=search_filter, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) -class RoomListNextBatch(namedtuple("RoomListNextBatch", ( - "stream_ordering", # stream_ordering of the first public room list - "public_room_stream_id", # public room stream id for first public room list - "current_limit", # The number of previous rooms returned - "direction_is_forward", # Bool if this is a next_batch, false if prev_batch -))): +class RoomListNextBatch( + namedtuple( + "RoomListNextBatch", + ( + "stream_ordering", # stream_ordering of the first public room list + "public_room_stream_id", # public room stream id for first public room list + "current_limit", # The number of previous rooms returned + "direction_is_forward", # Bool if this is a next_batch, false if prev_batch + ), + ) +): KEY_DICT = { "stream_ordering": "s", @@ -527,21 +579,19 @@ class RoomListNextBatch(namedtuple("RoomListNextBatch", ( decoded = msgpack.loads(decode_base64(token), raw=False) else: decoded = msgpack.loads(decode_base64(token)) - return RoomListNextBatch(**{ - cls.REVERSE_KEY_DICT[key]: val - for key, val in decoded.items() - }) + return RoomListNextBatch( + **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()} + ) def to_token(self): - return encode_base64(msgpack.dumps({ - self.KEY_DICT[key]: val - for key, val in self._asdict().items() - })) + return encode_base64( + msgpack.dumps( + {self.KEY_DICT[key]: val for key, val in self._asdict().items()} + ) + ) def copy_and_replace(self, **kwds): - return self._replace( - **kwds - ) + return self._replace(**kwds) def _matches_room_entry(room_entry, search_filter): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 458902bb7e..4d6e883802 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -166,7 +166,11 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _local_membership_update( - self, requester, target, room_id, membership, + self, + requester, + target, + room_id, + membership, prev_events_and_hashes, txn_id=None, ratelimit=True, @@ -190,7 +194,6 @@ class RoomMemberHandler(object): "room_id": room_id, "sender": requester.user.to_string(), "state_key": user_id, - # For backwards compatibility: "membership": membership, }, @@ -202,26 +205,19 @@ class RoomMemberHandler(object): # Check if this event matches the previous membership event for the user. duplicate = yield self.event_creation_handler.deduplicate_state_event( - event, context, + event, context ) if duplicate is not None: # Discard the new event since this membership change is a no-op. defer.returnValue(duplicate) yield self.event_creation_handler.handle_new_client_event( - requester, - event, - context, - extra_users=[target], - ratelimit=ratelimit, + requester, event, context, extra_users=[target], ratelimit=ratelimit ) prev_state_ids = yield context.get_prev_state_ids(self.store) - prev_member_event_id = prev_state_ids.get( - (EventTypes.Member, user_id), - None - ) + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) if event.membership == Membership.JOIN: # Only fire user_joined_room if the user has actually joined the @@ -243,11 +239,11 @@ class RoomMemberHandler(object): if predecessor: # It is an upgraded room. Copy over old tags self.copy_room_tags_and_direct_to_room( - predecessor["room_id"], room_id, user_id, + predecessor["room_id"], room_id, user_id ) # Move over old push rules self.store.move_push_rules_from_room_to_room_for_user( - predecessor["room_id"], room_id, user_id, + predecessor["room_id"], room_id, user_id ) elif event.membership == Membership.LEAVE: if prev_member_event_id: @@ -258,12 +254,7 @@ class RoomMemberHandler(object): defer.returnValue(event) @defer.inlineCallbacks - def copy_room_tags_and_direct_to_room( - self, - old_room_id, - new_room_id, - user_id, - ): + def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id): """Copies the tags and direct room state from one room to another. Args: @@ -275,9 +266,7 @@ class RoomMemberHandler(object): Deferred[None] """ # Retrieve user account data for predecessor room - user_account_data, _ = yield self.store.get_account_data_for_user( - user_id, - ) + user_account_data, _ = yield self.store.get_account_data_for_user(user_id) # Copy direct message state if applicable direct_rooms = user_account_data.get("m.direct", {}) @@ -291,34 +280,30 @@ class RoomMemberHandler(object): # Save back to user's m.direct account data yield self.store.add_account_data_for_user( - user_id, "m.direct", direct_rooms, + user_id, "m.direct", direct_rooms ) break # Copy room tags if applicable - room_tags = yield self.store.get_tags_for_room( - user_id, old_room_id, - ) + room_tags = yield self.store.get_tags_for_room(user_id, old_room_id) # Copy each room tag to the new room for tag, tag_content in room_tags.items(): - yield self.store.add_tag_to_room( - user_id, new_room_id, tag, tag_content - ) + yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) @defer.inlineCallbacks def update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + require_consent=True, ): key = (room_id,) @@ -340,17 +325,17 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + require_consent=True, ): content_specified = bool(content) if content is None: @@ -384,7 +369,7 @@ class RoomMemberHandler(object): if not remote_room_hosts: remote_room_hosts = [] - if effective_membership_state not in ("leave", "ban",): + if effective_membership_state not in ("leave", "ban"): is_blocked = yield self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -392,22 +377,19 @@ class RoomMemberHandler(object): if effective_membership_state == Membership.INVITE: # block any attempts to invite the server notices mxid if target.to_string() == self._server_notices_mxid: - raise SynapseError( - http_client.FORBIDDEN, - "Cannot invite this user", - ) + raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") block_invite = False - if (self._server_notices_mxid is not None and - requester.user.to_string() == self._server_notices_mxid): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): # allow the server notices mxid to send invites is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) if not is_requester_admin: if self.config.block_non_admin_invites: @@ -418,25 +400,19 @@ class RoomMemberHandler(object): block_invite = True if not self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), room_id, + requester.user.to_string(), target.to_string(), room_id ): logger.info("Blocking invite due to spam checker") block_invite = True if block_invite: - raise SynapseError( - 403, "Invites have been disabled on this server", - ) + raise SynapseError(403, "Invites have been disabled on this server") - prev_events_and_hashes = yield self.store.get_prev_events_for_room( - room_id, - ) - latest_event_ids = ( - event_id for (event_id, _, _) in prev_events_and_hashes - ) + prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) current_state_ids = yield self.state_handler.get_current_state_ids( - room_id, latest_event_ids=latest_event_ids, + room_id, latest_event_ids=latest_event_ids ) # TODO: Refactor into dictionary of explicitly allowed transitions @@ -451,13 +427,13 @@ class RoomMemberHandler(object): 403, "Cannot unban user who was not banned" " (membership=%s)" % old_membership, - errcode=Codes.BAD_STATE + errcode=Codes.BAD_STATE, ) if old_membership == "ban" and action != "unban": raise SynapseError( 403, "Cannot %s user who was banned" % (action,), - errcode=Codes.BAD_STATE + errcode=Codes.BAD_STATE, ) if old_state: @@ -473,8 +449,8 @@ class RoomMemberHandler(object): # we don't allow people to reject invites to the server notice # room, but they can leave it once they are joined. if ( - old_membership == Membership.INVITE and - effective_membership_state == Membership.LEAVE + old_membership == Membership.INVITE + and effective_membership_state == Membership.LEAVE ): is_blocked = yield self._is_server_notice_room(room_id) if is_blocked: @@ -535,7 +511,7 @@ class RoomMemberHandler(object): # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] res = yield self._remote_reject_invite( - requester, remote_room_hosts, room_id, target, + requester, remote_room_hosts, room_id, target ) defer.returnValue(res) @@ -554,12 +530,7 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def send_membership_event( - self, - requester, - event, - context, - remote_room_hosts=None, - ratelimit=True, + self, requester, event, context, remote_room_hosts=None, ratelimit=True ): """ Change the membership status of a user in a room. @@ -585,16 +556,15 @@ class RoomMemberHandler(object): if requester is not None: sender = UserID.from_string(event.sender) - assert sender == requester.user, ( - "Sender (%s) must be same as requester (%s)" % - (sender, requester.user) - ) + assert ( + sender == requester.user + ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: requester = synapse.types.create_requester(target_user) prev_event = yield self.event_creation_handler.deduplicate_state_event( - event, context, + event, context ) if prev_event is not None: return @@ -614,16 +584,11 @@ class RoomMemberHandler(object): raise SynapseError(403, "This room has been blocked on this server") yield self.event_creation_handler.handle_new_client_event( - requester, - event, - context, - extra_users=[target_user], - ratelimit=ratelimit, + requester, event, context, extra_users=[target_user], ratelimit=ratelimit ) prev_member_event_id = prev_state_ids.get( - (EventTypes.Member, event.state_key), - None + (EventTypes.Member, event.state_key), None ) if event.membership == Membership.JOIN: @@ -693,31 +658,20 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _get_inviter(self, user_id, room_id): invite = yield self.store.get_invite_for_user_in_room( - user_id=user_id, - room_id=room_id, + user_id=user_id, room_id=room_id ) if invite: defer.returnValue(UserID.from_string(invite.sender)) @defer.inlineCallbacks def do_3pid_invite( - self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id + self, room_id, inviter, medium, address, id_server, requester, txn_id ): if self.config.block_non_admin_invites: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( - 403, "Invites have been disabled on this server", - Codes.FORBIDDEN, + 403, "Invites have been disabled on this server", Codes.FORBIDDEN ) # We need to rate limit *before* we send out any 3PID invites, so we @@ -725,35 +679,24 @@ class RoomMemberHandler(object): yield self.base_handler.ratelimit(requester) can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( - medium, address, room_id, + medium, address, room_id ) if not can_invite: raise SynapseError( - 403, "This third-party identifier can not be invited in this room", + 403, + "This third-party identifier can not be invited in this room", Codes.FORBIDDEN, ) - invitee = yield self._lookup_3pid( - id_server, medium, address - ) + invitee = yield self._lookup_3pid(id_server, medium, address) if invitee: yield self.update_membership( - requester, - UserID.from_string(invitee), - room_id, - "invite", - txn_id=txn_id, + requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: yield self._make_and_store_3pid_invite( - requester, - id_server, - medium, - address, - room_id, - inviter, - txn_id=txn_id + requester, id_server, medium, address, room_id, inviter, txn_id=txn_id ) @defer.inlineCallbacks @@ -771,15 +714,12 @@ class RoomMemberHandler(object): """ if not self._enable_lookup: raise SynapseError( - 403, "Looking up third-party identifiers is denied from this server", + 403, "Looking up third-party identifiers is denied from this server" ) try: data = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), - { - "medium": medium, - "address": address, - } + "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), + {"medium": medium, "address": address}, ) if "mxid" in data: @@ -798,29 +738,25 @@ class RoomMemberHandler(object): raise AuthError(401, "No signature from server %s" % (server_hostname,)) for key_name, signature in data["signatures"][server_hostname].items(): key_data = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/api/v1/pubkey/%s" % - (id_server_scheme, server_hostname, key_name,), + "%s%s/_matrix/identity/api/v1/pubkey/%s" + % (id_server_scheme, server_hostname, key_name) ) if "public_key" not in key_data: - raise AuthError(401, "No public key named %s from %s" % - (key_name, server_hostname,)) + raise AuthError( + 401, "No public key named %s from %s" % (key_name, server_hostname) + ) verify_signed_json( data, server_hostname, - decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) + decode_verify_key_bytes( + key_name, decode_base64(key_data["public_key"]) + ), ) return @defer.inlineCallbacks def _make_and_store_3pid_invite( - self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id + self, requester, id_server, medium, address, room_id, user, txn_id ): room_state = yield self.state_handler.get_current_state(room_id) @@ -868,7 +804,7 @@ class RoomMemberHandler(object): room_join_rules=room_join_rules, room_name=room_name, inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url + inviter_avatar_url=inviter_avatar_url, ) ) @@ -879,7 +815,6 @@ class RoomMemberHandler(object): "content": { "display_name": display_name, "public_keys": public_keys, - # For backwards compatibility: "key_validity_url": fallback_public_key["key_validity_url"], "public_key": fallback_public_key["public_key"], @@ -893,19 +828,19 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _ask_id_server_for_third_party_invite( - self, - requester, - id_server, - medium, - address, - room_id, - inviter_user_id, - room_alias, - room_avatar_url, - room_join_rules, - room_name, - inviter_display_name, - inviter_avatar_url + self, + requester, + id_server, + medium, + address, + room_id, + inviter_user_id, + room_alias, + room_avatar_url, + room_join_rules, + room_name, + inviter_display_name, + inviter_avatar_url, ): """ Asks an identity server for a third party invite. @@ -937,7 +872,8 @@ class RoomMemberHandler(object): """ is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( - id_server_scheme, id_server, + id_server_scheme, + id_server, ) invite_config = { @@ -961,14 +897,15 @@ class RoomMemberHandler(object): inviter_user_id=inviter_user_id, ) - invite_config.update({ - "guest_access_token": guest_access_token, - "guest_user_id": guest_user_id, - }) + invite_config.update( + { + "guest_access_token": guest_access_token, + "guest_user_id": guest_user_id, + } + ) data = yield self.simple_http_client.post_urlencoded_get_json( - is_url, - invite_config + is_url, invite_config ) # TODO: Check for success token = data["token"] @@ -976,9 +913,8 @@ class RoomMemberHandler(object): if "public_key" in data: fallback_public_key = { "public_key": data["public_key"], - "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, id_server, - ), + "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" + % (id_server_scheme, id_server), } else: fallback_public_key = public_keys[0] @@ -1047,10 +983,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. yield self.federation_handler.do_invite_join( - remote_room_hosts, - room_id, - user.to_string(), - content, + remote_room_hosts, room_id, user.to_string(), content ) yield self._user_joined_room(user, room_id) @@ -1061,9 +994,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): fed_handler = self.federation_handler try: ret = yield fed_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - target.to_string(), + remote_room_hosts, room_id, target.to_string() ) defer.returnValue(ret) except Exception as e: @@ -1075,9 +1006,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warn("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite( - target.to_string(), room_id - ) + yield self.store.locally_reject_invite(target.to_string(), room_id) defer.returnValue({}) def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): @@ -1101,18 +1030,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): user_id = user.to_string() member = yield self.state_handler.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None if membership is not None and membership not in [ - Membership.LEAVE, Membership.BAN + Membership.LEAVE, + Membership.BAN, ]: - raise SynapseError(400, "User %s in room %s" % ( - user_id, room_id - )) + raise SynapseError(400, "User %s in room %s" % (user_id, room_id)) if membership: yield self.store.forget(user_id, room_id) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index acc6eb8099..da501f38c0 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -71,18 +71,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler): """Implements RoomMemberHandler._user_joined_room """ return self._notify_change_client( - user_id=target.to_string(), - room_id=room_id, - change="joined", + user_id=target.to_string(), room_id=room_id, change="joined" ) def _user_left_room(self, target, room_id): """Implements RoomMemberHandler._user_left_room """ return self._notify_change_client( - user_id=target.to_string(), - room_id=room_id, - change="left", + user_id=target.to_string(), room_id=room_id, change="left" ) def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9bba74d6c9..ddc4430d03 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -32,7 +32,6 @@ logger = logging.getLogger(__name__) class SearchHandler(BaseHandler): - def __init__(self, hs): super(SearchHandler, self).__init__(hs) self._event_serializer = hs.get_event_client_serializer() @@ -93,7 +92,7 @@ class SearchHandler(BaseHandler): batch_token = None if batch: try: - b = decode_base64(batch).decode('ascii') + b = decode_base64(batch).decode("ascii") batch_group, batch_group_key, batch_token = b.split("\n") assert batch_group is not None @@ -104,7 +103,9 @@ class SearchHandler(BaseHandler): logger.info( "Search batch properties: %r, %r, %r", - batch_group, batch_group_key, batch_token, + batch_group, + batch_group_key, + batch_token, ) logger.info("Search content: %s", content) @@ -116,9 +117,9 @@ class SearchHandler(BaseHandler): search_term = room_cat["search_term"] # Which "keys" to search over in FTS query - keys = room_cat.get("keys", [ - "content.body", "content.name", "content.topic", - ]) + keys = room_cat.get( + "keys", ["content.body", "content.name", "content.topic"] + ) # Filter to apply to results filter_dict = room_cat.get("filter", {}) @@ -130,9 +131,7 @@ class SearchHandler(BaseHandler): include_state = room_cat.get("include_state", False) # Include context around each event? - event_context = room_cat.get( - "event_context", None - ) + event_context = room_cat.get("event_context", None) # Group results together? May allow clients to paginate within a # group @@ -140,12 +139,8 @@ class SearchHandler(BaseHandler): group_keys = [g["key"] for g in group_by] if event_context is not None: - before_limit = int(event_context.get( - "before_limit", 5 - )) - after_limit = int(event_context.get( - "after_limit", 5 - )) + before_limit = int(event_context.get("before_limit", 5)) + after_limit = int(event_context.get("after_limit", 5)) # Return the historic display name and avatar for the senders # of the events? @@ -159,7 +154,8 @@ class SearchHandler(BaseHandler): if set(group_keys) - {"room_id", "sender"}: raise SynapseError( 400, - "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},) + "Invalid group by keys: %r" + % (set(group_keys) - {"room_id", "sender"},), ) search_filter = Filter(filter_dict) @@ -190,15 +186,13 @@ class SearchHandler(BaseHandler): room_ids.intersection_update({batch_group_key}) if not room_ids: - defer.returnValue({ - "search_categories": { - "room_events": { - "results": [], - "count": 0, - "highlights": [], + defer.returnValue( + { + "search_categories": { + "room_events": {"results": [], "count": 0, "highlights": []} } } - }) + ) rank_map = {} # event_id -> rank of event allowed_events = [] @@ -213,9 +207,7 @@ class SearchHandler(BaseHandler): count = None if order_by == "rank": - search_result = yield self.store.search_msgs( - room_ids, search_term, keys - ) + search_result = yield self.store.search_msgs(room_ids, search_term, keys) count = search_result["count"] @@ -235,19 +227,17 @@ class SearchHandler(BaseHandler): ) events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = events[:search_filter.limit()] + allowed_events = events[: search_filter.limit()] for e in allowed_events: - rm = room_groups.setdefault(e.room_id, { - "results": [], - "order": rank_map[e.event_id], - }) + rm = room_groups.setdefault( + e.room_id, {"results": [], "order": rank_map[e.event_id]} + ) rm["results"].append(e.event_id) - s = sender_group.setdefault(e.sender, { - "results": [], - "order": rank_map[e.event_id], - }) + s = sender_group.setdefault( + e.sender, {"results": [], "order": rank_map[e.event_id]} + ) s["results"].append(e.event_id) elif order_by == "recent": @@ -262,7 +252,10 @@ class SearchHandler(BaseHandler): while len(room_events) < search_filter.limit() and i < 5: i += 1 search_result = yield self.store.search_rooms( - room_ids, search_term, keys, search_filter.limit() * 2, + room_ids, + search_term, + keys, + search_filter.limit() * 2, pagination_token=pagination_token, ) @@ -277,16 +270,14 @@ class SearchHandler(BaseHandler): rank_map.update({r["event"].event_id: r["rank"] for r in results}) - filtered_events = search_filter.filter([ - r["event"] for r in results - ]) + filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( self.store, user.to_string(), filtered_events ) room_events.extend(events) - room_events = room_events[:search_filter.limit()] + room_events = room_events[: search_filter.limit()] if len(results) < search_filter.limit() * 2: pagination_token = None @@ -295,9 +286,7 @@ class SearchHandler(BaseHandler): pagination_token = results[-1]["pagination_token"] for event in room_events: - group = room_groups.setdefault(event.room_id, { - "results": [], - }) + group = room_groups.setdefault(event.room_id, {"results": []}) group["results"].append(event.event_id) if room_events and len(room_events) >= search_filter.limit(): @@ -309,18 +298,23 @@ class SearchHandler(BaseHandler): # it returns more from the same group (if applicable) rather # than reverting to searching all results again. if batch_group and batch_group_key: - global_next_batch = encode_base64(("%s\n%s\n%s" % ( - batch_group, batch_group_key, pagination_token - )).encode('ascii')) + global_next_batch = encode_base64( + ( + "%s\n%s\n%s" + % (batch_group, batch_group_key, pagination_token) + ).encode("ascii") + ) else: - global_next_batch = encode_base64(("%s\n%s\n%s" % ( - "all", "", pagination_token - )).encode('ascii')) + global_next_batch = encode_base64( + ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") + ) for room_id, group in room_groups.items(): - group["next_batch"] = encode_base64(("%s\n%s\n%s" % ( - "room_id", room_id, pagination_token - )).encode('ascii')) + group["next_batch"] = encode_base64( + ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( + "ascii" + ) + ) allowed_events.extend(room_events) @@ -338,12 +332,13 @@ class SearchHandler(BaseHandler): contexts = {} for event in allowed_events: res = yield self.store.get_events_around( - event.room_id, event.event_id, before_limit, after_limit, + event.room_id, event.event_id, before_limit, after_limit ) logger.info( "Context for search returned %d and %d events", - len(res["events_before"]), len(res["events_after"]), + len(res["events_before"]), + len(res["events_after"]), ) res["events_before"] = yield filter_events_for_client( @@ -403,12 +398,12 @@ class SearchHandler(BaseHandler): for context in contexts.values(): context["events_before"] = ( yield self._event_serializer.serialize_events( - context["events_before"], time_now, + context["events_before"], time_now ) ) context["events_after"] = ( yield self._event_serializer.serialize_events( - context["events_after"], time_now, + context["events_after"], time_now ) ) @@ -426,11 +421,15 @@ class SearchHandler(BaseHandler): results = [] for e in allowed_events: - results.append({ - "rank": rank_map[e.event_id], - "result": (yield self._event_serializer.serialize_event(e, time_now)), - "context": contexts.get(e.event_id, {}), - }) + results.append( + { + "rank": rank_map[e.event_id], + "result": ( + yield self._event_serializer.serialize_event(e, time_now) + ), + "context": contexts.get(e.event_id, {}), + } + ) rooms_cat_res = { "results": results, @@ -442,7 +441,7 @@ class SearchHandler(BaseHandler): s = {} for room_id, state in state_results.items(): s[room_id] = yield self._event_serializer.serialize_events( - state, time_now, + state, time_now ) rooms_cat_res["state"] = s @@ -456,8 +455,4 @@ class SearchHandler(BaseHandler): if global_next_batch: rooms_cat_res["next_batch"] = global_next_batch - defer.returnValue({ - "search_categories": { - "room_events": rooms_cat_res - } - }) + defer.returnValue({"search_categories": {"room_events": rooms_cat_res}}) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 7ecdede4dc..5a0995d4fe 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -25,6 +25,7 @@ logger = logging.getLogger(__name__) class SetPasswordHandler(BaseHandler): """Handler which deals with changing user account passwords""" + def __init__(self, hs): super(SetPasswordHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() @@ -47,11 +48,11 @@ class SetPasswordHandler(BaseHandler): # we want to log out all of the user's other sessions. First delete # all his other devices. yield self._device_handler.delete_all_devices_for_user( - user_id, except_device_id=except_device_id, + user_id, except_device_id=except_device_id ) # and now delete any access tokens which weren't associated with # devices (or were associated with this device). yield self._auth_handler.delete_access_tokens_for_user( - user_id, except_token_id=except_access_token_id, + user_id, except_token_id=except_access_token_id ) diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index b268bbcb2c..6b364befd5 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -21,7 +21,6 @@ logger = logging.getLogger(__name__) class StateDeltasHandler(object): - def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 7ad16c8566..a0ee8db988 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -156,7 +156,7 @@ class StatsHandler(StateDeltasHandler): prev_event_content = {} if prev_event_id is not None: prev_event = yield self.store.get_event( - prev_event_id, allow_none=True, + prev_event_id, allow_none=True ) if prev_event: prev_event_content = prev_event.content diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 62fda0c664..c5188a1f8e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -64,20 +64,14 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000 LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100 -SyncConfig = collections.namedtuple("SyncConfig", [ - "user", - "filter_collection", - "is_guest", - "request_key", - "device_id", -]) - - -class TimelineBatch(collections.namedtuple("TimelineBatch", [ - "prev_batch", - "events", - "limited", -])): +SyncConfig = collections.namedtuple( + "SyncConfig", ["user", "filter_collection", "is_guest", "request_key", "device_id"] +) + + +class TimelineBatch( + collections.namedtuple("TimelineBatch", ["prev_batch", "events", "limited"]) +): __slots__ = [] def __nonzero__(self): @@ -85,18 +79,24 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [ to tell if room needs to be part of the sync result. """ return bool(self.events) + __bool__ = __nonzero__ # python3 -class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "ephemeral", - "account_data", - "unread_notifications", - "summary", -])): +class JoinedSyncResult( + collections.namedtuple( + "JoinedSyncResult", + [ + "room_id", # str + "timeline", # TimelineBatch + "state", # dict[(str, str), FrozenEvent] + "ephemeral", + "account_data", + "unread_notifications", + "summary", + ], + ) +): __slots__ = [] def __nonzero__(self): @@ -111,77 +111,93 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ # nb the notification count does not, er, count: if there's nothing # else in the result, we don't need to send it. ) + __bool__ = __nonzero__ # python3 -class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "account_data", -])): +class ArchivedSyncResult( + collections.namedtuple( + "ArchivedSyncResult", + [ + "room_id", # str + "timeline", # TimelineBatch + "state", # dict[(str, str), FrozenEvent] + "account_data", + ], + ) +): __slots__ = [] def __nonzero__(self): """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ - return bool( - self.timeline - or self.state - or self.account_data - ) + return bool(self.timeline or self.state or self.account_data) + __bool__ = __nonzero__ # python3 -class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ - "room_id", # str - "invite", # FrozenEvent: the invite event -])): +class InvitedSyncResult( + collections.namedtuple( + "InvitedSyncResult", + ["room_id", "invite"], # str # FrozenEvent: the invite event + ) +): __slots__ = [] def __nonzero__(self): """Invited rooms should always be reported to the client""" return True + __bool__ = __nonzero__ # python3 -class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [ - "join", - "invite", - "leave", -])): +class GroupsSyncResult( + collections.namedtuple("GroupsSyncResult", ["join", "invite", "leave"]) +): __slots__ = [] def __nonzero__(self): return bool(self.join or self.invite or self.leave) + __bool__ = __nonzero__ # python3 -class DeviceLists(collections.namedtuple("DeviceLists", [ - "changed", # list of user_ids whose devices may have changed - "left", # list of user_ids whose devices we no longer track -])): +class DeviceLists( + collections.namedtuple( + "DeviceLists", + [ + "changed", # list of user_ids whose devices may have changed + "left", # list of user_ids whose devices we no longer track + ], + ) +): __slots__ = [] def __nonzero__(self): return bool(self.changed or self.left) + __bool__ = __nonzero__ # python3 -class SyncResult(collections.namedtuple("SyncResult", [ - "next_batch", # Token for the next sync - "presence", # List of presence events for the user. - "account_data", # List of account_data events for the user. - "joined", # JoinedSyncResult for each joined room. - "invited", # InvitedSyncResult for each invited room. - "archived", # ArchivedSyncResult for each archived room. - "to_device", # List of direct messages for the device. - "device_lists", # List of user_ids whose devices have changed - "device_one_time_keys_count", # Dict of algorithm to count for one time keys - # for this device - "groups", -])): +class SyncResult( + collections.namedtuple( + "SyncResult", + [ + "next_batch", # Token for the next sync + "presence", # List of presence events for the user. + "account_data", # List of account_data events for the user. + "joined", # JoinedSyncResult for each joined room. + "invited", # InvitedSyncResult for each invited room. + "archived", # ArchivedSyncResult for each archived room. + "to_device", # List of direct messages for the device. + "device_lists", # List of user_ids whose devices have changed + "device_one_time_keys_count", # Dict of algorithm to count for one time keys + # for this device + "groups", + ], + ) +): __slots__ = [] def __nonzero__(self): @@ -190,20 +206,20 @@ class SyncResult(collections.namedtuple("SyncResult", [ events. """ return bool( - self.presence or - self.joined or - self.invited or - self.archived or - self.account_data or - self.to_device or - self.device_lists or - self.groups + self.presence + or self.joined + or self.invited + or self.archived + or self.account_data + or self.to_device + or self.device_lists + or self.groups ) + __bool__ = __nonzero__ # python3 class SyncHandler(object): - def __init__(self, hs): self.hs_config = hs.config self.store = hs.get_datastore() @@ -217,13 +233,16 @@ class SyncHandler(object): # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( - "lazy_loaded_members_cache", self.clock, - max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, + "lazy_loaded_members_cache", + self.clock, + max_len=0, + expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) @defer.inlineCallbacks - def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, - full_state=False): + def wait_for_sync_for_user( + self, sync_config, since_token=None, timeout=0, full_state=False + ): """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. @@ -239,13 +258,15 @@ class SyncHandler(object): res = yield self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, - sync_config, since_token, timeout, full_state, + sync_config, + since_token, + timeout, + full_state, ) defer.returnValue(res) @defer.inlineCallbacks - def _wait_for_sync_for_user(self, sync_config, since_token, timeout, - full_state): + def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): if since_token is None: sync_type = "initial_sync" elif full_state: @@ -261,14 +282,17 @@ class SyncHandler(object): # we are going to return immediately, so don't bother calling # notifier.wait_for_events. result = yield self.current_sync_for_user( - sync_config, since_token, full_state=full_state, + sync_config, since_token, full_state=full_state ) else: + def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) result = yield self.notifier.wait_for_events( - sync_config.user.to_string(), timeout, current_sync_callback, + sync_config.user.to_string(), + timeout, + current_sync_callback, from_token=since_token, ) @@ -281,8 +305,7 @@ class SyncHandler(object): defer.returnValue(result) - def current_sync_for_user(self, sync_config, since_token=None, - full_state=False): + def current_sync_for_user(self, sync_config, since_token=None, full_state=False): """Get the sync for client needed to match what the server has now. Returns: A Deferred SyncResult. @@ -334,8 +357,7 @@ class SyncHandler(object): # result returned by the event source is poor form (it might cache # the object) room_id = event["room_id"] - event_copy = {k: v for (k, v) in iteritems(event) - if k != "room_id"} + event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) receipt_key = since_token.receipt_key if since_token else "0" @@ -353,22 +375,30 @@ class SyncHandler(object): for event in receipts: room_id = event["room_id"] # exclude room id, as above - event_copy = {k: v for (k, v) in iteritems(event) - if k != "room_id"} + event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) defer.returnValue((now_token, ephemeral_by_room)) @defer.inlineCallbacks - def _load_filtered_recents(self, room_id, sync_config, now_token, - since_token=None, recents=None, newly_joined_room=False): + def _load_filtered_recents( + self, + room_id, + sync_config, + now_token, + since_token=None, + recents=None, + newly_joined_room=False, + ): """ Returns: a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): timeline_limit = sync_config.filter_collection.timeline_limit() - block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline() + block_all_timeline = ( + sync_config.filter_collection.blocks_all_room_timeline() + ) if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True @@ -396,11 +426,9 @@ class SyncHandler(object): recents = [] if not limited or block_all_timeline: - defer.returnValue(TimelineBatch( - events=recents, - prev_batch=now_token, - limited=False - )) + defer.returnValue( + TimelineBatch(events=recents, prev_batch=now_token, limited=False) + ) filtering_factor = 2 load_limit = max(timeline_limit * filtering_factor, 10) @@ -427,9 +455,7 @@ class SyncHandler(object): ) else: events, end_key = yield self.store.get_recent_events_for_room( - room_id, - limit=load_limit + 1, - end_token=end_key, + room_id, limit=load_limit + 1, end_token=end_key ) loaded_recents = sync_config.filter_collection.filter_room_timeline( events @@ -462,15 +488,15 @@ class SyncHandler(object): recents = recents[-timeline_limit:] room_key = recents[0].internal_metadata.before - prev_batch_token = now_token.copy_and_replace( - "room_key", room_key - ) + prev_batch_token = now_token.copy_and_replace("room_key", room_key) - defer.returnValue(TimelineBatch( - events=recents, - prev_batch=prev_batch_token, - limited=limited or newly_joined_room - )) + defer.returnValue( + TimelineBatch( + events=recents, + prev_batch=prev_batch_token, + limited=limited or newly_joined_room, + ) + ) @defer.inlineCallbacks def get_state_after_event(self, event, state_filter=StateFilter.all()): @@ -486,7 +512,7 @@ class SyncHandler(object): A Deferred map from ((type, state_key)->Event) """ state_ids = yield self.store.get_state_ids_for_event( - event.event_id, state_filter=state_filter, + event.event_id, state_filter=state_filter ) if event.is_state(): state_ids = state_ids.copy() @@ -511,13 +537,13 @@ class SyncHandler(object): # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) last_events, _ = yield self.store.get_recent_events_for_room( - room_id, end_token=stream_position.room_key, limit=1, + room_id, end_token=stream_position.room_key, limit=1 ) if last_events: last_event = last_events[-1] state = yield self.get_state_after_event( - last_event, state_filter=state_filter, + last_event, state_filter=state_filter ) else: @@ -549,7 +575,7 @@ class SyncHandler(object): # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 last_events, _ = yield self.store.get_recent_event_ids_for_room( - room_id, end_token=now_token.room_key, limit=1, + room_id, end_token=now_token.room_key, limit=1 ) if not last_events: @@ -559,28 +585,25 @@ class SyncHandler(object): last_event = last_events[-1] state_ids = yield self.store.get_state_ids_for_event( last_event.event_id, - state_filter=StateFilter.from_types([ - (EventTypes.Name, ''), - (EventTypes.CanonicalAlias, ''), - ]), + state_filter=StateFilter.from_types( + [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] + ), ) # this is heavily cached, thus: fast. details = yield self.store.get_room_summary(room_id) - name_id = state_ids.get((EventTypes.Name, '')) - canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) + name_id = state_ids.get((EventTypes.Name, "")) + canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) summary = {} empty_ms = MemberSummary([], 0) # TODO: only send these when they change. - summary["m.joined_member_count"] = ( - details.get(Membership.JOIN, empty_ms).count - ) - summary["m.invited_member_count"] = ( - details.get(Membership.INVITE, empty_ms).count - ) + summary["m.joined_member_count"] = details.get(Membership.JOIN, empty_ms).count + summary["m.invited_member_count"] = details.get( + Membership.INVITE, empty_ms + ).count # if the room has a name or canonical_alias set, we can skip # calculating heroes. Empty strings are falsey, so we check @@ -592,7 +615,7 @@ class SyncHandler(object): if canonical_alias_id: canonical_alias = yield self.store.get_event( - canonical_alias_id, allow_none=True, + canonical_alias_id, allow_none=True ) if canonical_alias and canonical_alias.content.get("alias"): defer.returnValue(summary) @@ -600,26 +623,14 @@ class SyncHandler(object): me = sync_config.user.to_string() joined_user_ids = [ - r[0] - for r in details.get(Membership.JOIN, empty_ms).members - if r[0] != me + r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me ] invited_user_ids = [ - r[0] - for r in details.get(Membership.INVITE, empty_ms).members - if r[0] != me + r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me ] - gone_user_ids = ( - [ - r[0] - for r in details.get(Membership.LEAVE, empty_ms).members - if r[0] != me - ] + [ - r[0] - for r in details.get(Membership.BAN, empty_ms).members - if r[0] != me - ] - ) + gone_user_ids = [ + r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me + ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me] # FIXME: only build up a member_ids list for our heroes member_ids = {} @@ -627,20 +638,18 @@ class SyncHandler(object): Membership.JOIN, Membership.INVITE, Membership.LEAVE, - Membership.BAN + Membership.BAN, ): for user_id, event_id in details.get(membership, empty_ms).members: member_ids[user_id] = event_id # FIXME: order by stream ordering rather than as returned by SQL - if (joined_user_ids or invited_user_ids): - summary['m.heroes'] = sorted( + if joined_user_ids or invited_user_ids: + summary["m.heroes"] = sorted( [user_id for user_id in (joined_user_ids + invited_user_ids)] )[0:5] else: - summary['m.heroes'] = sorted( - [user_id for user_id in gone_user_ids] - )[0:5] + summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5] if not sync_config.filter_collection.lazy_load_members(): defer.returnValue(summary) @@ -652,8 +661,7 @@ class SyncHandler(object): # track which members the client should already know about via LL: # Ones which are already in state... existing_members = set( - user_id for (typ, user_id) in state.keys() - if typ == EventTypes.Member + user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member ) # ...or ones which are in the timeline... @@ -664,10 +672,10 @@ class SyncHandler(object): # ...and then ensure any missing ones get included in state. missing_hero_event_ids = [ member_ids[hero_id] - for hero_id in summary['m.heroes'] + for hero_id in summary["m.heroes"] if ( - cache.get(hero_id) != member_ids[hero_id] and - hero_id not in existing_members + cache.get(hero_id) != member_ids[hero_id] + and hero_id not in existing_members ) ] @@ -691,8 +699,9 @@ class SyncHandler(object): return cache @defer.inlineCallbacks - def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token, - full_state): + def compute_state_delta( + self, room_id, batch, sync_config, since_token, now_token, full_state + ): """ Works out the difference in state between the start of the timeline and the previous sync. @@ -745,23 +754,23 @@ class SyncHandler(object): timeline_state = { (event.type, event.state_key): event.event_id - for event in batch.events if event.is_state() + for event in batch.events + if event.is_state() } if full_state: if batch: current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter, + batch.events[-1].event_id, state_filter=state_filter ) state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter, + batch.events[0].event_id, state_filter=state_filter ) else: current_state_ids = yield self.get_state_at( - room_id, stream_position=now_token, - state_filter=state_filter, + room_id, stream_position=now_token, state_filter=state_filter ) state_ids = current_state_ids @@ -775,7 +784,7 @@ class SyncHandler(object): ) elif batch.limited: state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter, + batch.events[0].event_id, state_filter=state_filter ) # for now, we disable LL for gappy syncs - see @@ -793,12 +802,11 @@ class SyncHandler(object): state_filter = StateFilter.all() state_at_previous_sync = yield self.get_state_at( - room_id, stream_position=since_token, - state_filter=state_filter, + room_id, stream_position=since_token, state_filter=state_filter ) current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter, + batch.events[-1].event_id, state_filter=state_filter ) state_ids = _calculate_state( @@ -854,8 +862,7 @@ class SyncHandler(object): # add any member IDs we are about to send into our LruCache for t, event_id in itertools.chain( - state_ids.items(), - timeline_state.items(), + state_ids.items(), timeline_state.items() ): if t[0] == EventTypes.Member: cache.set(t[1], event_id) @@ -864,10 +871,14 @@ class SyncHandler(object): if state_ids: state = yield self.store.get_events(list(state_ids.values())) - defer.returnValue({ - (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state(list(state.values())) - }) + defer.returnValue( + { + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state( + list(state.values()) + ) + } + ) @defer.inlineCallbacks def unread_notifs_for_room_id(self, room_id, sync_config): @@ -875,7 +886,7 @@ class SyncHandler(object): last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, - receipt_type="m.read" + receipt_type="m.read", ) notifs = [] @@ -909,7 +920,9 @@ class SyncHandler(object): logger.info( "Calculating sync response for %r between %s and %s", - sync_config.user, since_token, now_token, + sync_config.user, + since_token, + now_token, ) user_id = sync_config.user.to_string() @@ -920,11 +933,12 @@ class SyncHandler(object): raise NotImplementedError() else: joined_room_ids = yield self.get_rooms_for_user_at( - user_id, now_token.room_stream_id, + user_id, now_token.room_stream_id ) sync_result_builder = SyncResultBuilder( - sync_config, full_state, + sync_config, + full_state, since_token=since_token, now_token=now_token, joined_room_ids=joined_room_ids, @@ -941,8 +955,7 @@ class SyncHandler(object): _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( - since_token is None and - sync_config.filter_collection.blocks_all_presence() + since_token is None and sync_config.filter_collection.blocks_all_presence() ) if self.hs_config.use_presence and not block_all_presence_data: yield self._generate_sync_entry_for_presence( @@ -973,22 +986,23 @@ class SyncHandler(object): room_id = joined_room.room_id if room_id in newly_joined_rooms: issue4422_logger.debug( - "Sync result for newly joined room %s: %r", - room_id, joined_room, + "Sync result for newly joined room %s: %r", room_id, joined_room ) - defer.returnValue(SyncResult( - presence=sync_result_builder.presence, - account_data=sync_result_builder.account_data, - joined=sync_result_builder.joined, - invited=sync_result_builder.invited, - archived=sync_result_builder.archived, - to_device=sync_result_builder.to_device, - device_lists=device_lists, - groups=sync_result_builder.groups, - device_one_time_keys_count=one_time_key_counts, - next_batch=sync_result_builder.now_token, - )) + defer.returnValue( + SyncResult( + presence=sync_result_builder.presence, + account_data=sync_result_builder.account_data, + joined=sync_result_builder.joined, + invited=sync_result_builder.invited, + archived=sync_result_builder.archived, + to_device=sync_result_builder.to_device, + device_lists=device_lists, + groups=sync_result_builder.groups, + device_one_time_keys_count=one_time_key_counts, + next_batch=sync_result_builder.now_token, + ) + ) @measure_func("_generate_sync_entry_for_groups") @defer.inlineCallbacks @@ -999,11 +1013,11 @@ class SyncHandler(object): if since_token and since_token.groups_key: results = yield self.store.get_groups_changes_for_user( - user_id, since_token.groups_key, now_token.groups_key, + user_id, since_token.groups_key, now_token.groups_key ) else: results = yield self.store.get_all_groups_for_user( - user_id, now_token.groups_key, + user_id, now_token.groups_key ) invited = {} @@ -1031,17 +1045,19 @@ class SyncHandler(object): left[group_id] = content["content"] sync_result_builder.groups = GroupsSyncResult( - join=joined, - invite=invited, - leave=left, + join=joined, invite=invited, leave=left ) @measure_func("_generate_sync_entry_for_device_list") @defer.inlineCallbacks - def _generate_sync_entry_for_device_list(self, sync_result_builder, - newly_joined_rooms, - newly_joined_or_invited_users, - newly_left_rooms, newly_left_users): + def _generate_sync_entry_for_device_list( + self, + sync_result_builder, + newly_joined_rooms, + newly_joined_or_invited_users, + newly_left_rooms, + newly_left_users, + ): user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token @@ -1065,24 +1081,20 @@ class SyncHandler(object): changed.update(newly_joined_or_invited_users) if not changed and not newly_left_users: - defer.returnValue(DeviceLists( - changed=[], - left=newly_left_users, - )) + defer.returnValue(DeviceLists(changed=[], left=newly_left_users)) users_who_share_room = yield self.store.get_users_who_share_room_with_user( user_id ) - defer.returnValue(DeviceLists( - changed=users_who_share_room & changed, - left=set(newly_left_users) - users_who_share_room, - )) + defer.returnValue( + DeviceLists( + changed=users_who_share_room & changed, + left=set(newly_left_users) - users_who_share_room, + ) + ) else: - defer.returnValue(DeviceLists( - changed=[], - left=[], - )) + defer.returnValue(DeviceLists(changed=[], left=[])) @defer.inlineCallbacks def _generate_sync_entry_for_to_device(self, sync_result_builder): @@ -1109,8 +1121,9 @@ class SyncHandler(object): deleted = yield self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) - logger.debug("Deleted %d to-device messages up to %d", - deleted, since_stream_id) + logger.debug( + "Deleted %d to-device messages up to %d", deleted, since_stream_id + ) messages, stream_id = yield self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key @@ -1118,7 +1131,10 @@ class SyncHandler(object): logger.debug( "Returning %d to-device messages between %d and %d (current token: %d)", - len(messages), since_stream_id, stream_id, now_token.to_device_key + len(messages), + since_stream_id, + stream_id, + now_token.to_device_key, ) sync_result_builder.now_token = now_token.copy_and_replace( "to_device_key", stream_id @@ -1145,8 +1161,7 @@ class SyncHandler(object): if since_token and not sync_result_builder.full_state: account_data, account_data_by_room = ( yield self.store.get_updated_account_data_for_user( - user_id, - since_token.account_data_key, + user_id, since_token.account_data_key ) ) @@ -1160,27 +1175,28 @@ class SyncHandler(object): ) else: account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user( - sync_config.user.to_string() - ) + yield self.store.get_account_data_for_user(sync_config.user.to_string()) ) - account_data['m.push_rules'] = yield self.push_rules_for_user( + account_data["m.push_rules"] = yield self.push_rules_for_user( sync_config.user ) - account_data_for_user = sync_config.filter_collection.filter_account_data([ - {"type": account_data_type, "content": content} - for account_data_type, content in account_data.items() - ]) + account_data_for_user = sync_config.filter_collection.filter_account_data( + [ + {"type": account_data_type, "content": content} + for account_data_type, content in account_data.items() + ] + ) sync_result_builder.account_data = account_data_for_user defer.returnValue(account_data_by_room) @defer.inlineCallbacks - def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms, - newly_joined_or_invited_users): + def _generate_sync_entry_for_presence( + self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users + ): """Generates the presence portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1223,17 +1239,13 @@ class SyncHandler(object): extra_users_ids.discard(user.to_string()) if extra_users_ids: - states = yield self.presence_handler.get_states( - extra_users_ids, - ) + states = yield self.presence_handler.get_states(extra_users_ids) presence.extend(states) # Deduplicate the presence entries so that there's at most one per user presence = list({p.user_id: p for p in presence}.values()) - presence = sync_config.filter_collection.filter_presence( - presence - ) + presence = sync_config.filter_collection.filter_presence(presence) sync_result_builder.presence = presence @@ -1253,8 +1265,8 @@ class SyncHandler(object): """ user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( - sync_result_builder.since_token is None and - sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() + sync_result_builder.since_token is None + and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) if block_all_room_ephemeral: @@ -1275,15 +1287,14 @@ class SyncHandler(object): have_changed = yield self._have_rooms_changed(sync_result_builder) if not have_changed: tags_by_room = yield self.store.get_updated_tags( - user_id, - since_token.account_data_key, + user_id, since_token.account_data_key ) if not tags_by_room: logger.debug("no-oping sync") defer.returnValue(([], [], [], [])) ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id=user_id, + "m.ignored_user_list", user_id=user_id ) if ignored_account_data: @@ -1296,7 +1307,7 @@ class SyncHandler(object): room_entries, invited, newly_joined_rooms, newly_left_rooms = res tags_by_room = yield self.store.get_updated_tags( - user_id, since_token.account_data_key, + user_id, since_token.account_data_key ) else: res = yield self._get_all_rooms(sync_result_builder, ignored_users) @@ -1331,8 +1342,8 @@ class SyncHandler(object): for event in it: if event.type == EventTypes.Member: if ( - event.membership == Membership.JOIN or - event.membership == Membership.INVITE + event.membership == Membership.JOIN + or event.membership == Membership.INVITE ): newly_joined_or_invited_users.add(event.state_key) else: @@ -1343,12 +1354,14 @@ class SyncHandler(object): newly_left_users -= newly_joined_or_invited_users - defer.returnValue(( - newly_joined_rooms, - newly_joined_or_invited_users, - newly_left_rooms, - newly_left_users, - )) + defer.returnValue( + ( + newly_joined_rooms, + newly_joined_or_invited_users, + newly_left_rooms, + newly_left_users, + ) + ) @defer.inlineCallbacks def _have_rooms_changed(self, sync_result_builder): @@ -1454,7 +1467,9 @@ class SyncHandler(object): prev_membership = old_mem_ev.membership issue4422_logger.debug( "Previous membership for room %s with join: %s (event %s)", - room_id, prev_membership, old_mem_ev_id, + room_id, + prev_membership, + old_mem_ev_id, ) if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: @@ -1476,8 +1491,7 @@ class SyncHandler(object): if not old_state_ids: old_state_ids = yield self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get( - (EventTypes.Member, user_id), - None, + (EventTypes.Member, user_id), None ) old_mem_ev = None if old_mem_ev_id: @@ -1498,7 +1512,8 @@ class SyncHandler(object): # Always include leave/ban events. Just take the last one. # TODO: How do we handle ban -> leave in same batch? leave_events = [ - e for e in non_joins + e + for e in non_joins if e.membership in (Membership.LEAVE, Membership.BAN) ] @@ -1526,15 +1541,17 @@ class SyncHandler(object): else: batch_events = None - room_entries.append(RoomSyncResultBuilder( - room_id=room_id, - rtype="archived", - events=batch_events, - newly_joined=room_id in newly_joined_rooms, - full_state=False, - since_token=since_token, - upto_token=leave_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=room_id, + rtype="archived", + events=batch_events, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=leave_token, + ) + ) timeline_limit = sync_config.filter_collection.timeline_limit() @@ -1581,7 +1598,8 @@ class SyncHandler(object): # debugging for https://github.com/matrix-org/synapse/issues/4422 issue4422_logger.debug( "RoomSyncResultBuilder events for newly joined room %s: %r", - room_id, entry.events, + room_id, + entry.events, ) room_entries.append(entry) @@ -1606,12 +1624,14 @@ class SyncHandler(object): sync_config = sync_result_builder.sync_config membership_list = ( - Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + Membership.BAN, ) room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user_id, - membership_list=membership_list + user_id=user_id, membership_list=membership_list ) room_entries = [] @@ -1619,23 +1639,22 @@ class SyncHandler(object): for event in room_list: if event.membership == Membership.JOIN: - room_entries.append(RoomSyncResultBuilder( - room_id=event.room_id, - rtype="joined", - events=None, - newly_joined=False, - full_state=True, - since_token=since_token, - upto_token=now_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=event.room_id, + rtype="joined", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=now_token, + ) + ) elif event.membership == Membership.INVITE: if event.sender in ignored_users: continue invite = yield self.store.get_event(event.event_id) - invited.append(InvitedSyncResult( - room_id=event.room_id, - invite=invite, - )) + invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) elif event.membership in (Membership.LEAVE, Membership.BAN): # Always send down rooms we were banned or kicked from. if not sync_config.filter_collection.include_leave: @@ -1646,22 +1665,31 @@ class SyncHandler(object): leave_token = now_token.copy_and_replace( "room_key", "s%d" % (event.stream_ordering,) ) - room_entries.append(RoomSyncResultBuilder( - room_id=event.room_id, - rtype="archived", - events=None, - newly_joined=False, - full_state=True, - since_token=since_token, - upto_token=leave_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=event.room_id, + rtype="archived", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=leave_token, + ) + ) defer.returnValue((room_entries, invited, [])) @defer.inlineCallbacks - def _generate_room_entry(self, sync_result_builder, ignored_users, - room_builder, ephemeral, tags, account_data, - always_include=False): + def _generate_room_entry( + self, + sync_result_builder, + ignored_users, + room_builder, + ephemeral, + tags, + account_data, + always_include=False, + ): """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. @@ -1678,9 +1706,7 @@ class SyncHandler(object): """ newly_joined = room_builder.newly_joined full_state = ( - room_builder.full_state - or newly_joined - or sync_result_builder.full_state + room_builder.full_state or newly_joined or sync_result_builder.full_state ) events = room_builder.events @@ -1697,7 +1723,8 @@ class SyncHandler(object): upto_token = room_builder.upto_token batch = yield self._load_filtered_recents( - room_id, sync_config, + room_id, + sync_config, now_token=upto_token, since_token=since_token, recents=events, @@ -1708,7 +1735,8 @@ class SyncHandler(object): # debug for https://github.com/matrix-org/synapse/issues/4422 issue4422_logger.debug( "Timeline events after filtering in newly-joined room %s: %r", - room_id, batch, + room_id, + batch, ) # When we join the room (or the client requests full_state), we should @@ -1726,16 +1754,10 @@ class SyncHandler(object): account_data_events = [] if tags is not None: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) account_data_events = sync_config.filter_collection.filter_room_account_data( account_data_events @@ -1743,16 +1765,13 @@ class SyncHandler(object): ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) - if not (always_include - or batch - or account_data_events - or ephemeral - or full_state): + if not ( + always_include or batch or account_data_events or ephemeral or full_state + ): return state = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, now_token, - full_state=full_state + room_id, batch, sync_config, since_token, now_token, full_state=full_state ) summary = {} @@ -1760,22 +1779,19 @@ class SyncHandler(object): # we include a summary in room responses when we're lazy loading # members (as the client otherwise doesn't have enough info to form # the name itself). - if ( - sync_config.filter_collection.lazy_load_members() and - ( - # we recalulate the summary: - # if there are membership changes in the timeline, or - # if membership has changed during a gappy sync, or - # if this is an initial sync. - any(ev.type == EventTypes.Member for ev in batch.events) or - ( - # XXX: this may include false positives in the form of LL - # members which have snuck into state - batch.limited and - any(t == EventTypes.Member for (t, k) in state) - ) or - since_token is None + if sync_config.filter_collection.lazy_load_members() and ( + # we recalulate the summary: + # if there are membership changes in the timeline, or + # if membership has changed during a gappy sync, or + # if this is an initial sync. + any(ev.type == EventTypes.Member for ev in batch.events) + or ( + # XXX: this may include false positives in the form of LL + # members which have snuck into state + batch.limited + and any(t == EventTypes.Member for (t, k) in state) ) + or since_token is None ): summary = yield self.compute_summary( room_id, sync_config, batch, state, now_token @@ -1794,9 +1810,7 @@ class SyncHandler(object): ) if room_sync or always_include: - notifs = yield self.unread_notifs_for_room_id( - room_id, sync_config - ) + notifs = yield self.unread_notifs_for_room_id(room_id, sync_config) if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"] @@ -1807,11 +1821,8 @@ class SyncHandler(object): if batch.limited and since_token: user_id = sync_result_builder.sync_config.user.to_string() logger.info( - "Incremental gappy sync of %s for user %s with %d state events" % ( - room_id, - user_id, - len(state), - ) + "Incremental gappy sync of %s for user %s with %d state events" + % (room_id, user_id, len(state)) ) elif room_builder.rtype == "archived": room_sync = ArchivedSyncResult( @@ -1841,9 +1852,7 @@ class SyncHandler(object): Deferred[frozenset[str]]: Set of room_ids the user is in at given stream_ordering. """ - joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering( - user_id, - ) + joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id) joined_room_ids = set() @@ -1862,11 +1871,9 @@ class SyncHandler(object): logger.info("User joined room after current token: %s", room_id) extrems = yield self.store.get_forward_extremeties_for_room( - room_id, stream_ordering, - ) - users_in_room = yield self.state.get_current_users_in_room( - room_id, extrems, + room_id, stream_ordering ) + users_in_room = yield self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: joined_room_ids.add(room_id) @@ -1886,7 +1893,7 @@ def _action_has_highlight(actions): def _calculate_state( - timeline_contains, timeline_start, previous, current, lazy_load_members, + timeline_contains, timeline_start, previous, current, lazy_load_members ): """Works out what state to include in a sync response. @@ -1930,15 +1937,12 @@ def _calculate_state( if lazy_load_members: p_ids.difference_update( - e for t, e in iteritems(timeline_start) - if t[0] == EventTypes.Member + e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member ) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids - return { - event_id_to_key[e]: e for e in state_ids - } + return {event_id_to_key[e]: e for e in state_ids} class SyncResultBuilder(object): @@ -1961,8 +1965,10 @@ class SyncResultBuilder(object): groups (GroupsSyncResult|None) to_device (list) """ - def __init__(self, sync_config, full_state, since_token, now_token, - joined_room_ids): + + def __init__( + self, sync_config, full_state, since_token, now_token, joined_room_ids + ): """ Args: sync_config (SyncConfig) @@ -1991,8 +1997,10 @@ class RoomSyncResultBuilder(object): """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. """ - def __init__(self, room_id, rtype, events, newly_joined, full_state, - since_token, upto_token): + + def __init__( + self, room_id, rtype, events, newly_joined, full_state, since_token, upto_token + ): """ Args: room_id(str) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 972662eb48..f8062c8671 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -68,13 +68,10 @@ class TypingHandler(object): # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( - "TypingStreamChangeCache", self._latest_room_serial, + "TypingStreamChangeCache", self._latest_room_serial ) - self.clock.looping_call( - self._handle_timeouts, - 5000, - ) + self.clock.looping_call(self._handle_timeouts, 5000) def _reset(self): """ @@ -108,19 +105,11 @@ class TypingHandler(object): if self.hs.is_mine_id(member.user_id): last_fed_poke = self._member_last_federation_poke.get(member, None) if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - run_in_background( - self._push_remote, - member=member, - typing=True - ) + run_in_background(self._push_remote, member=member, typing=True) # Add a paranoia timer to ensure that we always have a timer for # each person typing. - self.wheel_timer.insert( - now=now, - obj=member, - then=now + 60 * 1000, - ) + self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member): return member.user_id in self._room_typing.get(member.room_id, []) @@ -138,9 +127,7 @@ class TypingHandler(object): yield self.auth.check_joined_room(room_id, target_user_id) - logger.debug( - "%s has started typing in %s", target_user_id, room_id - ) + logger.debug("%s has started typing in %s", target_user_id, room_id) member = RoomMember(room_id=room_id, user_id=target_user_id) @@ -149,20 +136,13 @@ class TypingHandler(object): now = self.clock.time_msec() self._member_typing_until[member] = now + timeout - self.wheel_timer.insert( - now=now, - obj=member, - then=now + timeout, - ) + self.wheel_timer.insert(now=now, obj=member, then=now + timeout) if was_present: # No point sending another notification defer.returnValue(None) - self._push_update( - member=member, - typing=True, - ) + self._push_update(member=member, typing=True) @defer.inlineCallbacks def stopped_typing(self, target_user, auth_user, room_id): @@ -177,9 +157,7 @@ class TypingHandler(object): yield self.auth.check_joined_room(room_id, target_user_id) - logger.debug( - "%s has stopped typing in %s", target_user_id, room_id - ) + logger.debug("%s has stopped typing in %s", target_user_id, room_id) member = RoomMember(room_id=room_id, user_id=target_user_id) @@ -200,20 +178,14 @@ class TypingHandler(object): self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) - self._push_update( - member=member, - typing=False, - ) + self._push_update(member=member, typing=False) def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. run_in_background(self._push_remote, member, typing) - self._push_update_local( - member=member, - typing=typing - ) + self._push_update_local(member=member, typing=typing) @defer.inlineCallbacks def _push_remote(self, member, typing): @@ -223,9 +195,7 @@ class TypingHandler(object): now = self.clock.time_msec() self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_PING_INTERVAL, + now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) for domain in set(get_domain_from_id(u) for u in users): @@ -256,8 +226,7 @@ class TypingHandler(object): if user.domain != origin: logger.info( - "Got typing update from %r with bad 'user_id': %r", - origin, user_id, + "Got typing update from %r with bad 'user_id': %r", origin, user_id ) return @@ -268,15 +237,8 @@ class TypingHandler(object): logger.info("Got typing update from %s: %r", user_id, content) now = self.clock.time_msec() self._member_typing_until[member] = now + FEDERATION_TIMEOUT - self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_TIMEOUT, - ) - self._push_update_local( - member=member, - typing=content["typing"] - ) + self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT) + self._push_update_local(member=member, typing=content["typing"]) def _push_update_local(self, member, typing): room_set = self._room_typing.setdefault(member.room_id, set()) @@ -288,7 +250,7 @@ class TypingHandler(object): self._latest_room_serial += 1 self._room_serials[member.room_id] = self._latest_room_serial self._typing_stream_change_cache.entity_has_changed( - member.room_id, self._latest_room_serial, + member.room_id, self._latest_room_serial ) self.notifier.on_new_event( @@ -300,7 +262,7 @@ class TypingHandler(object): return [] changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( - last_id, + last_id ) if changed_rooms is None: @@ -334,9 +296,7 @@ class TypingNotificationEventSource(object): return { "type": "m.typing", "room_id": room_id, - "content": { - "user_ids": list(typing), - }, + "content": {"user_ids": list(typing)}, } def get_new_events(self, from_key, room_ids, **kwargs): diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index d36bcd6336..3acf772cd1 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" + def __init__(self): super(RequestTimedOutError, self).__init__(504, "Timed out") @@ -40,15 +41,12 @@ def cancelled_to_request_timed_out_error(value, timeout): return value -ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') +ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") def redact_uri(uri): """Strips access tokens from the uri replaces with """ - return ACCESS_TOKEN_RE.sub( - r'\1\3', - uri - ) + return ACCESS_TOKEN_RE.sub(r"\1\3", uri) class QuieterFileBodyProducer(FileBodyProducer): @@ -57,6 +55,7 @@ class QuieterFileBodyProducer(FileBodyProducer): Workaround for https://github.com/matrix-org/synapse/issues/4003 / https://twistedmatrix.com/trac/ticket/6528 """ + def stopProducing(self): try: FileBodyProducer.stopProducing(self) diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 0e10e3f8f7..096619a8c2 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -28,6 +28,7 @@ class AdditionalResource(Resource): This class is also where we wrap the request handler with logging, metrics, and exception handling. """ + def __init__(self, hs, handler): """Initialise AdditionalResource diff --git a/synapse/http/client.py b/synapse/http/client.py index 5c073fff07..9bc7035c8d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -103,8 +103,8 @@ class IPBlacklistingResolver(object): ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info( - "Dropped %s from DNS resolution to %s due to blacklist" % - (ip_address, hostname) + "Dropped %s from DNS resolution to %s due to blacklist" + % (ip_address, hostname) ) has_bad_ip = True @@ -156,7 +156,7 @@ class BlacklistingAgentWrapper(Agent): self._ip_blacklist = ip_blacklist def request(self, method, uri, headers=None, bodyProducer=None): - h = urllib.parse.urlparse(uri.decode('ascii')) + h = urllib.parse.urlparse(uri.decode("ascii")) try: ip_address = IPAddress(h.hostname) @@ -164,10 +164,7 @@ class BlacklistingAgentWrapper(Agent): if check_against_blacklist( ip_address, self._ip_whitelist, self._ip_blacklist ): - logger.info( - "Blocking access to %s due to blacklist" % - (ip_address,) - ) + logger.info("Blocking access to %s due to blacklist" % (ip_address,)) e = SynapseError(403, "IP address blocked by IP blacklist entry") return defer.fail(Failure(e)) except Exception: @@ -206,7 +203,7 @@ class SimpleHttpClient(object): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) - self.user_agent = self.user_agent.encode('ascii') + self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: real_reactor = hs.get_reactor() @@ -520,8 +517,8 @@ class SimpleHttpClient(object): resp_headers = dict(response.headers.getAllRawHeaders()) if ( - b'Content-Length' in resp_headers - and int(resp_headers[b'Content-Length'][0]) > max_size + b"Content-Length" in resp_headers + and int(resp_headers[b"Content-Length"][0]) > max_size ): logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( @@ -546,18 +543,13 @@ class SimpleHttpClient(object): # This can happen e.g. because the body is too large. raise except Exception as e: - raise_from( - SynapseError( - 502, ("Failed to download remote body: %s" % e), - ), - e - ) + raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e) defer.returnValue( ( length, resp_headers, - response.request.absoluteURI.decode('ascii'), + response.request.absoluteURI.decode("ascii"), response.code, ) ) @@ -647,7 +639,7 @@ def encode_urlencode_args(args): def encode_urlencode_arg(arg): if isinstance(arg, text_type): - return arg.encode('utf-8') + return arg.encode("utf-8") elif isinstance(arg, list): return [encode_urlencode_arg(i) for i in arg] else: diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index cd79ebab62..92a5b606c8 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -31,7 +31,7 @@ def parse_server_name(server_name): ValueError if the server name could not be parsed. """ try: - if server_name[-1] == ']': + if server_name[-1] == "]": # ipv6 literal, hopefully return server_name, None @@ -43,9 +43,7 @@ def parse_server_name(server_name): raise ValueError("Invalid server name '%s'" % server_name) -VALID_HOST_REGEX = re.compile( - "\\A[0-9a-zA-Z.-]+\\Z", -) +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") def parse_and_validate_server_name(server_name): @@ -67,17 +65,15 @@ def parse_and_validate_server_name(server_name): # that nobody is sneaking IP literals in that look like hostnames, etc. # look for ipv6 literals - if host[0] == '[': - if host[-1] != ']': - raise ValueError("Mismatched [...] in server name '%s'" % ( - server_name, - )) + if host[0] == "[": + if host[-1] != "]": + raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) return host, port # otherwise it should only be alphanumerics. if not VALID_HOST_REGEX.match(host): - raise ValueError("Server name '%s' contains invalid characters" % ( - server_name, - )) + raise ValueError( + "Server name '%s' contains invalid characters" % (server_name,) + ) return host, port diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index b4cbe97b41..414cde0777 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -48,7 +48,7 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 logger = logging.getLogger(__name__) -well_known_cache = TTLCache('well-known') +well_known_cache = TTLCache("well-known") @implementer(IAgent) @@ -78,7 +78,9 @@ class MatrixFederationAgent(object): """ def __init__( - self, reactor, tls_client_options_factory, + self, + reactor, + tls_client_options_factory, _well_known_tls_policy=None, _srv_resolver=None, _well_known_cache=well_known_cache, @@ -100,9 +102,9 @@ class MatrixFederationAgent(object): if _well_known_tls_policy is not None: # the param is called 'contextFactory', but actually passing a # contextfactory is deprecated, and it expects an IPolicyForHTTPS. - agent_args['contextFactory'] = _well_known_tls_policy + agent_args["contextFactory"] = _well_known_tls_policy _well_known_agent = RedirectAgent( - Agent(self._reactor, pool=self._pool, **agent_args), + Agent(self._reactor, pool=self._pool, **agent_args) ) self._well_known_agent = _well_known_agent @@ -149,7 +151,7 @@ class MatrixFederationAgent(object): tls_options = None else: tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii"), + res.tls_server_name.decode("ascii") ) # make sure that the Host header is set correctly @@ -158,14 +160,14 @@ class MatrixFederationAgent(object): else: headers = headers.copy() - if not headers.hasHeader(b'host'): - headers.addRawHeader(b'host', res.host_header) + if not headers.hasHeader(b"host"): + headers.addRawHeader(b"host", res.host_header) class EndpointFactory(object): @staticmethod def endpointForURI(_uri): ep = LoggingHostnameEndpoint( - self._reactor, res.target_host, res.target_port, + self._reactor, res.target_host, res.target_port ) if tls_options is not None: ep = wrapClientTLS(tls_options, ep) @@ -203,21 +205,25 @@ class MatrixFederationAgent(object): port = parsed_uri.port if port == -1: port = 8448 - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=port, + ) + ) if parsed_uri.port != -1: # there is an explicit port - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=parsed_uri.port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=parsed_uri.port, + ) + ) if lookup_well_known: # try a .well-known lookup @@ -229,8 +235,8 @@ class MatrixFederationAgent(object): # parse the server name in the .well-known response into host/port. # (This code is lifted from twisted.web.client.URI.fromBytes). - if b':' in well_known_server: - well_known_host, well_known_port = well_known_server.rsplit(b':', 1) + if b":" in well_known_server: + well_known_host, well_known_port = well_known_server.rsplit(b":", 1) try: well_known_port = int(well_known_port) except ValueError: @@ -264,21 +270,27 @@ class MatrixFederationAgent(object): port = 8448 logger.debug( "No SRV record for %s, using %s:%i", - parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port, + parsed_uri.host.decode("ascii"), + target_host.decode("ascii"), + port, ) else: target_host, port = pick_server_from_list(server_list) logger.debug( "Picked %s:%i from SRV records for %s", - target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"), + target_host.decode("ascii"), + port, + parsed_uri.host.decode("ascii"), ) - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=target_host, - target_port=port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=target_host, + target_port=port, + ) + ) @defer.inlineCallbacks def _get_well_known(self, server_name): @@ -318,18 +330,18 @@ class MatrixFederationAgent(object): - None if there was no .well-known file. - INVALID_WELL_KNOWN if the .well-known was invalid """ - uri = b"https://%s/.well-known/matrix/server" % (server_name, ) + uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") logger.info("Fetching %s", uri_str) try: response = yield make_deferred_yieldable( - self._well_known_agent.request(b"GET", uri), + self._well_known_agent.request(b"GET", uri) ) body = yield make_deferred_yieldable(readBody(response)) if response.code != 200: - raise Exception("Non-200 response %s" % (response.code, )) + raise Exception("Non-200 response %s" % (response.code,)) - parsed_body = json.loads(body.decode('utf-8')) + parsed_body = json.loads(body.decode("utf-8")) logger.info("Response from .well-known: %s", parsed_body) if not isinstance(parsed_body, dict): raise Exception("not a dict") @@ -347,8 +359,7 @@ class MatrixFederationAgent(object): result = parsed_body["m.server"].encode("ascii") cache_period = _cache_period_from_headers( - response.headers, - time_now=self._reactor.seconds, + response.headers, time_now=self._reactor.seconds ) if cache_period is None: cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD @@ -364,6 +375,7 @@ class MatrixFederationAgent(object): @implementer(IStreamClientEndpoint) class LoggingHostnameEndpoint(object): """A wrapper for HostnameEndpint which logs when it connects""" + def __init__(self, reactor, host, port, *args, **kwargs): self.host = host self.port = port @@ -377,17 +389,17 @@ class LoggingHostnameEndpoint(object): def _cache_period_from_headers(headers, time_now=time.time): cache_controls = _parse_cache_control(headers) - if b'no-store' in cache_controls: + if b"no-store" in cache_controls: return 0 - if b'max-age' in cache_controls: + if b"max-age" in cache_controls: try: - max_age = int(cache_controls[b'max-age']) + max_age = int(cache_controls[b"max-age"]) return max_age except ValueError: pass - expires = headers.getRawHeaders(b'expires') + expires = headers.getRawHeaders(b"expires") if expires is not None: try: expires_date = stringToDatetime(expires[-1]) @@ -403,9 +415,9 @@ def _cache_period_from_headers(headers, time_now=time.time): def _parse_cache_control(headers): cache_controls = {} - for hdr in headers.getRawHeaders(b'cache-control', []): - for directive in hdr.split(b','): - splits = [x.strip() for x in directive.split(b'=', 1)] + for hdr in headers.getRawHeaders(b"cache-control", []): + for directive in hdr.split(b","): + splits = [x.strip() for x in directive.split(b"=", 1)] k = splits[0].lower() v = splits[1] if len(splits) > 1 else None cache_controls[k] = v diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 71830c549d..1f22f78a75 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -45,6 +45,7 @@ class Server(object): expires (int): when the cache should expire this record - in *seconds* since the epoch """ + host = attr.ib() port = attr.ib() priority = attr.ib(default=0) @@ -79,9 +80,7 @@ def pick_server_from_list(server_list): return s.host, s.port # this should be impossible. - raise RuntimeError( - "pick_server_from_list got to end of eligible server list.", - ) + raise RuntimeError("pick_server_from_list got to end of eligible server list.") class SrvResolver(object): @@ -95,6 +94,7 @@ class SrvResolver(object): cache (dict): cache object get_time (callable): clock implementation. Should return seconds since the epoch """ + def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): self._dns_client = dns_client self._cache = cache @@ -124,7 +124,7 @@ class SrvResolver(object): try: answers, _, _ = yield make_deferred_yieldable( - self._dns_client.lookupService(service_name), + self._dns_client.lookupService(service_name) ) except DNSNameError: # TODO: cache this. We can get the SOA out of the exception, and use @@ -136,17 +136,18 @@ class SrvResolver(object): cache_entry = self._cache.get(service_name, None) if cache_entry: logger.warn( - "Failed to resolve %r, falling back to cache. %r", - service_name, e + "Failed to resolve %r, falling back to cache. %r", service_name, e ) defer.returnValue(list(cache_entry)) else: raise e - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name(b'.')): + if ( + len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name(b".") + ): raise ConnectError("Service %s unavailable" % service_name) servers = [] @@ -157,13 +158,15 @@ class SrvResolver(object): payload = answer.payload - servers.append(Server( - host=payload.target.name, - port=payload.port, - priority=payload.priority, - weight=payload.weight, - expires=now + answer.ttl, - )) + servers.append( + Server( + host=payload.target.name, + port=payload.port, + priority=payload.priority, + weight=payload.weight, + expires=now + answer.ttl, + ) + ) self._cache[service_name] = list(servers) defer.returnValue(servers) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 663ea72a7a..5ef8bb60a3 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -54,10 +54,12 @@ from synapse.util.metrics import Measure logger = logging.getLogger(__name__) -outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests", - "", ["method"]) -incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses", - "", ["method", "code"]) +outgoing_requests_counter = Counter( + "synapse_http_matrixfederationclient_requests", "", ["method"] +) +incoming_responses_counter = Counter( + "synapse_http_matrixfederationclient_responses", "", ["method", "code"] +) MAX_LONG_RETRIES = 10 @@ -137,11 +139,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): check_content_type_is_json(response.headers) d = treq.json_content(response) - d = timeout_deferred( - d, - timeout=timeout_sec, - reactor=reactor, - ) + d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) body = yield make_deferred_yieldable(d) except Exception as e: @@ -157,7 +155,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), ) defer.returnValue(body) @@ -181,7 +179,7 @@ class MatrixFederationHttpClient(object): # We need to use a DNS resolver which filters out blacklisted IP # addresses, to prevent DNS rebinding. nameResolver = IPBlacklistingResolver( - real_reactor, None, hs.config.federation_ip_range_blacklist, + real_reactor, None, hs.config.federation_ip_range_blacklist ) @implementer(IReactorPluggableNameResolver) @@ -194,21 +192,19 @@ class MatrixFederationHttpClient(object): self.reactor = Reactor() - self.agent = MatrixFederationAgent( - self.reactor, - tls_client_options_factory, - ) + self.agent = MatrixFederationAgent(self.reactor, tls_client_options_factory) # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( - self.agent, self.reactor, + self.agent, + self.reactor, ip_blacklist=hs.config.federation_ip_range_blacklist, ) self.clock = hs.get_clock() self._store = hs.get_datastore() - self.version_string_bytes = hs.version_string.encode('ascii') + self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout = 60 def schedule(x): @@ -218,10 +214,7 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def _send_request_with_optional_trailing_slash( - self, - request, - try_trailing_slash_on_400=False, - **send_request_args + self, request, try_trailing_slash_on_400=False, **send_request_args ): """Wrapper for _send_request which can optionally retry the request upon receiving a combination of a 400 HTTP response code and a @@ -244,9 +237,7 @@ class MatrixFederationHttpClient(object): Deferred[Dict]: Parsed JSON response body. """ try: - response = yield self._send_request( - request, **send_request_args - ) + response = yield self._send_request(request, **send_request_args) except HttpResponseException as e: # Received an HTTP error > 300. Check if it meets the requirements # to retry with a trailing slash @@ -262,9 +253,7 @@ class MatrixFederationHttpClient(object): logger.info("Retrying request with trailing slash") request.path += "/" - response = yield self._send_request( - request, **send_request_args - ) + response = yield self._send_request(request, **send_request_args) defer.returnValue(response) @@ -329,8 +318,8 @@ class MatrixFederationHttpClient(object): _sec_timeout = self.default_timeout if ( - self.hs.config.federation_domain_whitelist is not None and - request.destination not in self.hs.config.federation_domain_whitelist + self.hs.config.federation_domain_whitelist is not None + and request.destination not in self.hs.config.federation_domain_whitelist ): raise FederationDeniedError(request.destination) @@ -350,9 +339,7 @@ class MatrixFederationHttpClient(object): else: query_bytes = b"" - headers_dict = { - b"User-Agent": [self.version_string_bytes], - } + headers_dict = {b"User-Agent": [self.version_string_bytes]} with limiter: # XXX: Would be much nicer to retry only at the transaction-layer @@ -362,16 +349,14 @@ class MatrixFederationHttpClient(object): else: retries_left = MAX_SHORT_RETRIES - url_bytes = urllib.parse.urlunparse(( - b"matrix", destination_bytes, - path_bytes, None, query_bytes, b"", - )) - url_str = url_bytes.decode('ascii') + url_bytes = urllib.parse.urlunparse( + (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"") + ) + url_str = url_bytes.decode("ascii") - url_to_sign_bytes = urllib.parse.urlunparse(( - b"", b"", - path_bytes, None, query_bytes, b"", - )) + url_to_sign_bytes = urllib.parse.urlunparse( + (b"", b"", path_bytes, None, query_bytes, b"") + ) while True: try: @@ -379,26 +364,27 @@ class MatrixFederationHttpClient(object): if json: headers_dict[b"Content-Type"] = [b"application/json"] auth_headers = self.build_auth_headers( - destination_bytes, method_bytes, url_to_sign_bytes, - json, + destination_bytes, method_bytes, url_to_sign_bytes, json ) data = encode_canonical_json(json) producer = QuieterFileBodyProducer( - BytesIO(data), - cooperator=self._cooperator, + BytesIO(data), cooperator=self._cooperator ) else: producer = None auth_headers = self.build_auth_headers( - destination_bytes, method_bytes, url_to_sign_bytes, + destination_bytes, method_bytes, url_to_sign_bytes ) headers_dict[b"Authorization"] = auth_headers logger.info( "{%s} [%s] Sending request: %s %s; timeout %fs", - request.txn_id, request.destination, request.method, - url_str, _sec_timeout, + request.txn_id, + request.destination, + request.method, + url_str, + _sec_timeout, ) try: @@ -430,7 +416,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), ) if 200 <= response.code < 300: @@ -440,9 +426,7 @@ class MatrixFederationHttpClient(object): # Update transactions table? d = treq.content(response) d = timeout_deferred( - d, - timeout=_sec_timeout, - reactor=self.reactor, + d, timeout=_sec_timeout, reactor=self.reactor ) try: @@ -460,9 +444,7 @@ class MatrixFederationHttpClient(object): ) body = None - e = HttpResponseException( - response.code, response.phrase, body - ) + e = HttpResponseException(response.code, response.phrase, body) # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException @@ -521,7 +503,7 @@ class MatrixFederationHttpClient(object): defer.returnValue(response) def build_auth_headers( - self, destination, method, url_bytes, content=None, destination_is=None, + self, destination, method, url_bytes, content=None, destination_is=None ): """ Builds the Authorization headers for a federation request @@ -538,11 +520,7 @@ class MatrixFederationHttpClient(object): Returns: list[bytes]: a list of headers to be added as "Authorization:" headers """ - request = { - "method": method, - "uri": url_bytes, - "origin": self.server_name, - } + request = {"method": method, "uri": url_bytes, "origin": self.server_name} if destination is not None: request["destination"] = destination @@ -558,20 +536,28 @@ class MatrixFederationHttpClient(object): auth_headers = [] for key, sig in request["signatures"][self.server_name].items(): - auth_headers.append(( - "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( - self.server_name, key, sig, - )).encode('ascii') + auth_headers.append( + ( + 'X-Matrix origin=%s,key="%s",sig="%s"' + % (self.server_name, key, sig) + ).encode("ascii") ) return auth_headers @defer.inlineCallbacks - def put_json(self, destination, path, args={}, data={}, - json_data_callback=None, - long_retries=False, timeout=None, - ignore_backoff=False, - backoff_on_404=False, - try_trailing_slash_on_400=False): + def put_json( + self, + destination, + path, + args={}, + data={}, + json_data_callback=None, + long_retries=False, + timeout=None, + ignore_backoff=False, + backoff_on_404=False, + try_trailing_slash_on_400=False, + ): """ Sends the specifed json data using PUT Args: @@ -635,14 +621,22 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def post_json(self, destination, path, data={}, long_retries=False, - timeout=None, ignore_backoff=False, args={}): + def post_json( + self, + destination, + path, + data={}, + long_retries=False, + timeout=None, + ignore_backoff=False, + args={}, + ): """ Sends the specifed json data using POST Args: @@ -681,11 +675,7 @@ class MatrixFederationHttpClient(object): """ request = MatrixFederationRequest( - method="POST", - destination=destination, - path=path, - query=args, - json=data, + method="POST", destination=destination, path=path, query=args, json=data ) response = yield self._send_request( @@ -701,14 +691,21 @@ class MatrixFederationHttpClient(object): _sec_timeout = self.default_timeout body = yield _handle_json_response( - self.reactor, _sec_timeout, request, response, + self.reactor, _sec_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def get_json(self, destination, path, args=None, retry_on_dns_fail=True, - timeout=None, ignore_backoff=False, - try_trailing_slash_on_400=False): + def get_json( + self, + destination, + path, + args=None, + retry_on_dns_fail=True, + timeout=None, + ignore_backoff=False, + try_trailing_slash_on_400=False, + ): """ GETs some json from the given host homeserver and path Args: @@ -745,10 +742,7 @@ class MatrixFederationHttpClient(object): remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="GET", - destination=destination, - path=path, - query=args, + method="GET", destination=destination, path=path, query=args ) response = yield self._send_request_with_optional_trailing_slash( @@ -761,14 +755,21 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def delete_json(self, destination, path, long_retries=False, - timeout=None, ignore_backoff=False, args={}): + def delete_json( + self, + destination, + path, + long_retries=False, + timeout=None, + ignore_backoff=False, + args={}, + ): """Send a DELETE request to the remote expecting some json response Args: @@ -802,10 +803,7 @@ class MatrixFederationHttpClient(object): remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="DELETE", - destination=destination, - path=path, - query=args, + method="DELETE", destination=destination, path=path, query=args ) response = yield self._send_request( @@ -816,14 +814,21 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def get_file(self, destination, path, output_stream, args={}, - retry_on_dns_fail=True, max_size=None, - ignore_backoff=False): + def get_file( + self, + destination, + path, + output_stream, + args={}, + retry_on_dns_fail=True, + max_size=None, + ignore_backoff=False, + ): """GETs a file from a given homeserver Args: destination (str): The remote server to send the HTTP request to. @@ -848,16 +853,11 @@ class MatrixFederationHttpClient(object): remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="GET", - destination=destination, - path=path, - query=args, + method="GET", destination=destination, path=path, query=args ) response = yield self._send_request( - request, - retry_on_dns_fail=retry_on_dns_fail, - ignore_backoff=ignore_backoff, + request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff ) headers = dict(response.headers.getAllRawHeaders()) @@ -879,7 +879,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), length, ) defer.returnValue((length, headers)) @@ -896,11 +896,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback(SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - )) + self.deferred.errback( + SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + ) + ) self.deferred = defer.Deferred() self.transport.loseConnection() @@ -920,8 +922,7 @@ def _readBodyToFile(response, stream, max_size): def _flatten_response_never_received(e): if hasattr(e, "reasons"): reasons = ", ".join( - _flatten_response_never_received(f.value) - for f in e.reasons + _flatten_response_never_received(f.value) for f in e.reasons ) return "%s:[%s]" % (type(e).__name__, reasons) @@ -943,16 +944,15 @@ def check_content_type_is_json(headers): """ c_type = headers.getRawHeaders(b"Content-Type") if c_type is None: - raise RequestSendFailed(RuntimeError( - "No Content-Type header" - ), can_retry=False) + raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False) - c_type = c_type[0].decode('ascii') # only the first header + c_type = c_type[0].decode("ascii") # only the first header val, options = cgi.parse_header(c_type) if val != "application/json": - raise RequestSendFailed(RuntimeError( - "Content-Type not application/json: was '%s'" % c_type - ), can_retry=False) + raise RequestSendFailed( + RuntimeError("Content-Type not application/json: was '%s'" % c_type), + can_retry=False, + ) def encode_query_args(args): @@ -967,4 +967,4 @@ def encode_query_args(args): query_bytes = urllib.parse.urlencode(encoded_args, True) - return query_bytes.encode('utf8') + return query_bytes.encode("utf8") diff --git a/synapse/http/server.py b/synapse/http/server.py index 16fb7935da..6fd13e87d1 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -81,9 +81,7 @@ def wrap_json_request_handler(h): yield h(self, request) except SynapseError as e: code = e.code - logger.info( - "%s SynapseError: %s - %s", request, code, e.msg - ) + logger.info("%s SynapseError: %s - %s", request, code, e.msg) # Only respond with an error response if we haven't already started # writing, otherwise lets just kill the connection @@ -96,7 +94,10 @@ def wrap_json_request_handler(h): pass else: respond_with_json( - request, code, e.error_dict(), send_cors=True, + request, + code, + e.error_dict(), + send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -124,10 +125,7 @@ def wrap_json_request_handler(h): respond_with_json( request, 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, + {"error": "Internal server error", "errcode": Codes.UNKNOWN}, send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -143,6 +141,7 @@ def wrap_html_request_handler(h): The handler method must have a signature of "handle_foo(self, request)", where "request" must be a SynapseRequest. """ + def wrapped_request_handler(self, request): d = defer.maybeDeferred(h, self, request) d.addErrback(_return_html_error, request) @@ -164,9 +163,7 @@ def _return_html_error(f, request): msg = cme.msg if isinstance(cme, SynapseError): - logger.info( - "%s SynapseError: %s - %s", request, code, msg - ) + logger.info("%s SynapseError: %s - %s", request, code, msg) else: logger.error( "Failed handle request %r", @@ -183,9 +180,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) - body = HTML_ERROR_TEMPLATE.format( - code=code, msg=cgi.escape(msg), - ).encode("utf-8") + body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8") request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%i" % (len(body),)) @@ -205,6 +200,7 @@ def wrap_async_request_handler(h): The handler may return a deferred, in which case the completion of the request isn't logged until the deferred completes. """ + @defer.inlineCallbacks def wrapped_async_request_handler(self, request): with request.processing(): @@ -306,12 +302,14 @@ class JsonResource(HttpServer, resource.Resource): # URL again (as it was decoded by _get_handler_for request), as # ASCII because it's a URL, and then decode it to get the UTF-8 # characters that were quoted. - return urllib.parse.unquote(s.encode('ascii')).decode('utf8') + return urllib.parse.unquote(s.encode("ascii")).decode("utf8") - kwargs = intern_dict({ - name: _unquote(value) if value else value - for name, value in group_dict.items() - }) + kwargs = intern_dict( + { + name: _unquote(value) if value else value + for name, value in group_dict.items() + } + ) callback_return = yield callback(request, **kwargs) if callback_return is not None: @@ -339,7 +337,7 @@ class JsonResource(HttpServer, resource.Resource): # Loop through all the registered callbacks to check if the method # and path regex match for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path.decode('ascii')) + m = path_entry.pattern.match(request.path.decode("ascii")) if m: # We found a match! return path_entry.callback, m.groupdict() @@ -347,11 +345,14 @@ class JsonResource(HttpServer, resource.Resource): # Huh. No one wanted to handle that? Fiiiiiine. Send 400. return _unrecognised_request_handler, {} - def _send_response(self, request, code, response_json_object, - response_code_message=None): + def _send_response( + self, request, code, response_json_object, response_code_message=None + ): # TODO: Only enable CORS for the requests that need it. respond_with_json( - request, code, response_json_object, + request, + code, + response_json_object, send_cors=True, response_code_message=response_code_message, pretty_print=_request_user_agent_is_curl(request), @@ -395,7 +396,7 @@ class RootRedirect(resource.Resource): self.url = path def render_GET(self, request): - return redirectTo(self.url.encode('ascii'), request) + return redirectTo(self.url.encode("ascii"), request) def getChild(self, name, request): if len(name) == 0: @@ -403,16 +404,22 @@ class RootRedirect(resource.Resource): return resource.Resource.getChild(self, name, request) -def respond_with_json(request, code, json_object, send_cors=False, - response_code_message=None, pretty_print=False, - canonical_json=True): +def respond_with_json( + request, + code, + json_object, + send_cors=False, + response_code_message=None, + pretty_print=False, + canonical_json=True, +): # could alternatively use request.notifyFinish() and flip a flag when # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. if request._disconnected: logger.warn( - "Not sending response to request %s, already disconnected.", - request) + "Not sending response to request %s, already disconnected.", request + ) return if pretty_print: @@ -425,14 +432,17 @@ def respond_with_json(request, code, json_object, send_cors=False, json_bytes = json.dumps(json_object).encode("utf-8") return respond_with_json_bytes( - request, code, json_bytes, + request, + code, + json_bytes, send_cors=send_cors, response_code_message=response_code_message, ) -def respond_with_json_bytes(request, code, json_bytes, send_cors=False, - response_code_message=None): +def respond_with_json_bytes( + request, code, json_bytes, send_cors=False, response_code_message=None +): """Sends encoded JSON in response to the given request. Args: @@ -474,7 +484,7 @@ def set_cors_headers(request): ) request.setHeader( b"Access-Control-Allow-Headers", - b"Origin, X-Requested-With, Content-Type, Accept, Authorization" + b"Origin, X-Requested-With, Content-Type, Accept, Authorization", ) @@ -498,9 +508,7 @@ def finish_request(request): def _request_user_agent_is_curl(request): - user_agents = request.requestHeaders.getRawHeaders( - b"User-Agent", default=[] - ) + user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[]) for user_agent in user_agents: if b"curl" in user_agent: return True diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 197c652850..cd8415acd5 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -48,7 +48,7 @@ def parse_integer(request, name, default=None, required=False): def parse_integer_from_args(args, name, default=None, required=False): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: try: @@ -89,18 +89,14 @@ def parse_boolean(request, name, default=None, required=False): def parse_boolean_from_args(args, name, default=None, required=False): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: try: - return { - b"true": True, - b"false": False, - }[args[name][0]] + return {b"true": True, b"false": False}[args[name][0]] except Exception: message = ( - "Boolean query parameter %r must be one of" - " ['true', 'false']" + "Boolean query parameter %r must be one of" " ['true', 'false']" ) % (name,) raise SynapseError(400, message) else: @@ -111,8 +107,15 @@ def parse_boolean_from_args(args, name, default=None, required=False): return default -def parse_string(request, name, default=None, required=False, - allowed_values=None, param_type="string", encoding='ascii'): +def parse_string( + request, + name, + default=None, + required=False, + allowed_values=None, + param_type="string", + encoding="ascii", +): """ Parse a string parameter from the request query string. @@ -145,11 +148,18 @@ def parse_string(request, name, default=None, required=False, ) -def parse_string_from_args(args, name, default=None, required=False, - allowed_values=None, param_type="string", encoding='ascii'): +def parse_string_from_args( + args, + name, + default=None, + required=False, + allowed_values=None, + param_type="string", + encoding="ascii", +): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: value = args[name][0] @@ -159,7 +169,8 @@ def parse_string_from_args(args, name, default=None, required=False, if allowed_values is not None and value not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( - name, ", ".join(repr(v) for v in allowed_values) + name, + ", ".join(repr(v) for v in allowed_values), ) raise SynapseError(400, message) else: @@ -201,7 +212,7 @@ def parse_json_value_from_request(request, allow_empty_body=False): # Decode to Unicode so that simplejson will return Unicode strings on # Python 2 try: - content_unicode = content_bytes.decode('utf8') + content_unicode = content_bytes.decode("utf8") except UnicodeDecodeError: logger.warn("Unable to decode UTF-8") raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) @@ -227,9 +238,7 @@ def parse_json_object_from_request(request, allow_empty_body=False): SynapseError if the request body couldn't be decoded as JSON or if it wasn't a JSON object. """ - content = parse_json_value_from_request( - request, allow_empty_body=allow_empty_body, - ) + content = parse_json_value_from_request(request, allow_empty_body=allow_empty_body) if allow_empty_body and content is None: return {} diff --git a/synapse/http/site.py b/synapse/http/site.py index e508c0bd4f..93f679ea48 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -46,10 +46,11 @@ class SynapseRequest(Request): Attributes: logcontext(LoggingContext) : the log context for this request """ + def __init__(self, site, channel, *args, **kw): Request.__init__(self, channel, *args, **kw) self.site = site - self._channel = channel # this is used by the tests + self._channel = channel # this is used by the tests self.authenticated_entity = None self.start_time = 0 @@ -72,12 +73,12 @@ class SynapseRequest(Request): def __repr__(self): # We overwrite this so that we don't log ``access_token`` - return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( + return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % ( self.__class__.__name__, id(self), self.get_method(), self.get_redacted_uri(), - self.clientproto.decode('ascii', errors='replace'), + self.clientproto.decode("ascii", errors="replace"), self.site.site_tag, ) @@ -87,7 +88,7 @@ class SynapseRequest(Request): def get_redacted_uri(self): uri = self.uri if isinstance(uri, bytes): - uri = self.uri.decode('ascii') + uri = self.uri.decode("ascii") return redact_uri(uri) def get_method(self): @@ -102,7 +103,7 @@ class SynapseRequest(Request): """ method = self.method if isinstance(method, bytes): - method = self.method.decode('ascii') + method = self.method.decode("ascii") return method def get_user_agent(self): @@ -134,8 +135,7 @@ class SynapseRequest(Request): # dispatching to the handler, so that the handler # can update the servlet name in the request # metrics - requests_counter.labels(self.get_method(), - self.request_metrics.name).inc() + requests_counter.labels(self.get_method(), self.request_metrics.name).inc() @contextlib.contextmanager def processing(self): @@ -200,7 +200,7 @@ class SynapseRequest(Request): # the client disconnects. with PreserveLoggingContext(self.logcontext): logger.warn( - "Error processing request %r: %s %s", self, reason.type, reason.value, + "Error processing request %r: %s %s", self, reason.type, reason.value ) if not self._is_processing: @@ -222,7 +222,7 @@ class SynapseRequest(Request): self.start_time = time.time() self.request_metrics = RequestMetrics() self.request_metrics.start( - self.start_time, name=servlet_name, method=self.get_method(), + self.start_time, name=servlet_name, method=self.get_method() ) self.site.access_logger.info( @@ -230,7 +230,7 @@ class SynapseRequest(Request): self.getClientIP(), self.site.site_tag, self.get_method(), - self.get_redacted_uri() + self.get_redacted_uri(), ) def _finished_processing(self): @@ -282,7 +282,7 @@ class SynapseRequest(Request): self.site.access_logger.info( "%s - %s - {%s}" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" - " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", + ' %sB %s "%s %s %s" "%s" [%d dbevts]', self.getClientIP(), self.site.site_tag, authenticated_entity, @@ -297,7 +297,7 @@ class SynapseRequest(Request): code, self.get_method(), self.get_redacted_uri(), - self.clientproto.decode('ascii', errors='replace'), + self.clientproto.decode("ascii", errors="replace"), user_agent, usage.evt_db_fetch_count, ) @@ -316,14 +316,19 @@ class XForwardedForRequest(SynapseRequest): Add a layer on top of another request that only uses the value of an X-Forwarded-For header as the result of C{getClientIP}. """ + def getClientIP(self): """ @return: The client address (the first address) in the value of the I{X-Forwarded-For header}. If the header is not present, return C{b"-"}. """ - return self.requestHeaders.getRawHeaders( - b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii') + return ( + self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0] + .split(b",")[0] + .strip() + .decode("ascii") + ) class SynapseRequestFactory(object): @@ -343,8 +348,17 @@ class SynapseSite(Site): Subclass of a twisted http Site that does access logging with python's standard logging """ - def __init__(self, logger_name, site_tag, config, resource, - server_version_string, *args, **kwargs): + + def __init__( + self, + logger_name, + site_tag, + config, + resource, + server_version_string, + *args, + **kwargs + ): Site.__init__(self, resource, *args, **kwargs) self.site_tag = site_tag @@ -352,7 +366,7 @@ class SynapseSite(Site): proxied = config.get("x_forwarded", False) self.requestFactory = SynapseRequestFactory(self, proxied) self.access_logger = logging.getLogger(logger_name) - self.server_version_string = server_version_string.encode('ascii') + self.server_version_string = server_version_string.encode("ascii") def log(self, request): pass diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 8aee14a8a8..1f30179b51 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -231,10 +231,7 @@ class BucketCollector(object): res.append(["+Inf", sum(data.values())]) metric = HistogramMetricFamily( - self.name, - "", - buckets=res, - sum_value=sum([x * y for x, y in data.items()]), + self.name, "", buckets=res, sum_value=sum([x * y for x, y in data.items()]) ) yield metric @@ -263,7 +260,7 @@ class CPUMetrics(object): ticks_per_sec = 100 try: # Try and get the system config - ticks_per_sec = os.sysconf('SC_CLK_TCK') + ticks_per_sec = os.sysconf("SC_CLK_TCK") except (ValueError, TypeError, AttributeError): pass diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 037f1c490e..167e2c068a 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -60,8 +60,10 @@ _background_process_db_txn_count = Counter( _background_process_db_txn_duration = Counter( "synapse_background_process_db_txn_duration_seconds", - ("Seconds spent by background processes waiting for database " - "transactions, excluding scheduling time"), + ( + "Seconds spent by background processes waiting for database " + "transactions, excluding scheduling time" + ), ["name"], registry=None, ) @@ -94,6 +96,7 @@ class _Collector(object): Ensures that all of the metrics are up-to-date with any in-flight processes before they are returned. """ + def collect(self): background_process_in_flight_count = GaugeMetricFamily( "synapse_background_process_in_flight_count", @@ -105,14 +108,11 @@ class _Collector(object): # We also copy the process lists as that can also change with _bg_metrics_lock: _background_processes_copy = { - k: list(v) - for k, v in six.iteritems(_background_processes) + k: list(v) for k, v in six.iteritems(_background_processes) } for desc, processes in six.iteritems(_background_processes_copy): - background_process_in_flight_count.add_metric( - (desc,), len(processes), - ) + background_process_in_flight_count.add_metric((desc,), len(processes)) for process in processes: process.update_metrics() @@ -121,11 +121,11 @@ class _Collector(object): # now we need to run collect() over each of the static Counters, and # yield each metric they return. for m in ( - _background_process_ru_utime, - _background_process_ru_stime, - _background_process_db_txn_count, - _background_process_db_txn_duration, - _background_process_db_sched_duration, + _background_process_ru_utime, + _background_process_ru_stime, + _background_process_db_txn_count, + _background_process_db_txn_duration, + _background_process_db_sched_duration, ): for r in m.collect(): yield r @@ -151,14 +151,12 @@ class _BackgroundProcess(object): _background_process_ru_utime.labels(self.desc).inc(diff.ru_utime) _background_process_ru_stime.labels(self.desc).inc(diff.ru_stime) - _background_process_db_txn_count.labels(self.desc).inc( - diff.db_txn_count, - ) + _background_process_db_txn_count.labels(self.desc).inc(diff.db_txn_count) _background_process_db_txn_duration.labels(self.desc).inc( - diff.db_txn_duration_sec, + diff.db_txn_duration_sec ) _background_process_db_sched_duration.labels(self.desc).inc( - diff.db_sched_duration_sec, + diff.db_sched_duration_sec ) @@ -182,6 +180,7 @@ def run_as_background_process(desc, func, *args, **kwargs): Returns: Deferred which returns the result of func, but note that it does not follow the synapse logcontext rules. """ + @defer.inlineCallbacks def run(): with _bg_metrics_lock: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b3abd1b3c6..bf43ca09be 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -21,6 +21,7 @@ class ModuleApi(object): """A proxy object that gets passed to password auth providers so they can register new users etc if necessary. """ + def __init__(self, hs, auth_handler): self.hs = hs @@ -57,7 +58,7 @@ class ModuleApi(object): Returns: str: qualified @user:id """ - if username.startswith('@'): + if username.startswith("@"): return username return UserID(username, self.hs.hostname).to_string() @@ -89,8 +90,7 @@ class ModuleApi(object): # Register the user reg = self.hs.get_registration_handler() user_id, access_token = yield reg.register( - localpart=localpart, default_display_name=displayname, - bind_emails=emails, + localpart=localpart, default_display_name=displayname, bind_emails=emails ) defer.returnValue((user_id, access_token)) diff --git a/synapse/notifier.py b/synapse/notifier.py index ff589660da..d398078eed 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -37,7 +37,8 @@ logger = logging.getLogger(__name__) notified_events_counter = Counter("synapse_notifier_notified_events", "") users_woken_by_stream_counter = Counter( - "synapse_notifier_users_woken_by_stream", "", ["stream"]) + "synapse_notifier_users_woken_by_stream", "", ["stream"] +) # TODO(paul): Should be shared somewhere @@ -55,6 +56,7 @@ class _NotificationListener(object): The events stream handler will have yielded to the deferred, so to notify the handler it is sufficient to resolve the deferred. """ + __slots__ = ["deferred"] def __init__(self, deferred): @@ -95,9 +97,7 @@ class _NotifierUserStream(object): stream_id(str): The new id for the stream the event came from. time_now_ms(int): The current time in milliseconds. """ - self.current_token = self.current_token.copy_and_advance( - stream_key, stream_id - ) + self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.last_notified_token = self.current_token self.last_notified_ms = time_now_ms noify_deferred = self.notify_deferred @@ -141,6 +141,7 @@ class _NotifierUserStream(object): class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): def __nonzero__(self): return bool(self.events) + __bool__ = __nonzero__ # python3 @@ -190,15 +191,17 @@ class Notifier(object): all_user_streams.add(x) return sum(stream.count_listeners() for stream in all_user_streams) + LaterGauge("synapse_notifier_listeners", "", [], count_listeners) LaterGauge( - "synapse_notifier_rooms", "", [], + "synapse_notifier_rooms", + "", + [], lambda: count(bool, list(self.room_to_user_streams.values())), ) LaterGauge( - "synapse_notifier_users", "", [], - lambda: len(self.user_to_user_stream), + "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream) ) def add_replication_callback(self, cb): @@ -209,8 +212,9 @@ class Notifier(object): """ self.replication_callbacks.append(cb) - def on_new_room_event(self, event, room_stream_id, max_room_stream_id, - extra_users=[]): + def on_new_room_event( + self, event, room_stream_id, max_room_stream_id, extra_users=[] + ): """ Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -222,9 +226,7 @@ class Notifier(object): until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append(( - room_stream_id, event, extra_users - )) + self.pending_new_room_events.append((room_stream_id, event, extra_users)) self._notify_pending_new_room_events(max_room_stream_id) self.notify_replication() @@ -240,9 +242,9 @@ class Notifier(object): self.pending_new_room_events = [] for room_stream_id, event, extra_users in pending: if room_stream_id > max_room_stream_id: - self.pending_new_room_events.append(( - room_stream_id, event, extra_users - )) + self.pending_new_room_events.append( + (room_stream_id, event, extra_users) + ) else: self._on_new_room_event(event, room_stream_id, extra_users) @@ -250,8 +252,7 @@ class Notifier(object): """Notify any user streams that are interested in this room event""" # poke any interested application service. run_as_background_process( - "notify_app_services", - self._notify_app_services, room_stream_id, + "notify_app_services", self._notify_app_services, room_stream_id ) if self.federation_sender: @@ -261,9 +262,7 @@ class Notifier(object): self._user_joined_room(event.state_key, event.room_id) self.on_new_event( - "room_key", room_stream_id, - users=extra_users, - rooms=[event.room_id], + "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] ) @defer.inlineCallbacks @@ -305,8 +304,9 @@ class Notifier(object): self.notify_replication() @defer.inlineCallbacks - def wait_for_events(self, user_id, timeout, callback, room_ids=None, - from_token=StreamToken.START): + def wait_for_events( + self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START + ): """Wait until the callback returns a non empty response or the timeout fires. """ @@ -339,7 +339,7 @@ class Notifier(object): listener = user_stream.new_listener(prev_token) listener.deferred = timeout_deferred( listener.deferred, - (end_time - now) / 1000., + (end_time - now) / 1000.0, self.hs.get_reactor(), ) with PreserveLoggingContext(): @@ -368,9 +368,15 @@ class Notifier(object): defer.returnValue(result) @defer.inlineCallbacks - def get_events_for(self, user, pagination_config, timeout, - only_keys=None, - is_guest=False, explicit_room_id=None): + def get_events_for( + self, + user, + pagination_config, + timeout, + only_keys=None, + is_guest=False, + explicit_room_id=None, + ): """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. @@ -419,10 +425,7 @@ class Notifier(object): if name == "room": new_events = yield filter_events_for_client( - self.store, - user.to_string(), - new_events, - is_peeking=is_peeking, + self.store, user.to_string(), new_events, is_peeking=is_peeking ) elif name == "presence": now = self.clock.time_msec() @@ -450,7 +453,8 @@ class Notifier(object): # # I am sorry for what I have done. user_id_for_stream = "_PEEKING_%s_%s" % ( - explicit_room_id, user_id_for_stream + explicit_room_id, + user_id_for_stream, ) result = yield self.wait_for_events( @@ -477,9 +481,7 @@ class Notifier(object): @defer.inlineCallbacks def _is_world_readable(self, room_id): state = yield self.state_handler.get_current_state( - room_id, - EventTypes.RoomHistoryVisibility, - "", + room_id, EventTypes.RoomHistoryVisibility, "" ) if state and "history_visibility" in state.content: defer.returnValue(state.content["history_visibility"] == "world_readable") diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index a5de75c48a..1ffd5e2df3 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -40,6 +40,4 @@ class ActionGenerator(object): @defer.inlineCallbacks def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "action_for_event_by_user"): - yield self.bulk_evaluator.action_for_event_by_user( - event, context - ) + yield self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 3523a40108..96d087de22 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -31,48 +31,54 @@ def list_with_base_rules(rawrules): # Grab the base rules that the user has modified. # The modified base rules have a priority_class of -1. - modified_base_rules = { - r['rule_id']: r for r in rawrules if r['priority_class'] < 0 - } + modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0} # Remove the modified base rules from the list, They'll be added back # in the default postions in the list. - rawrules = [r for r in rawrules if r['priority_class'] >= 0] + rawrules = [r for r in rawrules if r["priority_class"] >= 0] # shove the server default rules for each kind onto the end of each current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1] - ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules - )) + ruleslist.extend( + make_base_prepend_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + ) + ) for r in rawrules: - if r['priority_class'] < current_prio_class: - while r['priority_class'] < current_prio_class: - ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - )) - current_prio_class -= 1 - if current_prio_class > 0: - ruleslist.extend(make_base_prepend_rules( + if r["priority_class"] < current_prio_class: + while r["priority_class"] < current_prio_class: + ruleslist.extend( + make_base_append_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules, - )) + ) + ) + current_prio_class -= 1 + if current_prio_class > 0: + ruleslist.extend( + make_base_prepend_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + ) + ) ruleslist.append(r) while current_prio_class > 0: - ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - )) + ruleslist.extend( + make_base_append_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + ) + ) current_prio_class -= 1 if current_prio_class > 0: - ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - )) + ruleslist.extend( + make_base_prepend_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + ) + ) return ruleslist @@ -80,20 +86,20 @@ def list_with_base_rules(rawrules): def make_base_append_rules(kind, modified_base_rules): rules = [] - if kind == 'override': + if kind == "override": rules = BASE_APPEND_OVERRIDE_RULES - elif kind == 'underride': + elif kind == "underride": rules = BASE_APPEND_UNDERRIDE_RULES - elif kind == 'content': + elif kind == "content": rules = BASE_APPEND_CONTENT_RULES # Copy the rules before modifying them rules = copy.deepcopy(rules) for r in rules: # Only modify the actions, keep the conditions the same. - modified = modified_base_rules.get(r['rule_id']) + modified = modified_base_rules.get(r["rule_id"]) if modified: - r['actions'] = modified['actions'] + r["actions"] = modified["actions"] return rules @@ -101,103 +107,86 @@ def make_base_append_rules(kind, modified_base_rules): def make_base_prepend_rules(kind, modified_base_rules): rules = [] - if kind == 'override': + if kind == "override": rules = BASE_PREPEND_OVERRIDE_RULES # Copy the rules before modifying them rules = copy.deepcopy(rules) for r in rules: # Only modify the actions, keep the conditions the same. - modified = modified_base_rules.get(r['rule_id']) + modified = modified_base_rules.get(r["rule_id"]) if modified: - r['actions'] = modified['actions'] + r["actions"] = modified["actions"] return rules BASE_APPEND_CONTENT_RULES = [ { - 'rule_id': 'global/content/.m.rule.contains_user_name', - 'conditions': [ + "rule_id": "global/content/.m.rule.contains_user_name", + "conditions": [ { - 'kind': 'event_match', - 'key': 'content.body', - 'pattern_type': 'user_localpart' + "kind": "event_match", + "key": "content.body", + "pattern_type": "user_localpart", } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default', - }, { - 'set_tweak': 'highlight' - } - ] - }, + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + ], + } ] BASE_PREPEND_OVERRIDE_RULES = [ { - 'rule_id': 'global/override/.m.rule.master', - 'enabled': False, - 'conditions': [], - 'actions': [ - "dont_notify" - ] + "rule_id": "global/override/.m.rule.master", + "enabled": False, + "conditions": [], + "actions": ["dont_notify"], } ] BASE_APPEND_OVERRIDE_RULES = [ { - 'rule_id': 'global/override/.m.rule.suppress_notices', - 'conditions': [ + "rule_id": "global/override/.m.rule.suppress_notices", + "conditions": [ { - 'kind': 'event_match', - 'key': 'content.msgtype', - 'pattern': 'm.notice', - '_id': '_suppress_notices', + "kind": "event_match", + "key": "content.msgtype", + "pattern": "m.notice", + "_id": "_suppress_notices", } ], - 'actions': [ - 'dont_notify', - ] + "actions": ["dont_notify"], }, # NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event # otherwise invites will be matched by .m.rule.member_event { - 'rule_id': 'global/override/.m.rule.invite_for_me', - 'conditions': [ + "rule_id": "global/override/.m.rule.invite_for_me", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.member', - '_id': '_member', + "kind": "event_match", + "key": "type", + "pattern": "m.room.member", + "_id": "_member", }, { - 'kind': 'event_match', - 'key': 'content.membership', - 'pattern': 'invite', - '_id': '_invite_member', - }, - { - 'kind': 'event_match', - 'key': 'state_key', - 'pattern_type': 'user_id' + "kind": "event_match", + "key": "content.membership", + "pattern": "invite", + "_id": "_invite_member", }, + {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, + ], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] }, # Will we sometimes want to know about people joining and leaving? # Perhaps: if so, this could be expanded upon. Seems the most usual case @@ -206,217 +195,164 @@ BASE_APPEND_OVERRIDE_RULES = [ # join/leave/avatar/displayname events. # See also: https://matrix.org/jira/browse/SYN-607 { - 'rule_id': 'global/override/.m.rule.member_event', - 'conditions': [ + "rule_id": "global/override/.m.rule.member_event", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.member', - '_id': '_member', + "kind": "event_match", + "key": "type", + "pattern": "m.room.member", + "_id": "_member", } ], - 'actions': [ - 'dont_notify' - ] + "actions": ["dont_notify"], }, # This was changed from underride to override so it's closer in priority # to the content rules where the user name highlight rule lives. This # way a room rule is lower priority than both but a custom override rule # is higher priority than both. { - 'rule_id': 'global/override/.m.rule.contains_display_name', - 'conditions': [ - { - 'kind': 'contains_display_name' - } + "rule_id": "global/override/.m.rule.contains_display_name", + "conditions": [{"kind": "contains_display_name"}], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight' - } - ] }, { - 'rule_id': 'global/override/.m.rule.roomnotif', - 'conditions': [ + "rule_id": "global/override/.m.rule.roomnotif", + "conditions": [ { - 'kind': 'event_match', - 'key': 'content.body', - 'pattern': '@room', - '_id': '_roomnotif_content', + "kind": "event_match", + "key": "content.body", + "pattern": "@room", + "_id": "_roomnotif_content", }, { - 'kind': 'sender_notification_permission', - 'key': 'room', - '_id': '_roomnotif_pl', + "kind": "sender_notification_permission", + "key": "room", + "_id": "_roomnotif_pl", }, ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': True, - } - ] + "actions": ["notify", {"set_tweak": "highlight", "value": True}], }, { - 'rule_id': 'global/override/.m.rule.tombstone', - 'conditions': [ + "rule_id": "global/override/.m.rule.tombstone", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.tombstone', - '_id': '_tombstone', + "kind": "event_match", + "key": "type", + "pattern": "m.room.tombstone", + "_id": "_tombstone", } ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': True, - } - ] - } + "actions": ["notify", {"set_tweak": "highlight", "value": True}], + }, ] BASE_APPEND_UNDERRIDE_RULES = [ { - 'rule_id': 'global/underride/.m.rule.call', - 'conditions': [ + "rule_id": "global/underride/.m.rule.call", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.call.invite', - '_id': '_call', + "kind": "event_match", + "key": "type", + "pattern": "m.call.invite", + "_id": "_call", } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'ring' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": [ + "notify", + {"set_tweak": "sound", "value": "ring"}, + {"set_tweak": "highlight", "value": False}, + ], }, # XXX: once m.direct is standardised everywhere, we should use it to detect # a DM from the user's perspective rather than this heuristic. { - 'rule_id': 'global/underride/.m.rule.room_one_to_one', - 'conditions': [ + "rule_id": "global/underride/.m.rule.room_one_to_one", + "conditions": [ + {"kind": "room_member_count", "is": "2", "_id": "member_count"}, { - 'kind': 'room_member_count', - 'is': '2', - '_id': 'member_count', + "kind": "event_match", + "key": "type", + "pattern": "m.room.message", + "_id": "_message", }, - { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.message', - '_id': '_message', - } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, + ], }, # XXX: this is going to fire for events which aren't m.room.messages # but are encrypted (e.g. m.call.*)... { - 'rule_id': 'global/underride/.m.rule.encrypted_room_one_to_one', - 'conditions': [ + "rule_id": "global/underride/.m.rule.encrypted_room_one_to_one", + "conditions": [ + {"kind": "room_member_count", "is": "2", "_id": "member_count"}, { - 'kind': 'room_member_count', - 'is': '2', - '_id': 'member_count', + "kind": "event_match", + "key": "type", + "pattern": "m.room.encrypted", + "_id": "_encrypted", }, - { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.encrypted', - '_id': '_encrypted', - } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, + ], }, { - 'rule_id': 'global/underride/.m.rule.message', - 'conditions': [ + "rule_id": "global/underride/.m.rule.message", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.message', - '_id': '_message', + "kind": "event_match", + "key": "type", + "pattern": "m.room.message", + "_id": "_message", } ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": ["notify", {"set_tweak": "highlight", "value": False}], }, # XXX: this is going to fire for events which aren't m.room.messages # but are encrypted (e.g. m.call.*)... { - 'rule_id': 'global/underride/.m.rule.encrypted', - 'conditions': [ + "rule_id": "global/underride/.m.rule.encrypted", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.encrypted', - '_id': '_encrypted', + "kind": "event_match", + "key": "type", + "pattern": "m.room.encrypted", + "_id": "_encrypted", } ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': False - } - ] - } + "actions": ["notify", {"set_tweak": "highlight", "value": False}], + }, ] BASE_RULE_IDS = set() for r in BASE_APPEND_CONTENT_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['content'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["content"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) for r in BASE_PREPEND_OVERRIDE_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['override'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) for r in BASE_APPEND_OVERRIDE_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['override'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) for r in BASE_APPEND_UNDERRIDE_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['underride'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["underride"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8f9a76147f..c8a5b381da 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -39,9 +39,11 @@ rules_by_room = {} push_rules_invalidation_counter = Counter( - "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "") + "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" +) push_rules_state_size_counter = Counter( - "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "") + "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "" +) # Measures whether we use the fast path of using state deltas, or if we have to # recalculate from scratch @@ -83,7 +85,7 @@ class BulkPushRuleEvaluator(object): # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited - if event.type == 'm.room.member' and event.content['membership'] == 'invite': + if event.type == "m.room.member" and event.content["membership"] == "invite": invited = event.state_key if invited and self.hs.is_mine_id(invited): has_pusher = yield self.store.user_has_pusher(invited) @@ -106,7 +108,9 @@ class BulkPushRuleEvaluator(object): # before any lookup methods get called on it as otherwise there may be # a race if invalidate_all gets called (which assumes its in the cache) return RulesForRoom( - self.hs, room_id, self._get_rules_for_room.cache, + self.hs, + room_id, + self._get_rules_for_room.cache, self.room_push_rule_cache_metrics, ) @@ -121,12 +125,10 @@ class BulkPushRuleEvaluator(object): auth_events = {POWER_KEY: pl_event} else: auth_events_ids = yield self.auth.compute_auth_events( - event, prev_state_ids, for_verification=False, + event, prev_state_ids, for_verification=False ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in itervalues(auth_events) - } + auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} sender_level = get_user_power_level(event.sender, auth_events) @@ -145,16 +147,14 @@ class BulkPushRuleEvaluator(object): rules_by_user = yield self._get_rules_for_event(event, context) actions_by_user = {} - room_members = yield self.store.get_joined_users_from_context( - event, context - ) + room_members = yield self.store.get_joined_users_from_context(event, context) (power_levels, sender_power_level) = ( yield self._get_power_levels_and_sender_level(event, context) ) evaluator = PushRuleEvaluatorForEvent( - event, len(room_members), sender_power_level, power_levels, + event, len(room_members), sender_power_level, power_levels ) condition_cache = {} @@ -180,15 +180,15 @@ class BulkPushRuleEvaluator(object): display_name = event.content.get("displayname", None) for rule in rules: - if 'enabled' in rule and not rule['enabled']: + if "enabled" in rule and not rule["enabled"]: continue matches = _condition_checker( - evaluator, rule['conditions'], uid, display_name, condition_cache + evaluator, rule["conditions"], uid, display_name, condition_cache ) if matches: - actions = [x for x in rule['actions'] if x != 'dont_notify'] - if actions and 'notify' in actions: + actions = [x for x in rule["actions"] if x != "dont_notify"] + if actions and "notify" in actions: # Push rules say we should notify the user of this event actions_by_user[uid] = actions break @@ -196,9 +196,7 @@ class BulkPushRuleEvaluator(object): # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) - yield self.store.add_push_actions_to_staging( - event.event_id, actions_by_user, - ) + yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -361,19 +359,19 @@ class RulesForRoom(object): self.sequence, members={}, # There were no membership changes rules_by_user=ret_rules_by_user, - state_group=state_group + state_group=state_group, ) if logger.isEnabledFor(logging.DEBUG): logger.debug( - "Returning push rules for %r %r", - self.room_id, ret_rules_by_user.keys(), + "Returning push rules for %r %r", self.room_id, ret_rules_by_user.keys() ) defer.returnValue(ret_rules_by_user) @defer.inlineCallbacks - def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids, - state_group, event): + def _update_rules_with_member_event_ids( + self, ret_rules_by_user, member_event_ids, state_group, event + ): """Update the partially filled rules_by_user dict by fetching rules for any newly joined users in the `member_event_ids` list. @@ -391,16 +389,13 @@ class RulesForRoom(object): table="room_memberships", column="event_id", iterable=member_event_ids.values(), - retcols=('user_id', 'membership', 'event_id'), + retcols=("user_id", "membership", "event_id"), keyvalues={}, batch_size=500, desc="_get_rules_for_member_event_ids", ) - members = { - row["event_id"]: (row["user_id"], row["membership"]) - for row in rows - } + members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} # If the event is a join event then it will be in current state evnts # map but not in the DB, so we have to explicitly insert it. @@ -413,15 +408,15 @@ class RulesForRoom(object): logger.debug("Found members %r: %r", self.room_id, members.values()) interested_in_user_ids = set( - user_id for user_id, membership in itervalues(members) + user_id + for user_id, membership in itervalues(members) if membership == Membership.JOIN ) logger.debug("Joined: %r", interested_in_user_ids) if_users_with_pushers = yield self.store.get_if_users_have_pushers( - interested_in_user_ids, - on_invalidate=self.invalidate_all_cb, + interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) user_ids = set( @@ -431,7 +426,7 @@ class RulesForRoom(object): logger.debug("With pushers: %r", user_ids) users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( - self.room_id, on_invalidate=self.invalidate_all_cb, + self.room_id, on_invalidate=self.invalidate_all_cb ) logger.debug("With receipts: %r", users_with_receipts) @@ -442,7 +437,7 @@ class RulesForRoom(object): user_ids.add(uid) rules_by_user = yield self.store.bulk_get_push_rules( - user_ids, on_invalidate=self.invalidate_all_cb, + user_ids, on_invalidate=self.invalidate_all_cb ) ret_rules_by_user.update( diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 8bd96b1178..a59b639f15 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -25,14 +25,14 @@ def format_push_rules_for_user(user, ruleslist): # We're going to be mutating this a lot, so do a deep copy ruleslist = copy.deepcopy(ruleslist) - rules = {'global': {}, 'device': {}} + rules = {"global": {}, "device": {}} - rules['global'] = _add_empty_priority_class_arrays(rules['global']) + rules["global"] = _add_empty_priority_class_arrays(rules["global"]) for r in ruleslist: rulearray = None - template_name = _priority_class_to_template_name(r['priority_class']) + template_name = _priority_class_to_template_name(r["priority_class"]) # Remove internal stuff. for c in r["conditions"]: @@ -44,14 +44,14 @@ def format_push_rules_for_user(user, ruleslist): elif pattern_type == "user_localpart": c["pattern"] = user.localpart - rulearray = rules['global'][template_name] + rulearray = rules["global"][template_name] template_rule = _rule_to_template(r) if template_rule: - if 'enabled' in r: - template_rule['enabled'] = r['enabled'] + if "enabled" in r: + template_rule["enabled"] = r["enabled"] else: - template_rule['enabled'] = True + template_rule["enabled"] = True rulearray.append(template_rule) return rules @@ -65,33 +65,33 @@ def _add_empty_priority_class_arrays(d): def _rule_to_template(rule): unscoped_rule_id = None - if 'rule_id' in rule: - unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id']) + if "rule_id" in rule: + unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"]) - template_name = _priority_class_to_template_name(rule['priority_class']) - if template_name in ['override', 'underride']: + template_name = _priority_class_to_template_name(rule["priority_class"]) + if template_name in ["override", "underride"]: templaterule = {k: rule[k] for k in ["conditions", "actions"]} elif template_name in ["sender", "room"]: - templaterule = {'actions': rule['actions']} - unscoped_rule_id = rule['conditions'][0]['pattern'] - elif template_name == 'content': + templaterule = {"actions": rule["actions"]} + unscoped_rule_id = rule["conditions"][0]["pattern"] + elif template_name == "content": if len(rule["conditions"]) != 1: return None thecond = rule["conditions"][0] if "pattern" not in thecond: return None - templaterule = {'actions': rule['actions']} + templaterule = {"actions": rule["actions"]} templaterule["pattern"] = thecond["pattern"] if unscoped_rule_id: - templaterule['rule_id'] = unscoped_rule_id - if 'default' in rule: - templaterule['default'] = rule['default'] + templaterule["rule_id"] = unscoped_rule_id + if "default" in rule: + templaterule["default"] = rule["default"] return templaterule def _rule_id_from_namespaced(in_rule_id): - return in_rule_id.split('/')[-1] + return in_rule_id.split("/")[-1] def _priority_class_to_template_name(pc): diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index c89a8438a9..424ffa8b68 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -32,13 +32,13 @@ DELAY_BEFORE_MAIL_MS = 10 * 60 * 1000 THROTTLE_START_MS = 10 * 60 * 1000 THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h # THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours -THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day +THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day # If no event triggers a notification for this long after the previous, # the throttle is released. # 12 hours - a gap of 12 hours in conversation is surely enough to merit a new # notification when things get going again... -THROTTLE_RESET_AFTER_MS = (12 * 60 * 60 * 1000) +THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000 # does each email include all unread notifs, or just the ones which have happened # since the last mail? @@ -53,17 +53,18 @@ class EmailPusher(object): This shares quite a bit of code with httpusher: it would be good to factor out the common parts """ + def __init__(self, hs, pusherdict, mailer): self.hs = hs self.mailer = mailer self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self.pusher_id = pusherdict['id'] - self.user_id = pusherdict['user_name'] - self.app_id = pusherdict['app_id'] - self.email = pusherdict['pushkey'] - self.last_stream_ordering = pusherdict['last_stream_ordering'] + self.pusher_id = pusherdict["id"] + self.user_id = pusherdict["user_name"] + self.app_id = pusherdict["app_id"] + self.email = pusherdict["pushkey"] + self.last_stream_ordering = pusherdict["last_stream_ordering"] self.timed_call = None self.throttle_params = None @@ -93,7 +94,9 @@ class EmailPusher(object): def on_new_notifications(self, min_stream_ordering, max_stream_ordering): if self.max_stream_ordering: - self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) + self.max_stream_ordering = max( + max_stream_ordering, self.max_stream_ordering + ) else: self.max_stream_ordering = max_stream_ordering self._start_processing() @@ -174,14 +177,12 @@ class EmailPusher(object): return for push_action in unprocessed: - received_at = push_action['received_ts'] + received_at = push_action["received_ts"] if received_at is None: received_at = 0 notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS - room_ready_at = self.room_ready_to_notify_at( - push_action['room_id'] - ) + room_ready_at = self.room_ready_to_notify_at(push_action["room_id"]) should_notify_at = max(notif_ready_at, room_ready_at) @@ -192,25 +193,23 @@ class EmailPusher(object): # to be delivered. reason = { - 'room_id': push_action['room_id'], - 'now': self.clock.time_msec(), - 'received_at': received_at, - 'delay_before_mail_ms': DELAY_BEFORE_MAIL_MS, - 'last_sent_ts': self.get_room_last_sent_ts(push_action['room_id']), - 'throttle_ms': self.get_room_throttle_ms(push_action['room_id']), + "room_id": push_action["room_id"], + "now": self.clock.time_msec(), + "received_at": received_at, + "delay_before_mail_ms": DELAY_BEFORE_MAIL_MS, + "last_sent_ts": self.get_room_last_sent_ts(push_action["room_id"]), + "throttle_ms": self.get_room_throttle_ms(push_action["room_id"]), } yield self.send_notification(unprocessed, reason) - yield self.save_last_stream_ordering_and_success(max([ - ea['stream_ordering'] for ea in unprocessed - ])) + yield self.save_last_stream_ordering_and_success( + max([ea["stream_ordering"] for ea in unprocessed]) + ) # we update the throttle on all the possible unprocessed push actions for ea in unprocessed: - yield self.sent_notif_update_throttle( - ea['room_id'], ea - ) + yield self.sent_notif_update_throttle(ea["room_id"], ea) break else: if soonest_due_at is None or should_notify_at < soonest_due_at: @@ -236,8 +235,11 @@ class EmailPusher(object): self.last_stream_ordering = last_stream_ordering yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, self.email, self.user_id, - last_stream_ordering, self.clock.time_msec() + self.app_id, + self.email, + self.user_id, + last_stream_ordering, + self.clock.time_msec(), ) def seconds_until(self, ts_msec): @@ -276,10 +278,10 @@ class EmailPusher(object): # THROTTLE_RESET_AFTER_MS after the previous one that triggered a # notif, we release the throttle. Otherwise, the throttle is increased. time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before( - notified_push_action['stream_ordering'] + notified_push_action["stream_ordering"] ) - time_of_this_notifs = notified_push_action['received_ts'] + time_of_this_notifs = notified_push_action["received_ts"] if time_of_previous_notifs is not None and time_of_this_notifs is not None: gap = time_of_this_notifs - time_of_previous_notifs @@ -298,12 +300,11 @@ class EmailPusher(object): new_throttle_ms = THROTTLE_START_MS else: new_throttle_ms = min( - current_throttle_ms * THROTTLE_MULTIPLIER, - THROTTLE_MAX_MS + current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS ) self.throttle_params[room_id] = { "last_sent_ts": self.clock.time_msec(), - "throttle_ms": new_throttle_ms + "throttle_ms": new_throttle_ms, } yield self.store.set_throttle_params( self.pusher_id, room_id, self.throttle_params[room_id] diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index fac05aa44c..4e7b6a5531 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -65,16 +65,16 @@ class HttpPusher(object): self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() self.state_handler = self.hs.get_state_handler() - self.user_id = pusherdict['user_name'] - self.app_id = pusherdict['app_id'] - self.app_display_name = pusherdict['app_display_name'] - self.device_display_name = pusherdict['device_display_name'] - self.pushkey = pusherdict['pushkey'] - self.pushkey_ts = pusherdict['ts'] - self.data = pusherdict['data'] - self.last_stream_ordering = pusherdict['last_stream_ordering'] + self.user_id = pusherdict["user_name"] + self.app_id = pusherdict["app_id"] + self.app_display_name = pusherdict["app_display_name"] + self.device_display_name = pusherdict["device_display_name"] + self.pushkey = pusherdict["pushkey"] + self.pushkey_ts = pusherdict["ts"] + self.data = pusherdict["data"] + self.last_stream_ordering = pusherdict["last_stream_ordering"] self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.failing_since = pusherdict['failing_since'] + self.failing_since = pusherdict["failing_since"] self.timed_call = None self._is_processing = False @@ -85,32 +85,26 @@ class HttpPusher(object): # off as None though as we don't know any better. self.max_stream_ordering = None - if 'data' not in pusherdict: - raise PusherConfigException( - "No 'data' key for HTTP pusher" - ) - self.data = pusherdict['data'] + if "data" not in pusherdict: + raise PusherConfigException("No 'data' key for HTTP pusher") + self.data = pusherdict["data"] self.name = "%s/%s/%s" % ( - pusherdict['user_name'], - pusherdict['app_id'], - pusherdict['pushkey'], + pusherdict["user_name"], + pusherdict["app_id"], + pusherdict["pushkey"], ) if self.data is None: - raise PusherConfigException( - "data can not be null for HTTP pusher" - ) + raise PusherConfigException("data can not be null for HTTP pusher") - if 'url' not in self.data: - raise PusherConfigException( - "'url' required in data for HTTP pusher" - ) - self.url = self.data['url'] + if "url" not in self.data: + raise PusherConfigException("'url' required in data for HTTP pusher") + self.url = self.data["url"] self.http_client = hs.get_simple_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) - del self.data_minus_url['url'] + del self.data_minus_url["url"] def on_started(self, should_check_for_notifs): """Called when this pusher has been started. @@ -124,7 +118,9 @@ class HttpPusher(object): self._start_processing() def on_new_notifications(self, min_stream_ordering, max_stream_ordering): - self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0) + self.max_stream_ordering = max( + max_stream_ordering, self.max_stream_ordering or 0 + ) self._start_processing() def on_new_receipts(self, min_stream_id, max_stream_id): @@ -192,7 +188,9 @@ class HttpPusher(object): logger.info( "Processing %i unprocessed push actions for %s starting at " "stream_ordering %s", - len(unprocessed), self.name, self.last_stream_ordering, + len(unprocessed), + self.name, + self.last_stream_ordering, ) for push_action in unprocessed: @@ -200,71 +198,72 @@ class HttpPusher(object): if processed: http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.last_stream_ordering = push_action['stream_ordering'] + self.last_stream_ordering = push_action["stream_ordering"] yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, self.pushkey, self.user_id, + self.app_id, + self.pushkey, + self.user_id, self.last_stream_ordering, - self.clock.time_msec() + self.clock.time_msec(), ) if self.failing_since: self.failing_since = None yield self.store.update_pusher_failing_since( - self.app_id, self.pushkey, self.user_id, - self.failing_since + self.app_id, self.pushkey, self.user_id, self.failing_since ) else: http_push_failed_counter.inc() if not self.failing_since: self.failing_since = self.clock.time_msec() yield self.store.update_pusher_failing_since( - self.app_id, self.pushkey, self.user_id, - self.failing_since + self.app_id, self.pushkey, self.user_id, self.failing_since ) if ( - self.failing_since and - self.failing_since < - self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS + self.failing_since + and self.failing_since + < self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS ): # we really only give up so that if the URL gets # fixed, we don't suddenly deliver a load # of old notifications. - logger.warn("Giving up on a notification to user %s, " - "pushkey %s", - self.user_id, self.pushkey) + logger.warn( + "Giving up on a notification to user %s, " "pushkey %s", + self.user_id, + self.pushkey, + ) self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.last_stream_ordering = push_action['stream_ordering'] + self.last_stream_ordering = push_action["stream_ordering"] yield self.store.update_pusher_last_stream_ordering( self.app_id, self.pushkey, self.user_id, - self.last_stream_ordering + self.last_stream_ordering, ) self.failing_since = None yield self.store.update_pusher_failing_since( - self.app_id, - self.pushkey, - self.user_id, - self.failing_since + self.app_id, self.pushkey, self.user_id, self.failing_since ) else: logger.info("Push failed: delaying for %ds", self.backoff_delay) self.timed_call = self.hs.get_reactor().callLater( self.backoff_delay, self.on_timer ) - self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC) + self.backoff_delay = min( + self.backoff_delay * 2, self.MAX_BACKOFF_SEC + ) break @defer.inlineCallbacks def _process_one(self, push_action): - if 'notify' not in push_action['actions']: + if "notify" not in push_action["actions"]: defer.returnValue(True) - tweaks = push_rule_evaluator.tweaks_for_actions(push_action['actions']) + tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - event = yield self.store.get_event(push_action['event_id'], allow_none=True) + event = yield self.store.get_event(push_action["event_id"], allow_none=True) if event is None: defer.returnValue(True) # It's been redacted rejected = yield self.dispatch_push(event, tweaks, badge) @@ -277,37 +276,30 @@ class HttpPusher(object): # for sanity, we only remove the pushkey if it # was the one we actually sent... logger.warn( - ("Ignoring rejected pushkey %s because we" - " didn't send it"), pk + ("Ignoring rejected pushkey %s because we" " didn't send it"), + pk, ) else: - logger.info( - "Pushkey %s was rejected: removing", - pk - ) - yield self.hs.remove_pusher( - self.app_id, pk, self.user_id - ) + logger.info("Pushkey %s was rejected: removing", pk) + yield self.hs.remove_pusher(self.app_id, pk, self.user_id) defer.returnValue(True) @defer.inlineCallbacks def _build_notification_dict(self, event, tweaks, badge): - if self.data.get('format') == 'event_id_only': + if self.data.get("format") == "event_id_only": d = { - 'notification': { - 'event_id': event.event_id, - 'room_id': event.room_id, - 'counts': { - 'unread': badge, - }, - 'devices': [ + "notification": { + "event_id": event.event_id, + "room_id": event.room_id, + "counts": {"unread": badge}, + "devices": [ { - 'app_id': self.app_id, - 'pushkey': self.pushkey, - 'pushkey_ts': long(self.pushkey_ts / 1000), - 'data': self.data_minus_url, + "app_id": self.app_id, + "pushkey": self.pushkey, + "pushkey_ts": long(self.pushkey_ts / 1000), + "data": self.data_minus_url, } - ] + ], } } defer.returnValue(d) @@ -317,41 +309,41 @@ class HttpPusher(object): ) d = { - 'notification': { - 'id': event.event_id, # deprecated: remove soon - 'event_id': event.event_id, - 'room_id': event.room_id, - 'type': event.type, - 'sender': event.user_id, - 'counts': { # -- we don't mark messages as read yet so - # we have no way of knowing + "notification": { + "id": event.event_id, # deprecated: remove soon + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "sender": event.user_id, + "counts": { # -- we don't mark messages as read yet so + # we have no way of knowing # Just set the badge to 1 until we have read receipts - 'unread': badge, + "unread": badge, # 'missed_calls': 2 }, - 'devices': [ + "devices": [ { - 'app_id': self.app_id, - 'pushkey': self.pushkey, - 'pushkey_ts': long(self.pushkey_ts / 1000), - 'data': self.data_minus_url, - 'tweaks': tweaks + "app_id": self.app_id, + "pushkey": self.pushkey, + "pushkey_ts": long(self.pushkey_ts / 1000), + "data": self.data_minus_url, + "tweaks": tweaks, } - ] + ], } } - if event.type == 'm.room.member' and event.is_state(): - d['notification']['membership'] = event.content['membership'] - d['notification']['user_is_target'] = event.state_key == self.user_id + if event.type == "m.room.member" and event.is_state(): + d["notification"]["membership"] = event.content["membership"] + d["notification"]["user_is_target"] = event.state_key == self.user_id if self.hs.config.push_include_content and event.content: - d['notification']['content'] = event.content + d["notification"]["content"] = event.content # We no longer send aliases separately, instead, we send the human # readable name of the room, which may be an alias. - if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0: - d['notification']['sender_display_name'] = ctx['sender_display_name'] - if 'name' in ctx and len(ctx['name']) > 0: - d['notification']['room_name'] = ctx['name'] + if "sender_display_name" in ctx and len(ctx["sender_display_name"]) > 0: + d["notification"]["sender_display_name"] = ctx["sender_display_name"] + if "name" in ctx and len(ctx["name"]) > 0: + d["notification"]["room_name"] = ctx["name"] defer.returnValue(d) @@ -361,16 +353,21 @@ class HttpPusher(object): if not notification_dict: defer.returnValue([]) try: - resp = yield self.http_client.post_json_get_json(self.url, notification_dict) + resp = yield self.http_client.post_json_get_json( + self.url, notification_dict + ) except Exception as e: logger.warning( "Failed to push event %s to %s: %s %s", - event.event_id, self.name, type(e), e, + event.event_id, + self.name, + type(e), + e, ) defer.returnValue(False) rejected = [] - if 'rejected' in resp: - rejected = resp['rejected'] + if "rejected" in resp: + rejected = resp["rejected"] defer.returnValue(rejected) @defer.inlineCallbacks @@ -381,21 +378,19 @@ class HttpPusher(object): """ logger.info("Sending updated badge count %d to %s", badge, self.name) d = { - 'notification': { - 'id': '', - 'type': None, - 'sender': '', - 'counts': { - 'unread': badge - }, - 'devices': [ + "notification": { + "id": "", + "type": None, + "sender": "", + "counts": {"unread": badge}, + "devices": [ { - 'app_id': self.app_id, - 'pushkey': self.pushkey, - 'pushkey_ts': long(self.pushkey_ts / 1000), - 'data': self.data_minus_url, + "app_id": self.app_id, + "pushkey": self.pushkey, + "pushkey_ts": long(self.pushkey_ts / 1000), + "data": self.data_minus_url, } - ] + ], } } try: @@ -403,7 +398,6 @@ class HttpPusher(object): http_badges_processed_counter.inc() except Exception as e: logger.warning( - "Failed to send badge count to %s: %s %s", - self.name, type(e), e, + "Failed to send badge count to %s: %s %s", self.name, type(e), e ) http_badges_failed_counter.inc() diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 099f9545ab..17c7d3195a 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -42,17 +42,21 @@ from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) -MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \ - "in the %(room)s room..." +MESSAGE_FROM_PERSON_IN_ROOM = ( + "You have a message on %(app)s from %(person)s " "in the %(room)s room..." +) MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..." MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..." MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..." -MESSAGES_IN_ROOM_AND_OTHERS = \ +MESSAGES_IN_ROOM_AND_OTHERS = ( "You have messages on %(app)s in the %(room)s room and others..." -MESSAGES_FROM_PERSON_AND_OTHERS = \ +) +MESSAGES_FROM_PERSON_AND_OTHERS = ( "You have messages on %(app)s from %(person)s and others..." -INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \ - "%(room)s room on %(app)s..." +) +INVITE_FROM_PERSON_TO_ROOM = ( + "%(person)s has invited you to join the " "%(room)s room on %(app)s..." +) INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..." CONTEXT_BEFORE = 1 @@ -60,12 +64,38 @@ CONTEXT_AFTER = 1 # From https://github.com/matrix-org/matrix-react-sdk/blob/master/src/HtmlUtils.js ALLOWED_TAGS = [ - 'font', # custom to matrix for IRC-style font coloring - 'del', # for markdown + "font", # custom to matrix for IRC-style font coloring + "del", # for markdown # deliberately no h1/h2 to stop people shouting. - 'h3', 'h4', 'h5', 'h6', 'blockquote', 'p', 'a', 'ul', 'ol', - 'nl', 'li', 'b', 'i', 'u', 'strong', 'em', 'strike', 'code', 'hr', 'br', 'div', - 'table', 'thead', 'caption', 'tbody', 'tr', 'th', 'td', 'pre' + "h3", + "h4", + "h5", + "h6", + "blockquote", + "p", + "a", + "ul", + "ol", + "nl", + "li", + "b", + "i", + "u", + "strong", + "em", + "strike", + "code", + "hr", + "br", + "div", + "table", + "thead", + "caption", + "tbody", + "tr", + "th", + "td", + "pre", ] ALLOWED_ATTRS = { # custom ones first: @@ -94,13 +124,7 @@ class Mailer(object): logger.info("Created Mailer for app_name %s" % app_name) @defer.inlineCallbacks - def send_password_reset_mail( - self, - email_address, - token, - client_secret, - sid, - ): + def send_password_reset_mail(self, email_address, token, client_secret, sid): """Send an email with a password reset link to a user Args: @@ -112,19 +136,16 @@ class Mailer(object): group together multiple email sending attempts sid (str): The generated session ID """ - if email.utils.parseaddr(email_address)[1] == '': + if email.utils.parseaddr(email_address)[1] == "": raise RuntimeError("Invalid 'to' email address") link = ( - self.hs.config.public_baseurl + - "_matrix/client/unstable/password_reset/email/submit_token" - "?token=%s&client_secret=%s&sid=%s" % - (token, client_secret, sid) + self.hs.config.public_baseurl + + "_matrix/client/unstable/password_reset/email/submit_token" + "?token=%s&client_secret=%s&sid=%s" % (token, client_secret, sid) ) - template_vars = { - "link": link, - } + template_vars = {"link": link} yield self.send_email( email_address, @@ -133,15 +154,14 @@ class Mailer(object): ) @defer.inlineCallbacks - def send_notification_mail(self, app_id, user_id, email_address, - push_actions, reason): + def send_notification_mail( + self, app_id, user_id, email_address, push_actions, reason + ): """Send email regarding a user's room notifications""" - rooms_in_order = deduped_ordered_list( - [pa['room_id'] for pa in push_actions] - ) + rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions]) notif_events = yield self.store.get_events( - [pa['event_id'] for pa in push_actions] + [pa["event_id"] for pa in push_actions] ) notifs_by_room = {} @@ -171,9 +191,7 @@ class Mailer(object): yield concurrently_execute(_fetch_room_state, rooms_in_order, 3) # actually sort our so-called rooms_in_order list, most recent room first - rooms_in_order.sort( - key=lambda r: -(notifs_by_room[r][-1]['received_ts'] or 0) - ) + rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0)) rooms = [] @@ -183,9 +201,11 @@ class Mailer(object): ) rooms.append(roomvars) - reason['room_name'] = yield calculate_room_name( - self.store, state_by_room[reason['room_id']], user_id, - fallback_to_members=True + reason["room_name"] = yield calculate_room_name( + self.store, + state_by_room[reason["room_id"]], + user_id, + fallback_to_members=True, ) summary_text = yield self.make_summary_text( @@ -204,25 +224,21 @@ class Mailer(object): } yield self.send_email( - email_address, - "[%s] %s" % (self.app_name, summary_text), - template_vars, + email_address, "[%s] %s" % (self.app_name, summary_text), template_vars ) @defer.inlineCallbacks def send_email(self, email_address, subject, template_vars): """Send an email with the given information and template text""" try: - from_string = self.hs.config.email_notif_from % { - "app": self.app_name - } + from_string = self.hs.config.email_notif_from % {"app": self.app_name} except TypeError: from_string = self.hs.config.email_notif_from raw_from = email.utils.parseaddr(from_string)[1] raw_to = email.utils.parseaddr(email_address)[1] - if raw_to == '': + if raw_to == "": raise RuntimeError("Invalid 'to' address") html_text = self.template_html.render(**template_vars) @@ -231,27 +247,31 @@ class Mailer(object): plain_text = self.template_text.render(**template_vars) text_part = MIMEText(plain_text, "plain", "utf8") - multipart_msg = MIMEMultipart('alternative') - multipart_msg['Subject'] = subject - multipart_msg['From'] = from_string - multipart_msg['To'] = email_address - multipart_msg['Date'] = email.utils.formatdate() - multipart_msg['Message-ID'] = email.utils.make_msgid() + multipart_msg = MIMEMultipart("alternative") + multipart_msg["Subject"] = subject + multipart_msg["From"] = from_string + multipart_msg["To"] = email_address + multipart_msg["Date"] = email.utils.formatdate() + multipart_msg["Message-ID"] = email.utils.make_msgid() multipart_msg.attach(text_part) multipart_msg.attach(html_part) logger.info("Sending email push notification to %s" % email_address) - yield make_deferred_yieldable(self.sendmail( - self.hs.config.email_smtp_host, - raw_from, raw_to, multipart_msg.as_string().encode('utf8'), - reactor=self.hs.get_reactor(), - port=self.hs.config.email_smtp_port, - requireAuthentication=self.hs.config.email_smtp_user is not None, - username=self.hs.config.email_smtp_user, - password=self.hs.config.email_smtp_pass, - requireTransportSecurity=self.hs.config.require_transport_security - )) + yield make_deferred_yieldable( + self.sendmail( + self.hs.config.email_smtp_host, + raw_from, + raw_to, + multipart_msg.as_string().encode("utf8"), + reactor=self.hs.get_reactor(), + port=self.hs.config.email_smtp_port, + requireAuthentication=self.hs.config.email_smtp_user is not None, + username=self.hs.config.email_smtp_user, + password=self.hs.config.email_smtp_pass, + requireTransportSecurity=self.hs.config.require_transport_security, + ) + ) @defer.inlineCallbacks def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids): @@ -272,17 +292,18 @@ class Mailer(object): if not is_invite: for n in notifs: notifvars = yield self.get_notif_vars( - n, user_id, notif_events[n['event_id']], room_state_ids + n, user_id, notif_events[n["event_id"]], room_state_ids ) # merge overlapping notifs together. # relies on the notifs being in chronological order. merge = False - if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]: - prev_messages = room_vars['notifs'][-1]['messages'] - for message in notifvars['messages']: - pm = list(filter(lambda pm: pm['id'] == message['id'], - prev_messages)) + if room_vars["notifs"] and "messages" in room_vars["notifs"][-1]: + prev_messages = room_vars["notifs"][-1]["messages"] + for message in notifvars["messages"]: + pm = list( + filter(lambda pm: pm["id"] == message["id"], prev_messages) + ) if pm: if not message["is_historical"]: pm[0]["is_historical"] = False @@ -293,20 +314,22 @@ class Mailer(object): prev_messages.append(message) if not merge: - room_vars['notifs'].append(notifvars) + room_vars["notifs"].append(notifvars) defer.returnValue(room_vars) @defer.inlineCallbacks def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): results = yield self.store.get_events_around( - notif['room_id'], notif['event_id'], - before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER + notif["room_id"], + notif["event_id"], + before_limit=CONTEXT_BEFORE, + after_limit=CONTEXT_AFTER, ) ret = { "link": self.make_notif_link(notif), - "ts": notif['received_ts'], + "ts": notif["received_ts"], "messages": [], } @@ -318,7 +341,7 @@ class Mailer(object): for event in the_events: messagevars = yield self.get_message_vars(notif, event, room_state_ids) if messagevars is not None: - ret['messages'].append(messagevars) + ret["messages"].append(messagevars) defer.returnValue(ret) @@ -340,7 +363,7 @@ class Mailer(object): ret = { "msgtype": msgtype, - "is_historical": event.event_id != notif['event_id'], + "is_historical": event.event_id != notif["event_id"], "id": event.event_id, "ts": event.origin_server_ts, "sender_name": sender_name, @@ -379,8 +402,9 @@ class Mailer(object): return messagevars @defer.inlineCallbacks - def make_summary_text(self, notifs_by_room, room_state_ids, - notif_events, user_id, reason): + def make_summary_text( + self, notifs_by_room, room_state_ids, notif_events, user_id, reason + ): if len(notifs_by_room) == 1: # Only one room has new stuff room_id = list(notifs_by_room.keys())[0] @@ -404,16 +428,19 @@ class Mailer(object): inviter_name = name_from_member_event(inviter_member_event) if room_name is None: - defer.returnValue(INVITE_FROM_PERSON % { - "person": inviter_name, - "app": self.app_name - }) + defer.returnValue( + INVITE_FROM_PERSON + % {"person": inviter_name, "app": self.app_name} + ) else: - defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % { - "person": inviter_name, - "room": room_name, - "app": self.app_name, - }) + defer.returnValue( + INVITE_FROM_PERSON_TO_ROOM + % { + "person": inviter_name, + "room": room_name, + "app": self.app_name, + } + ) sender_name = None if len(notifs_by_room[room_id]) == 1: @@ -427,67 +454,86 @@ class Mailer(object): sender_name = name_from_member_event(state_event) if sender_name is not None and room_name is not None: - defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % { - "person": sender_name, - "room": room_name, - "app": self.app_name, - }) + defer.returnValue( + MESSAGE_FROM_PERSON_IN_ROOM + % { + "person": sender_name, + "room": room_name, + "app": self.app_name, + } + ) elif sender_name is not None: - defer.returnValue(MESSAGE_FROM_PERSON % { - "person": sender_name, - "app": self.app_name, - }) + defer.returnValue( + MESSAGE_FROM_PERSON + % {"person": sender_name, "app": self.app_name} + ) else: # There's more than one notification for this room, so just # say there are several if room_name is not None: - defer.returnValue(MESSAGES_IN_ROOM % { - "room": room_name, - "app": self.app_name, - }) + defer.returnValue( + MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name} + ) else: # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" - sender_ids = list(set([ - notif_events[n['event_id']].sender - for n in notifs_by_room[room_id] - ])) - - member_events = yield self.store.get_events([ - room_state_ids[room_id][("m.room.member", s)] - for s in sender_ids - ]) - - defer.returnValue(MESSAGES_FROM_PERSON % { - "person": descriptor_from_member_events(member_events.values()), - "app": self.app_name, - }) + sender_ids = list( + set( + [ + notif_events[n["event_id"]].sender + for n in notifs_by_room[room_id] + ] + ) + ) + + member_events = yield self.store.get_events( + [ + room_state_ids[room_id][("m.room.member", s)] + for s in sender_ids + ] + ) + + defer.returnValue( + MESSAGES_FROM_PERSON + % { + "person": descriptor_from_member_events( + member_events.values() + ), + "app": self.app_name, + } + ) else: # Stuff's happened in multiple different rooms # ...but we still refer to the 'reason' room which triggered the mail - if reason['room_name'] is not None: - defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % { - "room": reason['room_name'], - "app": self.app_name, - }) + if reason["room_name"] is not None: + defer.returnValue( + MESSAGES_IN_ROOM_AND_OTHERS + % {"room": reason["room_name"], "app": self.app_name} + ) else: # If the reason room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" - sender_ids = list(set([ - notif_events[n['event_id']].sender - for n in notifs_by_room[reason['room_id']] - ])) + sender_ids = list( + set( + [ + notif_events[n["event_id"]].sender + for n in notifs_by_room[reason["room_id"]] + ] + ) + ) - member_events = yield self.store.get_events([ - room_state_ids[room_id][("m.room.member", s)] - for s in sender_ids - ]) + member_events = yield self.store.get_events( + [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids] + ) - defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % { - "person": descriptor_from_member_events(member_events.values()), - "app": self.app_name, - }) + defer.returnValue( + MESSAGES_FROM_PERSON_AND_OTHERS + % { + "person": descriptor_from_member_events(member_events.values()), + "app": self.app_name, + } + ) def make_room_link(self, room_id): if self.hs.config.email_riot_base_url: @@ -503,17 +549,17 @@ class Mailer(object): if self.hs.config.email_riot_base_url: return "%s/#/room/%s/%s" % ( self.hs.config.email_riot_base_url, - notif['room_id'], notif['event_id'] + notif["room_id"], + notif["event_id"], ) elif self.app_name == "Vector": # need /beta for Universal Links to work on iOS return "https://vector.im/beta/#/room/%s/%s" % ( - notif['room_id'], notif['event_id'] + notif["room_id"], + notif["event_id"], ) else: - return "https://matrix.to/#/%s/%s" % ( - notif['room_id'], notif['event_id'] - ) + return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"]) def make_unsubscribe_link(self, user_id, app_id, email_address): params = { @@ -530,12 +576,18 @@ class Mailer(object): def safe_markup(raw_html): - return jinja2.Markup(bleach.linkify(bleach.clean( - raw_html, tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRS, - # bleach master has this, but it isn't released yet - # protocols=ALLOWED_SCHEMES, - strip=True - ))) + return jinja2.Markup( + bleach.linkify( + bleach.clean( + raw_html, + tags=ALLOWED_TAGS, + attributes=ALLOWED_ATTRS, + # bleach master has this, but it isn't released yet + # protocols=ALLOWED_SCHEMES, + strip=True, + ) + ) + ) def safe_text(raw_text): @@ -543,10 +595,9 @@ def safe_text(raw_text): Process text: treat it as HTML but escape any tags (ie. just escape the HTML) then linkify it. """ - return jinja2.Markup(bleach.linkify(bleach.clean( - raw_text, tags=[], attributes={}, - strip=False - ))) + return jinja2.Markup( + bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False)) + ) def deduped_ordered_list(l): @@ -595,15 +646,11 @@ def _create_mxc_to_http_filter(config): serverAndMediaId = value[6:] fragment = None - if '#' in serverAndMediaId: - (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1) + if "#" in serverAndMediaId: + (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1) fragment = "#" + fragment - params = { - "width": width, - "height": height, - "method": resize_method, - } + params = {"width": width, "height": height, "method": resize_method} return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( config.public_baseurl, serverAndMediaId, diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 0c66702325..06056fbf4f 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -28,8 +28,13 @@ ALL_ALONE = "Empty Room" @defer.inlineCallbacks -def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True, - fallback_to_single_member=True): +def calculate_room_name( + store, + room_state_ids, + user_id, + fallback_to_members=True, + fallback_to_single_member=True, +): """ Works out a user-facing name for the given room as per Matrix spec recommendations. @@ -58,8 +63,10 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True room_state_ids[("m.room.canonical_alias", "")], allow_none=True ) if ( - canon_alias and canon_alias.content and canon_alias.content["alias"] and - _looks_like_an_alias(canon_alias.content["alias"]) + canon_alias + and canon_alias.content + and canon_alias.content["alias"] + and _looks_like_an_alias(canon_alias.content["alias"]) ): defer.returnValue(canon_alias.content["alias"]) @@ -71,9 +78,7 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True if "m.room.aliases" in room_state_bytype_ids: m_room_aliases = room_state_bytype_ids["m.room.aliases"] for alias_id in m_room_aliases.values(): - alias_event = yield store.get_event( - alias_id, allow_none=True - ) + alias_event = yield store.get_event(alias_id, allow_none=True) if alias_event and alias_event.content.get("aliases"): the_aliases = alias_event.content["aliases"] if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]): @@ -89,8 +94,8 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True ) if ( - my_member_event is not None and - my_member_event.content['membership'] == "invite" + my_member_event is not None + and my_member_event.content["membership"] == "invite" ): if ("m.room.member", my_member_event.sender) in room_state_ids: inviter_member_event = yield store.get_event( @@ -100,9 +105,8 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True if inviter_member_event: if fallback_to_single_member: defer.returnValue( - "Invite from %s" % ( - name_from_member_event(inviter_member_event), - ) + "Invite from %s" + % (name_from_member_event(inviter_member_event),) ) else: return @@ -116,8 +120,10 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True list(room_state_bytype_ids["m.room.member"].values()) ) all_members = [ - ev for ev in member_events.values() - if ev.content['membership'] == "join" or ev.content['membership'] == "invite" + ev + for ev in member_events.values() + if ev.content["membership"] == "join" + or ev.content["membership"] == "invite" ] # Sort the member events oldest-first so the we name people in the # order the joined (it should at least be deterministic rather than @@ -134,9 +140,9 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True # or inbound invite, or outbound 3PID invite. if all_members[0].sender == user_id: if "m.room.third_party_invite" in room_state_bytype_ids: - third_party_invites = ( - room_state_bytype_ids["m.room.third_party_invite"].values() - ) + third_party_invites = room_state_bytype_ids[ + "m.room.third_party_invite" + ].values() if len(third_party_invites) > 0: # technically third party invite events are not member @@ -191,8 +197,9 @@ def descriptor_from_member_events(member_events): def name_from_member_event(member_event): if ( - member_event.content and "displayname" in member_event.content and - member_event.content["displayname"] + member_event.content + and "displayname" in member_event.content + and member_event.content["displayname"] ): return member_event.content["displayname"] return member_event.state_key diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index cf6c8b875e..5ed9147de4 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -26,8 +26,8 @@ from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) -GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]') -IS_GLOB = re.compile(r'[\?\*\[\]]') +GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]") +IS_GLOB = re.compile(r"[\?\*\[\]]") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") @@ -36,20 +36,20 @@ def _room_member_count(ev, condition, room_member_count): def _sender_notification_permission(ev, condition, sender_power_level, power_levels): - notif_level_key = condition.get('key') + notif_level_key = condition.get("key") if notif_level_key is None: return False - notif_levels = power_levels.get('notifications', {}) + notif_levels = power_levels.get("notifications", {}) room_notif_level = notif_levels.get(notif_level_key, 50) return sender_power_level >= room_notif_level def _test_ineq_condition(condition, number): - if 'is' not in condition: + if "is" not in condition: return False - m = INEQUALITY_EXPR.match(condition['is']) + m = INEQUALITY_EXPR.match(condition["is"]) if not m: return False ineq = m.group(1) @@ -58,15 +58,15 @@ def _test_ineq_condition(condition, number): return False rhs = int(rhs) - if ineq == '' or ineq == '==': + if ineq == "" or ineq == "==": return number == rhs - elif ineq == '<': + elif ineq == "<": return number < rhs - elif ineq == '>': + elif ineq == ">": return number > rhs - elif ineq == '>=': + elif ineq == ">=": return number >= rhs - elif ineq == '<=': + elif ineq == "<=": return number <= rhs else: return False @@ -77,8 +77,8 @@ def tweaks_for_actions(actions): for a in actions: if not isinstance(a, dict): continue - if 'set_tweak' in a and 'value' in a: - tweaks[a['set_tweak']] = a['value'] + if "set_tweak" in a and "value" in a: + tweaks[a["set_tweak"]] = a["value"] return tweaks @@ -93,26 +93,24 @@ class PushRuleEvaluatorForEvent(object): self._value_cache = _flatten_dict(event) def matches(self, condition, user_id, display_name): - if condition['kind'] == 'event_match': + if condition["kind"] == "event_match": return self._event_match(condition, user_id) - elif condition['kind'] == 'contains_display_name': + elif condition["kind"] == "contains_display_name": return self._contains_display_name(display_name) - elif condition['kind'] == 'room_member_count': - return _room_member_count( - self._event, condition, self._room_member_count - ) - elif condition['kind'] == 'sender_notification_permission': + elif condition["kind"] == "room_member_count": + return _room_member_count(self._event, condition, self._room_member_count) + elif condition["kind"] == "sender_notification_permission": return _sender_notification_permission( - self._event, condition, self._sender_power_level, self._power_levels, + self._event, condition, self._sender_power_level, self._power_levels ) else: return True def _event_match(self, condition, user_id): - pattern = condition.get('pattern', None) + pattern = condition.get("pattern", None) if not pattern: - pattern_type = condition.get('pattern_type', None) + pattern_type = condition.get("pattern_type", None) if pattern_type == "user_id": pattern = user_id elif pattern_type == "user_localpart": @@ -123,14 +121,14 @@ class PushRuleEvaluatorForEvent(object): return False # XXX: optimisation: cache our pattern regexps - if condition['key'] == 'content.body': + if condition["key"] == "content.body": body = self._event.content.get("body", None) if not body: return False return _glob_matches(pattern, body, word_boundary=True) else: - haystack = self._get_value(condition['key']) + haystack = self._get_value(condition["key"]) if haystack is None: return False @@ -193,16 +191,13 @@ def _glob_to_re(glob, word_boundary): if IS_GLOB.search(glob): r = re.escape(glob) - r = r.replace(r'\*', '.*?') - r = r.replace(r'\?', '.') + r = r.replace(r"\*", ".*?") + r = r.replace(r"\?", ".") # handle [abc], [a-z] and [!a-z] style ranges. r = GLOB_REGEX.sub( lambda x: ( - '[%s%s]' % ( - x.group(1) and '^' or '', - x.group(2).replace(r'\\\-', '-') - ) + "[%s%s]" % (x.group(1) and "^" or "", x.group(2).replace(r"\\\-", "-")) ), r, ) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 8049c298c2..e37269cdb9 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -23,9 +23,7 @@ def get_badge_count(store, user_id): invites = yield store.get_invited_rooms_for_user(user_id) joins = yield store.get_rooms_for_user(user_id) - my_receipts_by_room = yield store.get_receipts_for_user( - user_id, "m.read", - ) + my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") badge = len(invites) @@ -57,10 +55,10 @@ def get_context_for_event(store, state_handler, ev, user_id): store, room_state_ids, user_id, fallback_to_single_member=False ) if name: - ctx['name'] = name + ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] sender_state_event = yield store.get_event(sender_state_event_id) - ctx['sender_display_name'] = name_from_member_event(sender_state_event) + ctx["sender_display_name"] = name_from_member_event(sender_state_event) defer.returnValue(ctx) diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index aff85daeb5..a9c64a9c54 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -36,9 +36,7 @@ class PusherFactory(object): def __init__(self, hs): self.hs = hs - self.pusher_types = { - "http": HttpPusher, - } + self.pusher_types = {"http": HttpPusher} logger.info("email enable notifs: %r", hs.config.email_enable_notifs) if hs.config.email_enable_notifs: @@ -56,7 +54,7 @@ class PusherFactory(object): logger.info("defined email pusher type") def create_pusher(self, pusherdict): - kind = pusherdict['kind'] + kind = pusherdict["kind"] f = self.pusher_types.get(kind, None) if not f: return None @@ -77,8 +75,8 @@ class PusherFactory(object): return EmailPusher(self.hs, pusherdict, mailer) def _app_name_from_pusherdict(self, pusherdict): - if 'data' in pusherdict and 'brand' in pusherdict['data']: - app_name = pusherdict['data']['brand'] + if "data" in pusherdict and "brand" in pusherdict["data"]: + app_name = pusherdict["data"]["brand"] else: app_name = self.hs.config.email_app_name diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 63c583565f..df6f670740 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -40,6 +40,7 @@ class PusherPool: notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and Pusher.on_new_receipts are not expected to return deferreds. """ + def __init__(self, _hs): self.hs = _hs self.pusher_factory = PusherFactory(_hs) @@ -57,9 +58,19 @@ class PusherPool: run_as_background_process("start_pushers", self._start_pushers) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, kind, app_id, - app_display_name, device_display_name, pushkey, lang, data, - profile_tag=""): + def add_pusher( + self, + user_id, + access_token, + kind, + app_id, + app_display_name, + device_display_name, + pushkey, + lang, + data, + profile_tag="", + ): """Creates a new pusher and adds it to the pool Returns: @@ -71,21 +82,23 @@ class PusherPool: # will then get pulled out of the database, # recreated, added and started: this means we have only one # code path adding pushers. - self.pusher_factory.create_pusher({ - "id": None, - "user_name": user_id, - "kind": kind, - "app_id": app_id, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "pushkey": pushkey, - "ts": time_now_msec, - "lang": lang, - "data": data, - "last_stream_ordering": None, - "last_success": None, - "failing_since": None - }) + self.pusher_factory.create_pusher( + { + "id": None, + "user_name": user_id, + "kind": kind, + "app_id": app_id, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "pushkey": pushkey, + "ts": time_now_msec, + "lang": lang, + "data": data, + "last_stream_ordering": None, + "last_success": None, + "failing_since": None, + } + ) # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process @@ -113,18 +126,19 @@ class PusherPool: defer.returnValue(pusher) @defer.inlineCallbacks - def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, - not_user_id): - to_remove = yield self.store.get_pushers_by_app_id_and_pushkey( - app_id, pushkey - ) + def remove_pushers_by_app_id_and_pushkey_not_user( + self, app_id, pushkey, not_user_id + ): + to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: - if p['user_name'] != not_user_id: + if p["user_name"] != not_user_id: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", - app_id, pushkey, p['user_name'] + app_id, + pushkey, + p["user_name"], ) - yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) @defer.inlineCallbacks def remove_pushers_by_access_token(self, user_id, access_tokens): @@ -138,14 +152,14 @@ class PusherPool: """ tokens = set(access_tokens) for p in (yield self.store.get_pushers_by_user_id(user_id)): - if p['access_token'] in tokens: + if p["access_token"] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", - p['app_id'], p['pushkey'], p['user_name'] - ) - yield self.remove_pusher( - p['app_id'], p['pushkey'], p['user_name'], + p["app_id"], + p["pushkey"], + p["user_name"], ) + yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) @defer.inlineCallbacks def on_new_notifications(self, min_stream_id, max_stream_id): @@ -199,13 +213,11 @@ class PusherPool: if not self._should_start_pushers: return - resultlist = yield self.store.get_pushers_by_app_id_and_pushkey( - app_id, pushkey - ) + resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None for r in resultlist: - if r['user_name'] == user_id: + if r["user_name"] == user_id: pusher_dict = r pusher = None @@ -245,9 +257,9 @@ class PusherPool: except PusherConfigException as e: logger.warning( "Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s", - pusherdict.get('user_name'), - pusherdict.get('app_id'), - pusherdict.get('pushkey'), + pusherdict.get("user_name"), + pusherdict.get("app_id"), + pusherdict.get("pushkey"), e, ) return @@ -258,11 +270,8 @@ class PusherPool: if not p: return - appid_pushkey = "%s:%s" % ( - pusherdict['app_id'], - pusherdict['pushkey'], - ) - byuser = self.pushers.setdefault(pusherdict['user_name'], {}) + appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) + byuser = self.pushers.setdefault(pusherdict["user_name"], {}) if appid_pushkey in byuser: byuser[appid_pushkey].on_stop() @@ -275,7 +284,7 @@ class PusherPool: last_stream_ordering = pusherdict["last_stream_ordering"] if last_stream_ordering: have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( - user_id, last_stream_ordering, + user_id, last_stream_ordering ) else: # We always want to default to starting up the pusher rather than diff --git a/synapse/push/rulekinds.py b/synapse/push/rulekinds.py index 4cae48ac07..ce7cc1b4ee 100644 --- a/synapse/push/rulekinds.py +++ b/synapse/push/rulekinds.py @@ -13,10 +13,10 @@ # limitations under the License. PRIORITY_CLASS_MAP = { - 'underride': 1, - 'sender': 2, - 'room': 3, - 'content': 4, - 'override': 5, + "underride": 1, + "sender": 2, + "room": 3, + "content": 4, + "override": 5, } PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()} diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 11ace2bfb1..77e2d1a0e1 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -45,14 +45,11 @@ REQUIREMENTS = [ "signedjson>=1.0.0", "pynacl>=1.2.1", "idna>=2.5", - # validating SSL certs for IP addresses requires service_identity 18.1. "service_identity>=18.1.0", - # our logcontext handling relies on the ability to cancel inlineCallbacks # (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7. "Twisted>=18.7.0", - "treq>=15.1", # Twisted has required pyopenssl 16.0 since about Twisted 16.6. "pyopenssl>=16.0.0", @@ -71,34 +68,27 @@ REQUIREMENTS = [ # prometheus_client 0.4.0 changed the format of counter metrics # (cf https://github.com/matrix-org/synapse/issues/4001) "prometheus_client>=0.0.18,<0.4.0", - # we use attr.s(slots), which arrived in 16.0.0 # Twisted 18.7.0 requires attrs>=17.4.0 "attrs>=17.4.0", - "netaddr>=0.7.18", ] CONDITIONAL_REQUIREMENTS = { "email": ["Jinja2>=2.9", "bleach>=1.4.3"], "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], - # we use execute_batch, which arrived in psycopg 2.7. "postgres": ["psycopg2>=2.7"], - # ConsentResource uses select_autoescape, which arrived in jinja 2.9 "resources.consent": ["Jinja2>=2.9"], - # ACME support is required to provision TLS certificates from authorities # that use the protocol, such as Let's Encrypt. "acme": [ "txacme>=0.9.2", - # txacme depends on eliot. Eliot 1.8.0 is incompatible with # python 3.5.2, as per https://github.com/itamarst/eliot/issues/418 'eliot<1.8.0;python_version<"3.5.3"', ], - "saml2": ["pysaml2>=4.5.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], @@ -121,12 +111,14 @@ def list_requirements(): class DependencyException(Exception): @property def message(self): - return "\n".join([ - "Missing Requirements: %s" % (", ".join(self.dependencies),), - "To install run:", - " pip install --upgrade --force %s" % (" ".join(self.dependencies),), - "", - ]) + return "\n".join( + [ + "Missing Requirements: %s" % (", ".join(self.dependencies),), + "To install run:", + " pip install --upgrade --force %s" % (" ".join(self.dependencies),), + "", + ] + ) @property def dependencies(self): diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 0a432a16fa..fe482e279f 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -83,8 +83,7 @@ class ReplicationEndpoint(object): def __init__(self, hs): if self.CACHE: self.response_cache = ResponseCache( - hs, "repl." + self.NAME, - timeout_ms=30 * 60 * 1000, + hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) assert self.METHOD in ("PUT", "POST", "GET") @@ -134,8 +133,7 @@ class ReplicationEndpoint(object): data = yield cls._serialize_payload(**kwargs) url_args = [ - urllib.parse.quote(kwargs[name], safe='') - for name in cls.PATH_ARGS + urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS ] if cls.CACHE: @@ -156,7 +154,10 @@ class ReplicationEndpoint(object): ) uri = "http://%s:%s/_synapse/replication/%s/%s" % ( - host, port, cls.NAME, "/".join(url_args) + host, + port, + cls.NAME, + "/".join(url_args), ) try: @@ -202,10 +203,7 @@ class ReplicationEndpoint(object): url_args.append("txn_id") args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) - pattern = re.compile("^/_synapse/replication/%s/%s$" % ( - self.NAME, - args - )) + pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) http_server.register_paths(method, [pattern], handler) @@ -219,8 +217,4 @@ class ReplicationEndpoint(object): assert self.CACHE - return self.response_cache.wrap( - txn_id, - self._handle_request, - request, **kwargs - ) + return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 0f0a07c422..61eafbe708 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -68,18 +68,17 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): for event, context in event_and_contexts: serialized_context = yield context.serialize(event, store) - event_payloads.append({ - "event": event.get_pdu_json(), - "event_format_version": event.format_version, - "internal_metadata": event.internal_metadata.get_dict(), - "rejected_reason": event.rejected_reason, - "context": serialized_context, - }) - - payload = { - "events": event_payloads, - "backfilled": backfilled, - } + event_payloads.append( + { + "event": event.get_pdu_json(), + "event_format_version": event.format_version, + "internal_metadata": event.internal_metadata.get_dict(), + "rejected_reason": event.rejected_reason, + "context": serialized_context, + } + ) + + payload = {"events": event_payloads, "backfilled": backfilled} defer.returnValue(payload) @@ -103,18 +102,15 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) context = yield EventContext.deserialize( - self.store, event_payload["context"], + self.store, event_payload["context"] ) event_and_contexts.append((event, context)) - logger.info( - "Got %d events from federation", - len(event_and_contexts), - ) + logger.info("Got %d events from federation", len(event_and_contexts)) yield self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled, + event_and_contexts, backfilled ) defer.returnValue((200, {})) @@ -146,10 +142,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): @staticmethod def _serialize_payload(edu_type, origin, content): - return { - "origin": origin, - "content": content, - } + return {"origin": origin, "content": content} @defer.inlineCallbacks def _handle_request(self, request, edu_type): @@ -159,10 +152,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): origin = content["origin"] edu_content = content["content"] - logger.info( - "Got %r edu from %s", - edu_type, origin, - ) + logger.info("Got %r edu from %s", edu_type, origin) result = yield self.registry.on_edu(edu_type, origin, edu_content) @@ -201,9 +191,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): query_type (str) args (dict): The arguments received for the given query type """ - return { - "args": args, - } + return {"args": args} @defer.inlineCallbacks def _handle_request(self, request, query_type): @@ -212,10 +200,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): args = content["args"] - logger.info( - "Got %r query", - query_type, - ) + logger.info("Got %r query", query_type) result = yield self.registry.on_query(query_type, args) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 63bc0405ea..7c1197e5dd 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -61,13 +61,10 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest = content["is_guest"] device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest, + user_id, device_id, initial_display_name, is_guest ) - defer.returnValue((200, { - "device_id": device_id, - "access_token": access_token, - })) + defer.returnValue((200, {"device_id": device_id, "access_token": access_token})) def register_servlets(hs, http_server): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 81a2b204c7..0a76a3762f 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -40,7 +40,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): """ NAME = "remote_join" - PATH_ARGS = ("room_id", "user_id",) + PATH_ARGS = ("room_id", "user_id") def __init__(self, hs): super(ReplicationRemoteJoinRestServlet, self).__init__(hs) @@ -50,8 +50,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts, - content): + def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): """ Args: requester(Requester) @@ -78,16 +77,10 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): if requester.user: request.authenticated_entity = requester.user.to_string() - logger.info( - "remote_join: %s into room: %s", - user_id, room_id, - ) + logger.info("remote_join: %s into room: %s", user_id, room_id) yield self.federation_handler.do_invite_join( - remote_room_hosts, - room_id, - user_id, - event_content, + remote_room_hosts, room_id, user_id, event_content ) defer.returnValue((200, {})) @@ -107,7 +100,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): """ NAME = "remote_reject_invite" - PATH_ARGS = ("room_id", "user_id",) + PATH_ARGS = ("room_id", "user_id") def __init__(self, hs): super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs) @@ -141,16 +134,11 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): if requester.user: request.authenticated_entity = requester.user.to_string() - logger.info( - "remote_reject_invite: %s out of room: %s", - user_id, room_id, - ) + logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: event = yield self.federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - user_id, + remote_room_hosts, room_id, user_id ) ret = event.get_pdu_json() except Exception as e: @@ -162,9 +150,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warn("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite( - user_id, room_id - ) + yield self.store.locally_reject_invite(user_id, room_id) ret = {} defer.returnValue((200, ret)) @@ -228,7 +214,7 @@ class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint): logger.info("get_or_register_3pid_guest: %r", content) ret = yield self.registeration_handler.get_or_register_3pid_guest( - medium, address, inviter_user_id, + medium, address, inviter_user_id ) defer.returnValue((200, ret)) @@ -264,7 +250,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): user_id (str) change (str): Either "joined" or "left" """ - assert change in ("joined", "left",) + assert change in ("joined", "left") return {} diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 912a5ac341..f81a0f1b8f 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -37,8 +37,16 @@ class ReplicationRegisterServlet(ReplicationEndpoint): @staticmethod def _serialize_payload( - user_id, token, password_hash, was_guest, make_guest, appservice_id, - create_profile_with_displayname, admin, user_type, address, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id, + create_profile_with_displayname, + admin, + user_type, + address, ): """ Args: @@ -85,7 +93,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): create_profile_with_displayname=content["create_profile_with_displayname"], admin=content["admin"], user_type=content["user_type"], - address=content["address"] + address=content["address"], ) defer.returnValue((200, {})) @@ -104,8 +112,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - def _serialize_payload(user_id, auth_result, access_token, bind_email, - bind_msisdn): + def _serialize_payload(user_id, auth_result, access_token, bind_email, bind_msisdn): """ Args: user_id (str): The user ID that consented diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 3635015eda..034763fe99 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -45,6 +45,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "extra_users": [], } """ + NAME = "send_event" PATH_ARGS = ("event_id",) @@ -57,8 +58,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): @staticmethod @defer.inlineCallbacks - def _serialize_payload(event_id, store, event, context, requester, - ratelimit, extra_users): + def _serialize_payload( + event_id, store, event, context, requester, ratelimit, extra_users + ): """ Args: event_id (str) @@ -108,14 +110,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): request.authenticated_entity = requester.user.to_string() logger.info( - "Got event to send with ID: %s into room: %s", - event.event_id, event.room_id, + "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) yield self.event_creation_handler.persist_and_notify_client_event( - requester, event, context, - ratelimit=ratelimit, - extra_users=extra_users, + requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) defer.returnValue((200, {})) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 817d1f67f9..182cb2a1d8 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -37,7 +37,7 @@ class BaseSlavedStore(SQLBaseStore): super(BaseSlavedStore, self).__init__(db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( - db_conn, "cache_invalidation_stream", "stream_id", + db_conn, "cache_invalidation_stream", "stream_id" ) else: self._cache_id_gen = None diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index d9ba6d69b1..3c44d1d48d 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -21,10 +21,9 @@ from synapse.storage.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( - db_conn, "account_data_max_stream_id", "stream_id", + db_conn, "account_data_max_stream_id", "stream_id" ) super(SlavedAccountDataStore, self).__init__(db_conn, hs) @@ -45,24 +44,20 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved self._account_data_id_gen.advance(token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed( - row.user_id, token - ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == "account_data": self._account_data_id_gen.advance(token) for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( - (row.data_type, row.user_id,) + (row.data_type, row.user_id) ) self.get_account_data_for_user.invalidate((row.user_id,)) - self.get_account_data_for_room.invalidate((row.user_id, row.room_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) self.get_account_data_for_room_and_type.invalidate( - (row.user_id, row.room_id, row.data_type,), - ) - self._account_data_stream_cache.entity_has_changed( - row.user_id, token + (row.user_id, row.room_id, row.data_type) ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedAccountDataStore, self).process_replication_rows( stream_name, token, rows ) diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index b53a4c6bd1..cda12ea70d 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -20,6 +20,7 @@ from synapse.storage.appservice import ( ) -class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore, - ApplicationServiceWorkerStore): +class SlavedApplicationServiceStore( + ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore +): pass diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 5b8521c770..14ced32333 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -25,9 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore): super(SlavedClientIpStore, self).__init__(db_conn, hs) self.client_ip_last_seen = Cache( - name="client_ip_last_seen", - keylen=4, - max_entries=50000 * CACHE_SIZE_FACTOR, + name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR ) def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 4d59778863..284fd30d89 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -24,15 +24,15 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( - db_conn, "device_max_stream_id", "stream_id", + db_conn, "device_max_stream_id", "stream_id" ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token() + self._device_inbox_id_gen.get_current_token(), ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token() + self._device_inbox_id_gen.get_current_token(), ) self._last_device_delete_cache = ExpiringCache( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 16c9a162c5..d9300fce33 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -27,14 +27,14 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self.hs = hs self._device_list_id_gen = SlavedIdTracker( - db_conn, "device_lists_stream", "stream_id", + db_conn, "device_lists_stream", "stream_id" ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( - "DeviceListStreamChangeCache", device_list_max, + "DeviceListStreamChangeCache", device_list_max ) self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", device_list_max, + "DeviceListFederationStreamChangeCache", device_list_max ) def stream_positions(self): @@ -46,17 +46,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto if stream_name == "device_lists": self._device_list_id_gen.advance(token) for row in rows: - self._invalidate_caches_for_devices( - token, row.user_id, row.destination, - ) + self._invalidate_caches_for_devices(token, row.user_id, row.destination) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) def _invalidate_caches_for_devices(self, token, user_id, destination): - self._device_list_stream_cache.entity_has_changed( - user_id, token - ) + self._device_list_stream_cache.entity_has_changed(user_id, token) if destination: self._device_list_federation_stream_cache.entity_has_changed( diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index a3952506c1..ab5937e638 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -45,21 +45,20 @@ logger = logging.getLogger(__name__) # the method descriptor on the DataStore and chuck them into our class. -class SlavedEventStore(EventFederationWorkerStore, - RoomMemberWorkerStore, - EventPushActionsWorkerStore, - StreamWorkerStore, - StateGroupWorkerStore, - EventsWorkerStore, - SignatureWorkerStore, - UserErasureWorkerStore, - RelationsWorkerStore, - BaseSlavedStore): - +class SlavedEventStore( + EventFederationWorkerStore, + RoomMemberWorkerStore, + EventPushActionsWorkerStore, + StreamWorkerStore, + StateGroupWorkerStore, + EventsWorkerStore, + SignatureWorkerStore, + UserErasureWorkerStore, + RelationsWorkerStore, + BaseSlavedStore, +): def __init__(self, db_conn, hs): - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", - ) + self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) @@ -90,8 +89,13 @@ class SlavedEventStore(EventFederationWorkerStore, self._backfill_id_gen.advance(-token) for row in rows: self.invalidate_caches_for_event( - -token, row.event_id, row.room_id, row.type, row.state_key, - row.redacts, row.relates_to, + -token, + row.event_id, + row.room_id, + row.type, + row.state_key, + row.redacts, + row.relates_to, backfilled=True, ) return super(SlavedEventStore, self).process_replication_rows( @@ -103,41 +107,48 @@ class SlavedEventStore(EventFederationWorkerStore, if row.type == EventsStreamEventRow.TypeId: self.invalidate_caches_for_event( - token, data.event_id, data.room_id, data.type, data.state_key, - data.redacts, data.relates_to, + token, + data.event_id, + data.room_id, + data.type, + data.state_key, + data.redacts, + data.relates_to, backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: if data.type == EventTypes.Member: self.get_rooms_for_user_with_stream_ordering.invalidate( - (data.state_key, ), + (data.state_key,) ) else: - raise Exception("Unknown events stream row type %s" % (row.type, )) - - def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, - etype, state_key, redacts, relates_to, - backfilled): + raise Exception("Unknown events stream row type %s" % (row.type,)) + + def invalidate_caches_for_event( + self, + stream_ordering, + event_id, + room_id, + etype, + state_key, + redacts, + relates_to, + backfilled, + ): self._invalidate_get_event_cache(event_id) self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_event_push_actions_by_room_for_user.invalidate_many( - (room_id,) - ) + self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: - self._events_stream_cache.entity_has_changed( - room_id, stream_ordering - ) + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) if redacts: self._invalidate_get_event_cache(redacts) if etype == EventTypes.Member: - self._membership_stream_cache.entity_has_changed( - state_key, stream_ordering - ) + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) self.get_invited_rooms_for_user.invalidate((state_key,)) if relates_to: diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index e933b170bb..28a46edd28 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -27,10 +27,11 @@ class SlavedGroupServerStore(BaseSlavedStore): self.hs = hs self._group_updates_id_gen = SlavedIdTracker( - db_conn, "local_group_updates", "stream_id", + db_conn, "local_group_updates", "stream_id" ) self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(), + "_group_updates_stream_cache", + self._group_updates_id_gen.get_current_token(), ) get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user) @@ -46,9 +47,7 @@ class SlavedGroupServerStore(BaseSlavedStore): if stream_name == "groups": self._group_updates_id_gen.advance(token) for row in rows: - self._group_updates_stream_cache.entity_has_changed( - row.user_id, token - ) + self._group_updates_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedGroupServerStore, self).process_replication_rows( stream_name, token, rows diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 0ec1db25ce..82d808af4c 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -24,9 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedPresenceStore, self).__init__(db_conn, hs) - self._presence_id_gen = SlavedIdTracker( - db_conn, "presence_stream", "stream_id", - ) + self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) @@ -55,9 +53,7 @@ class SlavedPresenceStore(BaseSlavedStore): if stream_name == "presence": self._presence_id_gen.advance(token) for row in rows: - self.presence_stream_cache.entity_has_changed( - row.user_id, token - ) + self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) return super(SlavedPresenceStore, self).process_replication_rows( stream_name, token, rows diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 45fc913c52..af7012702e 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -23,7 +23,7 @@ from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def __init__(self, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id", + db_conn, "push_rules_stream", "stream_id" ) super(SlavedPushRuleStore, self).__init__(db_conn, hs) @@ -47,9 +47,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) - self.push_rules_stream_cache.entity_has_changed( - row.user_id, token - ) + self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedPushRuleStore, self).process_replication_rows( stream_name, token, rows ) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 3b2213c0d4..8eeb267d61 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -21,12 +21,10 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): super(SlavedPusherStore, self).__init__(db_conn, hs) self._pushers_id_gen = SlavedIdTracker( - db_conn, "pushers", "id", - extra_tables=[("deleted_pushers", "stream_id")], + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) def stream_positions(self): diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index ed12342f40..91afa5a72b 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -29,7 +29,6 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 0cb474928c..f68b3378e3 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -38,6 +38,4 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication_rows( - stream_name, token, rows - ) + return super(RoomStore, self).process_replication_rows(stream_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 206dc3b397..a44ceb00e7 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -39,6 +39,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): Accepts a handler that will be called when new data is available or data is required. """ + maxDelay = 30 # Try at least once every N seconds def __init__(self, hs, client_name, handler): @@ -64,9 +65,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): def clientConnectionFailed(self, connector, reason): logger.error("Failed to connect to replication: %r", reason) - ReconnectingClientFactory.clientConnectionFailed( - self, connector, reason - ) + ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) class ReplicationClientHandler(object): @@ -74,6 +73,7 @@ class ReplicationClientHandler(object): By default proxies incoming replication data to the SlaveStore. """ + def __init__(self, store): self.store = store diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 2098c32a77..0ff2a7199f 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -23,9 +23,11 @@ import platform if platform.python_implementation() == "PyPy": import json + _json_encoder = json.JSONEncoder() else: import simplejson as json + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) logger = logging.getLogger(__name__) @@ -41,6 +43,7 @@ class Command(object): The default implementation creates a command of form ` ` """ + NAME = None def __init__(self, data): @@ -73,6 +76,7 @@ class ServerCommand(Command): SERVER """ + NAME = "SERVER" @@ -99,6 +103,7 @@ class RdataCommand(Command): RDATA presence batch ["@bar:example.com", "online", ...] RDATA presence 59 ["@baz:example.com", "online", ...] """ + NAME = "RDATA" def __init__(self, stream_name, token, row): @@ -110,17 +115,17 @@ class RdataCommand(Command): def from_line(cls, line): stream_name, token, row_json = line.split(" ", 2) return cls( - stream_name, - None if token == "batch" else int(token), - json.loads(row_json) + stream_name, None if token == "batch" else int(token), json.loads(row_json) ) def to_line(self): - return " ".join(( - self.stream_name, - str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), - )) + return " ".join( + ( + self.stream_name, + str(self.token) if self.token is not None else "batch", + _json_encoder.encode(self.row), + ) + ) def get_logcontext_id(self): return "RDATA-" + self.stream_name @@ -133,6 +138,7 @@ class PositionCommand(Command): Sent to the client after all missing updates for a stream have been sent to the client and they're now up to date. """ + NAME = "POSITION" def __init__(self, stream_name, token): @@ -145,19 +151,21 @@ class PositionCommand(Command): return cls(stream_name, int(token)) def to_line(self): - return " ".join((self.stream_name, str(self.token),)) + return " ".join((self.stream_name, str(self.token))) class ErrorCommand(Command): """Sent by either side if there was an ERROR. The data is a string describing the error. """ + NAME = "ERROR" class PingCommand(Command): """Sent by either side as a keep alive. The data is arbitary (often timestamp) """ + NAME = "PING" @@ -165,6 +173,7 @@ class NameCommand(Command): """Sent by client to inform the server of the client's identity. The data is the name """ + NAME = "NAME" @@ -184,6 +193,7 @@ class ReplicateCommand(Command): REPLICATE ALL NOW """ + NAME = "REPLICATE" def __init__(self, stream_name, token): @@ -200,7 +210,7 @@ class ReplicateCommand(Command): return cls(stream_name, token) def to_line(self): - return " ".join((self.stream_name, str(self.token),)) + return " ".join((self.stream_name, str(self.token))) def get_logcontext_id(self): return "REPLICATE-" + self.stream_name @@ -218,6 +228,7 @@ class UserSyncCommand(Command): Where is either "start" or "stop" """ + NAME = "USER_SYNC" def __init__(self, user_id, is_syncing, last_sync_ms): @@ -235,9 +246,13 @@ class UserSyncCommand(Command): return cls(user_id, state == "start", int(last_sync_ms)) def to_line(self): - return " ".join(( - self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), - )) + return " ".join( + ( + self.user_id, + "start" if self.is_syncing else "end", + str(self.last_sync_ms), + ) + ) class FederationAckCommand(Command): @@ -251,6 +266,7 @@ class FederationAckCommand(Command): FEDERATION_ACK """ + NAME = "FEDERATION_ACK" def __init__(self, token): @@ -268,6 +284,7 @@ class SyncCommand(Command): """Used for testing. The client protocol implementation allows waiting on a SYNC command with a specified data. """ + NAME = "SYNC" @@ -278,6 +295,7 @@ class RemovePusherCommand(Command): REMOVE_PUSHER """ + NAME = "REMOVE_PUSHER" def __init__(self, app_id, push_key, user_id): @@ -309,6 +327,7 @@ class InvalidateCacheCommand(Command): Where is a json list. """ + NAME = "INVALIDATE_CACHE" def __init__(self, cache_func, keys): @@ -322,9 +341,7 @@ class InvalidateCacheCommand(Command): return cls(cache_func, json.loads(keys_json)) def to_line(self): - return " ".join(( - self.cache_func, _json_encoder.encode(self.keys), - )) + return " ".join((self.cache_func, _json_encoder.encode(self.keys))) class UserIpCommand(Command): @@ -334,6 +351,7 @@ class UserIpCommand(Command): USER_IP , , , , , """ + NAME = "USER_IP" def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen): @@ -350,15 +368,22 @@ class UserIpCommand(Command): access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) - return cls( - user_id, access_token, ip, user_agent, device_id, last_seen - ) + return cls(user_id, access_token, ip, user_agent, device_id, last_seen) def to_line(self): - return self.user_id + " " + _json_encoder.encode(( - self.access_token, self.ip, self.user_agent, self.device_id, - self.last_seen, - )) + return ( + self.user_id + + " " + + _json_encoder.encode( + ( + self.access_token, + self.ip, + self.user_agent, + self.device_id, + self.last_seen, + ) + ) + ) # Map of command name to command type. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index b51590cf8f..97efb835ad 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -84,7 +84,8 @@ from .commands import ( from .streams import STREAMS_MAP connection_close_counter = Counter( - "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]) + "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] +) # A list of all connected protocols. This allows us to send metrics about the # connections. @@ -119,7 +120,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) """ - delimiter = b'\n' + + delimiter = b"\n" VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send @@ -183,10 +185,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if now - self.last_sent_command >= PING_TIME: self.send_command(PingCommand(now)) - if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS: + if ( + self.received_ping + and now - self.last_received_command > PING_TIMEOUT_MS + ): logger.info( "[%s] Connection hasn't received command in %r ms. Closing.", - self.id(), now - self.last_received_command + self.id(), + now - self.last_received_command, ) self.send_error("ping timeout") @@ -208,7 +214,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() self.inbound_commands_counter[cmd_name] = ( - self.inbound_commands_counter[cmd_name] + 1) + self.inbound_commands_counter[cmd_name] + 1 + ) cmd_cls = COMMAND_MAP[cmd_name] try: @@ -224,9 +231,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Now lets try and call on_ function run_as_background_process( - "replication-" + cmd.get_logcontext_id(), - self.handle_command, - cmd, + "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd ) def handle_command(self, cmd): @@ -274,8 +279,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): return self.outbound_commands_counter[cmd.NAME] = ( - self.outbound_commands_counter[cmd.NAME] + 1) - string = "%s %s" % (cmd.NAME, cmd.to_line(),) + self.outbound_commands_counter[cmd.NAME] + 1 + ) + string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -283,10 +289,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if len(encoded_string) > self.MAX_LENGTH: raise Exception( - "Failed to send command %s as too long (%d > %d)" % ( - cmd.NAME, - len(encoded_string), self.MAX_LENGTH, - ) + "Failed to send command %s as too long (%d > %d)" + % (cmd.NAME, len(encoded_string), self.MAX_LENGTH) ) self.sendLine(encoded_string) @@ -379,7 +383,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if self.transport: addr = str(self.transport.getPeer()) return "ReplicationConnection" % ( - self.name, self.conn_id, addr, + self.name, + self.conn_id, + addr, ) def id(self): @@ -422,7 +428,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): def on_USER_SYNC(self, cmd): return self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, + self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) def on_REPLICATE(self, cmd): @@ -432,10 +438,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): if stream_name == "ALL": # Subscribe to all streams we're publishing to. deferreds = [ - run_in_background( - self.subscribe_to_stream, - stream, token, - ) + run_in_background(self.subscribe_to_stream, stream, token) for stream in iterkeys(self.streamer.streams_by_name) ] @@ -449,16 +452,18 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): return self.streamer.federation_ack(cmd.token) def on_REMOVE_PUSHER(self, cmd): - return self.streamer.on_remove_pusher( - cmd.app_id, cmd.push_key, cmd.user_id, - ) + return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) def on_INVALIDATE_CACHE(self, cmd): return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) def on_USER_IP(self, cmd): return self.streamer.on_user_ip( - cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, cmd.last_seen, ) @@ -476,7 +481,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): try: # Get missing updates updates, current_token = yield self.streamer.get_stream_updates( - stream_name, token, + stream_name, token ) # Send all the missing updates @@ -608,8 +613,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): row = STREAMS_MAP[stream_name].parse_row(cmd.row) except Exception: logger.exception( - "[%s] Failed to parse RDATA: %r %r", - self.id(), stream_name, cmd.row + "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row ) raise @@ -643,7 +647,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): logger.info( "[%s] Subscribing to replication stream: %r from %r", - self.id(), stream_name, token + self.id(), + stream_name, + token, ) self.streams_connecting.add(stream_name) @@ -661,9 +667,7 @@ pending_commands = LaterGauge( "synapse_replication_tcp_protocol_pending_commands", "", ["name"], - lambda: { - (p.name,): len(p.pending_commands) for p in connected_connections - }, + lambda: {(p.name,): len(p.pending_commands) for p in connected_connections}, ) @@ -678,9 +682,7 @@ transport_send_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_send_buffer", "", ["name"], - lambda: { - (p.name,): transport_buffer_size(p) for p in connected_connections - }, + lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections}, ) @@ -694,7 +696,7 @@ def transport_kernel_read_buffer_size(protocol, read=True): op = SIOCINQ else: op = SIOCOUTQ - size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0] + size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0] return size return 0 @@ -726,7 +728,7 @@ tcp_inbound_commands = LaterGauge( "", ["command", "name"], lambda: { - (k, p.name,): count + (k, p.name): count for p in connected_connections for k, count in iteritems(p.inbound_commands_counter) }, @@ -737,7 +739,7 @@ tcp_outbound_commands = LaterGauge( "", ["command", "name"], lambda: { - (k, p.name,): count + (k, p.name): count for p in connected_connections for k, count in iteritems(p.outbound_commands_counter) }, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index f6a38f5140..d1e98428bc 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -33,13 +33,15 @@ from .protocol import ServerReplicationStreamProtocol from .streams import STREAMS_MAP from .streams.federation import FederationStream -stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", - "", ["stream_name"]) +stream_updates_counter = Counter( + "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] +) user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache", - "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -48,6 +50,7 @@ logger = logging.getLogger(__name__) class ReplicationStreamProtocolFactory(Factory): """Factory for new replication connections. """ + def __init__(self, hs): self.streamer = ReplicationStreamer(hs) self.clock = hs.get_clock() @@ -55,9 +58,7 @@ class ReplicationStreamProtocolFactory(Factory): def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, - self.clock, - self.streamer, + self.server_name, self.clock, self.streamer ) @@ -80,29 +81,39 @@ class ReplicationStreamer(object): # Current connections. self.connections = [] - LaterGauge("synapse_replication_tcp_resource_total_connections", "", [], - lambda: len(self.connections)) + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self.connections), + ) # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. self.streams = [ - stream(hs) for stream in itervalues(STREAMS_MAP) + stream(hs) + for stream in itervalues(STREAMS_MAP) if stream != FederationStream or not hs.config.send_federation ] self.streams_by_name = {stream.NAME: stream for stream in self.streams} LaterGauge( - "synapse_replication_tcp_resource_connections_per_stream", "", + "synapse_replication_tcp_resource_connections_per_stream", + "", ["stream_name"], lambda: { - (stream_name,): len([ - conn for conn in self.connections - if stream_name in conn.replication_streams - ]) + (stream_name,): len( + [ + conn + for conn in self.connections + if stream_name in conn.replication_streams + ] + ) for stream_name in self.streams_by_name - }) + }, + ) self.federation_sender = None if not hs.config.send_federation: @@ -179,7 +190,9 @@ class ReplicationStreamer(object): logger.debug( "Getting stream: %s: %s -> %s", - stream.NAME, stream.last_token, stream.upto_token + stream.NAME, + stream.last_token, + stream.upto_token, ) try: updates, current_token = yield stream.get_updates() @@ -189,7 +202,8 @@ class ReplicationStreamer(object): logger.debug( "Sending %d updates to %d connections", - len(updates), len(self.connections), + len(updates), + len(self.connections), ) if updates: @@ -243,7 +257,7 @@ class ReplicationStreamer(object): """ user_sync_counter.inc() yield self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms, + conn_id, user_id, is_syncing, last_sync_ms ) @measure_func("repl.on_remove_pusher") @@ -272,7 +286,7 @@ class ReplicationStreamer(object): """ user_ip_cache_counter.inc() yield self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen, + user_id, access_token, ip, user_agent, device_id, last_seen ) yield self._server_notices_sender.on_user_ip(user_id) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b6ce7a7bee..7ef67a5a73 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -26,78 +26,75 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 10000 -BackfillStreamRow = namedtuple("BackfillStreamRow", ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional -)) -PresenceStreamRow = namedtuple("PresenceStreamRow", ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool -)) -TypingStreamRow = namedtuple("TypingStreamRow", ( - "room_id", # str - "user_ids", # list(str) -)) -ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict -)) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ( - "user_id", # str -)) -PushersStreamRow = namedtuple("PushersStreamRow", ( - "user_id", # str - "app_id", # str - "pushkey", # str - "deleted", # bool -)) -CachesStreamRow = namedtuple("CachesStreamRow", ( - "cache_func", # str - "keys", # list(str) - "invalidation_ts", # int -)) -PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional -)) -DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ( - "user_id", # str - "destination", # str -)) -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ( - "entity", # str -)) -TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", ( - "user_id", # str - "room_id", # str - "data", # dict -)) -AccountDataStreamRow = namedtuple("AccountDataStream", ( - "user_id", # str - "room_id", # str - "data_type", # str - "data", # dict -)) -GroupsStreamRow = namedtuple("GroupsStreamRow", ( - "group_id", # str - "user_id", # str - "type", # str - "content", # dict -)) +BackfillStreamRow = namedtuple( + "BackfillStreamRow", + ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional + "relates_to", # str, optional + ), +) +PresenceStreamRow = namedtuple( + "PresenceStreamRow", + ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool + ), +) +TypingStreamRow = namedtuple( + "TypingStreamRow", ("room_id", "user_ids") # str # list(str) +) +ReceiptsStreamRow = namedtuple( + "ReceiptsStreamRow", + ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict + ), +) +PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str +PushersStreamRow = namedtuple( + "PushersStreamRow", + ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool +) +CachesStreamRow = namedtuple( + "CachesStreamRow", + ("cache_func", "keys", "invalidation_ts"), # str # list(str) # int +) +PublicRoomsStreamRow = namedtuple( + "PublicRoomsStreamRow", + ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional + ), +) +DeviceListsStreamRow = namedtuple( + "DeviceListsStreamRow", ("user_id", "destination") # str # str +) +ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str +TagAccountDataStreamRow = namedtuple( + "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict +) +AccountDataStreamRow = namedtuple( + "AccountDataStream", + ("user_id", "room_id", "data_type", "data"), # str # str # str # dict +) +GroupsStreamRow = namedtuple( + "GroupsStreamRow", + ("group_id", "user_id", "type", "content"), # str # str # str # dict +) class Stream(object): @@ -106,6 +103,7 @@ class Stream(object): Provides a `get_updates()` function that returns new updates since the last time it was called up until the point `advance_current_token` was called. """ + NAME = None # The name of the stream ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. _LIMITED = True # Whether the update function takes a limit @@ -185,16 +183,13 @@ class Stream(object): if self._LIMITED: rows = yield self.update_function( - from_token, current_token, - limit=MAX_EVENTS_BEHIND + 1, + from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 ) # never turn more than MAX_EVENTS_BEHIND + 1 into updates. rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: - rows = yield self.update_function( - from_token, current_token, - ) + rows = yield self.update_function(from_token, current_token) updates = [(row[0], row[1:]) for row in rows] @@ -230,6 +225,7 @@ class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. """ + NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -286,6 +282,7 @@ class ReceiptsStream(Stream): class PushRulesStream(Stream): """A user has changed their push rules """ + NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -306,6 +303,7 @@ class PushRulesStream(Stream): class PushersStream(Stream): """A user has added/changed/removed a pusher """ + NAME = "pushers" ROW_TYPE = PushersStreamRow @@ -322,6 +320,7 @@ class CachesStream(Stream): """A cache was invalidated on the master and no other stream would invalidate the cache on the workers """ + NAME = "caches" ROW_TYPE = CachesStreamRow @@ -337,6 +336,7 @@ class CachesStream(Stream): class PublicRoomsStream(Stream): """The public rooms list changed """ + NAME = "public_rooms" ROW_TYPE = PublicRoomsStreamRow @@ -352,6 +352,7 @@ class PublicRoomsStream(Stream): class DeviceListsStream(Stream): """Someone added/changed/removed a device """ + NAME = "device_lists" _LIMITED = False ROW_TYPE = DeviceListsStreamRow @@ -368,6 +369,7 @@ class DeviceListsStream(Stream): class ToDeviceStream(Stream): """New to_device messages for a client """ + NAME = "to_device" ROW_TYPE = ToDeviceStreamRow @@ -383,6 +385,7 @@ class ToDeviceStream(Stream): class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ + NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow @@ -398,6 +401,7 @@ class TagAccountDataStream(Stream): class AccountDataStream(Stream): """Global or per room account data was changed """ + NAME = "account_data" ROW_TYPE = AccountDataStreamRow @@ -416,7 +420,7 @@ class AccountDataStream(Stream): results = list(room_results) results.extend( - (stream_id, user_id, None, account_data_type, content,) + (stream_id, user_id, None, account_data_type, content) for stream_id, user_id, account_data_type, content in global_results ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f1290d022a..3d0694bb11 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -52,6 +52,7 @@ data part are: @attr.s(slots=True, frozen=True) class EventsStreamRow(object): """A parsed row from the events replication stream""" + type = attr.ib() # str: the TypeId of one of the *EventsStreamRows data = attr.ib() # BaseEventsStreamRow @@ -80,11 +81,11 @@ class BaseEventsStreamRow(object): class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional + event_id = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str, optional + redacts = attr.ib() # str, optional relates_to = attr.ib() # str, optional @@ -92,24 +93,21 @@ class EventsStreamEventRow(BaseEventsStreamRow): class EventsStreamCurrentStateRow(BaseEventsStreamRow): TypeId = "state" - room_id = attr.ib() # str - type = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str state_key = attr.ib() # str - event_id = attr.ib() # str, optional + event_id = attr.ib() # str, optional TypeToRow = { - Row.TypeId: Row - for Row in ( - EventsStreamEventRow, - EventsStreamCurrentStateRow, - ) + Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow) } class EventsStream(Stream): """We received a new event, or an event went from being an outlier to not """ + NAME = "events" def __init__(self, hs): @@ -121,19 +119,17 @@ class EventsStream(Stream): @defer.inlineCallbacks def update_function(self, from_token, current_token, limit=None): event_rows = yield self._store.get_all_new_forward_event_rows( - from_token, current_token, limit, + from_token, current_token, limit ) event_updates = ( - (row[0], EventsStreamEventRow.TypeId, row[1:]) - for row in event_rows + (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows ) state_rows = yield self._store.get_all_updated_current_state_deltas( from_token, current_token, limit ) state_updates = ( - (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) - for row in state_rows + (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows ) all_updates = heapq.merge(event_updates, state_updates) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 9aa43aa8d2..dc2484109d 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -17,16 +17,20 @@ from collections import namedtuple from ._base import Stream -FederationStreamRow = namedtuple("FederationStreamRow", ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow -)) +FederationStreamRow = namedtuple( + "FederationStreamRow", + ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow + ), +) class FederationStream(Stream): """Data to be sent over federation. Only available when master has federation sending disabled. """ + NAME = "federation" ROW_TYPE = FederationStreamRow diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index e6110ad9b1..1d20b96d03 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -66,6 +66,7 @@ class ClientRestResource(JsonResource): * /_matrix/client/unstable * etc """ + def __init__(self, hs): JsonResource.__init__(self, hs, canonical_json=False) self.register_servlets(self, hs) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index d6c4dcdb18..9843a902c6 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -61,7 +61,7 @@ def historical_admin_path_patterns(path_regex): "^/_synapse/admin/v1", "^/_matrix/client/api/v1/admin", "^/_matrix/client/unstable/admin", - "^/_matrix/client/r0/admin" + "^/_matrix/client/r0/admin", ) ) @@ -88,12 +88,12 @@ class UsersRestServlet(RestServlet): class VersionServlet(RestServlet): - PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"), ) + PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),) def __init__(self, hs): self.res = { - 'server_version': get_version_string(synapse), - 'python_version': platform.python_version(), + "server_version": get_version_string(synapse), + "python_version": platform.python_version(), } def on_GET(self, request): @@ -107,6 +107,7 @@ class UserRegisterServlet(RestServlet): nonces (dict[str, int]): The nonces that we will accept. A dict of nonce to the time it was generated, in int seconds. """ + PATTERNS = historical_admin_path_patterns("/register") NONCE_TIMEOUT = 60 @@ -146,28 +147,24 @@ class UserRegisterServlet(RestServlet): body = parse_json_object_from_request(request) if "nonce" not in body: - raise SynapseError( - 400, "nonce must be specified", errcode=Codes.BAD_JSON, - ) + raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON) nonce = body["nonce"] if nonce not in self.nonces: - raise SynapseError( - 400, "unrecognised nonce", - ) + raise SynapseError(400, "unrecognised nonce") # Delete the nonce, so it can't be reused, even if it's invalid del self.nonces[nonce] if "username" not in body: raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON, + 400, "username must be specified", errcode=Codes.BAD_JSON ) else: if ( - not isinstance(body['username'], text_type) - or len(body['username']) > 512 + not isinstance(body["username"], text_type) + or len(body["username"]) > 512 ): raise SynapseError(400, "Invalid username") @@ -177,12 +174,12 @@ class UserRegisterServlet(RestServlet): if "password" not in body: raise SynapseError( - 400, "password must be specified", errcode=Codes.BAD_JSON, + 400, "password must be specified", errcode=Codes.BAD_JSON ) else: if ( - not isinstance(body['password'], text_type) - or len(body['password']) > 512 + not isinstance(body["password"], text_type) + or len(body["password"]) > 512 ): raise SynapseError(400, "Invalid password") @@ -202,7 +199,7 @@ class UserRegisterServlet(RestServlet): key=self.hs.config.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) - want_mac.update(nonce.encode('utf8')) + want_mac.update(nonce.encode("utf8")) want_mac.update(b"\x00") want_mac.update(username) want_mac.update(b"\x00") @@ -211,13 +208,10 @@ class UserRegisterServlet(RestServlet): want_mac.update(b"admin" if admin else b"notadmin") if user_type: want_mac.update(b"\x00") - want_mac.update(user_type.encode('utf8')) + want_mac.update(user_type.encode("utf8")) want_mac = want_mac.hexdigest() - if not hmac.compare_digest( - want_mac.encode('ascii'), - got_mac.encode('ascii') - ): + if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication @@ -226,7 +220,7 @@ class UserRegisterServlet(RestServlet): register = RegisterRestServlet(self.hs) (user_id, _) = yield register.registration_handler.register( - localpart=body['username'].lower(), + localpart=body["username"].lower(), password=body["password"], admin=bool(admin), generate_token=False, @@ -308,7 +302,7 @@ class PurgeHistoryRestServlet(RestServlet): # user can provide an event_id in the URL or the request body, or can # provide a timestamp in the request body. if event_id is None: - event_id = body.get('purge_up_to_event_id') + event_id = body.get("purge_up_to_event_id") if event_id is not None: event = yield self.store.get_event(event_id) @@ -318,44 +312,39 @@ class PurgeHistoryRestServlet(RestServlet): token = yield self.store.get_topological_token_for_event(event_id) - logger.info( - "[purge] purging up to token %s (event_id %s)", - token, event_id, - ) - elif 'purge_up_to_ts' in body: - ts = body['purge_up_to_ts'] + logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) + elif "purge_up_to_ts" in body: + ts = body["purge_up_to_ts"] if not isinstance(ts, int): raise SynapseError( - 400, "purge_up_to_ts must be an int", - errcode=Codes.BAD_JSON, + 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON ) - stream_ordering = ( - yield self.store.find_first_stream_ordering_after_ts(ts) - ) + stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts)) r = ( yield self.store.get_room_event_after_stream_ordering( - room_id, stream_ordering, + room_id, stream_ordering ) ) if not r: logger.warn( "[purge] purging events not possible: No event found " "(received_ts %i => stream_ordering %i)", - ts, stream_ordering, + ts, + stream_ordering, ) raise SynapseError( - 404, - "there is no event to be purged", - errcode=Codes.NOT_FOUND, + 404, "there is no event to be purged", errcode=Codes.NOT_FOUND ) (stream, topo, _event_id) = r token = "t%d-%d" % (topo, stream) logger.info( "[purge] purging up to token %s (received_ts %i => " "stream_ordering %i)", - token, ts, stream_ordering, + token, + ts, + stream_ordering, ) else: raise SynapseError( @@ -365,13 +354,10 @@ class PurgeHistoryRestServlet(RestServlet): ) purge_id = yield self.pagination_handler.start_purge_history( - room_id, token, - delete_local_events=delete_local_events, + room_id, token, delete_local_events=delete_local_events ) - defer.returnValue((200, { - "purge_id": purge_id, - })) + defer.returnValue((200, {"purge_id": purge_id})) class PurgeHistoryStatusRestServlet(RestServlet): @@ -421,16 +407,14 @@ class DeactivateAccountRestServlet(RestServlet): UserID.from_string(target_user_id) result = yield self._deactivate_account_handler.deactivate_account( - target_user_id, erase, + target_user_id, erase ) if result: id_server_unbind_result = "success" else: id_server_unbind_result = "no-support" - defer.returnValue((200, { - "id_server_unbind_result": id_server_unbind_result, - })) + defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) class ShutdownRoomRestServlet(RestServlet): @@ -439,6 +423,7 @@ class ShutdownRoomRestServlet(RestServlet): to a new room created by `new_room_user_id` and kicked users will be auto joined to the new room. """ + PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P[^/]+)") DEFAULT_MESSAGE = ( @@ -474,9 +459,7 @@ class ShutdownRoomRestServlet(RestServlet): config={ "preset": "public_chat", "name": room_name, - "power_level_content_override": { - "users_default": -10, - }, + "power_level_content_override": {"users_default": -10}, }, ratelimit=False, ) @@ -485,8 +468,7 @@ class ShutdownRoomRestServlet(RestServlet): requester_user_id = requester.user.to_string() logger.info( - "Shutting down room %r, joining to new room: %r", - room_id, new_room_id, + "Shutting down room %r, joining to new room: %r", room_id, new_room_id ) # This will work even if the room is already blocked, but that is @@ -529,7 +511,7 @@ class ShutdownRoomRestServlet(RestServlet): kicked_users.append(user_id) except Exception: logger.exception( - "Failed to leave old room and join new room for %r", user_id, + "Failed to leave old room and join new room for %r", user_id ) failed_to_kick_users.append(user_id) @@ -550,18 +532,24 @@ class ShutdownRoomRestServlet(RestServlet): room_id, new_room_id, requester_user_id ) - defer.returnValue((200, { - "kicked_users": kicked_users, - "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, - "new_room_id": new_room_id, - })) + defer.returnValue( + ( + 200, + { + "kicked_users": kicked_users, + "failed_to_kick_users": failed_to_kick_users, + "local_aliases": aliases_for_room, + "new_room_id": new_room_id, + }, + ) + ) class QuarantineMediaInRoom(RestServlet): """Quarantines all media in a room so that no one can download it via this server. """ + PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P[^/]+)") def __init__(self, hs): @@ -574,7 +562,7 @@ class QuarantineMediaInRoom(RestServlet): yield assert_user_is_admin(self.auth, requester.user) num_quarantined = yield self.store.quarantine_media_ids_in_room( - room_id, requester.user.to_string(), + room_id, requester.user.to_string() ) defer.returnValue((200, {"num_quarantined": num_quarantined})) @@ -583,6 +571,7 @@ class QuarantineMediaInRoom(RestServlet): class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ + PATTERNS = historical_admin_path_patterns("/room/(?P[^/]+)/media") def __init__(self, hs): @@ -613,7 +602,10 @@ class ResetPasswordRestServlet(RestServlet): Returns: 200 OK with empty object if success otherwise an error. """ - PATTERNS = historical_admin_path_patterns("/reset_password/(?P[^/]*)") + + PATTERNS = historical_admin_path_patterns( + "/reset_password/(?P[^/]*)" + ) def __init__(self, hs): self.store = hs.get_datastore() @@ -633,7 +625,7 @@ class ResetPasswordRestServlet(RestServlet): params = parse_json_object_from_request(request) assert_params_in_dict(params, ["new_password"]) - new_password = params['new_password'] + new_password = params["new_password"] yield self._set_password_handler.set_password( target_user_id, new_password, requester @@ -650,7 +642,10 @@ class GetUsersPaginatedRestServlet(RestServlet): Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = historical_admin_path_patterns("/users_paginate/(?P[^/]*)") + + PATTERNS = historical_admin_path_patterns( + "/users_paginate/(?P[^/]*)" + ) def __init__(self, hs): self.store = hs.get_datastore() @@ -676,9 +671,7 @@ class GetUsersPaginatedRestServlet(RestServlet): logger.info("limit: %s, start: %s", limit, start) - ret = yield self.handlers.admin_handler.get_users_paginate( - order, start, limit - ) + ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) defer.returnValue((200, ret)) @defer.inlineCallbacks @@ -702,13 +695,11 @@ class GetUsersPaginatedRestServlet(RestServlet): order = "name" # order by name in user table params = parse_json_object_from_request(request) assert_params_in_dict(params, ["limit", "start"]) - limit = params['limit'] - start = params['start'] + limit = params["limit"] + start = params["start"] logger.info("limit: %s, start: %s", limit, start) - ret = yield self.handlers.admin_handler.get_users_paginate( - order, start, limit - ) + ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) defer.returnValue((200, ret)) @@ -722,6 +713,7 @@ class SearchUsersRestServlet(RestServlet): Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ + PATTERNS = historical_admin_path_patterns("/search_users/(?P[^/]*)") def __init__(self, hs): @@ -750,15 +742,14 @@ class SearchUsersRestServlet(RestServlet): term = parse_string(request, "term", required=True) logger.info("term: %s ", term) - ret = yield self.handlers.admin_handler.search_users( - term - ) + ret = yield self.handlers.admin_handler.search_users(term) defer.returnValue((200, ret)) class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups """ + PATTERNS = historical_admin_path_patterns("/delete_group/(?P[^/]*)") def __init__(self, hs): @@ -800,15 +791,15 @@ class AccountValidityRenewServlet(RestServlet): raise SynapseError(400, "Missing property 'user_id' in the request body") expiration_ts = yield self.account_activity_handler.renew_account_for_user( - body["user_id"], body.get("expiration_ts"), + body["user_id"], + body.get("expiration_ts"), not body.get("enable_renewal_emails", True), ) - res = { - "expiration_ts": expiration_ts, - } + res = {"expiration_ts": expiration_ts} defer.returnValue((200, res)) + ######################################################################################## # # please don't add more servlets here: this file is already long and unwieldy. Put diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index ae5aca9dac..ee66838a0d 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -46,6 +46,7 @@ class SendServerNoticeServlet(RestServlet): "event_id": "$1895723857jgskldgujpious" } """ + def __init__(self, hs): """ Args: @@ -58,15 +59,9 @@ class SendServerNoticeServlet(RestServlet): def register(self, json_resource): PATTERN = "^/_synapse/admin/v1/send_server_notice" + json_resource.register_paths("POST", (re.compile(PATTERN + "$"),), self.on_POST) json_resource.register_paths( - "POST", - (re.compile(PATTERN + "$"), ), - self.on_POST, - ) - json_resource.register_paths( - "PUT", - (re.compile(PATTERN + "/(?P[^/]*)$",), ), - self.on_PUT, + "PUT", (re.compile(PATTERN + "/(?P[^/]*)$"),), self.on_PUT ) @defer.inlineCallbacks @@ -96,5 +91,5 @@ class SendServerNoticeServlet(RestServlet): def on_PUT(self, request, txn_id): return self.txns.fetch_or_execute_request( - request, self.on_POST, request, txn_id, + request, self.on_POST, request, txn_id ) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 48c17f1b6d..36404b797d 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -26,7 +26,6 @@ CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins class HttpTransactionCache(object): - def __init__(self, hs): self.hs = hs self.auth = self.hs.get_auth() @@ -53,7 +52,7 @@ class HttpTransactionCache(object): str: A transaction key """ token = self.auth.get_access_token_from_request(request) - return request.path.decode('utf8') + "/" + token + return request.path.decode("utf8") + "/" + token def fetch_or_execute_request(self, request, fn, *args, **kwargs): """A helper function for fetch_or_execute which extracts diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 0035182bb9..dd0d38ea5c 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -56,8 +56,9 @@ class ClientDirectoryServer(RestServlet): content = parse_json_object_from_request(request) if "room_id" not in content: - raise SynapseError(400, 'Missing params: ["room_id"]', - errcode=Codes.BAD_JSON) + raise SynapseError( + 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON + ) logger.debug("Got content: %s", content) logger.debug("Got room name: %s", room_alias.to_string()) @@ -89,13 +90,11 @@ class ClientDirectoryServer(RestServlet): try: service = yield self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_appservice_association( - service, room_alias - ) + yield dir_handler.delete_appservice_association(service, room_alias) logger.info( "Application service at %s deleted alias %s", service.url, - room_alias.to_string() + room_alias.to_string(), ) defer.returnValue((200, {})) except AuthError: @@ -107,14 +106,10 @@ class ClientDirectoryServer(RestServlet): room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_association( - requester, room_alias - ) + yield dir_handler.delete_association(requester, room_alias) logger.info( - "User %s deleted alias %s", - user.to_string(), - room_alias.to_string() + "User %s deleted alias %s", user.to_string(), room_alias.to_string() ) defer.returnValue((200, {})) @@ -135,9 +130,9 @@ class ClientDirectoryListServer(RestServlet): if room is None: raise NotFoundError("Unknown room") - defer.returnValue((200, { - "visibility": "public" if room["is_public"] else "private" - })) + defer.returnValue( + (200, {"visibility": "public" if room["is_public"] else "private"}) + ) @defer.inlineCallbacks def on_PUT(self, request, room_id): @@ -147,7 +142,7 @@ class ClientDirectoryListServer(RestServlet): visibility = content.get("visibility", "public") yield self.handlers.directory_handler.edit_published_room_list( - requester, room_id, visibility, + requester, room_id, visibility ) defer.returnValue((200, {})) @@ -157,7 +152,7 @@ class ClientDirectoryListServer(RestServlet): requester = yield self.auth.get_user_by_req(request) yield self.handlers.directory_handler.edit_published_room_list( - requester, room_id, "private", + requester, room_id, "private" ) defer.returnValue((200, {})) @@ -191,7 +186,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): ) yield self.handlers.directory_handler.edit_published_appservice_room_list( - requester.app_service.id, network_id, room_id, visibility, + requester.app_service.id, network_id, room_id, visibility ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 84ca36270b..d6de2b7360 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -38,17 +38,14 @@ class EventStreamRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest room_id = None if is_guest: if b"room_id" not in request.args: raise SynapseError(400, "Guest users must specify room_id param") if b"room_id" in request.args: - room_id = request.args[b"room_id"][0].decode('ascii') + room_id = request.args[b"room_id"][0].decode("ascii") pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 3b60728628..4efb679a04 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission): to a typed object. """ if "user" in submission: - submission["identifier"] = { - "type": "m.id.user", - "user": submission["user"], - } + submission["identifier"] = {"type": "m.id.user", "user": submission["user"]} del submission["user"] if "medium" in submission and "address" in submission: @@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier): msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"]) - return { - "type": "m.id.thirdparty", - "medium": "msisdn", - "address": msisdn, - } + return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} class LoginRestServlet(RestServlet): @@ -120,9 +113,9 @@ class LoginRestServlet(RestServlet): # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - flows.extend(( - {"type": t} for t in self.auth_handler.get_supported_login_types() - )) + flows.extend( + ({"type": t} for t in self.auth_handler.get_supported_login_types()) + ) return (200, {"flows": flows}) @@ -132,7 +125,8 @@ class LoginRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): self._address_ratelimiter.ratelimit( - request.getClientIP(), time_now_s=self.hs.clock.time(), + request.getClientIP(), + time_now_s=self.hs.clock.time(), rate_hz=self.hs.config.rc_login_address.per_second, burst_count=self.hs.config.rc_login_address.burst_count, update=True, @@ -140,8 +134,9 @@ class LoginRestServlet(RestServlet): login_submission = parse_json_object_from_request(request) try: - if self.jwt_enabled and (login_submission["type"] == - LoginRestServlet.JWT_TYPE): + if self.jwt_enabled and ( + login_submission["type"] == LoginRestServlet.JWT_TYPE + ): result = yield self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) @@ -170,10 +165,10 @@ class LoginRestServlet(RestServlet): # field) logger.info( "Got login request with identifier: %r, medium: %r, address: %r, user: %r", - login_submission.get('identifier'), - login_submission.get('medium'), - login_submission.get('address'), - login_submission.get('user'), + login_submission.get("identifier"), + login_submission.get("medium"), + login_submission.get("address"), + login_submission.get("user"), ) login_submission_legacy_convert(login_submission) @@ -190,13 +185,13 @@ class LoginRestServlet(RestServlet): # convert threepid identifiers to user IDs if identifier["type"] == "m.id.thirdparty": - address = identifier.get('address') - medium = identifier.get('medium') + address = identifier.get("address") + medium = identifier.get("medium") if medium is None or address is None: raise SynapseError(400, "Invalid thirdparty identifier") - if medium == 'email': + if medium == "email": # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) @@ -205,34 +200,28 @@ class LoginRestServlet(RestServlet): # Check for login providers that support 3pid login types canonical_user_id, callback_3pid = ( yield self.auth_handler.check_password_provider_3pid( - medium, - address, - login_submission["password"], + medium, address, login_submission["password"] ) ) if canonical_user_id: # Authentication through password provider and 3pid succeeded result = yield self._register_device_with_callback( - canonical_user_id, login_submission, callback_3pid, + canonical_user_id, login_submission, callback_3pid ) defer.returnValue(result) # No password providers were able to handle this 3pid # Check local store user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - medium, address, + medium, address ) if not user_id: logger.warn( - "unknown 3pid identifier medium %s, address %r", - medium, address, + "unknown 3pid identifier medium %s, address %r", medium, address ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) - identifier = { - "type": "m.id.user", - "user": user_id, - } + identifier = {"type": "m.id.user", "user": user_id} # by this point, the identifier should be an m.id.user: if it's anything # else, we haven't understood it. @@ -242,22 +231,16 @@ class LoginRestServlet(RestServlet): raise SynapseError(400, "User identifier is missing 'user' key") canonical_user_id, callback = yield self.auth_handler.validate_login( - identifier["user"], - login_submission, + identifier["user"], login_submission ) result = yield self._register_device_with_callback( - canonical_user_id, login_submission, callback, + canonical_user_id, login_submission, callback ) defer.returnValue(result) @defer.inlineCallbacks - def _register_device_with_callback( - self, - user_id, - login_submission, - callback=None, - ): + def _register_device_with_callback(self, user_id, login_submission, callback=None): """ Registers a device with a given user_id. Optionally run a callback function after registration has completed. @@ -273,7 +256,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, + user_id, device_id, initial_display_name ) result = { @@ -290,7 +273,7 @@ class LoginRestServlet(RestServlet): @defer.inlineCallbacks def do_token_login(self, login_submission): - token = login_submission['token'] + token = login_submission["token"] auth_handler = self.auth_handler user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) @@ -299,7 +282,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, + user_id, device_id, initial_display_name ) result = { @@ -316,15 +299,16 @@ class LoginRestServlet(RestServlet): token = login_submission.get("token", None) if token is None: raise LoginError( - 401, "Token field for JWT is missing", - errcode=Codes.UNAUTHORIZED + 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED ) import jwt from jwt.exceptions import InvalidTokenError try: - payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm]) + payload = jwt.decode( + token, self.jwt_secret, algorithms=[self.jwt_algorithm] + ) except jwt.ExpiredSignatureError: raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED) except InvalidTokenError: @@ -342,7 +326,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name, + registered_user_id, device_id, initial_display_name ) result = { @@ -358,7 +342,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name, + registered_user_id, device_id, initial_display_name ) result = { @@ -375,21 +359,20 @@ class CasRedirectServlet(RestServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url.encode('ascii') - self.cas_service_url = hs.config.cas_service_url.encode('ascii') + self.cas_server_url = hs.config.cas_server_url.encode("ascii") + self.cas_service_url = hs.config.cas_service_url.encode("ascii") def on_GET(self, request): args = request.args if b"redirectUrl" not in args: return (400, "Redirect URL not specified for CAS auth") - client_redirect_url_param = urllib.parse.urlencode({ - b"redirectUrl": args[b"redirectUrl"][0] - }).encode('ascii') - hs_redirect_url = (self.cas_service_url + - b"/_matrix/client/r0/login/cas/ticket") - service_param = urllib.parse.urlencode({ - b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) - }).encode('ascii') + client_redirect_url_param = urllib.parse.urlencode( + {b"redirectUrl": args[b"redirectUrl"][0]} + ).encode("ascii") + hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" + service_param = urllib.parse.urlencode( + {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} + ).encode("ascii") request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) finish_request(request) @@ -411,7 +394,7 @@ class CasTicketServlet(RestServlet): uri = self.cas_server_url + "/proxyValidate" args = { "ticket": parse_string(request, "ticket", required=True), - "service": self.cas_service_url + "service": self.cas_service_url, } try: body = yield self._http_client.get_raw(uri, args) @@ -438,7 +421,7 @@ class CasTicketServlet(RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) return self._sso_auth_handler.on_successful_auth( - user, request, client_redirect_url, + user, request, client_redirect_url ) def parse_cas_response(self, cas_response_body): @@ -448,7 +431,7 @@ class CasTicketServlet(RestServlet): root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): raise Exception("root of CAS response is not serviceResponse") - success = (root[0].tag.endswith("authenticationSuccess")) + success = root[0].tag.endswith("authenticationSuccess") for child in root[0]: if child.tag.endswith("user"): user = child.text @@ -466,11 +449,11 @@ class CasTicketServlet(RestServlet): raise Exception("CAS response does not contain user") except Exception: logger.error("Error parsing CAS response", exc_info=1) - raise LoginError(401, "Invalid CAS response", - errcode=Codes.UNAUTHORIZED) + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) if not success: - raise LoginError(401, "Unsuccessful CAS response", - errcode=Codes.UNAUTHORIZED) + raise LoginError( + 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED + ) return user, attributes @@ -482,6 +465,7 @@ class SSOAuthHandler(object): Args: hs (synapse.server.HomeServer) """ + def __init__(self, hs): self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() @@ -490,8 +474,7 @@ class SSOAuthHandler(object): @defer.inlineCallbacks def on_successful_auth( - self, username, request, client_redirect_url, - user_display_name=None, + self, username, request, client_redirect_url, user_display_name=None ): """Called once the user has successfully authenticated with the SSO. diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index b8064f261e..cd711be519 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -46,7 +46,8 @@ class LogoutRestServlet(RestServlet): yield self._auth_handler.delete_access_token(access_token) else: yield self._device_handler.delete_device( - requester.user.to_string(), requester.device_id) + requester.user.to_string(), requester.device_id + ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index e263da3cb7..3e87f0fdb3 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -47,7 +47,7 @@ class PresenceStatusRestServlet(RestServlet): if requester.user != user: allowed = yield self.presence_handler.is_visible( - observed_user=user, observer_user=requester.user, + observed_user=user, observer_user=requester.user ) if not allowed: diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index e15d9d82a6..4d8ab1f47e 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -63,8 +63,7 @@ class ProfileDisplaynameRestServlet(RestServlet): except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.profile_handler.set_displayname( - user, requester, new_name, is_admin) + yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -113,8 +112,7 @@ class ProfileAvatarURLRestServlet(RestServlet): except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.profile_handler.set_avatar_url( - user, requester, new_name, is_admin) + yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 3d6326fe2f..e635efb420 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -21,7 +21,11 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) -from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string +from synapse.http.servlet import ( + RestServlet, + parse_json_value_from_request, + parse_string, +) from synapse.push.baserules import BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP @@ -32,7 +36,8 @@ from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundExc class PushRuleRestServlet(RestServlet): PATTERNS = client_patterns("/(?Ppushrules/.*)$", v1=True) SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( - "Unrecognised request: You probably wanted a trailing slash") + "Unrecognised request: You probably wanted a trailing slash" + ) def __init__(self, hs): super(PushRuleRestServlet, self).__init__() @@ -54,27 +59,25 @@ class PushRuleRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request) - if '/' in spec['rule_id'] or '\\' in spec['rule_id']: + if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: raise SynapseError(400, "rule_id may not contain slashes") content = parse_json_value_from_request(request) user_id = requester.user.to_string() - if 'attr' in spec: + if "attr" in spec: yield self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) defer.returnValue((200, {})) - if spec['rule_id'].startswith('.'): + if spec["rule_id"].startswith("."): # Rule ids starting with '.' are reserved for server default rules. raise SynapseError(400, "cannot add new rule_ids that start with '.'") try: (conditions, actions) = _rule_tuple_from_request_object( - spec['template'], - spec['rule_id'], - content, + spec["template"], spec["rule_id"], content ) except InvalidRuleException as e: raise SynapseError(400, str(e)) @@ -95,7 +98,7 @@ class PushRuleRestServlet(RestServlet): conditions=conditions, actions=actions, before=before, - after=after + after=after, ) self.notify_user(user_id) except InconsistentRuleException as e: @@ -118,9 +121,7 @@ class PushRuleRestServlet(RestServlet): namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.store.delete_push_rule( - user_id, namespaced_rule_id - ) + yield self.store.delete_push_rule(user_id, namespaced_rule_id) self.notify_user(user_id) defer.returnValue((200, {})) except StoreError as e: @@ -149,10 +150,10 @@ class PushRuleRestServlet(RestServlet): PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": defer.returnValue((200, rules)) - elif path[0] == 'global': - result = _filter_ruleset_with_path(rules['global'], path[1:]) + elif path[0] == "global": + result = _filter_ruleset_with_path(rules["global"], path[1:]) defer.returnValue((200, result)) else: raise UnrecognizedRequestError() @@ -162,12 +163,10 @@ class PushRuleRestServlet(RestServlet): def notify_user(self, user_id): stream_id, _ = self.store.get_push_rules_stream_token() - self.notifier.on_new_event( - "push_rules_key", stream_id, users=[user_id] - ) + self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) def set_rule_attr(self, user_id, spec, val): - if spec['attr'] == 'enabled': + if spec["attr"] == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] if not isinstance(val, bool): @@ -176,14 +175,12 @@ class PushRuleRestServlet(RestServlet): # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - return self.store.set_push_rule_enabled( - user_id, namespaced_rule_id, val - ) - elif spec['attr'] == 'actions': - actions = val.get('actions') + return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val) + elif spec["attr"] == "actions": + actions = val.get("actions") _check_actions(actions) namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec['rule_id'] + rule_id = spec["rule_id"] is_default_rule = rule_id.startswith(".") if is_default_rule: if namespaced_rule_id not in BASE_RULE_IDS: @@ -210,12 +207,12 @@ def _rule_spec_from_path(path): """ if len(path) < 2: raise UnrecognizedRequestError() - if path[0] != 'pushrules': + if path[0] != "pushrules": raise UnrecognizedRequestError() scope = path[1] path = path[2:] - if scope != 'global': + if scope != "global": raise UnrecognizedRequestError() if len(path) == 0: @@ -229,56 +226,40 @@ def _rule_spec_from_path(path): rule_id = path[0] - spec = { - 'scope': scope, - 'template': template, - 'rule_id': rule_id - } + spec = {"scope": scope, "template": template, "rule_id": rule_id} path = path[1:] if len(path) > 0 and len(path[0]) > 0: - spec['attr'] = path[0] + spec["attr"] = path[0] return spec def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): - if rule_template in ['override', 'underride']: - if 'conditions' not in req_obj: + if rule_template in ["override", "underride"]: + if "conditions" not in req_obj: raise InvalidRuleException("Missing 'conditions'") - conditions = req_obj['conditions'] + conditions = req_obj["conditions"] for c in conditions: - if 'kind' not in c: + if "kind" not in c: raise InvalidRuleException("Condition without 'kind'") - elif rule_template == 'room': - conditions = [{ - 'kind': 'event_match', - 'key': 'room_id', - 'pattern': rule_id - }] - elif rule_template == 'sender': - conditions = [{ - 'kind': 'event_match', - 'key': 'user_id', - 'pattern': rule_id - }] - elif rule_template == 'content': - if 'pattern' not in req_obj: + elif rule_template == "room": + conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}] + elif rule_template == "sender": + conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}] + elif rule_template == "content": + if "pattern" not in req_obj: raise InvalidRuleException("Content rule missing 'pattern'") - pat = req_obj['pattern'] + pat = req_obj["pattern"] - conditions = [{ - 'kind': 'event_match', - 'key': 'content.body', - 'pattern': pat - }] + conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}] else: raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) - if 'actions' not in req_obj: + if "actions" not in req_obj: raise InvalidRuleException("No actions found") - actions = req_obj['actions'] + actions = req_obj["actions"] _check_actions(actions) @@ -290,9 +271,9 @@ def _check_actions(actions): raise InvalidRuleException("No actions found") for a in actions: - if a in ['notify', 'dont_notify', 'coalesce']: + if a in ["notify", "dont_notify", "coalesce"]: pass - elif isinstance(a, dict) and 'set_tweak' in a: + elif isinstance(a, dict) and "set_tweak" in a: pass else: raise InvalidRuleException("Unrecognised action") @@ -304,7 +285,7 @@ def _filter_ruleset_with_path(ruleset, path): PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": return ruleset template_kind = path[0] if template_kind not in ruleset: @@ -314,13 +295,13 @@ def _filter_ruleset_with_path(ruleset, path): raise UnrecognizedRequestError( PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": return ruleset[template_kind] rule_id = path[0] the_rule = None for r in ruleset[template_kind]: - if r['rule_id'] == rule_id: + if r["rule_id"] == rule_id: the_rule = r if the_rule is None: raise NotFoundError @@ -339,19 +320,19 @@ def _filter_ruleset_with_path(ruleset, path): def _priority_class_from_spec(spec): - if spec['template'] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec['template'])) - pc = PRIORITY_CLASS_MAP[spec['template']] + if spec["template"] not in PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec["template"])) + pc = PRIORITY_CLASS_MAP[spec["template"]] return pc def _namespaced_rule_id_from_spec(spec): - return _namespaced_rule_id(spec, spec['rule_id']) + return _namespaced_rule_id(spec, spec["rule_id"]) def _namespaced_rule_id(spec, rule_id): - return "global/%s/%s" % (spec['template'], rule_id) + return "global/%s/%s" % (spec["template"], rule_id) class InvalidRuleException(Exception): diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 15d860db37..e9246018df 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -44,9 +44,7 @@ class PushersRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request) user = requester.user - pushers = yield self.hs.get_datastore().get_pushers_by_user_id( - user.to_string() - ) + pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) allowed_keys = [ "app_display_name", @@ -87,50 +85,61 @@ class PushersSetRestServlet(RestServlet): content = parse_json_object_from_request(request) - if ('pushkey' in content and 'app_id' in content - and 'kind' in content and - content['kind'] is None): + if ( + "pushkey" in content + and "app_id" in content + and "kind" in content + and content["kind"] is None + ): yield self.pusher_pool.remove_pusher( - content['app_id'], content['pushkey'], user_id=user.to_string() + content["app_id"], content["pushkey"], user_id=user.to_string() ) defer.returnValue((200, {})) assert_params_in_dict( content, - ['kind', 'app_id', 'app_display_name', - 'device_display_name', 'pushkey', 'lang', 'data'] + [ + "kind", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "lang", + "data", + ], ) - logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) + logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"]) logger.debug("Got pushers request with body: %r", content) append = False - if 'append' in content: - append = content['append'] + if "append" in content: + append = content["append"] if not append: yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( - app_id=content['app_id'], - pushkey=content['pushkey'], - not_user_id=user.to_string() + app_id=content["app_id"], + pushkey=content["pushkey"], + not_user_id=user.to_string(), ) try: yield self.pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, - kind=content['kind'], - app_id=content['app_id'], - app_display_name=content['app_display_name'], - device_display_name=content['device_display_name'], - pushkey=content['pushkey'], - lang=content['lang'], - data=content['data'], - profile_tag=content.get('profile_tag', ""), + kind=content["kind"], + app_id=content["app_id"], + app_display_name=content["app_display_name"], + device_display_name=content["device_display_name"], + pushkey=content["pushkey"], + lang=content["lang"], + data=content["data"], + profile_tag=content.get("profile_tag", ""), ) except PusherConfigException as pce: - raise SynapseError(400, "Config Error: " + str(pce), - errcode=Codes.MISSING_PARAM) + raise SynapseError( + 400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM + ) self.notifier.on_new_replication_data() @@ -144,6 +153,7 @@ class PushersRemoveRestServlet(RestServlet): """ To allow pusher to be delete by clicking a link (ie. GET request) """ + PATTERNS = client_patterns("/pushers/remove$", v1=True) SUCCESS_HTML = b"You have been unsubscribed" @@ -164,9 +174,7 @@ class PushersRemoveRestServlet(RestServlet): try: yield self.pusher_pool.remove_pusher( - app_id=app_id, - pushkey=pushkey, - user_id=user.to_string(), + app_id=app_id, pushkey=pushkey, user_id=user.to_string() ) except StoreError as se: if se.code != 404: @@ -177,9 +185,9 @@ class PushersRemoveRestServlet(RestServlet): request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % ( - len(PushersRemoveRestServlet.SUCCESS_HTML), - )) + request.setHeader( + b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),) + ) request.write(PushersRemoveRestServlet.SUCCESS_HTML) finish_request(request) defer.returnValue(None) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e8f672c4ba..a028337125 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -61,18 +61,16 @@ class RoomCreateRestServlet(TransactionRestServlet): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) # define CORS for all of /rooms in RoomCreateRestServlet for simplicity - http_server.register_paths("OPTIONS", - client_patterns("/rooms(?:/.*)?$", v1=True), - self.on_OPTIONS) + http_server.register_paths( + "OPTIONS", client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS + ) # define CORS for /createRoom[/txnid] - http_server.register_paths("OPTIONS", - client_patterns("/createRoom(?:/.*)?$", v1=True), - self.on_OPTIONS) + http_server.register_paths( + "OPTIONS", client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS + ) def on_PUT(self, request, txn_id): - return self.txns.fetch_or_execute_request( - request, self.on_POST, request - ) + return self.txns.fetch_or_execute_request(request, self.on_POST, request) @defer.inlineCallbacks def on_POST(self, request): @@ -107,21 +105,23 @@ class RoomStateEventRestServlet(TransactionRestServlet): no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" # /room/$roomid/state/$eventtype/$statekey - state_key = ("/rooms/(?P[^/]*)/state/" - "(?P[^/]*)/(?P[^/]*)$") - - http_server.register_paths("GET", - client_patterns(state_key, v1=True), - self.on_GET) - http_server.register_paths("PUT", - client_patterns(state_key, v1=True), - self.on_PUT) - http_server.register_paths("GET", - client_patterns(no_state_key, v1=True), - self.on_GET_no_state_key) - http_server.register_paths("PUT", - client_patterns(no_state_key, v1=True), - self.on_PUT_no_state_key) + state_key = ( + "/rooms/(?P[^/]*)/state/" + "(?P[^/]*)/(?P[^/]*)$" + ) + + http_server.register_paths( + "GET", client_patterns(state_key, v1=True), self.on_GET + ) + http_server.register_paths( + "PUT", client_patterns(state_key, v1=True), self.on_PUT + ) + http_server.register_paths( + "GET", client_patterns(no_state_key, v1=True), self.on_GET_no_state_key + ) + http_server.register_paths( + "PUT", client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key + ) def on_GET_no_state_key(self, request, room_id, event_type): return self.on_GET(request, room_id, event_type, "") @@ -132,8 +132,9 @@ class RoomStateEventRestServlet(TransactionRestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - format = parse_string(request, "format", default="content", - allowed_values=["content", "event"]) + format = parse_string( + request, "format", default="content", allowed_values=["content", "event"] + ) msg_handler = self.message_handler data = yield msg_handler.get_room_data( @@ -145,9 +146,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): ) if not data: - raise SynapseError( - 404, "Event not found.", errcode=Codes.NOT_FOUND - ) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) if format == "event": event = format_event_for_client_v2(data.get_dict()) @@ -182,9 +181,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): ) else: event = yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - event_dict, - txn_id=txn_id, + requester, event_dict, txn_id=txn_id ) ret = {} @@ -195,7 +192,6 @@ class RoomStateEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): - def __init__(self, hs): super(RoomSendEventRestServlet, self).__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() @@ -203,7 +199,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] - PATTERNS = ("/rooms/(?P[^/]*)/send/(?P[^/]*)") + PATTERNS = "/rooms/(?P[^/]*)/send/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) @defer.inlineCallbacks @@ -218,13 +214,11 @@ class RoomSendEventRestServlet(TransactionRestServlet): "sender": requester.user.to_string(), } - if b'ts' in request.args and requester.app_service: - event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) + if b"ts" in request.args and requester.app_service: + event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) event = yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - event_dict, - txn_id=txn_id, + requester, event_dict, txn_id=txn_id ) defer.returnValue((200, {"event_id": event.event_id})) @@ -247,15 +241,12 @@ class JoinRoomAliasServlet(TransactionRestServlet): def register(self, http_server): # /join/$room_identifier[/$txn_id] - PATTERNS = ("/join/(?P[^/]*)") + PATTERNS = "/join/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) try: content = parse_json_object_from_request(request) @@ -268,7 +259,7 @@ class JoinRoomAliasServlet(TransactionRestServlet): room_id = room_identifier try: remote_room_hosts = [ - x.decode('ascii') for x in request.args[b"server_name"] + x.decode("ascii") for x in request.args[b"server_name"] ] except Exception: remote_room_hosts = None @@ -278,9 +269,9 @@ class JoinRoomAliasServlet(TransactionRestServlet): room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) room_id = room_id.to_string() else: - raise SynapseError(400, "%s was not legal room ID or room alias" % ( - room_identifier, - )) + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) yield self.room_member_handler.update_membership( requester=requester, @@ -339,14 +330,11 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: data = yield handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, + server, limit=limit, since_token=since_token ) else: data = yield handler.get_local_public_room_list( - limit=limit, - since_token=since_token, + limit=limit, since_token=since_token ) defer.returnValue((200, data)) @@ -439,16 +427,13 @@ class RoomMemberListRestServlet(RestServlet): chunk = [] for event in events: - if ( - (membership and event['content'].get("membership") != membership) or - (not_membership and event['content'].get("membership") == not_membership) + if (membership and event["content"].get("membership") != membership) or ( + not_membership and event["content"].get("membership") == not_membership ): continue chunk.append(event) - defer.returnValue((200, { - "chunk": chunk - })) + defer.returnValue((200, {"chunk": chunk})) # deprecated in favour of /members?membership=join? @@ -466,12 +451,10 @@ class JoinedRoomMemberListRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request) users_with_profile = yield self.message_handler.get_joined_members( - requester, room_id, + requester, room_id ) - defer.returnValue((200, { - "joined": users_with_profile, - })) + defer.returnValue((200, {"joined": users_with_profile})) # TODO: Needs better unit testing @@ -486,9 +469,7 @@ class RoomMessageListRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request( - request, default_limit=10, - ) + pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args filter_bytes = parse_string(request, b"filter", encoding=None) if filter_bytes: @@ -544,9 +525,7 @@ class RoomInitialSyncRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) content = yield self.initial_sync_handler.room_initial_sync( - room_id=room_id, - requester=requester, - pagin_config=pagination_config, + room_id=room_id, requester=requester, pagin_config=pagination_config ) defer.returnValue((200, content)) @@ -603,30 +582,24 @@ class RoomEventContextServlet(RestServlet): event_filter = None results = yield self.room_context_handler.get_event_context( - requester.user, - room_id, - event_id, - limit, - event_filter, + requester.user, room_id, event_id, limit, event_filter ) if not results: - raise SynapseError( - 404, "Event not found.", errcode=Codes.NOT_FOUND - ) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() results["events_before"] = yield self._event_serializer.serialize_events( - results["events_before"], time_now, + results["events_before"], time_now ) results["event"] = yield self._event_serializer.serialize_event( - results["event"], time_now, + results["event"], time_now ) results["events_after"] = yield self._event_serializer.serialize_events( - results["events_after"], time_now, + results["events_after"], time_now ) results["state"] = yield self._event_serializer.serialize_events( - results["state"], time_now, + results["state"], time_now ) defer.returnValue((200, results)) @@ -639,20 +612,14 @@ class RoomForgetRestServlet(TransactionRestServlet): self.auth = hs.get_auth() def register(self, http_server): - PATTERNS = ("/rooms/(?P[^/]*)/forget") + PATTERNS = "/rooms/(?P[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=False, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=False) - yield self.room_member_handler.forget( - user=requester.user, - room_id=room_id, - ) + yield self.room_member_handler.forget(user=requester.user, room_id=room_id) defer.returnValue((200, {})) @@ -664,7 +631,6 @@ class RoomForgetRestServlet(TransactionRestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): - def __init__(self, hs): super(RoomMembershipRestServlet, self).__init__(hs) self.room_member_handler = hs.get_room_member_handler() @@ -672,20 +638,19 @@ class RoomMembershipRestServlet(TransactionRestServlet): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] - PATTERNS = ("/rooms/(?P[^/]*)/" - "(?Pjoin|invite|leave|ban|unban|kick)") + PATTERNS = ( + "/rooms/(?P[^/]*)/" + "(?Pjoin|invite|leave|ban|unban|kick)" + ) register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { Membership.JOIN, - Membership.LEAVE + Membership.LEAVE, }: raise AuthError(403, "Guest access not allowed") @@ -704,7 +669,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): content["address"], content["id_server"], requester, - txn_id + txn_id, ) defer.returnValue((200, {})) return @@ -715,8 +680,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): target = UserID.from_string(content["user_id"]) event_content = None - if 'reason' in content and membership_action in ['kick', 'ban']: - event_content = {'reason': content['reason']} + if "reason" in content and membership_action in ["kick", "ban"]: + event_content = {"reason": content["reason"]} yield self.room_member_handler.update_membership( requester=requester, @@ -755,7 +720,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): self.auth = hs.get_auth() def register(self, http_server): - PATTERNS = ("/rooms/(?P[^/]*)/redact/(?P[^/]*)") + PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks @@ -817,9 +782,7 @@ class RoomTypingRestServlet(RestServlet): ) else: yield self.typing_handler.stopped_typing( - target_user=target_user, - auth_user=requester.user, - room_id=room_id, + target_user=target_user, auth_user=requester.user, room_id=room_id ) defer.returnValue((200, {})) @@ -841,9 +804,7 @@ class SearchRestServlet(RestServlet): batch = parse_string(request, "next_batch") results = yield self.handlers.search_handler.search( - requester.user, - content, - batch, + requester.user, content, batch ) defer.returnValue((200, results)) @@ -879,20 +840,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): with_get: True to also register respective GET paths for the PUTs. """ http_server.register_paths( - "POST", - client_patterns(regex_string + "$", v1=True), - servlet.on_POST + "POST", client_patterns(regex_string + "$", v1=True), servlet.on_POST ) http_server.register_paths( "PUT", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_PUT + servlet.on_PUT, ) if with_get: http_server.register_paths( "GET", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_GET + servlet.on_GET, ) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 6381049210..41b3171ac8 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -34,8 +34,7 @@ class VoipRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req( - request, - self.hs.config.turn_allow_guests + request, self.hs.config.turn_allow_guests ) turnUris = self.hs.config.turn_uris @@ -49,9 +48,7 @@ class VoipRestServlet(RestServlet): username = "%d:%s" % (expiry, requester.user.to_string()) mac = hmac.new( - turnSecret.encode(), - msg=username.encode(), - digestmod=hashlib.sha1 + turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1 ) # We need to use standard padded base64 encoding here # encode_base64 because we need to add the standard padding to get the @@ -65,12 +62,17 @@ class VoipRestServlet(RestServlet): else: defer.returnValue((200, {})) - defer.returnValue((200, { - 'username': username, - 'password': password, - 'ttl': userLifetime / 1000, - 'uris': turnUris, - })) + defer.returnValue( + ( + 200, + { + "username": username, + "password": password, + "ttl": userLifetime / 1000, + "uris": turnUris, + }, + ) + ) def on_OPTIONS(self, request): return (200, {}) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 5236d5d566..e3d59ac3ac 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -52,11 +52,11 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): def set_timeline_upper_limit(filter_json, filter_timeline_limit): if filter_timeline_limit < 0: return # no upper limits - timeline = filter_json.get('room', {}).get('timeline', {}) - if 'limit' in timeline: - filter_json['room']['timeline']["limit"] = min( - filter_json['room']['timeline']['limit'], - filter_timeline_limit) + timeline = filter_json.get("room", {}).get("timeline", {}) + if "limit" in timeline: + filter_json["room"]["timeline"]["limit"] = min( + filter_json["room"]["timeline"]["limit"], filter_timeline_limit + ) def interactive_auth_handler(orig): @@ -74,10 +74,12 @@ def interactive_auth_handler(orig): # ... yield self.auth_handler.check_auth """ + def wrapped(*args, **kwargs): res = defer.maybeDeferred(orig, *args, **kwargs) res.addErrback(_catch_incomplete_interactive_auth) return res + return wrapped diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index ab75f6c2b2..f143d8b85c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -52,6 +52,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if self.config.email_password_reset_behaviour == "local": from synapse.push.mailer import Mailer, load_jinja2_templates + templates = load_jinja2_templates( config=hs.config, template_html_name=hs.config.email_password_reset_template_html, @@ -72,14 +73,12 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): "User password resets have been disabled due to lack of email config" ) raise SynapseError( - 400, "Email-based password resets have been disabled on this server", + 400, "Email-based password resets have been disabled on this server" ) body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'client_secret', 'email', 'send_attempt' - ]) + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) # Extract params from body client_secret = body["client_secret"] @@ -95,24 +94,24 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', email, + "email", email ) if existingUid is None: raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.email_password_reset_behaviour == "remote": - if 'id_server' not in body: + if "id_server" not in body: raise SynapseError(400, "Missing 'id_server' param in body") # Have the identity server handle the password reset flow ret = yield self.identity_handler.requestEmailToken( - body["id_server"], email, client_secret, send_attempt, next_link, + body["id_server"], email, client_secret, send_attempt, next_link ) else: # Send password reset emails from Synapse sid = yield self.send_password_reset( - email, client_secret, send_attempt, next_link, + email, client_secret, send_attempt, next_link ) # Wrap the session id in a JSON object @@ -121,13 +120,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): defer.returnValue((200, ret)) @defer.inlineCallbacks - def send_password_reset( - self, - email, - client_secret, - send_attempt, - next_link=None, - ): + def send_password_reset(self, email, client_secret, send_attempt, next_link=None): """Send a password reset email Args: @@ -144,14 +137,14 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Check that this email/client_secret/send_attempt combo is new or # greater than what we've seen previously session = yield self.datastore.get_threepid_validation_session( - "email", client_secret, address=email, validated=False, + "email", client_secret, address=email, validated=False ) # Check to see if a session already exists and that it is not yet # marked as validated if session and session.get("validated_at") is None: - session_id = session['session_id'] - last_send_attempt = session['last_send_attempt'] + session_id = session["session_id"] + last_send_attempt = session["last_send_attempt"] # Check that the send_attempt is higher than previous attempts if send_attempt <= last_send_attempt: @@ -169,22 +162,27 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # and session_id try: yield self.mailer.send_password_reset_mail( - email, token, client_secret, session_id, + email, token, client_secret, session_id ) except Exception: - logger.exception( - "Error sending a password reset email to %s", email, - ) + logger.exception("Error sending a password reset email to %s", email) raise SynapseError( 500, "An error was encountered when sending the password reset email" ) - token_expires = (self.hs.clock.time_msec() + - self.config.email_validation_token_lifetime) + token_expires = ( + self.hs.clock.time_msec() + self.config.email_validation_token_lifetime + ) yield self.datastore.start_or_continue_validation_session( - "email", email, session_id, client_secret, send_attempt, - next_link, token, token_expires, + "email", + email, + session_id, + client_secret, + send_attempt, + next_link, + token, + token_expires, ) defer.returnValue(session_id) @@ -203,12 +201,12 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( @@ -217,9 +215,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existingUid = yield self.datastore.get_user_id_by_threepid( - 'msisdn', msisdn - ) + existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn) if existingUid is None: raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND) @@ -230,10 +226,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): class PasswordResetSubmitTokenServlet(RestServlet): """Handles 3PID validation token submission""" + PATTERNS = client_patterns( - "/password_reset/(?P[^/]*)/submit_token/*$", - releases=(), - unstable=True, + "/password_reset/(?P[^/]*)/submit_token/*$", releases=(), unstable=True ) def __init__(self, hs): @@ -252,8 +247,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): def on_GET(self, request, medium): if medium != "email": raise SynapseError( - 400, - "This medium is currently not supported for password resets", + 400, "This medium is currently not supported for password resets" ) if self.config.email_password_reset_behaviour == "off": if self.config.password_resets_were_disabled_due_to_email_config: @@ -261,7 +255,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): "User password resets have been disabled due to lack of email config" ) raise SynapseError( - 400, "Email-based password resets have been disabled on this server", + 400, "Email-based password resets have been disabled on this server" ) sid = parse_string(request, "sid") @@ -272,10 +266,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): try: # Mark the session as valid next_link = yield self.datastore.validate_threepid_session( - sid, - client_secret, - token, - self.clock.time_msec(), + sid, client_secret, token, self.clock.time_msec() ) # Perform a 302 redirect if next_link is set @@ -298,13 +289,11 @@ class PasswordResetSubmitTokenServlet(RestServlet): html = self.load_jinja2_template( self.config.email_template_dir, self.config.email_password_reset_failure_template, - template_vars={ - "failure_reason": e.msg, - } + template_vars={"failure_reason": e.msg}, ) request.setResponseCode(e.code) - request.write(html.encode('utf-8')) + request.write(html.encode("utf-8")) finish_request(request) defer.returnValue(None) @@ -330,20 +319,14 @@ class PasswordResetSubmitTokenServlet(RestServlet): def on_POST(self, request, medium): if medium != "email": raise SynapseError( - 400, - "This medium is currently not supported for password resets", + 400, "This medium is currently not supported for password resets" ) body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'sid', 'client_secret', 'token', - ]) + assert_params_in_dict(body, ["sid", "client_secret", "token"]) valid, _ = yield self.datastore.validate_threepid_validation_token( - body['sid'], - body['client_secret'], - body['token'], - self.clock.time_msec(), + body["sid"], body["client_secret"], body["token"], self.clock.time_msec() ) response_code = 200 if valid else 400 @@ -379,29 +362,30 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = yield self.auth.get_user_by_req(request) params = yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) user_id = requester.user.to_string() else: requester = None result, params, _ = yield self.auth_handler.check_auth( [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], - body, self.hs.get_ip_from_request(request), + body, + self.hs.get_ip_from_request(request), password_servlet=True, ) if LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] - if 'medium' not in threepid or 'address' not in threepid: + if "medium" not in threepid or "address" not in threepid: raise SynapseError(500, "Malformed threepid") - if threepid['medium'] == 'email': + if threepid["medium"] == "email": # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) - threepid['address'] = threepid['address'].lower() + threepid["address"] = threepid["address"].lower() # if using email, we must know about the email they're authing with! threepid_user_id = yield self.datastore.get_user_id_by_threepid( - threepid['medium'], threepid['address'] + threepid["medium"], threepid["address"] ) if not threepid_user_id: raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) @@ -411,11 +395,9 @@ class PasswordRestServlet(RestServlet): raise SynapseError(500, "", Codes.UNKNOWN) assert_params_in_dict(params, ["new_password"]) - new_password = params['new_password'] + new_password = params["new_password"] - yield self._set_password_handler.set_password( - user_id, new_password, requester - ) + yield self._set_password_handler.set_password(user_id, new_password, requester) defer.returnValue((200, {})) @@ -450,25 +432,22 @@ class DeactivateAccountRestServlet(RestServlet): # allow ASes to dectivate their own users if requester.app_service: yield self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, + requester.user.to_string(), erase ) defer.returnValue((200, {})) yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) result = yield self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, - id_server=body.get("id_server"), + requester.user.to_string(), erase, id_server=body.get("id_server") ) if result: id_server_unbind_result = "success" else: id_server_unbind_result = "no-support" - defer.returnValue((200, { - "id_server_unbind_result": id_server_unbind_result, - })) + defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) class EmailThreepidRequestTokenRestServlet(RestServlet): @@ -484,11 +463,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( - body, - ['id_server', 'client_secret', 'email', 'send_attempt'], + body, ["id_server", "client_secret", "email", "send_attempt"] ) - if not check_3pid_allowed(self.hs, "email", body['email']): + if not check_3pid_allowed(self.hs, "email", body["email"]): raise SynapseError( 403, "Your email domain is not authorized on this server", @@ -496,7 +474,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) existingUid = yield self.datastore.get_user_id_by_threepid( - 'email', body['email'] + "email", body["email"] ) if existingUid is not None: @@ -518,12 +496,12 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( @@ -532,9 +510,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existingUid = yield self.datastore.get_user_id_by_threepid( - 'msisdn', msisdn - ) + existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn) if existingUid is not None: raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) @@ -558,18 +534,16 @@ class ThreepidRestServlet(RestServlet): def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - threepids = yield self.datastore.user_get_threepids( - requester.user.to_string() - ) + threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) - defer.returnValue((200, {'threepids': threepids})) + defer.returnValue((200, {"threepids": threepids})) @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - threePidCreds = body.get('threePidCreds') - threePidCreds = body.get('three_pid_creds', threePidCreds) + threePidCreds = body.get("threePidCreds") + threePidCreds = body.get("three_pid_creds", threePidCreds) if threePidCreds is None: raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) @@ -579,30 +553,20 @@ class ThreepidRestServlet(RestServlet): threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) if not threepid: - raise SynapseError( - 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED - ) + raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED) - for reqd in ['medium', 'address', 'validated_at']: + for reqd in ["medium", "address", "validated_at"]: if reqd not in threepid: logger.warn("Couldn't add 3pid: invalid response from ID server") raise SynapseError(500, "Invalid response from ID Server") yield self.auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) - if 'bind' in body and body['bind']: - logger.debug( - "Binding threepid %s to %s", - threepid, user_id - ) - yield self.identity_handler.bind_threepid( - threePidCreds, user_id - ) + if "bind" in body and body["bind"]: + logger.debug("Binding threepid %s to %s", threepid, user_id) + yield self.identity_handler.bind_threepid(threePidCreds, user_id) defer.returnValue((200, {})) @@ -618,14 +582,14 @@ class ThreepidDeleteRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, ['medium', 'address']) + assert_params_in_dict(body, ["medium", "address"]) requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() try: ret = yield self.auth_handler.delete_threepid( - user_id, body['medium'], body['address'], body.get("id_server"), + user_id, body["medium"], body["address"], body.get("id_server") ) except Exception: # NB. This endpoint should succeed if there is nothing to @@ -639,9 +603,7 @@ class ThreepidDeleteRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - defer.returnValue((200, { - "id_server_unbind_result": id_server_unbind_result, - })) + defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) class WhoamiRestServlet(RestServlet): @@ -655,7 +617,7 @@ class WhoamiRestServlet(RestServlet): def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - defer.returnValue((200, {'user_id': requester.user.to_string()})) + defer.returnValue((200, {"user_id": requester.user.to_string()})) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 574a6298ce..f155c26259 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -30,6 +30,7 @@ class AccountDataServlet(RestServlet): PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P[^/]*)/account_data/(?P[^/]*)" ) @@ -52,9 +53,7 @@ class AccountDataServlet(RestServlet): user_id, account_data_type, body ) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -65,7 +64,7 @@ class AccountDataServlet(RestServlet): raise AuthError(403, "Cannot get account data for other users.") event = yield self.store.get_global_account_data_by_type_for_user( - account_data_type, user_id, + account_data_type, user_id ) if event is None: @@ -79,6 +78,7 @@ class RoomAccountDataServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P[^/]*)" "/rooms/(?P[^/]*)" @@ -103,16 +103,14 @@ class RoomAccountDataServlet(RestServlet): raise SynapseError( 405, "Cannot set m.fully_read through this API." - " Use /rooms/!roomId:server.name/read_markers" + " Use /rooms/!roomId:server.name/read_markers", ) max_id = yield self.store.add_account_data_to_room( user_id, room_id, account_data_type, body ) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -123,7 +121,7 @@ class RoomAccountDataServlet(RestServlet): raise AuthError(403, "Cannot get account data for other users.") event = yield self.store.get_account_data_for_room_and_type( - user_id, room_id, account_data_type, + user_id, room_id, account_data_type ) if event is None: diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 63bdc33564..d29c10b83d 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -28,7 +28,9 @@ logger = logging.getLogger(__name__) class AccountValidityRenewServlet(RestServlet): PATTERNS = client_patterns("/account_validity/renew$") - SUCCESS_HTML = b"Your account has been successfully renewed." + SUCCESS_HTML = ( + b"Your account has been successfully renewed." + ) def __init__(self, hs): """ @@ -47,13 +49,13 @@ class AccountValidityRenewServlet(RestServlet): raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - yield self.account_activity_handler.renew_account(renewal_token.decode('utf8')) + yield self.account_activity_handler.renew_account(renewal_token.decode("utf8")) request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % ( - len(AccountValidityRenewServlet.SUCCESS_HTML), - )) + request.setHeader( + b"Content-Length", b"%d" % (len(AccountValidityRenewServlet.SUCCESS_HTML),) + ) request.write(AccountValidityRenewServlet.SUCCESS_HTML) finish_request(request) defer.returnValue(None) @@ -77,7 +79,9 @@ class AccountValiditySendMailServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): if not self.account_validity.renew_by_email_enabled: - raise AuthError(403, "Account renewal via email is disabled on this server.") + raise AuthError( + 403, "Account renewal via email is disabled on this server." + ) requester = yield self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 8dfe5cba02..bebc2951e7 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -122,6 +122,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ + PATTERNS = client_patterns(r"/auth/(?P[\w\.]*)/fallback/web") def __init__(self, hs): @@ -138,11 +139,10 @@ class AuthRestServlet(RestServlet): if stagetype == LoginType.RECAPTCHA: html = RECAPTCHA_TEMPLATE % { - 'session': session, - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.RECAPTCHA - ), - 'sitekey': self.hs.config.recaptcha_public_key, + "session": session, + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + "sitekey": self.hs.config.recaptcha_public_key, } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -154,14 +154,11 @@ class AuthRestServlet(RestServlet): return None elif stagetype == LoginType.TERMS: html = TERMS_TEMPLATE % { - 'session': session, - 'terms_url': "%s_matrix/consent?v=%s" % ( - self.hs.config.public_baseurl, - self.hs.config.user_consent_version, - ), - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.TERMS - ), + "session": session, + "terms_url": "%s_matrix/consent?v=%s" + % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -187,26 +184,20 @@ class AuthRestServlet(RestServlet): if not response: raise SynapseError(400, "No captcha response supplied") - authdict = { - 'response': response, - 'session': session, - } + authdict = {"response": response, "session": session} success = yield self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, - authdict, - self.hs.get_ip_from_request(request) + LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) ) if success: html = SUCCESS_TEMPLATE else: html = RECAPTCHA_TEMPLATE % { - 'session': session, - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.RECAPTCHA - ), - 'sitekey': self.hs.config.recaptcha_public_key, + "session": session, + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + "sitekey": self.hs.config.recaptcha_public_key, } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -218,31 +209,28 @@ class AuthRestServlet(RestServlet): defer.returnValue(None) elif stagetype == LoginType.TERMS: - if ('session' not in request.args or - len(request.args['session'])) == 0: + if ("session" not in request.args or len(request.args["session"])) == 0: raise SynapseError(400, "No session supplied") - session = request.args['session'][0] - authdict = {'session': session} + session = request.args["session"][0] + authdict = {"session": session} success = yield self.auth_handler.add_oob_auth( - LoginType.TERMS, - authdict, - self.hs.get_ip_from_request(request) + LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) ) if success: html = SUCCESS_TEMPLATE else: html = TERMS_TEMPLATE % { - 'session': session, - 'terms_url': "%s_matrix/consent?v=%s" % ( + "session": session, + "terms_url": "%s_matrix/consent?v=%s" + % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.TERMS - ), + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), } html_bytes = html.encode("utf8") request.setResponseCode(200) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 78665304a5..d279229d74 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -56,6 +56,7 @@ class DeleteDevicesRestServlet(RestServlet): API for bulk deletion of devices. Accepts a JSON object with a devices key which lists the device_ids to delete. Requires user interactive auth. """ + PATTERNS = client_patterns("/delete_devices") def __init__(self, hs): @@ -84,12 +85,11 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) yield self.device_handler.delete_devices( - requester.user.to_string(), - body['devices'], + requester.user.to_string(), body["devices"] ) defer.returnValue((200, {})) @@ -112,8 +112,7 @@ class DeviceRestServlet(RestServlet): def on_GET(self, request, device_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) device = yield self.device_handler.get_device( - requester.user.to_string(), - device_id, + requester.user.to_string(), device_id ) defer.returnValue((200, device)) @@ -134,12 +133,10 @@ class DeviceRestServlet(RestServlet): raise yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) - yield self.device_handler.delete_device( - requester.user.to_string(), device_id, - ) + yield self.device_handler.delete_device(requester.user.to_string(), device_id) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -148,9 +145,7 @@ class DeviceRestServlet(RestServlet): body = parse_json_object_from_request(request) yield self.device_handler.update_device( - requester.user.to_string(), - device_id, - body + requester.user.to_string(), device_id, body ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 65db48c3cc..3f0adf4a21 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -53,8 +53,7 @@ class GetFilterRestServlet(RestServlet): try: filter = yield self.filtering.get_user_filter( - user_localpart=target_user.localpart, - filter_id=filter_id, + user_localpart=target_user.localpart, filter_id=filter_id ) defer.returnValue((200, filter.get_filter_json())) @@ -84,14 +83,10 @@ class CreateFilterRestServlet(RestServlet): raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) - set_timeline_upper_limit( - content, - self.hs.config.filter_timeline_limit - ) + set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) filter_id = yield self.filtering.add_user_filter( - user_localpart=target_user.localpart, - user_filter=content, + user_localpart=target_user.localpart, user_filter=content ) defer.returnValue((200, {"filter_id": str(filter_id)})) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index d082385ec7..a312dd2593 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -29,6 +29,7 @@ logger = logging.getLogger(__name__) class GroupServlet(RestServlet): """Get the group profile """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") def __init__(self, hs): @@ -43,8 +44,7 @@ class GroupServlet(RestServlet): requester_user_id = requester.user.to_string() group_description = yield self.groups_handler.get_group_profile( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, group_description)) @@ -56,7 +56,7 @@ class GroupServlet(RestServlet): content = parse_json_object_from_request(request) yield self.groups_handler.update_group_profile( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, {})) @@ -65,6 +65,7 @@ class GroupServlet(RestServlet): class GroupSummaryServlet(RestServlet): """Get the full group summary """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") def __init__(self, hs): @@ -79,8 +80,7 @@ class GroupSummaryServlet(RestServlet): requester_user_id = requester.user.to_string() get_group_summary = yield self.groups_handler.get_group_summary( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, get_group_summary)) @@ -93,6 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): - /groups/:group/summary/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/summary" "(/categories/(?P[^/]+))?" @@ -112,7 +113,8 @@ class GroupSummaryRoomsCatServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_summary_room( - group_id, requester_user_id, + group_id, + requester_user_id, room_id=room_id, category_id=category_id, content=content, @@ -126,9 +128,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_summary_room( - group_id, requester_user_id, - room_id=room_id, - category_id=category_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -137,6 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): class GroupCategoryServlet(RestServlet): """Get/add/update/delete a group category """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/categories/(?P[^/]+)$" ) @@ -153,8 +154,7 @@ class GroupCategoryServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_category( - group_id, requester_user_id, - category_id=category_id, + group_id, requester_user_id, category_id=category_id ) defer.returnValue((200, category)) @@ -166,9 +166,7 @@ class GroupCategoryServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_category( - group_id, requester_user_id, - category_id=category_id, - content=content, + group_id, requester_user_id, category_id=category_id, content=content ) defer.returnValue((200, resp)) @@ -179,8 +177,7 @@ class GroupCategoryServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_category( - group_id, requester_user_id, - category_id=category_id, + group_id, requester_user_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -189,9 +186,8 @@ class GroupCategoryServlet(RestServlet): class GroupCategoriesServlet(RestServlet): """Get all group categories """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/categories/$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/categories/$") def __init__(self, hs): super(GroupCategoriesServlet, self).__init__() @@ -205,7 +201,7 @@ class GroupCategoriesServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_categories( - group_id, requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, category)) @@ -214,9 +210,8 @@ class GroupCategoriesServlet(RestServlet): class GroupRoleServlet(RestServlet): """Get/add/update/delete a group role """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/roles/(?P[^/]+)$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/(?P[^/]+)$") def __init__(self, hs): super(GroupRoleServlet, self).__init__() @@ -230,8 +225,7 @@ class GroupRoleServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_role( - group_id, requester_user_id, - role_id=role_id, + group_id, requester_user_id, role_id=role_id ) defer.returnValue((200, category)) @@ -243,9 +237,7 @@ class GroupRoleServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_role( - group_id, requester_user_id, - role_id=role_id, - content=content, + group_id, requester_user_id, role_id=role_id, content=content ) defer.returnValue((200, resp)) @@ -256,8 +248,7 @@ class GroupRoleServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_role( - group_id, requester_user_id, - role_id=role_id, + group_id, requester_user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -266,9 +257,8 @@ class GroupRoleServlet(RestServlet): class GroupRolesServlet(RestServlet): """Get all group roles """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/roles/$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/$") def __init__(self, hs): super(GroupRolesServlet, self).__init__() @@ -282,7 +272,7 @@ class GroupRolesServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_roles( - group_id, requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, category)) @@ -295,6 +285,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): - /groups/:group/summary/users/:room_id - /groups/:group/summary/roles/:role/users/:user_id """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/summary" "(/roles/(?P[^/]+))?" @@ -314,7 +305,8 @@ class GroupSummaryUsersRoleServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_summary_user( - group_id, requester_user_id, + group_id, + requester_user_id, user_id=user_id, role_id=role_id, content=content, @@ -328,9 +320,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_summary_user( - group_id, requester_user_id, - user_id=user_id, - role_id=role_id, + group_id, requester_user_id, user_id=user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -339,6 +329,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): class GroupRoomServlet(RestServlet): """Get all rooms in a group """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/rooms$") def __init__(self, hs): @@ -352,7 +343,9 @@ class GroupRoomServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id) + result = yield self.groups_handler.get_rooms_in_group( + group_id, requester_user_id + ) defer.returnValue((200, result)) @@ -360,6 +353,7 @@ class GroupRoomServlet(RestServlet): class GroupUsersServlet(RestServlet): """Get all users in a group """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") def __init__(self, hs): @@ -373,7 +367,9 @@ class GroupUsersServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id) + result = yield self.groups_handler.get_users_in_group( + group_id, requester_user_id + ) defer.returnValue((200, result)) @@ -381,6 +377,7 @@ class GroupUsersServlet(RestServlet): class GroupInvitedUsersServlet(RestServlet): """Get users invited to a group """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") def __init__(self, hs): @@ -395,8 +392,7 @@ class GroupInvitedUsersServlet(RestServlet): requester_user_id = requester.user.to_string() result = yield self.groups_handler.get_invited_users_in_group( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, result)) @@ -405,6 +401,7 @@ class GroupInvitedUsersServlet(RestServlet): class GroupSettingJoinPolicyServlet(RestServlet): """Set group join policy """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") def __init__(self, hs): @@ -420,9 +417,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.set_group_join_policy( - group_id, - requester_user_id, - content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -431,6 +426,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): class GroupCreateServlet(RestServlet): """Create a group """ + PATTERNS = client_patterns("/create_group$") def __init__(self, hs): @@ -451,9 +447,7 @@ class GroupCreateServlet(RestServlet): group_id = GroupID(localpart, self.server_name).to_string() result = yield self.groups_handler.create_group( - group_id, - requester_user_id, - content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -462,6 +456,7 @@ class GroupCreateServlet(RestServlet): class GroupAdminRoomsServlet(RestServlet): """Add a room to the group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)$" ) @@ -479,7 +474,7 @@ class GroupAdminRoomsServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.add_room_to_group( - group_id, requester_user_id, room_id, content, + group_id, requester_user_id, room_id, content ) defer.returnValue((200, result)) @@ -490,7 +485,7 @@ class GroupAdminRoomsServlet(RestServlet): requester_user_id = requester.user.to_string() result = yield self.groups_handler.remove_room_from_group( - group_id, requester_user_id, room_id, + group_id, requester_user_id, room_id ) defer.returnValue((200, result)) @@ -499,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet): class GroupAdminRoomsConfigServlet(RestServlet): """Update the config of a room in a group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)" "/config/(?P[^/]*)$" @@ -517,7 +513,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content, + group_id, requester_user_id, room_id, config_key, content ) defer.returnValue((200, result)) @@ -526,6 +522,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): class GroupAdminUsersInviteServlet(RestServlet): """Invite a user to the group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/users/invite/(?P[^/]*)$" ) @@ -546,7 +543,7 @@ class GroupAdminUsersInviteServlet(RestServlet): content = parse_json_object_from_request(request) config = content.get("config", {}) result = yield self.groups_handler.invite( - group_id, user_id, requester_user_id, config, + group_id, user_id, requester_user_id, config ) defer.returnValue((200, result)) @@ -555,6 +552,7 @@ class GroupAdminUsersInviteServlet(RestServlet): class GroupAdminUsersKickServlet(RestServlet): """Kick a user from the group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/users/remove/(?P[^/]*)$" ) @@ -572,7 +570,7 @@ class GroupAdminUsersKickServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -581,9 +579,8 @@ class GroupAdminUsersKickServlet(RestServlet): class GroupSelfLeaveServlet(RestServlet): """Leave a joined group """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/leave$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/leave$") def __init__(self, hs): super(GroupSelfLeaveServlet, self).__init__() @@ -598,7 +595,7 @@ class GroupSelfLeaveServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.remove_user_from_group( - group_id, requester_user_id, requester_user_id, content, + group_id, requester_user_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -607,9 +604,8 @@ class GroupSelfLeaveServlet(RestServlet): class GroupSelfJoinServlet(RestServlet): """Attempt to join a group, or knock """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/join$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/join$") def __init__(self, hs): super(GroupSelfJoinServlet, self).__init__() @@ -624,7 +620,7 @@ class GroupSelfJoinServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.join_group( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -633,9 +629,8 @@ class GroupSelfJoinServlet(RestServlet): class GroupSelfAcceptInviteServlet(RestServlet): """Accept a group invite """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/accept_invite$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/accept_invite$") def __init__(self, hs): super(GroupSelfAcceptInviteServlet, self).__init__() @@ -650,7 +645,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.accept_invite( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -659,9 +654,8 @@ class GroupSelfAcceptInviteServlet(RestServlet): class GroupSelfUpdatePublicityServlet(RestServlet): """Update whether we publicise a users membership of a group """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/update_publicity$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/update_publicity$") def __init__(self, hs): super(GroupSelfUpdatePublicityServlet, self).__init__() @@ -676,9 +670,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): content = parse_json_object_from_request(request) publicise = content["publicise"] - yield self.store.update_group_publicity( - group_id, requester_user_id, publicise, - ) + yield self.store.update_group_publicity(group_id, requester_user_id, publicise) defer.returnValue((200, {})) @@ -686,9 +678,8 @@ class GroupSelfUpdatePublicityServlet(RestServlet): class PublicisedGroupsForUserServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_patterns( - "/publicised_groups/(?P[^/]*)$" - ) + + PATTERNS = client_patterns("/publicised_groups/(?P[^/]*)$") def __init__(self, hs): super(PublicisedGroupsForUserServlet, self).__init__() @@ -701,9 +692,7 @@ class PublicisedGroupsForUserServlet(RestServlet): def on_GET(self, request, user_id): yield self.auth.get_user_by_req(request, allow_guest=True) - result = yield self.groups_handler.get_publicised_groups_for_user( - user_id - ) + result = yield self.groups_handler.get_publicised_groups_for_user(user_id) defer.returnValue((200, result)) @@ -711,9 +700,8 @@ class PublicisedGroupsForUserServlet(RestServlet): class PublicisedGroupsForUsersServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_patterns( - "/publicised_groups$" - ) + + PATTERNS = client_patterns("/publicised_groups$") def __init__(self, hs): super(PublicisedGroupsForUsersServlet, self).__init__() @@ -729,9 +717,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): content = parse_json_object_from_request(request) user_ids = content["user_ids"] - result = yield self.groups_handler.bulk_get_publicised_groups( - user_ids - ) + result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) defer.returnValue((200, result)) @@ -739,9 +725,8 @@ class PublicisedGroupsForUsersServlet(RestServlet): class GroupsForUserServlet(RestServlet): """Get all groups the logged in user is joined to """ - PATTERNS = client_patterns( - "/joined_groups$" - ) + + PATTERNS = client_patterns("/joined_groups$") def __init__(self, hs): super(GroupsForUserServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 4cbfbf5631..45c9928b65 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -56,6 +56,7 @@ class KeyUploadServlet(RestServlet): }, } """ + PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") def __init__(self, hs): @@ -76,18 +77,19 @@ class KeyUploadServlet(RestServlet): if device_id is not None: # passing the device_id here is deprecated; however, we allow it # for now for compatibility with older clients. - if (requester.device_id is not None and - device_id != requester.device_id): - logger.warning("Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, device_id) + if requester.device_id is not None and device_id != requester.device_id: + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id if device_id is None: raise SynapseError( - 400, - "To upload keys, you must pass device_id when authenticating" + 400, "To upload keys, you must pass device_id when authenticating" ) result = yield self.e2e_keys_handler.upload_keys_for_user( @@ -159,6 +161,7 @@ class KeyChangesServlet(RestServlet): 200 OK { "changed": ["@foo:example.com"] } """ + PATTERNS = client_patterns("/keys/changes$") def __init__(self, hs): @@ -184,9 +187,7 @@ class KeyChangesServlet(RestServlet): user_id = requester.user.to_string() - results = yield self.device_handler.get_user_ids_changed( - user_id, from_token, - ) + results = yield self.device_handler.get_user_ids_changed(user_id, from_token) defer.returnValue((200, results)) @@ -209,6 +210,7 @@ class OneTimeKeyServlet(RestServlet): } } } } """ + PATTERNS = client_patterns("/keys/claim$") def __init__(self, hs): @@ -221,10 +223,7 @@ class OneTimeKeyServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.claim_one_time_keys( - body, - timeout, - ) + result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) defer.returnValue((200, result)) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 53e666989b..728a52328f 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -51,7 +51,7 @@ class NotificationsServlet(RestServlet): ) receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( - user_id, 'm.read' + user_id, "m.read" ) notif_event_ids = [pa["event_id"] for pa in push_actions] @@ -67,11 +67,13 @@ class NotificationsServlet(RestServlet): "profile_tag": pa["profile_tag"], "actions": pa["actions"], "ts": pa["received_ts"], - "event": (yield self._event_serializer.serialize_event( - notif_events[pa["event_id"]], - self.clock.time_msec(), - event_format=format_event_for_client_v2_without_room_id, - )), + "event": ( + yield self._event_serializer.serialize_event( + notif_events[pa["event_id"]], + self.clock.time_msec(), + event_format=format_event_for_client_v2_without_room_id, + ) + ), } if pa["room_id"] not in receipts_by_room: @@ -80,17 +82,15 @@ class NotificationsServlet(RestServlet): receipt = receipts_by_room[pa["room_id"]] returned_pa["read"] = ( - receipt["topological_ordering"], receipt["stream_ordering"] - ) >= ( - pa["topological_ordering"], pa["stream_ordering"] - ) + receipt["topological_ordering"], + receipt["stream_ordering"], + ) >= (pa["topological_ordering"], pa["stream_ordering"]) returned_push_actions.append(returned_pa) next_token = str(pa["stream_ordering"]) - defer.returnValue((200, { - "notifications": returned_push_actions, - "next_token": next_token, - })) + defer.returnValue( + (200, {"notifications": returned_push_actions, "next_token": next_token}) + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index bb927d9f9d..b1b5385b09 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -56,9 +56,8 @@ class IdTokenServlet(RestServlet): "expires_in": 3600, } """ - PATTERNS = client_patterns( - "/user/(?P[^/]*)/openid/request_token" - ) + + PATTERNS = client_patterns("/user/(?P[^/]*)/openid/request_token") EXPIRES_MS = 3600 * 1000 @@ -84,12 +83,17 @@ class IdTokenServlet(RestServlet): yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) - defer.returnValue((200, { - "access_token": token, - "token_type": "Bearer", - "matrix_server_name": self.server_name, - "expires_in": self.EXPIRES_MS / 1000, - })) + defer.returnValue( + ( + 200, + { + "access_token": token, + "token_type": "Bearer", + "matrix_server_name": self.server_name, + "expires_in": self.EXPIRES_MS / 1000, + }, + ) + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index f4bd0d077f..e75664279b 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet): room_id, "m.read", user_id=requester.user.to_string(), - event_id=read_event_id + event_id=read_event_id, ) read_marker_event_id = body.get("m.fully_read", None) @@ -56,7 +56,7 @@ class ReadMarkerRestServlet(RestServlet): yield self.read_marker_handler.received_client_read_marker( room_id, user_id=requester.user.to_string(), - event_id=read_marker_event_id + event_id=read_marker_event_id, ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index fa12ac3e4d..488905626a 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -49,10 +49,7 @@ class ReceiptRestServlet(RestServlet): yield self.presence_handler.bump_presence_active_time(requester.user) yield self.receipts_handler.received_client_receipt( - room_id, - receipt_type, - user_id=requester.user.to_string(), - event_id=event_id + room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 79c085408b..5c120e4dd5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -52,6 +52,7 @@ from ._base import client_patterns, interactive_auth_handler if hasattr(hmac, "compare_digest"): compare_digest = hmac.compare_digest else: + def compare_digest(a, b): return a == b @@ -75,11 +76,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', 'email', 'send_attempt' - ]) + assert_params_in_dict( + body, ["id_server", "client_secret", "email", "send_attempt"] + ) - if not check_3pid_allowed(self.hs, "email", body['email']): + if not check_3pid_allowed(self.hs, "email", body["email"]): raise SynapseError( 403, "Your email domain is not authorized to register on this server", @@ -87,7 +88,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', body['email'] + "email", body["email"] ) if existingUid is not None: @@ -113,13 +114,12 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', - 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( @@ -129,7 +129,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'msisdn', msisdn + "msisdn", msisdn ) if existingUid is not None: @@ -165,7 +165,7 @@ class UsernameAvailabilityRestServlet(RestServlet): reject_limit=1, # Allow 1 request at a time concurrent_requests=1, - ) + ), ) @defer.inlineCallbacks @@ -212,7 +212,8 @@ class RegisterRestServlet(RestServlet): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, time_now_s=time_now, + client_addr, + time_now_s=time_now, rate_hz=self.hs.config.rc_registration.per_second, burst_count=self.hs.config.rc_registration.burst_count, update=False, @@ -220,7 +221,7 @@ class RegisterRestServlet(RestServlet): if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) kind = b"user" @@ -239,18 +240,22 @@ class RegisterRestServlet(RestServlet): # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. desired_password = None - if 'password' in body: - if (not isinstance(body['password'], string_types) or - len(body['password']) > 512): + if "password" in body: + if ( + not isinstance(body["password"], string_types) + or len(body["password"]) > 512 + ): raise SynapseError(400, "Invalid password") desired_password = body["password"] desired_username = None - if 'username' in body: - if (not isinstance(body['username'], string_types) or - len(body['username']) > 512): + if "username" in body: + if ( + not isinstance(body["username"], string_types) + or len(body["username"]) > 512 + ): raise SynapseError(400, "Invalid username") - desired_username = body['username'] + desired_username = body["username"] appservice = None if self.auth.has_access_token(request): @@ -290,7 +295,7 @@ class RegisterRestServlet(RestServlet): desired_username = desired_username.lower() # == Shared Secret Registration == (e.g. create new user scripts) - if 'mac' in body: + if "mac" in body: # FIXME: Should we really be determining if this is shared secret # auth based purely on the 'mac' key? result = yield self._do_shared_secret_registration( @@ -305,16 +310,13 @@ class RegisterRestServlet(RestServlet): guest_access_token = body.get("guest_access_token", None) - if ( - 'initial_device_display_name' in body and - 'password' not in body - ): + if "initial_device_display_name" in body and "password" not in body: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out # the original registration params logger.warn("Ignoring initial_device_display_name without password") - del body['initial_device_display_name'] + del body["initial_device_display_name"] session_id = self.auth_handler.get_session_id(body) registered_user_id = None @@ -336,8 +338,8 @@ class RegisterRestServlet(RestServlet): # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one - require_email = 'email' in self.hs.config.registrations_require_3pid - require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + require_email = "email" in self.hs.config.registrations_require_3pid + require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid show_msisdn = True if self.hs.config.disable_msisdn_registration: @@ -362,9 +364,9 @@ class RegisterRestServlet(RestServlet): if not require_email: flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]]) # always let users provide both MSISDN & email - flows.extend([ - [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY], - ]) + flows.extend( + [[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]] + ) else: # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: @@ -378,9 +380,7 @@ class RegisterRestServlet(RestServlet): if not require_email or require_msisdn: flows.extend([[LoginType.MSISDN]]) # always let users provide both MSISDN & email - flows.extend([ - [LoginType.MSISDN, LoginType.EMAIL_IDENTITY] - ]) + flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]) # Append m.login.terms to all flows if we're requiring consent if self.hs.config.user_consent_at_registration: @@ -410,21 +410,20 @@ class RegisterRestServlet(RestServlet): if auth_result: for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: - medium = auth_result[login_type]['medium'] - address = auth_result[login_type]['address'] + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] if not check_3pid_allowed(self.hs, medium, address): raise SynapseError( 403, - "Third party identifiers (email/phone numbers)" + - " are not authorized on this server", + "Third party identifiers (email/phone numbers)" + + " are not authorized on this server", Codes.THREEPID_DENIED, ) if registered_user_id is not None: logger.info( - "Already registered user ID %r for this session", - registered_user_id + "Already registered user ID %r for this session", registered_user_id ) # don't re-register the threepids registered = False @@ -451,11 +450,11 @@ class RegisterRestServlet(RestServlet): # the two activation emails, they would register the same 3pid twice. for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: - medium = auth_result[login_type]['medium'] - address = auth_result[login_type]['address'] + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] existingUid = yield self.store.get_user_id_by_threepid( - medium, address, + medium, address ) if existingUid is not None: @@ -520,7 +519,7 @@ class RegisterRestServlet(RestServlet): raise SynapseError(400, "Shared secret registration is not enabled") if not username: raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON, + 400, "username must be specified", errcode=Codes.BAD_JSON ) # use the username from the original request rather than the @@ -541,12 +540,10 @@ class RegisterRestServlet(RestServlet): ).hexdigest() if not compare_digest(want_mac, got_mac): - raise SynapseError( - 403, "HMAC incorrect", - ) + raise SynapseError(403, "HMAC incorrect") (user_id, _) = yield self.registration_handler.register( - localpart=username, password=password, generate_token=False, + localpart=username, password=password, generate_token=False ) result = yield self._create_registration_details(user_id, body) @@ -565,21 +562,15 @@ class RegisterRestServlet(RestServlet): Returns: defer.Deferred: (object) dictionary for response from /register """ - result = { - "user_id": user_id, - "home_server": self.hs.hostname, - } + result = {"user_id": user_id, "home_server": self.hs.hostname} if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=False, + user_id, device_id, initial_display_name, is_guest=False ) - result.update({ - "access_token": access_token, - "device_id": device_id, - }) + result.update({"access_token": access_token, "device_id": device_id}) defer.returnValue(result) @defer.inlineCallbacks @@ -587,9 +578,7 @@ class RegisterRestServlet(RestServlet): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id, _ = yield self.registration_handler.register( - generate_token=False, - make_guest=True, - address=address, + generate_token=False, make_guest=True, address=address ) # we don't allow guests to specify their own device_id, because @@ -597,15 +586,20 @@ class RegisterRestServlet(RestServlet): device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=True, + user_id, device_id, initial_display_name, is_guest=True ) - defer.returnValue((200, { - "user_id": user_id, - "device_id": device_id, - "access_token": access_token, - "home_server": self.hs.hostname, - })) + defer.returnValue( + ( + 200, + { + "user_id": user_id, + "device_id": device_id, + "access_token": access_token, + "home_server": self.hs.hostname, + }, + ) + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index f8f8742bdc..8e362782cc 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -32,7 +32,10 @@ from synapse.http.servlet import ( parse_string, ) from synapse.rest.client.transactions import HttpTransactionCache -from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken +from synapse.storage.relations import ( + AggregationPaginationToken, + RelationPaginationToken, +) from ._base import client_patterns diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 10198662a9..e7578af804 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -33,9 +33,7 @@ logger = logging.getLogger(__name__) class ReportEventRestServlet(RestServlet): - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)/report/(?P[^/]*)$" - ) + PATTERNS = client_patterns("/rooms/(?P[^/]*)/report/(?P[^/]*)$") def __init__(self, hs): super(ReportEventRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 87779645f9..8d1b810565 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -129,22 +129,12 @@ class RoomKeysServlet(RestServlet): version = parse_string(request, "version") if session_id: - body = { - "sessions": { - session_id: body - } - } + body = {"sessions": {session_id: body}} if room_id: - body = { - "rooms": { - room_id: body - } - } + body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys( - user_id, version, body - ) + yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -212,10 +202,10 @@ class RoomKeysServlet(RestServlet): if session_id: # If the client requests a specific session, but that session was # not backed up, then return an M_NOT_FOUND. - if room_keys['rooms'] == {}: + if room_keys["rooms"] == {}: raise NotFoundError("No room_keys found") else: - room_keys = room_keys['rooms'][room_id]['sessions'][session_id] + room_keys = room_keys["rooms"][room_id]["sessions"][session_id] elif room_id: # If the client requests all sessions from a room, but no sessions # are found, then return an empty result rather than an error, so @@ -223,10 +213,10 @@ class RoomKeysServlet(RestServlet): # empty result is valid. (Similarly if the client requests all # sessions from the backup, but in that case, room_keys is already # in the right format, so we don't need to do anything about it.) - if room_keys['rooms'] == {}: - room_keys = {'sessions': {}} + if room_keys["rooms"] == {}: + room_keys = {"sessions": {}} else: - room_keys = room_keys['rooms'][room_id] + room_keys = room_keys["rooms"][room_id] defer.returnValue((200, room_keys)) @@ -256,9 +246,7 @@ class RoomKeysServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet): - PATTERNS = client_patterns( - "/room_keys/version$" - ) + PATTERNS = client_patterns("/room_keys/version$") def __init__(self, hs): """ @@ -304,9 +292,7 @@ class RoomKeysNewVersionServlet(RestServlet): user_id = requester.user.to_string() info = parse_json_object_from_request(request) - new_version = yield self.e2e_room_keys_handler.create_version( - user_id, info - ) + new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) defer.returnValue((200, {"version": new_version})) # we deliberately don't have a PUT /version, as these things really should @@ -314,9 +300,7 @@ class RoomKeysNewVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet): - PATTERNS = client_patterns( - "/room_keys/version(/(?P[^/]+))?$" - ) + PATTERNS = client_patterns("/room_keys/version(/(?P[^/]+))?$") def __init__(self, hs): """ @@ -350,9 +334,7 @@ class RoomKeysVersionServlet(RestServlet): user_id = requester.user.to_string() try: - info = yield self.e2e_room_keys_handler.get_version_info( - user_id, version - ) + info = yield self.e2e_room_keys_handler.get_version_info(user_id, version) except SynapseError as e: if e.code == 404: raise SynapseError(404, "No backup found", Codes.NOT_FOUND) @@ -375,9 +357,7 @@ class RoomKeysVersionServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - yield self.e2e_room_keys_handler.delete_version( - user_id, version - ) + yield self.e2e_room_keys_handler.delete_version(user_id, version) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -407,11 +387,11 @@ class RoomKeysVersionServlet(RestServlet): info = parse_json_object_from_request(request) if version is None: - raise SynapseError(400, "No version specified to update", Codes.MISSING_PARAM) + raise SynapseError( + 400, "No version specified to update", Codes.MISSING_PARAM + ) - yield self.e2e_room_keys_handler.update_version( - user_id, version, info - ) + yield self.e2e_room_keys_handler.update_version(user_id, version, info) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index c621a90fba..d7f7faa029 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -47,9 +47,10 @@ class RoomUpgradeRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ + PATTERNS = client_patterns( # /rooms/$roomid/upgrade - "/rooms/(?P[^/]*)/upgrade$", + "/rooms/(?P[^/]*)/upgrade$" ) def __init__(self, hs): @@ -63,7 +64,7 @@ class RoomUpgradeRestServlet(RestServlet): requester = yield self._auth.get_user_by_req(request) content = parse_json_object_from_request(request) - assert_params_in_dict(content, ("new_version", )) + assert_params_in_dict(content, ("new_version",)) new_version = content["new_version"] if new_version not in KNOWN_ROOM_VERSIONS: @@ -77,9 +78,7 @@ class RoomUpgradeRestServlet(RestServlet): requester, room_id, new_version ) - ret = { - "replacement_room": new_room_id, - } + ret = {"replacement_room": new_room_id} defer.returnValue((200, ret)) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index 120a713361..78075b8fc0 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class SendToDeviceRestServlet(servlet.RestServlet): PATTERNS = client_patterns( - "/sendToDevice/(?P[^/]*)/(?P[^/]*)$", + "/sendToDevice/(?P[^/]*)/(?P[^/]*)$" ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 148fc6c985..02d56dee6c 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -96,44 +96,42 @@ class SyncRestServlet(RestServlet): 400, "'from' is not a valid query parameter. Did you mean 'since'?" ) - requester = yield self.auth.get_user_by_req( - request, allow_guest=True - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) user = requester.user device_id = requester.device_id timeout = parse_integer(request, "timeout", default=0) since = parse_string(request, "since") set_presence = parse_string( - request, "set_presence", default="online", - allowed_values=self.ALLOWED_PRESENCE + request, + "set_presence", + default="online", + allowed_values=self.ALLOWED_PRESENCE, ) filter_id = parse_string(request, "filter", default=None) full_state = parse_boolean(request, "full_state", default=False) logger.debug( "/sync: user=%r, timeout=%r, since=%r," - " set_presence=%r, filter_id=%r, device_id=%r" % ( - user, timeout, since, set_presence, filter_id, device_id - ) + " set_presence=%r, filter_id=%r, device_id=%r" + % (user, timeout, since, set_presence, filter_id, device_id) ) request_key = (user, timeout, since, filter_id, full_state, device_id) if filter_id: - if filter_id.startswith('{'): + if filter_id.startswith("{"): try: filter_object = json.loads(filter_id) - set_timeline_upper_limit(filter_object, - self.hs.config.filter_timeline_limit) + set_timeline_upper_limit( + filter_object, self.hs.config.filter_timeline_limit + ) except Exception: raise SynapseError(400, "Invalid filter JSON") self.filtering.check_valid_filter(filter_object) filter = FilterCollection(filter_object) else: - filter = yield self.filtering.get_user_filter( - user.localpart, filter_id - ) + filter = yield self.filtering.get_user_filter(user.localpart, filter_id) else: filter = DEFAULT_FILTER_COLLECTION @@ -156,15 +154,19 @@ class SyncRestServlet(RestServlet): affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state(user, {"presence": set_presence}, True) + yield self.presence_handler.set_state( + user, {"presence": set_presence}, True + ) context = yield self.presence_handler.user_syncing( - user.to_string(), affect_presence=affect_presence, + user.to_string(), affect_presence=affect_presence ) with context: sync_result = yield self.sync_handler.wait_for_sync_for_user( - sync_config, since_token=since_token, timeout=timeout, - full_state=full_state + sync_config, + since_token=since_token, + timeout=timeout, + full_state=full_state, ) time_now = self.clock.time_msec() @@ -176,53 +178,54 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def encode_response(self, time_now, sync_result, access_token_id, filter): - if filter.event_format == 'client': + if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id - elif filter.event_format == 'federation': + elif filter.event_format == "federation": event_formatter = format_event_raw else: - raise Exception("Unknown event format %s" % (filter.event_format, )) + raise Exception("Unknown event format %s" % (filter.event_format,)) joined = yield self.encode_joined( - sync_result.joined, time_now, access_token_id, + sync_result.joined, + time_now, + access_token_id, filter.event_fields, event_formatter, ) invited = yield self.encode_invited( - sync_result.invited, time_now, access_token_id, - event_formatter, + sync_result.invited, time_now, access_token_id, event_formatter ) archived = yield self.encode_archived( - sync_result.archived, time_now, access_token_id, + sync_result.archived, + time_now, + access_token_id, filter.event_fields, event_formatter, ) - defer.returnValue({ - "account_data": {"events": sync_result.account_data}, - "to_device": {"events": sync_result.to_device}, - "device_lists": { - "changed": list(sync_result.device_lists.changed), - "left": list(sync_result.device_lists.left), - }, - "presence": SyncRestServlet.encode_presence( - sync_result.presence, time_now - ), - "rooms": { - "join": joined, - "invite": invited, - "leave": archived, - }, - "groups": { - "join": sync_result.groups.join, - "invite": sync_result.groups.invite, - "leave": sync_result.groups.leave, - }, - "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "next_batch": sync_result.next_batch.to_string(), - }) + defer.returnValue( + { + "account_data": {"events": sync_result.account_data}, + "to_device": {"events": sync_result.to_device}, + "device_lists": { + "changed": list(sync_result.device_lists.changed), + "left": list(sync_result.device_lists.left), + }, + "presence": SyncRestServlet.encode_presence( + sync_result.presence, time_now + ), + "rooms": {"join": joined, "invite": invited, "leave": archived}, + "groups": { + "join": sync_result.groups.join, + "invite": sync_result.groups.invite, + "leave": sync_result.groups.leave, + }, + "device_one_time_keys_count": sync_result.device_one_time_keys_count, + "next_batch": sync_result.next_batch.to_string(), + } + ) @staticmethod def encode_presence(events, time_now): @@ -262,7 +265,11 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = yield self.encode_room( - room, time_now, token_id, joined=True, only_fields=event_fields, + room, + time_now, + token_id, + joined=True, + only_fields=event_fields, event_formatter=event_formatter, ) @@ -290,7 +297,9 @@ class SyncRestServlet(RestServlet): invited = {} for room in rooms: invite = yield self._event_serializer.serialize_event( - room.invite, time_now, token_id=token_id, + room.invite, + time_now, + token_id=token_id, event_format=event_formatter, is_invite=True, ) @@ -298,9 +307,7 @@ class SyncRestServlet(RestServlet): invite["unsigned"] = unsigned invited_state = list(unsigned.pop("invite_room_state", [])) invited_state.append(invite) - invited[room.room_id] = { - "invite_state": {"events": invited_state} - } + invited[room.room_id] = {"invite_state": {"events": invited_state}} defer.returnValue(invited) @@ -327,7 +334,10 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = yield self.encode_room( - room, time_now, token_id, joined=False, + room, + time_now, + token_id, + joined=False, only_fields=event_fields, event_formatter=event_formatter, ) @@ -336,8 +346,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def encode_room( - self, room, time_now, token_id, joined, - only_fields, event_formatter, + self, room, time_now, token_id, joined, only_fields, event_formatter ): """ Args: @@ -355,9 +364,11 @@ class SyncRestServlet(RestServlet): Returns: dict[str, object]: the room, encoded in our response format """ + def serialize(events): return self._event_serializer.serialize_events( - events, time_now=time_now, + events, + time_now=time_now, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. bundle_aggregations=False, @@ -377,7 +388,9 @@ class SyncRestServlet(RestServlet): if event.room_id != room.room_id: logger.warn( "Event %r is under room %r instead of %r", - event.event_id, room.room_id, event.room_id, + event.event_id, + room.room_id, + event.room_id, ) serialized_state = yield serialize(state_events) diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index ebff7cff45..07b6ede603 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -29,9 +29,8 @@ class TagListServlet(RestServlet): """ GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 """ - PATTERNS = client_patterns( - "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags" - ) + + PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags") def __init__(self, hs): super(TagListServlet, self).__init__() @@ -54,6 +53,7 @@ class TagServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags/(?P[^/]*)" ) @@ -74,9 +74,7 @@ class TagServlet(RestServlet): max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -88,9 +86,7 @@ class TagServlet(RestServlet): max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index e7a987466a..1e66662a05 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -57,7 +57,7 @@ class ThirdPartyProtocolServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols( - only_protocol=protocol, + only_protocol=protocol ) if protocol in protocols: defer.returnValue((200, protocols[protocol])) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 6c366142e1..2da0f55811 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -26,6 +26,7 @@ class TokenRefreshRestServlet(RestServlet): Exchanges refresh tokens for a pair of an access token and a new refresh token. """ + PATTERNS = client_patterns("/tokenrefresh") def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 69e4efc47a..e19fb6d583 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -60,10 +60,7 @@ class UserDirectorySearchRestServlet(RestServlet): user_id = requester.user.to_string() if not self.hs.config.user_directory_search_enabled: - defer.returnValue((200, { - "limited": False, - "results": [], - })) + defer.returnValue((200, {"limited": False, "results": []})) body = parse_json_object_from_request(request) @@ -76,7 +73,7 @@ class UserDirectorySearchRestServlet(RestServlet): raise SynapseError(400, "`search_term` is required field") results = yield self.user_directory_handler.search_users( - user_id, search_term, limit, + user_id, search_term, limit ) defer.returnValue((200, results)) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index babbf6a23c..0e09191632 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -25,27 +25,28 @@ class VersionsRestServlet(RestServlet): PATTERNS = [re.compile("^/_matrix/client/versions$")] def on_GET(self, request): - return (200, { - "versions": [ - # XXX: at some point we need to decide whether we need to include - # the previous version numbers, given we've defined r0.3.0 to be - # backwards compatible with r0.2.0. But need to check how - # conscientious we've been in compatibility, and decide whether the - # middle number is the major revision when at 0.X.Y (as opposed to - # X.Y.Z). And we need to decide whether it's fair to make clients - # parse the version string to figure out what's going on. - "r0.0.1", - "r0.1.0", - "r0.2.0", - "r0.3.0", - "r0.4.0", - "r0.5.0", - ], - # as per MSC1497: - "unstable_features": { - "m.lazy_load_members": True, - } - }) + return ( + 200, + { + "versions": [ + # XXX: at some point we need to decide whether we need to include + # the previous version numbers, given we've defined r0.3.0 to be + # backwards compatible with r0.2.0. But need to check how + # conscientious we've been in compatibility, and decide whether the + # middle number is the major revision when at 0.X.Y (as opposed to + # X.Y.Z). And we need to decide whether it's fair to make clients + # parse the version string to figure out what's going on. + "r0.0.1", + "r0.1.0", + "r0.2.0", + "r0.3.0", + "r0.4.0", + "r0.5.0", + ], + # as per MSC1497: + "unstable_features": {"m.lazy_load_members": True}, + }, + ) def register_servlets(http_server): diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 6b371bfa2f..9a32892d8b 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -42,6 +42,7 @@ logger = logging.getLogger(__name__) if hasattr(hmac, "compare_digest"): compare_digest = hmac.compare_digest else: + def compare_digest(a, b): return a == b @@ -80,6 +81,7 @@ class ConsentResource(Resource): For POST: required; gives the value to be recorded in the database against the user. """ + def __init__(self, hs): """ Args: @@ -98,21 +100,20 @@ class ConsentResource(Resource): if self._default_consent_version is None: raise ConfigError( "Consent resource is enabled but user_consent section is " - "missing in config file.", + "missing in config file." ) consent_template_directory = hs.config.user_consent_template_dir loader = jinja2.FileSystemLoader(consent_template_directory) self._jinja_env = jinja2.Environment( - loader=loader, - autoescape=jinja2.select_autoescape(['html', 'htm', 'xml']), + loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) ) if hs.config.form_secret is None: raise ConfigError( "Consent resource is enabled but form_secret is not set in " - "config file. It should be set to an arbitrary secret string.", + "config file. It should be set to an arbitrary secret string." ) self._hmac_secret = hs.config.form_secret.encode("utf-8") @@ -139,7 +140,7 @@ class ConsentResource(Resource): self._check_hash(username, userhmac_bytes) - if username.startswith('@'): + if username.startswith("@"): qualified_user_id = username else: qualified_user_id = UserID(username, self.hs.hostname).to_string() @@ -153,7 +154,8 @@ class ConsentResource(Resource): try: self._render_template( - request, "%s.html" % (version,), + request, + "%s.html" % (version,), user=username, userhmac=userhmac, version=version, @@ -180,7 +182,7 @@ class ConsentResource(Resource): self._check_hash(username, userhmac) - if username.startswith('@'): + if username.startswith("@"): qualified_user_id = username else: qualified_user_id = UserID(username, self.hs.hostname).to_string() @@ -221,11 +223,13 @@ class ConsentResource(Resource): SynapseError if the hash doesn't match """ - want_mac = hmac.new( - key=self._hmac_secret, - msg=userid.encode('utf-8'), - digestmod=sha256, - ).hexdigest().encode('ascii') + want_mac = ( + hmac.new( + key=self._hmac_secret, msg=userid.encode("utf-8"), digestmod=sha256 + ) + .hexdigest() + .encode("ascii") + ) if not compare_digest(want_mac, userhmac): raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect") diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index ec0ec7b431..c16280f668 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -80,33 +80,27 @@ class LocalKey(Resource): for key in self.config.signing_key: verify_key_bytes = key.verify_key.encode() key_id = "%s:%s" % (key.alg, key.version) - verify_keys[key_id] = { - u"key": encode_base64(verify_key_bytes) - } + verify_keys[key_id] = {"key": encode_base64(verify_key_bytes)} old_verify_keys = {} for key_id, key in self.config.old_signing_keys.items(): verify_key_bytes = key.encode() old_verify_keys[key_id] = { - u"key": encode_base64(verify_key_bytes), - u"expired_ts": key.expired_ts, + "key": encode_base64(verify_key_bytes), + "expired_ts": key.expired_ts, } tls_fingerprints = self.config.tls_fingerprints json_object = { - u"valid_until_ts": self.valid_until_ts, - u"server_name": self.config.server_name, - u"verify_keys": verify_keys, - u"old_verify_keys": old_verify_keys, - u"tls_fingerprints": tls_fingerprints, + "valid_until_ts": self.valid_until_ts, + "server_name": self.config.server_name, + "verify_keys": verify_keys, + "old_verify_keys": old_verify_keys, + "tls_fingerprints": tls_fingerprints, } for key in self.config.signing_key: - json_object = sign_json( - json_object, - self.config.server_name, - key, - ) + json_object = sign_json(json_object, self.config.server_name, key) return json_object def render_GET(self, request): @@ -114,6 +108,4 @@ class LocalKey(Resource): # Update the expiry time if less than half the interval remains. if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts: self.update_response_body(time_now) - return respond_with_json_bytes( - request, 200, self.response_body, - ) + return respond_with_json_bytes(request, 200, self.response_body) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 8a730bbc35..ec8b9d7269 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -103,20 +103,16 @@ class RemoteKey(Resource): def async_render_GET(self, request): if len(request.postpath) == 1: server, = request.postpath - query = {server.decode('ascii'): {}} + query = {server.decode("ascii"): {}} elif len(request.postpath) == 2: server, key_id = request.postpath - minimum_valid_until_ts = parse_integer( - request, "minimum_valid_until_ts" - ) + minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}} + query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}} else: - raise SynapseError( - 404, "Not found %r" % request.postpath, Codes.NOT_FOUND - ) + raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) yield self.query_keys(request, query, query_remote_on_cache_miss=True) @@ -140,8 +136,8 @@ class RemoteKey(Resource): store_queries = [] for server_name, key_ids in query.items(): if ( - self.federation_domain_whitelist is not None and - server_name not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist ): logger.debug("Federation denied with %s", server_name) continue @@ -159,9 +155,7 @@ class RemoteKey(Resource): cache_misses = dict() for (server_name, key_id, from_server), results in cached.items(): - results = [ - (result["ts_added_ms"], result) for result in results - ] + results = [(result["ts_added_ms"], result) for result in results] if not results and key_id is not None: cache_misses.setdefault(server_name, set()).add(key_id) @@ -178,23 +172,30 @@ class RemoteKey(Resource): logger.debug( "Cached response for %r/%r is older than requested" ": valid_until (%r) < minimum_valid_until (%r)", - server_name, key_id, - ts_valid_until_ms, req_valid_until + server_name, + key_id, + ts_valid_until_ms, + req_valid_until, ) miss = True else: logger.debug( "Cached response for %r/%r is newer than requested" ": valid_until (%r) >= minimum_valid_until (%r)", - server_name, key_id, - ts_valid_until_ms, req_valid_until + server_name, + key_id, + ts_valid_until_ms, + req_valid_until, ) elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms: logger.debug( "Cached response for %r/%r is too old" ": (added (%r) + valid_until (%r)) / 2 < now (%r)", - server_name, key_id, - ts_added_ms, ts_valid_until_ms, time_now_ms + server_name, + key_id, + ts_added_ms, + ts_valid_until_ms, + time_now_ms, ) # We more than half way through the lifetime of the # response. We should fetch a fresh copy. @@ -203,8 +204,11 @@ class RemoteKey(Resource): logger.debug( "Cached response for %r/%r is still valid" ": (added (%r) + valid_until (%r)) / 2 < now (%r)", - server_name, key_id, - ts_added_ms, ts_valid_until_ms, time_now_ms + server_name, + key_id, + ts_added_ms, + ts_valid_until_ms, + time_now_ms, ) if miss: @@ -216,12 +220,10 @@ class RemoteKey(Resource): if cache_misses and query_remote_on_cache_miss: yield self.fetcher.get_keys(cache_misses) - yield self.query_keys( - request, query, query_remote_on_cache_miss=False - ) + yield self.query_keys(request, query, query_remote_on_cache_miss=False) else: result_io = BytesIO() - result_io.write(b"{\"server_keys\":") + result_io.write(b'{"server_keys":') sep = b"[" for json_bytes in json_results: result_io.write(sep) @@ -231,6 +233,4 @@ class RemoteKey(Resource): result_io.write(sep) result_io.write(b"]}") - respond_with_json_bytes( - request, 200, result_io.getvalue(), - ) + respond_with_json_bytes(request, 200, result_io.getvalue()) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index 5a426ff2f6..86884c0ef4 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -44,6 +44,7 @@ class ContentRepoResource(resource.Resource): - Content type base64d (so we can return it when clients GET it) """ + isLeaf = True def __init__(self, hs, directory): @@ -56,7 +57,7 @@ class ContentRepoResource(resource.Resource): # servers. # TODO: A little crude here, we could do this better. - filename = request.path.decode('ascii').split('/')[-1] + filename = request.path.decode("ascii").split("/")[-1] # be paranoid filename = re.sub("[^0-9A-z.-_]", "", filename) @@ -69,17 +70,15 @@ class ContentRepoResource(resource.Resource): base64_contentype = filename.split(".")[1] content_type = base64.urlsafe_b64decode(base64_contentype) logger.info("Sending file %s", file_path) - f = open(file_path, 'rb') - request.setHeader('Content-Type', content_type) + f = open(file_path, "rb") + request.setHeader("Content-Type", content_type) # cache for at least a day. # XXX: we might want to turn this off for data we don't want to # recommend caching as it's sensitive or private - or at least # select private. don't bother setting Expires as all our matrix # clients are smart enough to be happy with Cache-Control (right?) - request.setHeader( - b"Cache-Control", b"public,max-age=86400,s-maxage=86400" - ) + request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") d = FileSender().beginFileTransfer(f, request) @@ -87,13 +86,15 @@ class ContentRepoResource(resource.Resource): def cbFinished(ignored): f.close() finish_request(request) + d.addCallback(cbFinished) else: respond_with_json_bytes( request, 404, json.dumps(cs_error("Not found", code=Codes.NOT_FOUND)), - send_cors=True) + send_cors=True, + ) return server.NOT_DONE_YET diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 2dcc8f74d6..3318638d3e 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -38,8 +38,8 @@ def parse_media_id(request): server_name, media_id = request.postpath[:2] if isinstance(server_name, bytes): - server_name = server_name.decode('utf-8') - media_id = media_id.decode('utf8') + server_name = server_name.decode("utf-8") + media_id = media_id.decode("utf8") file_name = None if len(request.postpath) > 2: @@ -120,11 +120,11 @@ def add_file_headers(request, media_type, file_size, upload_name): # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we # may as well just do the filename* version. if _can_encode_filename_as_token(upload_name): - disposition = 'inline; filename=%s' % (upload_name, ) + disposition = "inline; filename=%s" % (upload_name,) else: - disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name), ) + disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) - request.setHeader(b"Content-Disposition", disposition.encode('ascii')) + request.setHeader(b"Content-Disposition", disposition.encode("ascii")) # cache for at least a day. # XXX: we might want to turn this off for data we don't want to @@ -137,10 +137,27 @@ def add_file_headers(request, media_type, file_size, upload_name): # separators as defined in RFC2616. SP and HT are handled separately. # see _can_encode_filename_as_token. -_FILENAME_SEPARATOR_CHARS = set(( - "(", ")", "<", ">", "@", ",", ";", ":", "\\", '"', - "/", "[", "]", "?", "=", "{", "}", -)) +_FILENAME_SEPARATOR_CHARS = set( + ( + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", + ) +) def _can_encode_filename_as_token(x): @@ -271,7 +288,7 @@ def get_filename_from_headers(headers): Returns: A Unicode string of the filename, or None. """ - content_disposition = headers.get(b"Content-Disposition", [b'']) + content_disposition = headers.get(b"Content-Disposition", [b""]) # No header, bail out. if not content_disposition[0]: @@ -293,7 +310,7 @@ def get_filename_from_headers(headers): # Once it is decoded, we can then unquote the %-encoded # parts strictly into a unicode string. upload_name = urllib.parse.unquote( - upload_name_utf8.decode('ascii'), errors="strict" + upload_name_utf8.decode("ascii"), errors="strict" ) except UnicodeDecodeError: # Incorrect UTF-8. @@ -302,7 +319,7 @@ def get_filename_from_headers(headers): # On Python 2, we first unquote the %-encoded parts and then # decode it strictly using UTF-8. try: - upload_name = urllib.parse.unquote(upload_name_utf8).decode('utf8') + upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8") except UnicodeDecodeError: pass @@ -310,7 +327,7 @@ def get_filename_from_headers(headers): if not upload_name: upload_name_ascii = params.get(b"filename", None) if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii.decode('ascii') + upload_name = upload_name_ascii.decode("ascii") # This may be None here, indicating we did not find a matching name. return upload_name @@ -328,19 +345,19 @@ def _parse_header(line): Tuple[bytes, dict[bytes, bytes]]: the main content-type, followed by the parameter dictionary """ - parts = _parseparam(b';' + line) + parts = _parseparam(b";" + line) key = next(parts) pdict = {} for p in parts: - i = p.find(b'=') + i = p.find(b"=") if i >= 0: name = p[:i].strip().lower() - value = p[i + 1:].strip() + value = p[i + 1 :].strip() # strip double-quotes if len(value) >= 2 and value[0:1] == value[-1:] == b'"': value = value[1:-1] - value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"') + value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') pdict[name] = value return key, pdict @@ -357,16 +374,16 @@ def _parseparam(s): Returns: Iterable[bytes]: the split input """ - while s[:1] == b';': + while s[:1] == b";": s = s[1:] # look for the next ; - end = s.find(b';') + end = s.find(b";") # if there is an odd number of " marks between here and the next ;, skip to the # next ; instead while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: - end = s.find(b';', end + 1) + end = s.find(b";", end + 1) if end < 0: end = len(s) diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 77316033f7..fa3d6680fc 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -29,9 +29,7 @@ class MediaConfigResource(Resource): config = hs.get_config() self.clock = hs.get_clock() self.auth = hs.get_auth() - self.limits_dict = { - "m.upload.size": config.max_upload_size, - } + self.limits_dict = {"m.upload.size": config.max_upload_size} def render_GET(self, request): self._async_render_GET(request) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index bdc5daecc1..a21a35f843 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -54,18 +54,20 @@ class DownloadResource(Resource): b" plugin-types application/pdf;" b" style-src 'unsafe-inline';" b" media-src 'self';" - b" object-src 'self';" + b" object-src 'self';", ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: yield self.media_repo.get_local_media(request, media_id, name) else: allow_remote = synapse.http.servlet.parse_boolean( - request, "allow_remote", default=True) + request, "allow_remote", default=True + ) if not allow_remote: logger.info( "Rejecting request for remote media %s/%s due to allow_remote", - server_name, media_id, + server_name, + media_id, ) respond_404(request) return diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index c8586fa280..e25c382c9c 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -24,6 +24,7 @@ def _wrap_in_base_path(func): """Takes a function that returns a relative path and turns it into an absolute path based on the location of the primary media store """ + @functools.wraps(func) def _wrapped(self, *args, **kwargs): path = func(self, *args, **kwargs) @@ -43,125 +44,102 @@ class MediaFilePaths(object): def __init__(self, primary_base_path): self.base_path = primary_base_path - def default_thumbnail_rel(self, default_top_level, default_sub_type, width, - height, content_type, method): + def default_thumbnail_rel( + self, default_top_level, default_sub_type, width, height, content_type, method + ): top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % ( - width, height, top_level_type, sub_type, method - ) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( - "default_thumbnails", default_top_level, - default_sub_type, file_name + "default_thumbnails", default_top_level, default_sub_type, file_name ) default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) def local_media_filepath_rel(self, media_id): - return os.path.join( - "local_content", - media_id[0:2], media_id[2:4], media_id[4:] - ) + return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - def local_media_thumbnail_rel(self, media_id, width, height, content_type, - method): + def local_media_thumbnail_rel(self, media_id, width, height, content_type, method): top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % ( - width, height, top_level_type, sub_type, method - ) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( - "local_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - file_name + "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name ) local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) def remote_media_filepath_rel(self, server_name, file_id): return os.path.join( - "remote_content", server_name, - file_id[0:2], file_id[2:4], file_id[4:] + "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) - def remote_media_thumbnail_rel(self, server_name, file_id, width, height, - content_type, method): + def remote_media_thumbnail_rel( + self, server_name, file_id, width, height, content_type, method + ): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( - "remote_thumbnail", server_name, - file_id[0:2], file_id[2:4], file_id[4:], - file_name + "remote_thumbnail", + server_name, + file_id[0:2], + file_id[2:4], + file_id[4:], + file_name, ) remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) def remote_media_thumbnail_dir(self, server_name, file_id): return os.path.join( - self.base_path, "remote_thumbnail", server_name, - file_id[0:2], file_id[2:4], file_id[4:], + self.base_path, + "remote_thumbnail", + server_name, + file_id[0:2], + file_id[2:4], + file_id[4:], ) def url_cache_filepath_rel(self, media_id): if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf - return os.path.join( - "url_cache", - media_id[:10], media_id[11:] - ) + return os.path.join("url_cache", media_id[:10], media_id[11:]) else: - return os.path.join( - "url_cache", - media_id[0:2], media_id[2:4], media_id[4:], - ) + return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:]) url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) def url_cache_filepath_dirs_to_delete(self, media_id): "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): - return [ - os.path.join( - self.base_path, "url_cache", - media_id[:10], - ), - ] + return [os.path.join(self.base_path, "url_cache", media_id[:10])] else: return [ - os.path.join( - self.base_path, "url_cache", - media_id[0:2], media_id[2:4], - ), - os.path.join( - self.base_path, "url_cache", - media_id[0:2], - ), + os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]), + os.path.join(self.base_path, "url_cache", media_id[0:2]), ] - def url_cache_thumbnail_rel(self, media_id, width, height, content_type, - method): + def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % ( - width, height, top_level_type, sub_type, method - ) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - "url_cache_thumbnails", - media_id[:10], media_id[11:], - file_name + "url_cache_thumbnails", media_id[:10], media_id[11:], file_name ) else: return os.path.join( "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - file_name + media_id[0:2], + media_id[2:4], + media_id[4:], + file_name, ) url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) @@ -172,13 +150,15 @@ class MediaFilePaths(object): if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[:10], media_id[11:], + self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] ) else: return os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], + self.base_path, + "url_cache_thumbnails", + media_id[0:2], + media_id[2:4], + media_id[4:], ) def url_cache_thumbnail_dirs_to_delete(self, media_id): @@ -188,26 +168,21 @@ class MediaFilePaths(object): if NEW_FORMAT_ID_RE.match(media_id): return [ os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[:10], media_id[11:], - ), - os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[:10], + self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] ), + os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]), ] else: return [ os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - ), - os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], + self.base_path, + "url_cache_thumbnails", + media_id[0:2], + media_id[2:4], + media_id[4:], ), os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], + self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4] ), + os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]), ] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index a4929dd5db..df3d985a38 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -100,17 +100,16 @@ class MediaRepository(object): storage_providers.append(provider) self.media_storage = MediaStorage( - self.hs, self.primary_base_path, self.filepaths, storage_providers, + self.hs, self.primary_base_path, self.filepaths, storage_providers ) self.clock.looping_call( - self._start_update_recently_accessed, - UPDATE_RECENTLY_ACCESSED_TS, + self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS ) def _start_update_recently_accessed(self): return run_as_background_process( - "update_recently_accessed_media", self._update_recently_accessed, + "update_recently_accessed_media", self._update_recently_accessed ) @defer.inlineCallbacks @@ -138,8 +137,9 @@ class MediaRepository(object): self.recently_accessed_locals.add(media_id) @defer.inlineCallbacks - def create_content(self, media_type, upload_name, content, content_length, - auth_user): + def create_content( + self, media_type, upload_name, content, content_length, auth_user + ): """Store uploaded content for a local user and return the mxc URL Args: @@ -154,10 +154,7 @@ class MediaRepository(object): """ media_id = random_string(24) - file_info = FileInfo( - server_name=None, - file_id=media_id, - ) + file_info = FileInfo(server_name=None, file_id=media_id) fname = yield self.media_storage.store_file(content, file_info) @@ -172,9 +169,7 @@ class MediaRepository(object): user_id=auth_user, ) - yield self._generate_thumbnails( - None, media_id, media_id, media_type, - ) + yield self._generate_thumbnails(None, media_id, media_id, media_type) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) @@ -205,14 +200,11 @@ class MediaRepository(object): upload_name = name if name else media_info["upload_name"] url_cache = media_info["url_cache"] - file_info = FileInfo( - None, media_id, - url_cache=url_cache, - ) + file_info = FileInfo(None, media_id, url_cache=url_cache) responder = yield self.media_storage.fetch_media(file_info) yield respond_with_responder( - request, responder, media_type, media_length, upload_name, + request, responder, media_type, media_length, upload_name ) @defer.inlineCallbacks @@ -232,8 +224,8 @@ class MediaRepository(object): to request """ if ( - self.federation_domain_whitelist is not None and - server_name not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist ): raise FederationDeniedError(server_name) @@ -244,7 +236,7 @@ class MediaRepository(object): key = (server_name, media_id) with (yield self.remote_media_linearizer.queue(key)): responder, media_info = yield self._get_remote_media_impl( - server_name, media_id, + server_name, media_id ) # We deliberately stream the file outside the lock @@ -253,7 +245,7 @@ class MediaRepository(object): media_length = media_info["media_length"] upload_name = name if name else media_info["upload_name"] yield respond_with_responder( - request, responder, media_type, media_length, upload_name, + request, responder, media_type, media_length, upload_name ) else: respond_404(request) @@ -272,8 +264,8 @@ class MediaRepository(object): Deferred[dict]: The media_info of the file """ if ( - self.federation_domain_whitelist is not None and - server_name not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist ): raise FederationDeniedError(server_name) @@ -282,7 +274,7 @@ class MediaRepository(object): key = (server_name, media_id) with (yield self.remote_media_linearizer.queue(key)): responder, media_info = yield self._get_remote_media_impl( - server_name, media_id, + server_name, media_id ) # Ensure we actually use the responder so that it releases resources @@ -305,9 +297,7 @@ class MediaRepository(object): Returns: Deferred[(Responder, media_info)] """ - media_info = yield self.store.get_cached_remote_media( - server_name, media_id - ) + media_info = yield self.store.get_cached_remote_media(server_name, media_id) # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new @@ -331,9 +321,7 @@ class MediaRepository(object): # Failed to find the file anywhere, lets download it. - media_info = yield self._download_remote_file( - server_name, media_id, file_id - ) + media_info = yield self._download_remote_file(server_name, media_id, file_id) responder = yield self.media_storage.fetch_media(file_info) defer.returnValue((responder, media_info)) @@ -354,54 +342,60 @@ class MediaRepository(object): Deferred[MediaInfo] """ - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - ) + file_info = FileInfo(server_name=server_name, file_id=file_id) with self.media_storage.store_into_file(file_info) as (f, fname, finish): - request_path = "/".join(( - "/_matrix/media/v1/download", server_name, media_id, - )) + request_path = "/".join( + ("/_matrix/media/v1/download", server_name, media_id) + ) try: length, headers = yield self.client.get_file( - server_name, request_path, output_stream=f, - max_size=self.max_upload_size, args={ + server_name, + request_path, + output_stream=f, + max_size=self.max_upload_size, + args={ # tell the remote server to 404 if it doesn't # recognise the server_name, to make sure we don't # end up with a routing loop. - "allow_remote": "false", - } + "allow_remote": "false" + }, ) except RequestSendFailed as e: - logger.warn("Request failed fetching remote media %s/%s: %r", - server_name, media_id, e) + logger.warn( + "Request failed fetching remote media %s/%s: %r", + server_name, + media_id, + e, + ) raise SynapseError(502, "Failed to fetch remote media") except HttpResponseException as e: - logger.warn("HTTP error fetching remote media %s/%s: %s", - server_name, media_id, e.response) + logger.warn( + "HTTP error fetching remote media %s/%s: %s", + server_name, + media_id, + e.response, + ) if e.code == twisted.web.http.NOT_FOUND: raise e.to_synapse_error() raise SynapseError(502, "Failed to fetch remote media") except SynapseError: - logger.warn( - "Failed to fetch remote media %s/%s", - server_name, media_id, - ) + logger.warn("Failed to fetch remote media %s/%s", server_name, media_id) raise except NotRetryingDestination: logger.warn("Not retrying destination %r", server_name) raise SynapseError(502, "Failed to fetch remote media") except Exception: - logger.exception("Failed to fetch remote media %s/%s", - server_name, media_id) + logger.exception( + "Failed to fetch remote media %s/%s", server_name, media_id + ) raise SynapseError(502, "Failed to fetch remote media") yield finish() - media_type = headers[b"Content-Type"][0].decode('ascii') + media_type = headers[b"Content-Type"][0].decode("ascii") upload_name = get_filename_from_headers(headers) time_now_ms = self.clock.time_msec() @@ -425,24 +419,23 @@ class MediaRepository(object): "filesystem_id": file_id, } - yield self._generate_thumbnails( - server_name, media_id, file_id, media_type, - ) + yield self._generate_thumbnails(server_name, media_id, file_id, media_type) defer.returnValue(media_info) def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, thumbnailer, t_width, t_height, - t_method, t_type): + def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): m_width = thumbnailer.width m_height = thumbnailer.height if m_width * m_height >= self.max_image_pixels: logger.info( "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels + m_width, + m_height, + self.max_image_pixels, ) return @@ -462,17 +455,22 @@ class MediaRepository(object): return t_byte_source @defer.inlineCallbacks - def generate_local_exact_thumbnail(self, media_id, t_width, t_height, - t_method, t_type, url_cache): - input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( - None, media_id, url_cache=url_cache, - )) + def generate_local_exact_thumbnail( + self, media_id, t_width, t_height, t_method, t_type, url_cache + ): + input_path = yield self.media_storage.ensure_media_is_in_local_cache( + FileInfo(None, media_id, url_cache=url_cache) + ) thumbnailer = Thumbnailer(input_path) t_byte_source = yield logcontext.defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, - thumbnailer, t_width, t_height, t_method, t_type + thumbnailer, + t_width, + t_height, + t_method, + t_type, ) if t_byte_source: @@ -489,7 +487,7 @@ class MediaRepository(object): ) output_path = yield self.media_storage.store_file( - t_byte_source, file_info, + t_byte_source, file_info ) finally: t_byte_source.close() @@ -505,17 +503,22 @@ class MediaRepository(object): defer.returnValue(output_path) @defer.inlineCallbacks - def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, - t_width, t_height, t_method, t_type): - input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( - server_name, file_id, url_cache=False, - )) + def generate_remote_exact_thumbnail( + self, server_name, file_id, media_id, t_width, t_height, t_method, t_type + ): + input_path = yield self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id, url_cache=False) + ) thumbnailer = Thumbnailer(input_path) t_byte_source = yield logcontext.defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, - thumbnailer, t_width, t_height, t_method, t_type + thumbnailer, + t_width, + t_height, + t_method, + t_type, ) if t_byte_source: @@ -531,7 +534,7 @@ class MediaRepository(object): ) output_path = yield self.media_storage.store_file( - t_byte_source, file_info, + t_byte_source, file_info ) finally: t_byte_source.close() @@ -541,15 +544,22 @@ class MediaRepository(object): t_len = os.path.getsize(output_path) yield self.store.store_remote_media_thumbnail( - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, ) defer.returnValue(output_path) @defer.inlineCallbacks - def _generate_thumbnails(self, server_name, media_id, file_id, media_type, - url_cache=False): + def _generate_thumbnails( + self, server_name, media_id, file_id, media_type, url_cache=False + ): """Generate and store thumbnails for an image. Args: @@ -568,9 +578,9 @@ class MediaRepository(object): if not requirements: return - input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( - server_name, file_id, url_cache=url_cache, - )) + input_path = yield self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id, url_cache=url_cache) + ) thumbnailer = Thumbnailer(input_path) m_width = thumbnailer.width @@ -579,14 +589,15 @@ class MediaRepository(object): if m_width * m_height >= self.max_image_pixels: logger.info( "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels + m_width, + m_height, + self.max_image_pixels, ) return if thumbnailer.transpose_method is not None: m_width, m_height = yield logcontext.defer_to_thread( - self.hs.get_reactor(), - thumbnailer.transpose + self.hs.get_reactor(), thumbnailer.transpose ) # We deduplicate the thumbnail sizes by ignoring the cropped versions if @@ -606,15 +617,11 @@ class MediaRepository(object): # Generate the thumbnail if t_method == "crop": t_byte_source = yield logcontext.defer_to_thread( - self.hs.get_reactor(), - thumbnailer.crop, - t_width, t_height, t_type, + self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type ) elif t_method == "scale": t_byte_source = yield logcontext.defer_to_thread( - self.hs.get_reactor(), - thumbnailer.scale, - t_width, t_height, t_type, + self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type ) else: logger.error("Unrecognized method: %r", t_method) @@ -636,7 +643,7 @@ class MediaRepository(object): ) output_path = yield self.media_storage.store_file( - t_byte_source, file_info, + t_byte_source, file_info ) finally: t_byte_source.close() @@ -646,18 +653,21 @@ class MediaRepository(object): # Write to database if server_name: yield self.store.store_remote_media_thumbnail( - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, ) else: yield self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue({ - "width": m_width, - "height": m_height, - }) + defer.returnValue({"width": m_width, "height": m_height}) @defer.inlineCallbacks def delete_old_remote_media(self, before_ts): @@ -749,11 +759,12 @@ class MediaRepositoryResource(Resource): self.putChild(b"upload", UploadResource(hs, media_repo)) self.putChild(b"download", DownloadResource(hs, media_repo)) - self.putChild(b"thumbnail", ThumbnailResource( - hs, media_repo, media_repo.media_storage, - )) + self.putChild( + b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) + ) if hs.config.url_preview_enabled: - self.putChild(b"preview_url", PreviewUrlResource( - hs, media_repo, media_repo.media_storage, - )) + self.putChild( + b"preview_url", + PreviewUrlResource(hs, media_repo, media_repo.media_storage), + ) self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 896078fe76..eff86836fb 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -66,8 +66,7 @@ class MediaStorage(object): with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository yield logcontext.defer_to_thread( - self.hs.get_reactor(), - _write_file_synchronously, source, f, + self.hs.get_reactor(), _write_file_synchronously, source, f ) yield finish_cb() @@ -179,7 +178,8 @@ class MediaStorage(object): if res: with res: consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.hs.get_reactor()) + open(local_path, "wb"), self.hs.get_reactor() + ) yield res.write_to_consumer(consumer) yield consumer.wait() defer.returnValue(local_path) @@ -217,10 +217,10 @@ class MediaStorage(object): width=file_info.thumbnail_width, height=file_info.thumbnail_height, content_type=file_info.thumbnail_type, - method=file_info.thumbnail_method + method=file_info.thumbnail_method, ) return self.filepaths.remote_media_filepath_rel( - file_info.server_name, file_info.file_id, + file_info.server_name, file_info.file_id ) if file_info.thumbnail: @@ -229,11 +229,9 @@ class MediaStorage(object): width=file_info.thumbnail_width, height=file_info.thumbnail_height, content_type=file_info.thumbnail_type, - method=file_info.thumbnail_method + method=file_info.thumbnail_method, ) - return self.filepaths.local_media_filepath_rel( - file_info.file_id, - ) + return self.filepaths.local_media_filepath_rel(file_info.file_id) def _write_file_synchronously(source, dest): @@ -255,6 +253,7 @@ class FileResponder(Responder): open_file (file): A file like object to be streamed ot the client, is closed when finished streaming. """ + def __init__(self, open_file): self.open_file = open_file diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index acf87709f2..de6f292ffb 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -92,7 +92,7 @@ class PreviewUrlResource(Resource): ) self._cleaner_loop = self.clock.looping_call( - self._start_expire_url_cache_data, 10 * 1000, + self._start_expire_url_cache_data, 10 * 1000 ) def render_OPTIONS(self, request): @@ -121,16 +121,16 @@ class PreviewUrlResource(Resource): for attrib in entry: pattern = entry[attrib] value = getattr(url_tuple, attrib) - logger.debug(( - "Matching attrib '%s' with value '%s' against" - " pattern '%s'" - ) % (attrib, value, pattern)) + logger.debug( + ("Matching attrib '%s' with value '%s' against" " pattern '%s'") + % (attrib, value, pattern) + ) if value is None: match = False continue - if pattern.startswith('^'): + if pattern.startswith("^"): if not re.match(pattern, getattr(url_tuple, attrib)): match = False continue @@ -139,12 +139,9 @@ class PreviewUrlResource(Resource): match = False continue if match: - logger.warn( - "URL %s blocked by url_blacklist entry %s", url, entry - ) + logger.warn("URL %s blocked by url_blacklist entry %s", url, entry) raise SynapseError( - 403, "URL blocked by url pattern blacklist entry", - Codes.UNKNOWN + 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN ) # the in-memory cache: @@ -156,14 +153,8 @@ class PreviewUrlResource(Resource): observable = self._cache.get(url) if not observable: - download = run_in_background( - self._do_preview, - url, requester.user, ts, - ) - observable = ObservableDeferred( - download, - consumeErrors=True - ) + download = run_in_background(self._do_preview, url, requester.user, ts) + observable = ObservableDeferred(download, consumeErrors=True) self._cache[url] = observable else: logger.info("Returning cached response") @@ -187,15 +178,15 @@ class PreviewUrlResource(Resource): # historical previews, if we have any) cache_result = yield self.store.get_url_cache(url, ts) if ( - cache_result and - cache_result["expires_ts"] > ts and - cache_result["response_code"] / 100 == 2 + cache_result + and cache_result["expires_ts"] > ts + and cache_result["response_code"] / 100 == 2 ): # It may be stored as text in the database, not as bytes (such as # PostgreSQL). If so, encode it back before handing it on. og = cache_result["og"] if isinstance(og, six.text_type): - og = og.encode('utf8') + og = og.encode("utf8") defer.returnValue(og) return @@ -203,33 +194,31 @@ class PreviewUrlResource(Resource): logger.debug("got media_info of '%s'" % media_info) - if _is_media(media_info['media_type']): - file_id = media_info['filesystem_id'] + if _is_media(media_info["media_type"]): + file_id = media_info["filesystem_id"] dims = yield self.media_repo._generate_thumbnails( - None, file_id, file_id, media_info["media_type"], - url_cache=True, + None, file_id, file_id, media_info["media_type"], url_cache=True ) og = { - "og:description": media_info['download_name'], - "og:image": "mxc://%s/%s" % ( - self.server_name, media_info['filesystem_id'] - ), - "og:image:type": media_info['media_type'], - "matrix:image:size": media_info['media_length'], + "og:description": media_info["download_name"], + "og:image": "mxc://%s/%s" + % (self.server_name, media_info["filesystem_id"]), + "og:image:type": media_info["media_type"], + "matrix:image:size": media_info["media_length"], } if dims: - og["og:image:width"] = dims['width'] - og["og:image:height"] = dims['height'] + og["og:image:width"] = dims["width"] + og["og:image:height"] = dims["height"] else: logger.warn("Couldn't get dims for %s" % url) # define our OG response for this media - elif _is_html(media_info['media_type']): + elif _is_html(media_info["media_type"]): # TODO: somehow stop a big HTML tree from exploding synapse's RAM - with open(media_info['filename'], 'rb') as file: + with open(media_info["filename"], "rb") as file: body = file.read() encoding = None @@ -242,45 +231,43 @@ class PreviewUrlResource(Resource): # If we find a match, it should take precedence over the # Content-Type header, so set it here. if match: - encoding = match.group(1).decode('ascii') + encoding = match.group(1).decode("ascii") # If we don't find a match, we'll look at the HTTP Content-Type, and # if that doesn't exist, we'll fall back to UTF-8. if not encoding: - match = _content_type_match.match( - media_info['media_type'] - ) + match = _content_type_match.match(media_info["media_type"]) encoding = match.group(1) if match else "utf-8" - og = decode_and_calc_og(body, media_info['uri'], encoding) + og = decode_and_calc_og(body, media_info["uri"], encoding) # pre-cache the image for posterity # FIXME: it might be cleaner to use the same flow as the main /preview_url # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. - if 'og:image' in og and og['og:image']: + if "og:image" in og and og["og:image"]: image_info = yield self._download_url( - _rebase_url(og['og:image'], media_info['uri']), user + _rebase_url(og["og:image"], media_info["uri"]), user ) - if _is_media(image_info['media_type']): + if _is_media(image_info["media_type"]): # TODO: make sure we don't choke on white-on-transparent images - file_id = image_info['filesystem_id'] + file_id = image_info["filesystem_id"] dims = yield self.media_repo._generate_thumbnails( - None, file_id, file_id, image_info["media_type"], - url_cache=True, + None, file_id, file_id, image_info["media_type"], url_cache=True ) if dims: - og["og:image:width"] = dims['width'] - og["og:image:height"] = dims['height'] + og["og:image:width"] = dims["width"] + og["og:image:height"] = dims["height"] else: logger.warn("Couldn't get dims for %s" % og["og:image"]) og["og:image"] = "mxc://%s/%s" % ( - self.server_name, image_info['filesystem_id'] + self.server_name, + image_info["filesystem_id"], ) - og["og:image:type"] = image_info['media_type'] - og["matrix:image:size"] = image_info['media_length'] + og["og:image:type"] = image_info["media_type"] + og["matrix:image:size"] = image_info["media_length"] else: del og["og:image"] else: @@ -289,7 +276,7 @@ class PreviewUrlResource(Resource): logger.debug("Calculated OG for %s as %s" % (url, og)) - jsonog = json.dumps(og).encode('utf8') + jsonog = json.dumps(og).encode("utf8") # store OG in history-aware DB cache yield self.store.store_url_cache( @@ -310,19 +297,15 @@ class PreviewUrlResource(Resource): # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? - file_id = datetime.date.today().isoformat() + '_' + random_string(16) + file_id = datetime.date.today().isoformat() + "_" + random_string(16) - file_info = FileInfo( - server_name=None, - file_id=file_id, - url_cache=True, - ) + file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: logger.debug("Trying to get url '%s'" % url) length, headers, uri, code = yield self.client.get_file( - url, output_stream=f, max_size=self.max_spider_size, + url, output_stream=f, max_size=self.max_spider_size ) except SynapseError: # Pass SynapseErrors through directly, so that the servlet @@ -334,24 +317,25 @@ class PreviewUrlResource(Resource): # Note: This will also be the case if one of the resolved IP # addresses is blacklisted raise SynapseError( - 502, "DNS resolution failure during URL preview generation", - Codes.UNKNOWN + 502, + "DNS resolution failure during URL preview generation", + Codes.UNKNOWN, ) except Exception as e: # FIXME: pass through 404s and other error messages nicely logger.warn("Error downloading %s: %r", url, e) raise SynapseError( - 500, "Failed to download content: %s" % ( - traceback.format_exception_only(sys.exc_info()[0], e), - ), + 500, + "Failed to download content: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), Codes.UNKNOWN, ) yield finish() try: if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode('ascii') + media_type = headers[b"Content-Type"][0].decode("ascii") else: media_type = "application/octet-stream" time_now_ms = self.clock.time_msec() @@ -375,24 +359,26 @@ class PreviewUrlResource(Resource): # therefore not expire it. raise - defer.returnValue({ - "media_type": media_type, - "media_length": length, - "download_name": download_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - "filename": fname, - "uri": uri, - "response_code": code, - # FIXME: we should calculate a proper expiration based on the - # Cache-Control and Expire headers. But for now, assume 1 hour. - "expires": 60 * 60 * 1000, - "etag": headers["ETag"][0] if "ETag" in headers else None, - }) + defer.returnValue( + { + "media_type": media_type, + "media_length": length, + "download_name": download_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + "filename": fname, + "uri": uri, + "response_code": code, + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + "expires": 60 * 60 * 1000, + "etag": headers["ETag"][0] if "ETag" in headers else None, + } + ) def _start_expire_url_cache_data(self): return run_as_background_process( - "expire_url_cache_data", self._expire_url_cache_data, + "expire_url_cache_data", self._expire_url_cache_data ) @defer.inlineCallbacks @@ -496,7 +482,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None): # blindly try decoding the body as utf-8, which seems to fix # the charset mismatches on https://google.com parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser) + tree = etree.fromstring(body.decode("utf-8", "ignore"), parser) og = _calc_og(tree, media_uri) return og @@ -523,8 +509,8 @@ def _calc_og(tree, media_uri): og = {} for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - if 'content' in tag.attrib: - og[tag.attrib['property']] = tag.attrib['content'] + if "content" in tag.attrib: + og[tag.attrib["property"]] = tag.attrib["content"] # TODO: grab article: meta tags too, e.g.: @@ -535,39 +521,43 @@ def _calc_og(tree, media_uri): # "article:published_time" content="2016-03-31T19:58:24+00:00" /> # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> - if 'og:title' not in og: + if "og:title" not in og: # do some basic spidering of the HTML title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") if title and title[0].text is not None: - og['og:title'] = title[0].text.strip() + og["og:title"] = title[0].text.strip() else: - og['og:title'] = None + og["og:title"] = None - if 'og:image' not in og: + if "og:image" not in og: # TODO: extract a favicon failing all else meta_image = tree.xpath( "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" ) if meta_image: - og['og:image'] = _rebase_url(meta_image[0], media_uri) + og["og:image"] = _rebase_url(meta_image[0], media_uri) else: # TODO: consider inlined CSS styles as well as width & height attribs images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") - images = sorted(images, key=lambda i: ( - -1 * float(i.attrib['width']) * float(i.attrib['height']) - )) + images = sorted( + images, + key=lambda i: ( + -1 * float(i.attrib["width"]) * float(i.attrib["height"]) + ), + ) if not images: images = tree.xpath("//img[@src]") if images: - og['og:image'] = images[0].attrib['src'] + og["og:image"] = images[0].attrib["src"] - if 'og:description' not in og: + if "og:description" not in og: meta_description = tree.xpath( "//*/meta" "[translate(@name, 'DESCRIPTION', 'description')='description']" - "/@content") + "/@content" + ) if meta_description: - og['og:description'] = meta_description[0] + og["og:description"] = meta_description[0] else: # grab any text nodes which are inside the tag... # unless they are within an HTML5 semantic markup tag... @@ -588,18 +578,18 @@ def _calc_og(tree, media_uri): "script", "noscript", "style", - etree.Comment + etree.Comment, ) # Split all the text nodes into paragraphs (by splitting on new # lines) text_nodes = ( - re.sub(r'\s+', '\n', el).strip() + re.sub(r"\s+", "\n", el).strip() for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) ) - og['og:description'] = summarize_paragraphs(text_nodes) + og["og:description"] = summarize_paragraphs(text_nodes) else: - og['og:description'] = summarize_paragraphs([og['og:description']]) + og["og:description"] = summarize_paragraphs([og["og:description"]]) # TODO: delete the url downloads to stop diskfilling, # as we only ever cared about its OG @@ -636,7 +626,7 @@ def _iterate_over_text(tree, *tags_to_ignore): [child, child.tail] if child.tail else [child] for child in el.iterchildren() ), - elements + elements, ) @@ -647,8 +637,8 @@ def _rebase_url(url, base): url[0] = base[0] or "http" if not url[1]: # fix up hostname url[1] = base[1] - if not url[2].startswith('/'): - url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] + if not url[2].startswith("/"): + url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2] return urlparse.urlunparse(url) @@ -659,9 +649,8 @@ def _is_media(content_type): def _is_html(content_type): content_type = content_type.lower() - if ( - content_type.startswith("text/html") or - content_type.startswith("application/xhtml") + if content_type.startswith("text/html") or content_type.startswith( + "application/xhtml" ): return True @@ -671,19 +660,19 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500): # first paragraph and then word boundaries. # TODO: Respect sentences? - description = '' + description = "" # Keep adding paragraphs until we get to the MIN_SIZE. for text_node in text_nodes: if len(description) < min_size: - text_node = re.sub(r'[\t \r\n]+', ' ', text_node) - description += text_node + '\n\n' + text_node = re.sub(r"[\t \r\n]+", " ", text_node) + description += text_node + "\n\n" else: break description = description.strip() - description = re.sub(r'[\t ]+', ' ', description) - description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description) + description = re.sub(r"[\t ]+", " ", description) + description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description) # If the concatenation of paragraphs to get above MIN_SIZE # took us over MAX_SIZE, then we need to truncate mid paragraph @@ -715,5 +704,5 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500): # We always add an ellipsis because at the very least # we chopped mid paragraph. - description = new_desc.strip() + u"…" + description = new_desc.strip() + "…" return description if description else None diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index d90cbfb56a..359b45ebfc 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -32,6 +32,7 @@ class StorageProvider(object): """A storage provider is a service that can store uploaded media and retrieve them. """ + def store_file(self, path, file_info): """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. @@ -70,6 +71,7 @@ class StorageProviderWrapper(StorageProvider): uploaded, or todo the upload in the backgroud. store_remote (bool): Whether remote media should be uploaded """ + def __init__(self, backend, store_local, store_synchronous, store_remote): self.backend = backend self.store_local = store_local @@ -92,6 +94,7 @@ class StorageProviderWrapper(StorageProvider): return self.backend.store_file(path, file_info) except Exception: logger.exception("Error storing file") + run_in_background(store) return defer.succeed(None) @@ -123,8 +126,7 @@ class FileStorageProviderBackend(StorageProvider): os.makedirs(dirname) return logcontext.defer_to_thread( - self.hs.get_reactor(), - shutil.copyfile, primary_fname, backup_fname, + self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) def fetch(self, path, file_info): diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 35a750923b..ca84c9f139 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -74,19 +74,18 @@ class ThumbnailResource(Resource): else: if self.dynamic_thumbnails: yield self._select_or_generate_remote_thumbnail( - request, server_name, media_id, - width, height, method, m_type + request, server_name, media_id, width, height, method, m_type ) else: yield self._respond_remote_thumbnail( - request, server_name, media_id, - width, height, method, m_type + request, server_name, media_id, width, height, method, m_type ) self.media_repo.mark_recently_accessed(server_name, media_id) @defer.inlineCallbacks - def _respond_local_thumbnail(self, request, media_id, width, height, - method, m_type): + def _respond_local_thumbnail( + self, request, media_id, width, height, method, m_type + ): media_info = yield self.store.get_local_media(media_id) if not media_info: @@ -105,7 +104,8 @@ class ThumbnailResource(Resource): ) file_info = FileInfo( - server_name=None, file_id=media_id, + server_name=None, + file_id=media_id, url_cache=media_info["url_cache"], thumbnail=True, thumbnail_width=thumbnail_info["thumbnail_width"], @@ -124,9 +124,15 @@ class ThumbnailResource(Resource): respond_404(request) @defer.inlineCallbacks - def _select_or_generate_local_thumbnail(self, request, media_id, desired_width, - desired_height, desired_method, - desired_type): + def _select_or_generate_local_thumbnail( + self, + request, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, + ): media_info = yield self.store.get_local_media(media_id) if not media_info: @@ -146,7 +152,8 @@ class ThumbnailResource(Resource): if t_w and t_h and t_method and t_type: file_info = FileInfo( - server_name=None, file_id=media_id, + server_name=None, + file_id=media_id, url_cache=media_info["url_cache"], thumbnail=True, thumbnail_width=info["thumbnail_width"], @@ -167,7 +174,11 @@ class ThumbnailResource(Resource): # Okay, so we generate one. file_path = yield self.media_repo.generate_local_exact_thumbnail( - media_id, desired_width, desired_height, desired_method, desired_type, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, url_cache=media_info["url_cache"], ) @@ -178,13 +189,20 @@ class ThumbnailResource(Resource): respond_404(request) @defer.inlineCallbacks - def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, - desired_width, desired_height, - desired_method, desired_type): + def _select_or_generate_remote_thumbnail( + self, + request, + server_name, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, + ): media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( - server_name, media_id, + server_name, media_id ) file_id = media_info["filesystem_id"] @@ -197,7 +215,8 @@ class ThumbnailResource(Resource): if t_w and t_h and t_method and t_type: file_info = FileInfo( - server_name=server_name, file_id=media_info["filesystem_id"], + server_name=server_name, + file_id=media_info["filesystem_id"], thumbnail=True, thumbnail_width=info["thumbnail_width"], thumbnail_height=info["thumbnail_height"], @@ -217,8 +236,13 @@ class ThumbnailResource(Resource): # Okay, so we generate one. file_path = yield self.media_repo.generate_remote_exact_thumbnail( - server_name, file_id, media_id, desired_width, - desired_height, desired_method, desired_type + server_name, + file_id, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, ) if file_path: @@ -228,15 +252,16 @@ class ThumbnailResource(Resource): respond_404(request) @defer.inlineCallbacks - def _respond_remote_thumbnail(self, request, server_name, media_id, width, - height, method, m_type): + def _respond_remote_thumbnail( + self, request, server_name, media_id, width, height, method, m_type + ): # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( - server_name, media_id, + server_name, media_id ) if thumbnail_infos: @@ -244,7 +269,8 @@ class ThumbnailResource(Resource): width, height, method, m_type, thumbnail_infos ) file_info = FileInfo( - server_name=server_name, file_id=media_info["filesystem_id"], + server_name=server_name, + file_id=media_info["filesystem_id"], thumbnail=True, thumbnail_width=thumbnail_info["thumbnail_width"], thumbnail_height=thumbnail_info["thumbnail_height"], @@ -261,8 +287,14 @@ class ThumbnailResource(Resource): logger.info("Failed to find any generated thumbnails") respond_404(request) - def _select_thumbnail(self, desired_width, desired_height, desired_method, - desired_type, thumbnail_infos): + def _select_thumbnail( + self, + desired_width, + desired_height, + desired_method, + desired_type, + thumbnail_infos, + ): d_w = desired_width d_h = desired_height @@ -280,15 +312,27 @@ class ThumbnailResource(Resource): type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] if t_w >= d_w or t_h >= d_h: - info_list.append(( - aspect_quality, min_quality, size_quality, type_quality, - length_quality, info - )) + info_list.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, + ) + ) else: - info_list2.append(( - aspect_quality, min_quality, size_quality, type_quality, - length_quality, info - )) + info_list2.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, + ) + ) if info_list: return min(info_list)[-1] else: @@ -304,13 +348,11 @@ class ThumbnailResource(Resource): type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] if t_method == "scale" and (t_w >= d_w or t_h >= d_h): - info_list.append(( - size_quality, type_quality, length_quality, info - )) + info_list.append((size_quality, type_quality, length_quality, info)) elif t_method == "scale": - info_list2.append(( - size_quality, type_quality, length_quality, info - )) + info_list2.append( + (size_quality, type_quality, length_quality, info) + ) if info_list: return min(info_list)[-1] else: diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 3efd0d80fc..90d8e6bffe 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -28,16 +28,13 @@ EXIF_TRANSPOSE_MAPPINGS = { 5: Image.TRANSPOSE, 6: Image.ROTATE_270, 7: Image.TRANSVERSE, - 8: Image.ROTATE_90 + 8: Image.ROTATE_90, } class Thumbnailer(object): - FORMATS = { - "image/jpeg": "JPEG", - "image/png": "PNG", - } + FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} def __init__(self, input_path): self.image = Image.open(input_path) @@ -110,17 +107,13 @@ class Thumbnailer(object): """ if width * self.height > height * self.width: scaled_height = (width * self.height) // self.width - scaled_image = self.image.resize( - (width, scaled_height), Image.ANTIALIAS - ) + scaled_image = self.image.resize((width, scaled_height), Image.ANTIALIAS) crop_top = (scaled_height - height) // 2 crop_bottom = height + crop_top cropped = scaled_image.crop((0, crop_top, width, crop_bottom)) else: scaled_width = (height * self.width) // self.height - scaled_image = self.image.resize( - (scaled_width, height), Image.ANTIALIAS - ) + scaled_image = self.image.resize((scaled_width, height), Image.ANTIALIAS) crop_left = (scaled_width - width) // 2 crop_right = width + crop_left cropped = scaled_image.crop((crop_left, 0, crop_right, height)) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index c1240e1963..d1d7e959f0 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -55,48 +55,36 @@ class UploadResource(Resource): requester = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point - content_length = request.getHeader(b"Content-Length").decode('ascii') + content_length = request.getHeader(b"Content-Length").decode("ascii") if content_length is None: - raise SynapseError( - msg="Request must specify a Content-Length", code=400 - ) + raise SynapseError(msg="Request must specify a Content-Length", code=400) if int(content_length) > self.max_upload_size: - raise SynapseError( - msg="Upload request body is too large", - code=413, - ) + raise SynapseError(msg="Upload request body is too large", code=413) upload_name = parse_string(request, b"filename", encoding=None) if upload_name: try: - upload_name = upload_name.decode('utf8') + upload_name = upload_name.decode("utf8") except UnicodeDecodeError: raise SynapseError( - msg="Invalid UTF-8 filename parameter: %r" % (upload_name), - code=400, + msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 ) headers = request.requestHeaders if headers.hasHeader(b"Content-Type"): - media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii') + media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii") else: - raise SynapseError( - msg="Upload request missing 'Content-Type'", - code=400, - ) + raise SynapseError(msg="Upload request missing 'Content-Type'", code=400) # if headers.hasHeader(b"Content-Disposition"): # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion content_uri = yield self.media_repo.create_content( - media_type, upload_name, request.content, - content_length, requester.user + media_type, upload_name, request.content, content_length, requester.user ) logger.info("Uploaded content with URI %r", content_uri) - respond_with_json( - request, 200, {"content_uri": content_uri}, send_cors=True - ) + respond_with_json(request, 200, {"content_uri": content_uri}, send_cors=True) diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/saml2/metadata_resource.py index e8c680aeb4..1e8526e22e 100644 --- a/synapse/rest/saml2/metadata_resource.py +++ b/synapse/rest/saml2/metadata_resource.py @@ -30,7 +30,7 @@ class SAML2MetadataResource(Resource): def render_GET(self, request): metadata_xml = saml2.metadata.create_metadata_string( - configfile=None, config=self.sp_config, + configfile=None, config=self.sp_config ) request.setHeader(b"Content-Type", b"text/xml; charset=utf-8") return metadata_xml diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index 69fb77b322..ab14b70675 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -46,18 +46,16 @@ class SAML2ResponseResource(Resource): @wrap_html_request_handler def _async_render_POST(self, request): - resp_bytes = parse_string(request, 'SAMLResponse', required=True) - relay_state = parse_string(request, 'RelayState', required=True) + resp_bytes = parse_string(request, "SAMLResponse", required=True) + relay_state = parse_string(request, "RelayState", required=True) try: saml2_auth = self._saml_client.parse_authn_request_response( - resp_bytes, saml2.BINDING_HTTP_POST, + resp_bytes, saml2.BINDING_HTTP_POST ) except Exception as e: logger.warning("Exception parsing SAML2 response", exc_info=1) - raise CodeMessageException( - 400, "Unable to parse SAML2 response: %s" % (e,), - ) + raise CodeMessageException(400, "Unable to parse SAML2 response: %s" % (e,)) if saml2_auth.not_signed: raise CodeMessageException(400, "SAML2 response was not signed") @@ -69,6 +67,5 @@ class SAML2ResponseResource(Resource): displayName = saml2_auth.ava.get("displayName", [None])[0] return self._sso_auth_handler.on_successful_auth( - username, request, relay_state, - user_display_name=displayName, + username, request, relay_state, user_display_name=displayName ) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index a7fa4f39af..5e8fda4b65 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -29,6 +29,7 @@ class WellKnownBuilder(object): Args: hs (synapse.server.HomeServer): """ + def __init__(self, hs): self._config = hs.config @@ -37,15 +38,11 @@ class WellKnownBuilder(object): if self._config.public_baseurl is None: return None - result = { - "m.homeserver": { - "base_url": self._config.public_baseurl, - }, - } + result = {"m.homeserver": {"base_url": self._config.public_baseurl}} if self._config.default_identity_server: result["m.identity_server"] = { - "base_url": self._config.default_identity_server, + "base_url": self._config.default_identity_server } return result @@ -66,7 +63,7 @@ class WellKnownResource(Resource): if not r: request.setResponseCode(404) request.setHeader(b"Content-Type", b"text/plain") - return b'.well-known not available' + return b".well-known not available" logger.debug("returning: %s", r) request.setHeader(b"Content-Type", b"application/json") diff --git a/synapse/secrets.py b/synapse/secrets.py index f6280f951c..0b327a0f82 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py @@ -29,6 +29,7 @@ if sys.version_info[0:2] >= (3, 6): def Secrets(): return secrets + else: import os import binascii @@ -38,4 +39,4 @@ else: return os.urandom(nbytes) def token_hex(self, nbytes=32): - return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii') + return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii") diff --git a/synapse/server.py b/synapse/server.py index a54e023cc9..a9592c396c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -91,7 +91,9 @@ from synapse.rest.media.v1.media_repository import ( from synapse.secrets import Secrets from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender -from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender +from synapse.server_notices.worker_server_notices_sender import ( + WorkerServerNoticesSender, +) from synapse.state import StateHandler, StateResolutionHandler from synapse.streams.events import EventSources from synapse.util import Clock @@ -127,79 +129,76 @@ class HomeServer(object): __metaclass__ = abc.ABCMeta DEPENDENCIES = [ - 'http_client', - 'db_pool', - 'federation_client', - 'federation_server', - 'handlers', - 'auth', - 'room_creation_handler', - 'state_handler', - 'state_resolution_handler', - 'presence_handler', - 'sync_handler', - 'typing_handler', - 'room_list_handler', - 'acme_handler', - 'auth_handler', - 'device_handler', - 'stats_handler', - 'e2e_keys_handler', - 'e2e_room_keys_handler', - 'event_handler', - 'event_stream_handler', - 'initial_sync_handler', - 'application_service_api', - 'application_service_scheduler', - 'application_service_handler', - 'device_message_handler', - 'profile_handler', - 'event_creation_handler', - 'deactivate_account_handler', - 'set_password_handler', - 'notifier', - 'event_sources', - 'keyring', - 'pusherpool', - 'event_builder_factory', - 'filtering', - 'http_client_context_factory', - 'simple_http_client', - 'media_repository', - 'media_repository_resource', - 'federation_transport_client', - 'federation_sender', - 'receipts_handler', - 'macaroon_generator', - 'tcp_replication', - 'read_marker_handler', - 'action_generator', - 'user_directory_handler', - 'groups_local_handler', - 'groups_server_handler', - 'groups_attestation_signing', - 'groups_attestation_renewer', - 'secrets', - 'spam_checker', - 'third_party_event_rules', - 'room_member_handler', - 'federation_registry', - 'server_notices_manager', - 'server_notices_sender', - 'message_handler', - 'pagination_handler', - 'room_context_handler', - 'sendmail', - 'registration_handler', - 'account_validity_handler', - 'event_client_serializer', - ] - - REQUIRED_ON_MASTER_STARTUP = [ + "http_client", + "db_pool", + "federation_client", + "federation_server", + "handlers", + "auth", + "room_creation_handler", + "state_handler", + "state_resolution_handler", + "presence_handler", + "sync_handler", + "typing_handler", + "room_list_handler", + "acme_handler", + "auth_handler", + "device_handler", + "stats_handler", + "e2e_keys_handler", + "e2e_room_keys_handler", + "event_handler", + "event_stream_handler", + "initial_sync_handler", + "application_service_api", + "application_service_scheduler", + "application_service_handler", + "device_message_handler", + "profile_handler", + "event_creation_handler", + "deactivate_account_handler", + "set_password_handler", + "notifier", + "event_sources", + "keyring", + "pusherpool", + "event_builder_factory", + "filtering", + "http_client_context_factory", + "simple_http_client", + "media_repository", + "media_repository_resource", + "federation_transport_client", + "federation_sender", + "receipts_handler", + "macaroon_generator", + "tcp_replication", + "read_marker_handler", + "action_generator", "user_directory_handler", - "stats_handler" + "groups_local_handler", + "groups_server_handler", + "groups_attestation_signing", + "groups_attestation_renewer", + "secrets", + "spam_checker", + "third_party_event_rules", + "room_member_handler", + "federation_registry", + "server_notices_manager", + "server_notices_sender", + "message_handler", + "pagination_handler", + "room_context_handler", + "sendmail", + "registration_handler", + "account_validity_handler", + "event_client_serializer", ] + REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] + # This is overridden in derived application classes # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be # instantiated during setup() for future return by get_datastore() @@ -410,9 +409,7 @@ class HomeServer(object): name = self.db_config["name"] return adbapi.ConnectionPool( - name, - cp_reactor=self.get_reactor(), - **self.db_config.get("args", {}) + name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) ) def get_db_conn(self, run_new_connection=True): @@ -424,7 +421,8 @@ class HomeServer(object): # Any param beginning with cp_ is a parameter for adbapi, and should # not be passed to the database engine. db_params = { - k: v for k, v in self.db_config.get("args", {}).items() + k: v + for k, v in self.db_config.get("args", {}).items() if not k.startswith("cp_") } db_conn = self.database_engine.module.connect(**db_params) @@ -555,9 +553,7 @@ def _make_dependency_method(depname): if builder: # Prevent cyclic dependencies from deadlocking if depname in hs._building: - raise ValueError("Cyclic dependency while building %s" % ( - depname, - )) + raise ValueError("Cyclic dependency while building %s" % (depname,)) hs._building[depname] = 1 dep = builder() @@ -568,9 +564,7 @@ def _make_dependency_method(depname): return dep raise NotImplementedError( - "%s has no %s nor a builder for it" % ( - type(hs).__name__, depname, - ) + "%s has no %s nor a builder for it" % (type(hs).__name__, depname) ) setattr(HomeServer, "get_%s" % (depname), _get) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9583e82d52..16f8f6b573 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -22,60 +22,57 @@ class HomeServer(object): @property def config(self) -> synapse.config.homeserver.HomeServerConfig: pass - def get_auth(self) -> synapse.api.auth.Auth: pass - def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: pass - def get_datastore(self) -> synapse.storage.DataStore: pass - def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: pass - def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler: pass - def get_handlers(self) -> synapse.handlers.Handlers: pass - def get_state_handler(self) -> synapse.state.StateHandler: pass - def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass - - def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: + def get_deactivate_account_handler( + self + ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: pass - def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: pass - def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: pass - - def get_event_creation_handler(self) -> synapse.handlers.message.EventCreationHandler: + def get_event_creation_handler( + self + ) -> synapse.handlers.message.EventCreationHandler: pass - - def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler: + def get_set_password_handler( + self + ) -> synapse.handlers.set_password.SetPasswordHandler: pass - def get_federation_sender(self) -> synapse.federation.sender.FederationSender: pass - - def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient: + def get_federation_transport_client( + self + ) -> synapse.federation.transport.client.TransportLayerClient: pass - - def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: + def get_media_repository_resource( + self + ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: pass - - def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository: + def get_media_repository( + self + ) -> synapse.rest.media.v1.media_repository.MediaRepository: pass - - def get_server_notices_manager(self) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: + def get_server_notices_manager( + self + ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: pass - - def get_server_notices_sender(self) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: + def get_server_notices_sender( + self + ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 5e3044d164..415e9c17d8 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -30,6 +30,7 @@ class ConsentServerNotices(object): """Keeps track of whether we need to send users server_notices about privacy policy consent, and sends one if we do. """ + def __init__(self, hs): """ @@ -49,12 +50,11 @@ class ConsentServerNotices(object): if not self._server_notices_manager.is_enabled(): raise ConfigError( "user_consent configuration requires server notices, but " - "server notices are not enabled.", + "server notices are not enabled." ) - if 'body' not in self._server_notice_content: + if "body" not in self._server_notice_content: raise ConfigError( - "user_consent server_notice_consent must contain a 'body' " - "key.", + "user_consent server_notice_consent must contain a 'body' " "key." ) self._consent_uri_builder = ConsentURIBuilder(hs.config) @@ -95,18 +95,14 @@ class ConsentServerNotices(object): # need to send a message. try: consent_uri = self._consent_uri_builder.build_user_consent_uri( - get_localpart_from_id(user_id), + get_localpart_from_id(user_id) ) content = copy_with_str_subst( - self._server_notice_content, { - 'consent_uri': consent_uri, - }, - ) - yield self._server_notices_manager.send_notice( - user_id, content, + self._server_notice_content, {"consent_uri": consent_uri} ) + yield self._server_notices_manager.send_notice(user_id, content) yield self._store.user_set_consent_server_notice_sent( - user_id, self._current_consent_version, + user_id, self._current_consent_version ) except SynapseError as e: logger.error("Error sending server notice about user consent: %s", e) @@ -128,9 +124,7 @@ def copy_with_str_subst(x, substitutions): if isinstance(x, string_types): return x % substitutions if isinstance(x, dict): - return { - k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x) - } + return {k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)} if isinstance(x, (list, tuple)): return [copy_with_str_subst(y) for y in x] diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index af15cba0ee..f183743f31 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -33,6 +33,7 @@ class ResourceLimitsServerNotices(object): """ Keeps track of whether the server has reached it's resource limit and ensures that the client is kept up to date. """ + def __init__(self, hs): """ Args: @@ -104,34 +105,28 @@ class ResourceLimitsServerNotices(object): if currently_blocked and not is_auth_blocking: # Room is notifying of a block, when it ought not to be. # Remove block notification - content = { - "pinned": ref_events - } + content = {"pinned": ref_events} yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, '', + user_id, content, EventTypes.Pinned, "" ) elif not currently_blocked and is_auth_blocking: # Room is not notifying of a block, when it ought to be. # Add block notification content = { - 'body': event_content, - 'msgtype': ServerNoticeMsgType, - 'server_notice_type': ServerNoticeLimitReached, - 'admin_contact': self._config.admin_contact, - 'limit_type': event_limit_type + "body": event_content, + "msgtype": ServerNoticeMsgType, + "server_notice_type": ServerNoticeLimitReached, + "admin_contact": self._config.admin_contact, + "limit_type": event_limit_type, } event = yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Message, + user_id, content, EventTypes.Message ) - content = { - "pinned": [ - event.event_id, - ] - } + content = {"pinned": [event.event_id]} yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, '', + user_id, content, EventTypes.Pinned, "" ) except SynapseError as e: @@ -156,9 +151,7 @@ class ResourceLimitsServerNotices(object): max_id = yield self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) - self._notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) @defer.inlineCallbacks def _is_room_currently_blocked(self, room_id): @@ -188,7 +181,7 @@ class ResourceLimitsServerNotices(object): referenced_events = [] if pinned_state_event is not None: - referenced_events = list(pinned_state_event.content.get('pinned', [])) + referenced_events = list(pinned_state_event.content.get("pinned", [])) events = yield self._store.get_events(referenced_events) for event_id, event in iteritems(events): diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index c5cc6d728e..71e7e75320 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -51,8 +51,7 @@ class ServerNoticesManager(object): @defer.inlineCallbacks def send_notice( - self, user_id, event_content, - type=EventTypes.Message, state_key=None + self, user_id, event_content, type=EventTypes.Message, state_key=None ): """Send a notice to the given user @@ -82,10 +81,10 @@ class ServerNoticesManager(object): } if state_key is not None: - event_dict['state_key'] = state_key + event_dict["state_key"] = state_key res = yield self._event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, ratelimit=False, + requester, event_dict, ratelimit=False ) defer.returnValue(res) @@ -104,11 +103,10 @@ class ServerNoticesManager(object): if not self.is_enabled(): raise Exception("Server notices not enabled") - assert self._is_mine_id(user_id), \ - "Cannot send server notices to remote users" + assert self._is_mine_id(user_id), "Cannot send server notices to remote users" rooms = yield self._store.get_rooms_for_user_where_membership_is( - user_id, [Membership.INVITE, Membership.JOIN], + user_id, [Membership.INVITE, Membership.JOIN] ) system_mxid = self._config.server_notices_mxid for room in rooms: @@ -132,8 +130,8 @@ class ServerNoticesManager(object): # avatar, we have to use both. join_profile = None if ( - self._config.server_notices_mxid_display_name is not None or - self._config.server_notices_mxid_avatar_url is not None + self._config.server_notices_mxid_display_name is not None + or self._config.server_notices_mxid_avatar_url is not None ): join_profile = { "displayname": self._config.server_notices_mxid_display_name, @@ -146,22 +144,18 @@ class ServerNoticesManager(object): config={ "preset": RoomCreationPreset.PRIVATE_CHAT, "name": self._config.server_notices_room_name, - "power_level_content_override": { - "users_default": -10, - }, - "invite": (user_id,) + "power_level_content_override": {"users_default": -10}, + "invite": (user_id,), }, ratelimit=False, creator_join_profile=join_profile, ) - room_id = info['room_id'] + room_id = info["room_id"] max_id = yield self._store.add_tag_to_room( - user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}, - ) - self._notifier.on_new_event( - "account_data_key", max_id, users=[user_id] + user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) + self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) logger.info("Created server notices room %s for %s", room_id, user_id) defer.returnValue(room_id) diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index 6121b2f267..652bab58e3 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -24,6 +24,7 @@ class ServerNoticesSender(object): """A centralised place which sends server notices automatically when Certain Events take place """ + def __init__(self, hs): """ @@ -32,7 +33,7 @@ class ServerNoticesSender(object): """ self._server_notices = ( ConsentServerNotices(hs), - ResourceLimitsServerNotices(hs) + ResourceLimitsServerNotices(hs), ) @defer.inlineCallbacks @@ -43,9 +44,7 @@ class ServerNoticesSender(object): user_id (str): mxid of user who synced """ for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user( - user_id, - ) + yield sn.maybe_send_server_notice_to_user(user_id) @defer.inlineCallbacks def on_user_ip(self, user_id): @@ -58,6 +57,4 @@ class ServerNoticesSender(object): # we check for notices to send to the user in on_user_ip as well as # in on_user_syncing for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user( - user_id, - ) + yield sn.maybe_send_server_notice_to_user(user_id) diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py index 4a133026c3..245ec7c64f 100644 --- a/synapse/server_notices/worker_server_notices_sender.py +++ b/synapse/server_notices/worker_server_notices_sender.py @@ -17,6 +17,7 @@ from twisted.internet import defer class WorkerServerNoticesSender(object): """Stub impl of ServerNoticesSender which does nothing""" + def __init__(self, hs): """ Args: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 36684ef9f6..fc20d1eaee 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -98,8 +98,9 @@ class StateHandler(object): self._state_resolution_handler = hs.get_state_resolution_handler() @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key="", - latest_event_ids=None): + def get_current_state( + self, room_id, event_type=None, state_key="", latest_event_ids=None + ): """ Retrieves the current state for the room. This is done by calling `get_latest_events_in_room` to get the leading edges of the event graph and then resolving any of the state conflicts. @@ -128,8 +129,9 @@ class StateHandler(object): defer.returnValue(event) return - state_map = yield self.store.get_events(list(state.values()), - get_prev_content=False) + state_map = yield self.store.get_events( + list(state.values()), get_prev_content=False + ) state = { key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map } @@ -211,9 +213,7 @@ class StateHandler(object): # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. if old_state: - prev_state_ids = { - (s.type, s.state_key): s.event_id for s in old_state - } + prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): current_state_ids = dict(prev_state_ids) key = (event.type, event.state_key) @@ -239,9 +239,7 @@ class StateHandler(object): # Let's just correctly fill out the context and create a # new state group for it. - prev_state_ids = { - (s.type, s.state_key): s.event_id for s in old_state - } + prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): key = (event.type, event.state_key) @@ -273,7 +271,7 @@ class StateHandler(object): logger.debug("calling resolve_state_groups from compute_event_context") entry = yield self.resolve_state_groups_for_events( - event.room_id, event.prev_event_ids(), + event.room_id, event.prev_event_ids() ) prev_state_ids = entry.state @@ -296,9 +294,7 @@ class StateHandler(object): # If the state at the event has a state group assigned then # we can use that as the prev group prev_group = entry.state_group - delta_ids = { - key: event.event_id - } + delta_ids = {key: event.event_id} elif entry.prev_group: # If the state at the event only has a prev group, then we can # use that as a prev group too. @@ -360,31 +356,31 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.store.get_state_groups_ids( - room_id, event_ids - ) + state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) if len(state_groups_ids) == 0: - defer.returnValue(_StateCacheEntry( - state={}, - state_group=None, - )) + defer.returnValue(_StateCacheEntry(state={}, state_group=None)) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() prev_group, delta_ids = yield self.store.get_state_group_delta(name) - defer.returnValue(_StateCacheEntry( - state=state_list, - state_group=name, - prev_group=prev_group, - delta_ids=delta_ids, - )) + defer.returnValue( + _StateCacheEntry( + state=state_list, + state_group=name, + prev_group=prev_group, + delta_ids=delta_ids, + ) + ) room_version = yield self.store.get_room_version(room_id) result = yield self._state_resolution_handler.resolve_state_groups( - room_id, room_version, state_groups_ids, None, + room_id, + room_version, + state_groups_ids, + None, state_res_store=StateResolutionStore(self.store), ) defer.returnValue(result) @@ -394,27 +390,21 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) - state_set_ids = [{ - (ev.type, ev.state_key): ev.event_id - for ev in st - } for st in state_sets] - - state_map = { - ev.event_id: ev - for st in state_sets - for ev in st - } + state_set_ids = [ + {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets + ] + + state_map = {ev.event_id: ev for st in state_sets for ev in st} with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( - room_version, state_set_ids, + room_version, + state_set_ids, event_map=state_map, state_res_store=StateResolutionStore(self.store), ) - new_state = { - key: state_map[ev_id] for key, ev_id in iteritems(new_state) - } + new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)} defer.returnValue(new_state) @@ -425,6 +415,7 @@ class StateResolutionHandler(object): Note that the storage layer depends on this handler, so all functions must be storage-independent. """ + def __init__(self, hs): self.clock = hs.get_clock() @@ -444,7 +435,7 @@ class StateResolutionHandler(object): @defer.inlineCallbacks @log_function def resolve_state_groups( - self, room_id, room_version, state_groups_ids, event_map, state_res_store, + self, room_id, room_version, state_groups_ids, event_map, state_res_store ): """Resolves conflicts between a set of state groups @@ -471,10 +462,7 @@ class StateResolutionHandler(object): Returns: Deferred[_StateCacheEntry]: resolved state """ - logger.debug( - "resolve_state_groups state_groups %s", - state_groups_ids.keys() - ) + logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) group_names = frozenset(state_groups_ids.keys()) @@ -529,10 +517,7 @@ class StateResolutionHandler(object): defer.returnValue(cache) -def _make_state_cache_entry( - new_state, - state_groups_ids, -): +def _make_state_cache_entry(new_state, state_groups_ids): """Given a resolved state, and a set of input state groups, pick one to base a new state group on (if any), and return an appropriately-constructed _StateCacheEntry. @@ -562,10 +547,7 @@ def _make_state_cache_entry( old_state_event_ids = set(itervalues(state)) if new_state_event_ids == old_state_event_ids: # got an exact match. - return _StateCacheEntry( - state=new_state, - state_group=sg, - ) + return _StateCacheEntry(state=new_state, state_group=sg) # TODO: We want to create a state group for this set of events, to # increase cache hits, but we need to make sure that it doesn't @@ -576,20 +558,13 @@ def _make_state_cache_entry( delta_ids = None for old_group, old_state in iteritems(state_groups_ids): - n_delta_ids = { - k: v - for k, v in iteritems(new_state) - if old_state.get(k) != v - } + n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v} if not delta_ids or len(n_delta_ids) < len(delta_ids): prev_group = old_group delta_ids = n_delta_ids return _StateCacheEntry( - state=new_state, - state_group=None, - prev_group=prev_group, - delta_ids=delta_ids, + state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids ) @@ -618,11 +593,11 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( - state_sets, event_map, state_res_store.get_events, + state_sets, event_map, state_res_store.get_events ) else: return v2.resolve_events_with_store( - room_version, state_sets, event_map, state_res_store, + room_version, state_sets, event_map, state_res_store ) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 29b4e86cfd..88acd4817e 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -57,23 +57,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): if len(state_sets) == 1: defer.returnValue(state_sets[0]) - unconflicted_state, conflicted_state = _seperate( - state_sets, - ) + unconflicted_state, conflicted_state = _seperate(state_sets) needed_events = set( - event_id - for event_ids in itervalues(conflicted_state) - for event_id in event_ids + event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids ) needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) logger.info( - "Asking for %d/%d conflicted events", - len(needed_events), - needed_event_count, + "Asking for %d/%d conflicted events", len(needed_events), needed_event_count ) # dict[str, FrozenEvent]: a map from state event id to event. Only includes @@ -97,17 +91,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): new_needed_events -= set(iterkeys(event_map)) logger.info( - "Asking for %d/%d auth events", - len(new_needed_events), - new_needed_event_count, + "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count ) state_map_new = yield state_map_factory(new_needed_events) state_map.update(state_map_new) - defer.returnValue(_resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - )) + defer.returnValue( + _resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + ) + ) def _seperate(state_sets): @@ -173,8 +167,9 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma return auth_events -def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids, - state_map): +def _resolve_with_state( + unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map +): conflicted_state = {} for key, event_ids in iteritems(conflicted_state_ids): events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] @@ -190,9 +185,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event } try: - resolved_state = _resolve_state_events( - conflicted_state, auth_events - ) + resolved_state = _resolve_state_events(conflicted_state, auth_events) except Exception: logger.exception("Failed to resolve state") raise @@ -218,37 +211,28 @@ def _resolve_state_events(conflicted_state, auth_events): if POWER_KEY in conflicted_state: events = conflicted_state[POWER_KEY] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[POWER_KEY] = _resolve_auth_events( - events, auth_events) + resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) + resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) + resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = _resolve_normal_events( - events, auth_events - ) + resolved_state[key] = _resolve_normal_events(events, auth_events) return resolved_state @@ -257,9 +241,7 @@ def _resolve_auth_events(events, auth_events): reverse = [i for i in reversed(_ordered_events(events))] auth_keys = set( - key - for event in events - for key in event_auth.auth_types_for_event(event) + key for event in events for key in event_auth.auth_types_for_event(event) ) new_auth_events = {} @@ -313,6 +295,6 @@ def _ordered_events(events): def key_func(e): # we have to use utf-8 rather than ascii here because it turns out we allow # people to send us events with non-ascii event IDs :/ - return -int(e.depth), hashlib.sha1(e.event_id.encode('utf-8')).hexdigest() + return -int(e.depth), hashlib.sha1(e.event_id.encode("utf-8")).hexdigest() return sorted(events, key=key_func) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 650995c92c..db969e8997 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -70,19 +70,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = yield _get_auth_chain_difference( - state_sets, event_map, state_res_store, - ) + auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store) - full_conflicted_set = set(itertools.chain( - itertools.chain.from_iterable(itervalues(conflicted_state)), - auth_diff, - )) + full_conflicted_set = set( + itertools.chain( + itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff + ) + ) - events = yield state_res_store.get_events([ - eid for eid in full_conflicted_set - if eid not in event_map - ], allow_rejected=True) + events = yield state_res_store.get_events( + [eid for eid in full_conflicted_set if eid not in event_map], + allow_rejected=True, + ) event_map.update(events) full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) @@ -91,22 +90,21 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # Get and sort all the power events (kicks/bans/etc) power_events = ( - eid for eid in full_conflicted_set - if _is_power_event(event_map[eid]) + eid for eid in full_conflicted_set if _is_power_event(event_map[eid]) ) sorted_power_events = yield _reverse_topological_power_sort( - power_events, - event_map, - state_res_store, - full_conflicted_set, + power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one resolved_state = yield _iterative_auth_checks( - room_version, sorted_power_events, unconflicted_state, event_map, + room_version, + sorted_power_events, + unconflicted_state, + event_map, state_res_store, ) @@ -116,23 +114,20 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # events using the mainline of the resolved power level. leftover_events = [ - ev_id - for ev_id in full_conflicted_set - if ev_id not in sorted_power_events + ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events ] logger.debug("sorting %d remaining events", len(leftover_events)) pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - leftover_events, pl, event_map, state_res_store, + leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") resolved_state = yield _iterative_auth_checks( - room_version, leftover_events, resolved_state, event_map, - state_res_store, + room_version, leftover_events, resolved_state, event_map, state_res_store ) logger.debug("resolved") @@ -209,14 +204,16 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): auth_ids = set( eid for key, eid in iteritems(state_set) - if (key[0] in ( - EventTypes.Member, - EventTypes.ThirdPartyInvite, - ) or key in ( - (EventTypes.PowerLevels, ''), - (EventTypes.Create, ''), - (EventTypes.JoinRules, ''), - )) and eid not in common + if ( + key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) + or key + in ( + (EventTypes.PowerLevels, ""), + (EventTypes.Create, ""), + (EventTypes.JoinRules, ""), + ) + ) + and eid not in common ) auth_chain = yield state_res_store.get_auth_chain(auth_ids) @@ -274,15 +271,16 @@ def _is_power_event(event): return True if event.type == EventTypes.Member: - if event.membership in ('leave', 'ban'): + if event.membership in ("leave", "ban"): return event.sender != event.state_key return False @defer.inlineCallbacks -def _add_event_and_auth_chain_to_graph(graph, event_id, event_map, - state_res_store, auth_diff): +def _add_event_and_auth_chain_to_graph( + graph, event_id, event_map, state_res_store, auth_diff +): """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -327,7 +325,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ graph = {} for event_id in event_ids: yield _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff, + graph, event_id, event_map, state_res_store, auth_diff ) event_to_pl = {} @@ -342,18 +340,16 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ return -pl, ev.origin_server_ts, event_id # Note: graph is modified during the sort - it = lexicographical_topological_sort( - graph, - key=_get_power_order, - ) + it = lexicographical_topological_sort(graph, key=_get_power_order) sorted_events = list(it) defer.returnValue(sorted_events) @defer.inlineCallbacks -def _iterative_auth_checks(room_version, event_ids, base_state, event_map, - state_res_store): +def _iterative_auth_checks( + room_version, event_ids, base_state, event_map, state_res_store +): """Sequentially apply auth checks to each event in given list, updating the state as it goes along. @@ -389,9 +385,11 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map, try: event_auth.check( - room_version, event, auth_events, + room_version, + event, + auth_events, do_sig_check=False, - do_size_check=False + do_size_check=False, ) resolved_state[(event.type, event.state_key)] = event_id @@ -402,8 +400,7 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map, @defer.inlineCallbacks -def _mainline_sort(event_ids, resolved_power_event_id, event_map, - state_res_store): +def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id @@ -436,8 +433,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, order_map = {} for ev_id in event_ids: depth = yield _get_mainline_depth_for_event( - event_map[ev_id], mainline_map, - event_map, state_res_store, + event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 0ca6f6121f..6b0ca80087 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -280,7 +280,7 @@ class DataStore( Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.runInteraction("count_daily_users", self._count_users, yesterday,) + return self.runInteraction("count_daily_users", self._count_users, yesterday) def count_monthly_users(self): """ @@ -291,9 +291,7 @@ class DataStore( """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) return self.runInteraction( - "count_monthly_users", - self._count_users, - thirty_days_ago, + "count_monthly_users", self._count_users, thirty_days_ago ) def _count_users(self, txn, time_from): @@ -361,7 +359,7 @@ class DataStore( txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) for row in txn: - if row[0] == 'unknown': + if row[0] == "unknown": pass results[row[0]] = row[1] @@ -388,7 +386,7 @@ class DataStore( txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) count, = txn.fetchone() - results['all'] = count + results["all"] = count return results diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 941c07fce5..c74bcc8f0b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -312,9 +312,7 @@ class SQLBaseStore(object): if res: for user in res: self.set_expiration_date_for_user_txn( - txn, - user["name"], - use_delta=True, + txn, user["name"], use_delta=True ) yield self.runInteraction( @@ -1667,7 +1665,7 @@ def db_to_json(db_content): # Decode it to a Unicode string before feeding it to json.loads, so we # consistenty get a Unicode-containing object out. if isinstance(db_content, (bytes, bytearray)): - db_content = db_content.decode('utf8') + db_content = db_content.decode("utf8") try: return json.loads(db_content) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index b8b8273f73..50f913a414 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -169,7 +169,7 @@ class BackgroundUpdateStore(SQLBaseStore): in_flight = set(update["update_name"] for update in updates) for update in updates: if update["depends_on"] not in in_flight: - self._background_update_queue.append(update['update_name']) + self._background_update_queue.append(update["update_name"]) if not self._background_update_queue: # no work left to do diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d102e07372..3413a46675 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -149,9 +149,7 @@ class DeviceWorkerStore(SQLBaseStore): defer.returnValue((stream_id_cutoff, [])) results = yield self._get_device_update_edus_by_remote( - destination, - from_stream_id, - query_map, + destination, from_stream_id, query_map ) defer.returnValue((now_stream_id, results)) @@ -182,9 +180,7 @@ class DeviceWorkerStore(SQLBaseStore): return list(txn) @defer.inlineCallbacks - def _get_device_update_edus_by_remote( - self, destination, from_stream_id, query_map, - ): + def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): """Returns a list of device update EDUs as well as E2EE keys Args: @@ -210,7 +206,7 @@ class DeviceWorkerStore(SQLBaseStore): # The prev_id for the first row is always the last row before # `from_stream_id` prev_id = yield self._get_last_device_update_for_remote_user( - destination, user_id, from_stream_id, + destination, user_id, from_stream_id ) for device_id, device in iteritems(user_devices): stream_id = query_map[(user_id, device_id)] @@ -238,7 +234,7 @@ class DeviceWorkerStore(SQLBaseStore): defer.returnValue(results) def _get_last_device_update_for_remote_user( - self, destination, user_id, from_stream_id, + self, destination, user_id, from_stream_id ): def f(txn): prev_sent_id_sql = """ diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 521936e3b0..f40ef2ab64 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -87,10 +87,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): }, values={ "version": version, - "first_message_index": room_key['first_message_index'], - "forwarded_count": room_key['forwarded_count'], - "is_verified": room_key['is_verified'], - "session_data": json.dumps(room_key['session_data']), + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), }, lock=False, ) @@ -118,13 +118,13 @@ class EndToEndRoomKeyStore(SQLBaseStore): try: version = int(version) except ValueError: - defer.returnValue({'rooms': {}}) + defer.returnValue({"rooms": {}}) keyvalues = {"user_id": user_id, "version": version} if room_id: - keyvalues['room_id'] = room_id + keyvalues["room_id"] = room_id if session_id: - keyvalues['session_id'] = session_id + keyvalues["session_id"] = session_id rows = yield self._simple_select_list( table="e2e_room_keys", @@ -141,10 +141,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="get_e2e_room_keys", ) - sessions = {'rooms': {}} + sessions = {"rooms": {}} for row in rows: - room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}}) - room_entry['sessions'][row['session_id']] = { + room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) + room_entry["sessions"][row["session_id"]] = { "first_message_index": row["first_message_index"], "forwarded_count": row["forwarded_count"], "is_verified": row["is_verified"], @@ -174,9 +174,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): keyvalues = {"user_id": user_id, "version": int(version)} if room_id: - keyvalues['room_id'] = room_id + keyvalues["room_id"] = room_id if session_id: - keyvalues['session_id'] = session_id + keyvalues["session_id"] = session_id yield self._simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" @@ -191,7 +191,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) row = txn.fetchone() if not row: - raise StoreError(404, 'No current backup version') + raise StoreError(404, "No current backup version") return row[0] def get_e2e_room_keys_version_info(self, user_id, version=None): @@ -255,7 +255,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) current_version = txn.fetchone()[0] if current_version is None: - current_version = '0' + current_version = "0" new_version = str(int(current_version) + 1) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 933bcf42c2..e9b9caa49a 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -85,7 +85,7 @@ class Sqlite3Engine(object): def _parse_match_info(buf): bufsize = len(buf) - return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)] + return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)] def _rank(raw_match_info): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index e8d16edbc8..cb4478342f 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -215,8 +215,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return [room_id for room_id, in txn] return self.runInteraction( - "get_rooms_with_many_extremities", - _get_rooms_with_many_extremities_txn, + "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @cached(max_entries=5000, iterable=True) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index a729f3e067..eca77069fd 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -277,7 +277,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by stream_ordering, oldest first. - notifs.sort(key=lambda r: r['stream_ordering']) + notifs.sort(key=lambda r: r["stream_ordering"]) # Take only up to the limit. We have to stop at the limit because # one of the subqueries may have hit the limit. @@ -379,7 +379,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by received_ts (most recent first) - notifs.sort(key=lambda r: -(r['received_ts'] or 0)) + notifs.sort(key=lambda r: -(r["received_ts"] or 0)) # Now return the first `limit` defer.returnValue(notifs[:limit]) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index f631fb1733..de5965569c 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -234,7 +234,7 @@ class EventsStore( BucketCollector( "synapse_forward_extremities", lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"] + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], ) # Read the extrems every 60 minutes diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py index 75c1935bf3..1ce21d190c 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/events_bg_updates.py @@ -64,8 +64,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) self.register_background_update_handler( - self.DELETE_SOFT_FAILED_EXTREMITIES, - self._cleanup_extremities_bg_update, + self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update ) @defer.inlineCallbacks @@ -269,7 +268,8 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): LEFT JOIN events USING (event_id) LEFT JOIN event_json USING (event_id) LEFT JOIN rejections USING (event_id) - """, (batch_size,) + """, + (batch_size,), ) for prev_event_id, event_id, metadata, rejected, outlier in txn: @@ -364,13 +364,12 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): column="event_id", iterable=to_delete, keyvalues={}, - retcols=("room_id",) + retcols=("room_id",), ) room_ids = set(row["room_id"] for row in rows) for room_id in room_ids: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, - (room_id,) + self.get_latest_event_ids_in_room.invalidate, (room_id,) ) self._simple_delete_many_txn( @@ -384,7 +383,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(original_set) num_handled = yield self.runInteraction( - "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn, + "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) if not num_handled: @@ -394,8 +393,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): txn.execute("DROP TABLE _extremities_to_check") yield self.runInteraction( - "_cleanup_extremities_bg_update_drop_table", - _drop_table_txn, + "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) defer.returnValue(num_handled) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index cc7df5cf14..6d680d405a 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -27,7 +27,6 @@ from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError from synapse.api.room_versions import EventFormatVersions from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 -# these are only included to make the type annotations work from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event from synapse.metrics.background_process_metrics import run_as_background_process @@ -111,8 +110,7 @@ class EventsWorkerStore(SQLBaseStore): return ts return self.runInteraction( - "get_approximate_received_ts", - _get_approximate_received_ts_txn, + "get_approximate_received_ts", _get_approximate_received_ts_txn ) @defer.inlineCallbacks @@ -677,7 +675,8 @@ class EventsWorkerStore(SQLBaseStore): """ return self.runInteraction( "get_total_state_event_counts", - self._get_total_state_event_counts_txn, room_id + self._get_total_state_event_counts_txn, + room_id, ) def _get_current_state_event_counts_txn(self, txn, room_id): @@ -701,7 +700,8 @@ class EventsWorkerStore(SQLBaseStore): """ return self.runInteraction( "get_current_state_event_counts", - self._get_current_state_event_counts_txn, room_id + self._get_current_state_event_counts_txn, + room_id, ) @defer.inlineCallbacks diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py index dce6a43ac1..73e6fc6de2 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/group_server.py @@ -1179,11 +1179,7 @@ class GroupServerStore(SQLBaseStore): for table in tables: self._simple_delete_txn( - txn, - table=table, - keyvalues={"group_id": group_id}, + txn, table=table, keyvalues={"group_id": group_id} ) - return self.runInteraction( - "delete_group", _delete_group_txn - ) + return self.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index e3655ad8d7..e72f89e446 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -131,7 +131,7 @@ class KeyStore(SQLBaseStore): def _invalidate(res): f = self._get_server_verify_key.invalidate for i in invalidations: - f((i, )) + f((i,)) return res return self.runInteraction( diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 3ecf47e7a7..6b1238ce4a 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -22,11 +22,11 @@ class MediaRepositoryStore(BackgroundUpdateStore): super(MediaRepositoryStore, self).__init__(db_conn, hs) self.register_background_index_update( - update_name='local_media_repository_url_idx', - index_name='local_media_repository_url_idx', - table='local_media_repository', - columns=['created_ts'], - where_clause='url_cache IS NOT NULL', + update_name="local_media_repository_url_idx", + index_name="local_media_repository_url_idx", + table="local_media_repository", + columns=["created_ts"], + where_clause="url_cache IS NOT NULL", ) def get_local_media(self, media_id): @@ -108,12 +108,12 @@ class MediaRepositoryStore(BackgroundUpdateStore): return dict( zip( ( - 'response_code', - 'etag', - 'expires_ts', - 'og', - 'media_id', - 'download_ts', + "response_code", + "etag", + "expires_ts", + "og", + "media_id", + "download_ts", ), row, ) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 8aa8abc470..081564360f 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -86,11 +86,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): if len(self.reserved_users) > 0: # questionmarks is a hack to overcome sqlite not supporting # tuples in 'WHERE IN %s' - questionmarks = '?' * len(self.reserved_users) + questionmarks = "?" * len(self.reserved_users) query_args.extend(self.reserved_users) sql = base_sql + """ AND user_id NOT IN ({})""".format( - ','.join(questionmarks) + ",".join(questionmarks) ) else: sql = base_sql @@ -124,7 +124,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): if len(self.reserved_users) > 0: query_args.extend(self.reserved_users) sql = base_sql + """ AND user_id NOT IN ({})""".format( - ','.join(questionmarks) + ",".join(questionmarks) ) else: sql = base_sql diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index f2c1bed487..fc10b9534e 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -146,9 +146,10 @@ def _setup_new_database(cur, database_engine): directory_entries = os.listdir(sql_dir) - for filename in sorted(fnmatch.filter(directory_entries, "*.sql") + fnmatch.filter( - directory_entries, "*.sql." + specific - )): + for filename in sorted( + fnmatch.filter(directory_entries, "*.sql") + + fnmatch.filter(directory_entries, "*.sql." + specific) + ): sql_loc = os.path.join(sql_dir, filename) logger.debug("Applying schema %s", sql_loc) executescript(cur, sql_loc) @@ -313,7 +314,7 @@ def _apply_module_schemas(txn, database_engine, config): application config """ for (mod, _config) in config.password_providers: - if not hasattr(mod, 'get_db_schema_files'): + if not hasattr(mod, "get_db_schema_files"): continue modname = ".".join((mod.__module__, mod.__name__)) _apply_module_schema_files( @@ -343,7 +344,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) continue root_name, ext = os.path.splitext(name) - if ext != '.sql': + if ext != ".sql": raise PrepareDatabaseException( "only .sql files are currently supported for module schemas" ) @@ -407,7 +408,7 @@ def get_statements(f): def executescript(txn, schema_path): - with open(schema_path, 'r') as f: + with open(schema_path, "r") as f: for statement in get_statements(f): txn.execute(statement) diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index aeec2f57c4..0ff392bdb4 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -41,7 +41,7 @@ class ProfileWorkerStore(SQLBaseStore): defer.returnValue( ProfileInfo( - avatar_url=profile['avatar_url'], display_name=profile['displayname'] + avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) ) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9e406baafa..98cec8c82b 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -46,12 +46,12 @@ def _load_rules(rawrules, enabled_map): rules = list(list_with_base_rules(ruleslist)) for i, rule in enumerate(rules): - rule_id = rule['rule_id'] + rule_id = rule["rule_id"] if rule_id in enabled_map: - if rule.get('enabled', True) != bool(enabled_map[rule_id]): + if rule.get("enabled", True) != bool(enabled_map[rule_id]): # Rules are cached across users. rule = dict(rule) - rule['enabled'] = bool(enabled_map[rule_id]) + rule["enabled"] = bool(enabled_map[rule_id]) rules[i] = rule return rules @@ -126,12 +126,12 @@ class PushRulesWorkerStore( def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", - keyvalues={'user_name': user_id}, + keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), desc="get_push_rules_enabled_for_user", ) defer.returnValue( - {r['rule_id']: False if r['enabled'] == 0 else True for r in results} + {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} ) def have_push_rules_changed_for_user(self, user_id, last_id): @@ -175,7 +175,7 @@ class PushRulesWorkerStore( rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) for row in rows: - results.setdefault(row['user_name'], []).append(row) + results.setdefault(row["user_name"], []).append(row) enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) @@ -194,7 +194,7 @@ class PushRulesWorkerStore( rule (Dict): A push rule. """ # Create new rule id - rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1]) + rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) new_rule_id = rule_id_scope + "/" + new_room_id # Change room id in each condition @@ -334,8 +334,8 @@ class PushRulesWorkerStore( desc="bulk_get_push_rules_enabled", ) for row in rows: - enabled = bool(row['enabled']) - results.setdefault(row['user_name'], {})[row['rule_id']] = enabled + enabled = bool(row["enabled"]) + results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled defer.returnValue(results) @@ -568,7 +568,7 @@ class PushRuleStore(PushRulesWorkerStore): def delete_push_rule_txn(txn, stream_id, event_stream_ordering): self._simple_delete_one_txn( - txn, "push_rules", {'user_name': user_id, 'rule_id': rule_id} + txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) self._insert_push_rules_update_txn( @@ -605,9 +605,9 @@ class PushRuleStore(PushRulesWorkerStore): self._simple_upsert_txn( txn, "push_rules_enable", - {'user_name': user_id, 'rule_id': rule_id}, - {'enabled': 1 if enabled else 0}, - {'id': new_id}, + {"user_name": user_id, "rule_id": rule_id}, + {"enabled": 1 if enabled else 0}, + {"id": new_id}, ) self._insert_push_rules_update_txn( @@ -645,8 +645,8 @@ class PushRuleStore(PushRulesWorkerStore): self._simple_update_one_txn( txn, "push_rules", - {'user_name': user_id, 'rule_id': rule_id}, - {'actions': actions_json}, + {"user_name": user_id, "rule_id": rule_id}, + {"actions": actions_json}, ) self._insert_push_rules_update_txn( diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 1567e1df48..cfe0a94330 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -37,24 +37,24 @@ else: class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): for r in rows: - dataJson = r['data'] - r['data'] = None + dataJson = r["data"] + r["data"] = None try: if isinstance(dataJson, db_binary_type): dataJson = str(dataJson).decode("UTF8") - r['data'] = json.loads(dataJson) + r["data"] = json.loads(dataJson) except Exception as e: logger.warn( "Invalid JSON in data for pusher %d: %s, %s", - r['id'], + r["id"], dataJson, e.args[0], ) pass - if isinstance(r['pushkey'], db_binary_type): - r['pushkey'] = str(r['pushkey']).decode("UTF8") + if isinstance(r["pushkey"], db_binary_type): + r["pushkey"] = str(r["pushkey"]).decode("UTF8") return rows @@ -195,15 +195,15 @@ class PusherWorkerStore(SQLBaseStore): ) def get_if_users_have_pushers(self, user_ids): rows = yield self._simple_select_many_batch( - table='pushers', - column='user_name', + table="pushers", + column="user_name", iterable=user_ids, - retcols=['user_name'], - desc='get_if_users_have_pushers', + retcols=["user_name"], + desc="get_if_users_have_pushers", ) result = {user_id: False for user_id in user_ids} - result.update({r['user_name']: True for r in rows}) + result.update({r["user_name"]: True for r in rows}) defer.returnValue(result) @@ -299,8 +299,8 @@ class PusherStore(PusherWorkerStore): ): yield self._simple_update_one( "pushers", - {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, - {'last_stream_ordering': last_stream_ordering}, + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + {"last_stream_ordering": last_stream_ordering}, desc="update_pusher_last_stream_ordering", ) @@ -310,10 +310,10 @@ class PusherStore(PusherWorkerStore): ): yield self._simple_update_one( "pushers", - {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, { - 'last_stream_ordering': last_stream_ordering, - 'last_success': last_success, + "last_stream_ordering": last_stream_ordering, + "last_success": last_success, }, desc="update_pusher_last_stream_ordering_and_success", ) @@ -322,8 +322,8 @@ class PusherStore(PusherWorkerStore): def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): yield self._simple_update_one( "pushers", - {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, - {'failing_since': failing_since}, + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + {"failing_since": failing_since}, desc="update_pusher_failing_since", ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index a1647e50a1..b477da12b1 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") - defer.returnValue(set(r['user_id'] for r in receipts)) + defer.returnValue(set(r["user_id"] for r in receipts)) @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 0b3c656e90..983ce13291 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -116,8 +116,9 @@ class RegistrationWorkerStore(SQLBaseStore): defer.returnValue(res) @defer.inlineCallbacks - def set_account_validity_for_user(self, user_id, expiration_ts, email_sent, - renewal_token=None): + def set_account_validity_for_user( + self, user_id, expiration_ts, email_sent, renewal_token=None + ): """Updates the account validity properties of the given account, with the given values. @@ -131,6 +132,7 @@ class RegistrationWorkerStore(SQLBaseStore): renewal_token (str): Renewal token the user can use to extend the validity of their account. Defaults to no token. """ + def set_account_validity_for_user_txn(txn): self._simple_update_txn( txn=txn, @@ -143,12 +145,11 @@ class RegistrationWorkerStore(SQLBaseStore): }, ) self._invalidate_cache_and_stream( - txn, self.get_expiration_ts_for_user, (user_id,), + txn, self.get_expiration_ts_for_user, (user_id,) ) yield self.runInteraction( - "set_account_validity_for_user", - set_account_validity_for_user_txn, + "set_account_validity_for_user", set_account_validity_for_user_txn ) @defer.inlineCallbacks @@ -217,6 +218,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] """ + def select_users_txn(txn, now_ms, renew_at): sql = ( "SELECT user_id, expiration_ts_ms FROM account_validity" @@ -229,7 +231,8 @@ class RegistrationWorkerStore(SQLBaseStore): res = yield self.runInteraction( "get_users_expiring_soon", select_users_txn, - self.clock.time_msec(), self.config.account_validity.renew_at, + self.clock.time_msec(), + self.config.account_validity.renew_at, ) defer.returnValue(res) @@ -369,7 +372,7 @@ class RegistrationWorkerStore(SQLBaseStore): WHERE creation_ts > ? ) AS t GROUP BY user_type """ - results = {'native': 0, 'guest': 0, 'bridged': 0} + results = {"native": 0, "guest": 0, "bridged": 0} txn.execute(sql, (yesterday,)) for row in txn: results[row[0]] = row[1] @@ -435,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore): {"medium": medium, "address": address}, ["guest_access_token"], True, - 'get_3pid_guest_access_token', + "get_3pid_guest_access_token", ) if ret: defer.returnValue(ret["guest_access_token"]) @@ -472,11 +475,11 @@ class RegistrationWorkerStore(SQLBaseStore): txn, "user_threepids", {"medium": medium, "address": address}, - ['user_id'], + ["user_id"], True, ) if ret: - return ret['user_id'] + return ret["user_id"] return None @defer.inlineCallbacks @@ -492,8 +495,8 @@ class RegistrationWorkerStore(SQLBaseStore): ret = yield self._simple_select_list( "user_threepids", {"user_id": user_id}, - ['medium', 'address', 'validated_at', 'added_at'], - 'user_get_threepids', + ["medium", "address", "validated_at", "added_at"], + "user_get_threepids", ) defer.returnValue(ret) @@ -572,11 +575,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ return self._simple_select_onecol( table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - }, + keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", desc="get_id_servers_user_bound", ) @@ -612,16 +611,16 @@ class RegistrationStore( self.register_noop_background_update("refresh_tokens_device_index") self.register_background_update_handler( - "user_threepids_grandfather", self._bg_user_threepids_grandfather, + "user_threepids_grandfather", self._bg_user_threepids_grandfather ) self.register_background_update_handler( - "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag, + "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag ) # Create a background job for culling expired 3PID validity tokens hs.get_clock().looping_call( - self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS, + self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS ) @defer.inlineCallbacks @@ -677,8 +676,7 @@ class RegistrationStore( return False end = yield self.runInteraction( - "users_set_deactivated_flag", - _backgroud_update_set_deactivated_flag_txn, + "users_set_deactivated_flag", _backgroud_update_set_deactivated_flag_txn ) if end: @@ -851,7 +849,7 @@ class RegistrationStore( def user_set_password_hash_txn(txn): self._simple_update_one_txn( - txn, 'users', {'name': user_id}, {'password_hash': password_hash} + txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -872,9 +870,9 @@ class RegistrationStore( def f(txn): self._simple_update_one_txn( txn, - table='users', - keyvalues={'name': user_id}, - updatevalues={'consent_version': consent_version}, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_version": consent_version}, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -896,9 +894,9 @@ class RegistrationStore( def f(txn): self._simple_update_one_txn( txn, - table='users', - keyvalues={'name': user_id}, - updatevalues={'consent_server_notice_sent': consent_version}, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_server_notice_sent": consent_version}, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -1068,7 +1066,7 @@ class RegistrationStore( if id_servers: yield self.runInteraction( - "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn, + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) yield self._end_background_update("user_threepids_grandfather") @@ -1076,12 +1074,7 @@ class RegistrationStore( defer.returnValue(1) def get_threepid_validation_session( - self, - medium, - client_secret, - address=None, - sid=None, - validated=True, + self, medium, client_secret, address=None, sid=None, validated=True ): """Gets a session_id and last_send_attempt (if available) for a client_secret/medium/(address|session_id) combo @@ -1101,23 +1094,22 @@ class RegistrationStore( latest session_id and send_attempt count for this 3PID. Otherwise None if there hasn't been a previous attempt """ - keyvalues = { - "medium": medium, - "client_secret": client_secret, - } + keyvalues = {"medium": medium, "client_secret": client_secret} if address: keyvalues["address"] = address if sid: keyvalues["session_id"] = sid - assert(address or sid) + assert address or sid def get_threepid_validation_session_txn(txn): sql = """ SELECT address, session_id, medium, client_secret, last_send_attempt, validated_at FROM threepid_validation_session WHERE %s - """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),) + """ % ( + " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), + ) if validated is not None: sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") @@ -1132,17 +1124,10 @@ class RegistrationStore( return rows[0] return self.runInteraction( - "get_threepid_validation_session", - get_threepid_validation_session_txn, + "get_threepid_validation_session", get_threepid_validation_session_txn ) - def validate_threepid_session( - self, - session_id, - client_secret, - token, - current_ts, - ): + def validate_threepid_session(self, session_id, client_secret, token, current_ts): """Attempt to validate a threepid session using a token Args: @@ -1174,7 +1159,7 @@ class RegistrationStore( if retrieved_client_secret != client_secret: raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id", + 400, "This client_secret does not match the provided session_id" ) row = self._simple_select_one_txn( @@ -1187,7 +1172,7 @@ class RegistrationStore( if not row: raise ThreepidValidationError( - 400, "Validation token not found or has expired", + 400, "Validation token not found or has expired" ) expires = row["expires"] next_link = row["next_link"] @@ -1198,7 +1183,7 @@ class RegistrationStore( if expires <= current_ts: raise ThreepidValidationError( - 400, "This token has expired. Please request a new one", + 400, "This token has expired. Please request a new one" ) # Looks good. Validate the session @@ -1213,8 +1198,7 @@ class RegistrationStore( # Return next_link if it exists return self.runInteraction( - "validate_threepid_session_txn", - validate_threepid_session_txn, + "validate_threepid_session_txn", validate_threepid_session_txn ) def upsert_threepid_validation_session( @@ -1281,6 +1265,7 @@ class RegistrationStore( token_expires (int): The timestamp for which after the token will no longer be valid """ + def start_or_continue_validation_session_txn(txn): # Create or update a validation session self._simple_upsert_txn( @@ -1314,6 +1299,7 @@ class RegistrationStore( def cull_expired_threepid_validation_tokens(self): """Remove threepid validation tokens with expiry dates that have passed""" + def cull_expired_threepid_validation_tokens_txn(txn, ts): sql = """ DELETE FROM threepid_validation_token WHERE @@ -1335,6 +1321,7 @@ class RegistrationStore( Args: session_id (str): The ID of the session to delete """ + def delete_threepid_session_txn(txn): self._simple_delete_txn( txn, @@ -1348,8 +1335,7 @@ class RegistrationStore( ) return self.runInteraction( - "delete_threepid_session", - delete_threepid_session_txn, + "delete_threepid_session", delete_threepid_session_txn ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): @@ -1360,7 +1346,7 @@ class RegistrationStore( updatevalues={"deactivated": 1 if deactivated else 0}, ) self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,), + txn, self.get_user_deactivated_status, (user_id,) ) @defer.inlineCallbacks @@ -1375,7 +1361,8 @@ class RegistrationStore( yield self.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, - user_id, deactivated, + user_id, + deactivated, ) @cachedInlineCallbacks() diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 4c83800cca..1b01934c19 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -468,9 +468,5 @@ class RelationsStore(RelationsWorkerStore): """ self._simple_delete_txn( - txn, - table="event_relations", - keyvalues={ - "event_id": redacted_event_id, - } + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 7617913326..8004aeb909 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -420,7 +420,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): table="room_memberships", column="event_id", iterable=missing_member_event_ids, - retcols=('user_id', 'display_name', 'avatar_url'), + retcols=("user_id", "display_name", "avatar_url"), keyvalues={"membership": Membership.JOIN}, batch_size=500, desc="_get_joined_users_from_context", @@ -448,7 +448,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cachedInlineCallbacks(max_entries=10000) def is_host_joined(self, room_id, host): - if '%' in host or '_' in host: + if "%" in host or "_" in host: raise Exception("Invalid host name") sql = """ @@ -490,7 +490,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): Deferred: Resolves to True if the host is/was in the room, otherwise False. """ - if '%' in host or '_' in host: + if "%" in host or "_" in host: raise Exception("Invalid host name") sql = """ @@ -723,7 +723,7 @@ class RoomMemberStore(RoomMemberWorkerStore): room_id = row["room_id"] try: event_json = json.loads(row["json"]) - content = event_json['content'] + content = event_json["content"] except Exception: continue diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py index 147496a38b..3edfcfd783 100644 --- a/synapse/storage/schema/delta/20/pushers.py +++ b/synapse/storage/schema/delta/20/pushers.py @@ -29,7 +29,8 @@ logger = logging.getLogger(__name__) def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table...") - cur.execute(""" + cur.execute( + """ CREATE TABLE IF NOT EXISTS pushers2 ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, @@ -48,27 +49,34 @@ def run_create(cur, database_engine, *args, **kwargs): failing_since BIGINT, UNIQUE (app_id, pushkey, user_name) ) - """) - cur.execute("""SELECT + """ + ) + cur.execute( + """SELECT id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers - """) + """ + ) count = 0 for row in cur.fetchall(): row = list(row) row[8] = bytes(row[8]).decode("utf-8") row[11] = bytes(row[11]).decode("utf-8") - cur.execute(database_engine.convert_param_style(""" + cur.execute( + database_engine.convert_param_style( + """ INSERT into pushers2 ( id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since - ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))), - row + ) values (%s)""" + % (",".join(["?" for _ in range(len(row))])) + ), + row, ) count += 1 cur.execute("DROP TABLE pushers") diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index ef7ec34346..9b95411fb6 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -40,9 +40,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): logger.warning("Could not get app_service_config_files from config") pass - appservices = load_appservices( - config.server_name, config_files - ) + appservices = load_appservices(config.server_name, config_files) owned = {} @@ -53,20 +51,19 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): if user_id in owned.keys(): logger.error( "user_id %s was owned by more than one application" - " service (IDs %s and %s); assigning arbitrarily to %s" % - (user_id, owned[user_id], appservice.id, owned[user_id]) + " service (IDs %s and %s); assigning arbitrarily to %s" + % (user_id, owned[user_id], appservice.id, owned[user_id]) ) owned.setdefault(appservice.id, []).append(user_id) for as_id, user_ids in owned.items(): n = 100 - user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n)) + user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name IN (%s)" % ( - ",".join("?" for _ in chunk), - ) + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" + % (",".join("?" for _ in chunk),) ), - [as_id] + chunk + [as_id] + chunk, ) diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/schema/delta/31/pushers.py index 93367fa09e..9bb504aad5 100644 --- a/synapse/storage/schema/delta/31/pushers.py +++ b/synapse/storage/schema/delta/31/pushers.py @@ -24,12 +24,13 @@ logger = logging.getLogger(__name__) def token_to_stream_ordering(token): - return int(token[1:].split('_')[0]) + return int(token[1:].split("_")[0]) def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table, delta 31...") - cur.execute(""" + cur.execute( + """ CREATE TABLE IF NOT EXISTS pushers2 ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, @@ -48,26 +49,33 @@ def run_create(cur, database_engine, *args, **kwargs): failing_since BIGINT, UNIQUE (app_id, pushkey, user_name) ) - """) - cur.execute("""SELECT + """ + ) + cur.execute( + """SELECT id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers - """) + """ + ) count = 0 for row in cur.fetchall(): row = list(row) row[12] = token_to_stream_ordering(row[12]) - cur.execute(database_engine.convert_param_style(""" + cur.execute( + database_engine.convert_param_style( + """ INSERT into pushers2 ( id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_stream_ordering, last_success, failing_since - ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))), - row + ) values (%s)""" + % (",".join(["?" for _ in range(len(row))])) + ), + row, ) count += 1 cur.execute("DROP TABLE pushers") diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py index 9754d3ccfb..a26057dfb6 100644 --- a/synapse/storage/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/schema/delta/33/remote_media_ts.py @@ -26,5 +26,5 @@ def run_upgrade(cur, database_engine, *args, **kwargs): database_engine.convert_param_style( "UPDATE remote_media_cache SET last_access_ts = ?" ), - (int(time.time() * 1000),) + (int(time.time() * 1000),), ) diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py index f6766501d2..9fd1ccf6f7 100644 --- a/synapse/storage/schema/delta/47/state_group_seq.py +++ b/synapse/storage/schema/delta/47/state_group_seq.py @@ -27,10 +27,7 @@ def run_create(cur, database_engine, *args, **kwargs): else: start_val = row[0] + 1 - cur.execute( - "CREATE SEQUENCE state_group_id_seq START WITH %s", - (start_val, ), - ) + cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,)) def run_upgrade(*args, **kwargs): diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py index 2233af87d7..49f5f2c003 100644 --- a/synapse/storage/schema/delta/48/group_unique_indexes.py +++ b/synapse/storage/schema/delta/48/group_unique_indexes.py @@ -38,16 +38,22 @@ def run_create(cur, database_engine, *args, **kwargs): rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" # remove duplicates from group_users & group_invites tables - cur.execute(""" + cur.execute( + """ DELETE FROM group_users WHERE %s NOT IN ( SELECT min(%s) FROM group_users GROUP BY group_id, user_id ); - """ % (rowid, rowid)) - cur.execute(""" + """ + % (rowid, rowid) + ) + cur.execute( + """ DELETE FROM group_invites WHERE %s NOT IN ( SELECT min(%s) FROM group_invites GROUP BY group_id, user_id ); - """ % (rowid, rowid)) + """ + % (rowid, rowid) + ) for statement in get_statements(FIX_INDEXES.splitlines()): cur.execute(statement) diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/schema/delta/50/make_event_content_nullable.py index 6dd467b6c5..b1684a8441 100644 --- a/synapse/storage/schema/delta/50/make_event_content_nullable.py +++ b/synapse/storage/schema/delta/50/make_event_content_nullable.py @@ -65,14 +65,18 @@ def run_create(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs): if isinstance(database_engine, PostgresEngine): - cur.execute(""" + cur.execute( + """ ALTER TABLE events ALTER COLUMN content DROP NOT NULL; - """) + """ + ) return # sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html - cur.execute("SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'") + cur.execute( + "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'" + ) (oldsql,) = cur.fetchone() sql = oldsql.replace("content TEXT NOT NULL", "content TEXT") @@ -86,7 +90,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute("PRAGMA writable_schema=ON") cur.execute( "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", - (sql, ), + (sql,), ) cur.execute("PRAGMA schema_version=%i" % (oldver + 1,)) cur.execute("PRAGMA writable_schema=OFF") diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 10a27c207a..f3b1cec933 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -31,8 +31,8 @@ from .background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) SearchEntry = namedtuple( - 'SearchEntry', - ['key', 'value', 'event_id', 'room_id', 'stream_ordering', 'origin_server_ts'], + "SearchEntry", + ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], ) @@ -216,7 +216,7 @@ class SearchStore(BackgroundUpdateStore): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - have_added_index = progress['have_added_indexes'] + have_added_index = progress["have_added_indexes"] if not have_added_index: diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index ff266b09b0..1cec84ee2e 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -71,7 +71,8 @@ class StatsStore(StateDeltasStore): # Get all the rooms that we want to process. def _make_staging_area(txn): # Create the temporary tables - stmts = get_statements(""" + stmts = get_statements( + """ -- We just recreate the table, we'll be reinserting the -- correct entries again later anyway. DROP TABLE IF EXISTS {temp}_rooms; @@ -85,7 +86,10 @@ class StatsStore(StateDeltasStore): ON {temp}_rooms(events); CREATE INDEX {temp}_rooms_id ON {temp}_rooms(room_id); - """.format(temp=TEMP_TABLE).splitlines()) + """.format( + temp=TEMP_TABLE + ).splitlines() + ) for statement in stmts: txn.execute(statement) @@ -105,7 +109,9 @@ class StatsStore(StateDeltasStore): LEFT JOIN room_stats_earliest_token AS t USING (room_id) WHERE t.room_id IS NULL GROUP BY c.room_id - """ % (TEMP_TABLE,) + """ % ( + TEMP_TABLE, + ) txn.execute(sql) new_pos = yield self.get_max_stream_id_in_current_state_deltas() @@ -184,7 +190,8 @@ class StatsStore(StateDeltasStore): logger.info( "Processing the next %d rooms of %d remaining", - len(rooms_to_work_on), progress["remaining"], + len(rooms_to_work_on), + progress["remaining"], ) # Number of state events we've processed by going through each room @@ -204,10 +211,17 @@ class StatsStore(StateDeltasStore): avatar_id = current_state_ids.get((EventTypes.RoomAvatar, "")) canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, "")) - state_events = yield self.get_events([ - join_rules_id, history_visibility_id, encryption_id, name_id, - topic_id, avatar_id, canonical_alias_id, - ]) + state_events = yield self.get_events( + [ + join_rules_id, + history_visibility_id, + encryption_id, + name_id, + topic_id, + avatar_id, + canonical_alias_id, + ] + ) def _get_or_none(event_id, arg): event = state_events.get(event_id) @@ -271,7 +285,7 @@ class StatsStore(StateDeltasStore): # We've finished a room. Delete it from the table. self._simple_delete_one_txn( - txn, TEMP_TABLE + "_rooms", {"room_id": room_id}, + txn, TEMP_TABLE + "_rooms", {"room_id": room_id} ) yield self.runInteraction("update_room_stats", _fetch_data) @@ -338,7 +352,7 @@ class StatsStore(StateDeltasStore): "name", "topic", "avatar", - "canonical_alias" + "canonical_alias", ): field = fields.get(col) if field and "\0" in field: diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 6f7f65d96b..d9482a3843 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -65,7 +65,7 @@ _EventDictReturn = namedtuple( def generate_pagination_where_clause( - direction, column_names, from_token, to_token, engine, + direction, column_names, from_token, to_token, engine ): """Creates an SQL expression to bound the columns by the pagination tokens. @@ -153,7 +153,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine): str """ - assert(bound in (">", "<", ">=", "<=")) + assert bound in (">", "<", ">=", "<=") name1, name2 = column_names val1, val2 = values @@ -169,11 +169,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine): # Postgres doesn't optimise ``(x < a) OR (x=a AND y" % ( - id(self), self._result, self._deferred, + id(self), + self._result, + self._deferred, ) @@ -150,10 +153,12 @@ def concurrently_execute(func, args, limit): except StopIteration: pass - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(_concurrently_execute_inner) - for _ in range(limit) - ], consumeErrors=True)).addErrback(unwrapFirstError) + return logcontext.make_deferred_yieldable( + defer.gatherResults( + [run_in_background(_concurrently_execute_inner) for _ in range(limit)], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) def yieldable_gather_results(func, iter, *args, **kwargs): @@ -169,10 +174,12 @@ def yieldable_gather_results(func, iter, *args, **kwargs): Deferred[list]: Resolved when all functions have been invoked, or errors if one of the function calls fails. """ - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(func, item, *args, **kwargs) - for item in iter - ], consumeErrors=True)).addErrback(unwrapFirstError) + return logcontext.make_deferred_yieldable( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) class Linearizer(object): @@ -185,6 +192,7 @@ class Linearizer(object): # do some work. """ + def __init__(self, name=None, max_count=1, clock=None): """ Args: @@ -197,6 +205,7 @@ class Linearizer(object): if not clock: from twisted.internet import reactor + clock = Clock(reactor) self._clock = clock self.max_count = max_count @@ -221,7 +230,7 @@ class Linearizer(object): res = self._await_lock(key) else: logger.debug( - "Acquired uncontended linearizer lock %r for key %r", self.name, key, + "Acquired uncontended linearizer lock %r for key %r", self.name, key ) entry[0] += 1 res = defer.succeed(None) @@ -266,9 +275,7 @@ class Linearizer(object): """ entry = self.key_to_defer[key] - logger.debug( - "Waiting to acquire linearizer lock %r for key %r", self.name, key, - ) + logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) new_defer = make_deferred_yieldable(defer.Deferred()) entry[1][new_defer] = 1 @@ -293,14 +300,14 @@ class Linearizer(object): logger.info("defer %r got err %r", new_defer, e) if isinstance(e, CancelledError): logger.debug( - "Cancelling wait for linearizer lock %r for key %r", - self.name, key, + "Cancelling wait for linearizer lock %r for key %r", self.name, key ) else: logger.warn( "Unexpected exception waiting for linearizer lock %r for key %r", - self.name, key, + self.name, + key, ) # we just have to take ourselves back out of the queue. @@ -438,7 +445,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): try: deferred.cancel() - except: # noqa: E722, if we throw any exception it'll break time outs + except: # noqa: E722, if we throw any exception it'll break time outs logger.exception("Canceller failed during timeout") if not new_d.called: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index f37d5bec08..8271229015 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -104,8 +104,8 @@ def register_cache(cache_type, cache_name, cache): KNOWN_KEYS = { - key: key for key in - ( + key: key + for key in ( "auth_events", "content", "depth", @@ -150,7 +150,7 @@ def intern_dict(dictionary): def _intern_known_values(key, value): - intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",) + intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key") if key in intern_keys: return intern_string(value) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 187510576a..d2f25063aa 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -40,9 +40,7 @@ _CacheSentinel = object() class CacheEntry(object): - __slots__ = [ - "deferred", "callbacks", "invalidated" - ] + __slots__ = ["deferred", "callbacks", "invalidated"] def __init__(self, deferred, callbacks): self.deferred = deferred @@ -73,7 +71,9 @@ class Cache(object): self._pending_deferred_cache = cache_type() self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type, + max_size=max_entries, + keylen=keylen, + cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, ) @@ -133,10 +133,7 @@ class Cache(object): def set(self, key, value, callback=None): callbacks = [callback] if callback else [] self.check_thread() - entry = CacheEntry( - deferred=value, - callbacks=callbacks, - ) + entry = CacheEntry(deferred=value, callbacks=callbacks) existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: @@ -191,9 +188,7 @@ class Cache(object): def invalidate_many(self, key): self.check_thread() if not isinstance(key, tuple): - raise TypeError( - "The cache key must be a tuple not %r" % (type(key),) - ) + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the @@ -244,29 +239,25 @@ class _CacheDescriptorBase(object): raise Exception( "Not enough explicit positional arguments to key off for %r: " "got %i args, but wanted %i. (@cached cannot key off *args or " - "**kwargs)" - % (orig.__name__, len(all_args), num_args) + "**kwargs)" % (orig.__name__, len(all_args), num_args) ) self.num_args = num_args # list of the names of the args used as the cache key - self.arg_names = all_args[1:num_args + 1] + self.arg_names = all_args[1 : num_args + 1] # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value if arg_spec.defaults: - self.arg_defaults = dict(zip( - all_args[-len(arg_spec.defaults):], - arg_spec.defaults - )) + self.arg_defaults = dict( + zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults) + ) else: self.arg_defaults = {} if "cache_context" in self.arg_names: - raise Exception( - "cache_context arg cannot be included among the cache keys" - ) + raise Exception("cache_context arg cannot be included among the cache keys") self.add_cache_context = cache_context @@ -304,12 +295,24 @@ class CacheDescriptor(_CacheDescriptorBase): ``cache_context``) to use as cache keys. Defaults to all named args of the function. """ - def __init__(self, orig, max_entries=1000, num_args=None, tree=False, - inlineCallbacks=False, cache_context=False, iterable=False): + + def __init__( + self, + orig, + max_entries=1000, + num_args=None, + tree=False, + inlineCallbacks=False, + cache_context=False, + iterable=False, + ): super(CacheDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks, - cache_context=cache_context) + orig, + num_args=num_args, + inlineCallbacks=inlineCallbacks, + cache_context=cache_context, + ) max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) @@ -356,7 +359,9 @@ class CacheDescriptor(_CacheDescriptorBase): return args[0] else: return self.arg_defaults[nm] + else: + def get_cache_key(args, kwargs): return tuple(get_cache_key_gen(args, kwargs)) @@ -383,8 +388,7 @@ class CacheDescriptor(_CacheDescriptorBase): except KeyError: ret = defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), - obj, *args, **kwargs + logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs ) def onErr(f): @@ -437,8 +441,9 @@ class CacheListDescriptor(_CacheDescriptorBase): results. """ - def __init__(self, orig, cached_method_name, list_name, num_args=None, - inlineCallbacks=False): + def __init__( + self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False + ): """ Args: orig (function) @@ -451,7 +456,8 @@ class CacheListDescriptor(_CacheDescriptorBase): be wrapped by defer.inlineCallbacks """ super(CacheListDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks) + orig, num_args=num_args, inlineCallbacks=inlineCallbacks + ) self.list_name = list_name @@ -463,7 +469,7 @@ class CacheListDescriptor(_CacheDescriptorBase): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cached_method_name,) + % (self.list_name, cached_method_name) ) def __get__(self, obj, objtype=None): @@ -494,8 +500,10 @@ class CacheListDescriptor(_CacheDescriptorBase): # If the cache takes a single arg then that is used as the key, # otherwise a tuple is used. if num_args == 1: + def arg_to_cache_key(arg): return arg + else: keylist = list(keyargs) @@ -505,8 +513,7 @@ class CacheListDescriptor(_CacheDescriptorBase): for arg in list_args: try: - res = cache.get(arg_to_cache_key(arg), - callback=invalidate_callback) + res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) if not isinstance(res, ObservableDeferred): results[arg] = res elif not res.has_succeeded(): @@ -554,18 +561,15 @@ class CacheListDescriptor(_CacheDescriptorBase): args_to_call = dict(arg_dict) args_to_call[self.list_name] = list(missing) - cached_defers.append(defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), - **args_to_call - ).addCallbacks(complete_all, errback)) + cached_defers.append( + defer.maybeDeferred( + logcontext.preserve_fn(self.function_to_call), **args_to_call + ).addCallbacks(complete_all, errback) + ) if cached_defers: - d = defer.gatherResults( - cached_defers, - consumeErrors=True, - ).addCallbacks( - lambda _: results, - unwrapFirstError + d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( + lambda _: results, unwrapFirstError ) return logcontext.make_deferred_yieldable(d) else: @@ -586,8 +590,9 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, - iterable=False): +def cached( + max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False +): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -598,8 +603,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, ) -def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, - cache_context=False, iterable=False): +def cachedInlineCallbacks( + max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False +): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 6c0b5a4094..6834e6f3ae 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -35,6 +35,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va there. value (dict): The full or partial dict value """ + def __len__(self): return len(self.value) @@ -84,13 +85,15 @@ class DictionaryCache(object): self.metrics.inc_hits() if dict_keys is None: - return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value)) + return DictionaryEntry( + entry.full, entry.known_absent, dict(entry.value) + ) else: - return DictionaryEntry(entry.full, entry.known_absent, { - k: entry.value[k] - for k in dict_keys - if k in entry.value - }) + return DictionaryEntry( + entry.full, + entry.known_absent, + {k: entry.value[k] for k in dict_keys if k in entry.value}, + ) self.metrics.inc_misses() return DictionaryEntry(False, set(), {}) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index f369780277..cddf1ed515 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -28,8 +28,15 @@ SENTINEL = object() class ExpiringCache(object): - def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False, iterable=False): + def __init__( + self, + cache_name, + clock, + max_len=0, + expiry_ms=0, + reset_expiry_on_get=False, + iterable=False, + ): """ Args: cache_name (str): Name of this cache, used for logging. @@ -67,8 +74,7 @@ class ExpiringCache(object): def f(): return run_as_background_process( - "prune_cache_%s" % self._cache_name, - self._prune_cache, + "prune_cache_%s" % self._cache_name, self._prune_cache ) self._clock.looping_call(f, self._expiry_ms / 2) @@ -153,7 +159,9 @@ class ExpiringCache(object): logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self) + self._cache_name, + begin_length, + len(self), ) def __len__(self): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index b684f24e7b..1536cb64f3 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,8 +49,15 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None, - evicted_callback=None): + + def __init__( + self, + max_size, + keylen=1, + cache_type=dict, + size_callback=None, + evicted_callback=None, + ): """ Args: max_size (int): @@ -93,9 +100,12 @@ class LruCache(object): cached_cache_len = [0] if size_callback is not None: + def cache_len(): return cached_cache_len[0] + else: + def cache_len(): return len(cache) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index afb03b2e1b..b1da81633c 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -35,12 +35,10 @@ class ResponseCache(object): self.pending_result_cache = {} # Requests that haven't finished yet. self.clock = hs.get_clock() - self.timeout_sec = timeout_ms / 1000. + self.timeout_sec = timeout_ms / 1000.0 self._name = name - self._metrics = register_cache( - "response_cache", name, self - ) + self._metrics = register_cache("response_cache", name, self) def size(self): return len(self.pending_result_cache) @@ -100,8 +98,7 @@ class ResponseCache(object): def remove(r): if self.timeout_sec: self.clock.call_later( - self.timeout_sec, - self.pending_result_cache.pop, key, None, + self.timeout_sec, self.pending_result_cache.pop, key, None ) else: self.pending_result_cache.pop(key, None) @@ -147,14 +144,15 @@ class ResponseCache(object): """ result = self.get(key) if not result: - logger.info("[%s]: no cached result for [%s], calculating new one", - self._name, key) + logger.info( + "[%s]: no cached result for [%s], calculating new one", self._name, key + ) d = run_in_background(callback, *args, **kwargs) result = self.set(key, d) elif not isinstance(result, defer.Deferred) or result.called: - logger.info("[%s]: using completed cached result for [%s]", - self._name, key) + logger.info("[%s]: using completed cached result for [%s]", self._name, key) else: - logger.info("[%s]: using incomplete cached result for [%s]", - self._name, key) + logger.info( + "[%s]: using incomplete cached result for [%s]", self._name, key + ) return make_deferred_yieldable(result) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 625aedc940..235f64049c 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -77,9 +77,8 @@ class StreamChangeCache(object): if stream_pos >= self._earliest_known_stream_pos: changed_entities = { - self._cache[k] for k in self._cache.islice( - start=self._cache.bisect_right(stream_pos), - ) + self._cache[k] + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) } result = changed_entities.intersection(entities) @@ -114,8 +113,10 @@ class StreamChangeCache(object): assert type(stream_pos) is int if stream_pos >= self._earliest_known_stream_pos: - return [self._cache[k] for k in self._cache.islice( - start=self._cache.bisect_right(stream_pos))] + return [ + self._cache[k] + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) + ] else: return None @@ -136,7 +137,7 @@ class StreamChangeCache(object): while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) self._earliest_known_stream_pos = max( - k, self._earliest_known_stream_pos, + k, self._earliest_known_stream_pos ) self._entity_to_key.pop(r, None) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index dd4c9e6067..9a72218d85 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -9,6 +9,7 @@ class TreeCache(object): efficiently. Keys must be tuples. """ + def __init__(self): self.size = 0 self.root = {} diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 5ba1862506..2af8ca43b1 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -155,6 +155,7 @@ class TTLCache(object): @attr.s(frozen=True, slots=True) class _CacheEntry(object): """TTLCache entry""" + # expiry_time is the first attribute, so that entries are sorted by expiry. expiry_time = attr.ib() key = attr.ib() diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index e14c8bdfda..5a79db821c 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -51,9 +51,7 @@ class Distributor(object): if name in self.signals: raise KeyError("%r already has a signal named %s" % (self, name)) - self.signals[name] = Signal( - name, - ) + self.signals[name] = Signal(name) if name in self.pre_registration: signal = self.signals[name] @@ -78,11 +76,7 @@ class Distributor(object): if name not in self.signals: raise KeyError("%r does not have a signal named %s" % (self, name)) - run_as_background_process( - name, - self.signals[name].fire, - *args, **kwargs - ) + run_as_background_process(name, self.signals[name].fire, *args, **kwargs) class Signal(object): @@ -118,22 +112,23 @@ class Signal(object): def eb(failure): logger.warning( "%s signal observer %s failed: %r", - self.name, observer, failure, + self.name, + observer, + failure, exc_info=( failure.type, failure.value, - failure.getTracebackObject())) + failure.getTracebackObject(), + ), + ) return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) - deferreds = [ - run_in_background(do, o) - for o in self.observers - ] + deferreds = [run_in_background(do, o) for o in self.observers] - return make_deferred_yieldable(defer.gatherResults( - deferreds, consumeErrors=True, - )) + return make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) def __repr__(self): return "" % (self.name,) diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 014edea971..635b897d6c 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -60,11 +60,10 @@ def _handle_frozendict(obj): # fishing the protected dict out of the object is a bit nasty, # but we don't really want the overhead of copying the dict. return obj._dict - raise TypeError('Object of type %s is not JSON serializable' % - obj.__class__.__name__) + raise TypeError( + "Object of type %s is not JSON serializable" % obj.__class__.__name__ + ) # A JSONEncoder which is capable of encoding frozendics without barfing -frozendict_json_encoder = json.JSONEncoder( - default=_handle_frozendict, -) +frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 2d7ddc1cbe..1a20c596bf 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -45,7 +45,7 @@ def create_resource_tree(desired_tree, root_resource): logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource - for path_seg in full_path.split(b'/')[1:-1]: + for path_seg in full_path.split(b"/")[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" child_resource = NoResource() @@ -60,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource): # =========================== # now attach the actual desired resource - last_path_seg = full_path.split(b'/')[-1] + last_path_seg = full_path.split(b"/")[-1] # if there is already a resource here, thieve its children and # replace it @@ -70,9 +70,7 @@ def create_resource_tree(desired_tree, root_resource): # to be replaced with the desired resource. existing_dummy_resource = resource_mappings[res_id] for child_name in existing_dummy_resource.listNames(): - child_res_id = _resource_id( - existing_dummy_resource, child_name - ) + child_res_id = _resource_id(existing_dummy_resource, child_name) child_resource = resource_mappings[child_res_id] # steal the children res.putChild(child_name, child_resource) diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index d668e5a6b8..6dce03dd3a 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -70,7 +70,8 @@ class JsonEncodedObject(object): dict """ d = { - k: _encode(v) for (k, v) in self.__dict__.items() + k: _encode(v) + for (k, v) in self.__dict__.items() if k in self.valid_keys and k not in self.internal_keys } d.update(self.unrecognized_keys) @@ -78,7 +79,8 @@ class JsonEncodedObject(object): def get_internal_dict(self): d = { - k: _encode(v, internal=True) for (k, v) in self.__dict__.items() + k: _encode(v, internal=True) + for (k, v) in self.__dict__.items() if k in self.valid_keys } d.update(self.unrecognized_keys) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index fe412355d8..a9885cb507 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -42,6 +42,8 @@ try: def get_thread_resource_usage(): return resource.getrusage(RUSAGE_THREAD) + + except Exception: # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we # won't track resource usage by returning None. @@ -64,8 +66,11 @@ class ContextResourceUsage(object): """ __slots__ = [ - "ru_stime", "ru_utime", - "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec", + "ru_stime", + "ru_utime", + "db_txn_count", + "db_txn_duration_sec", + "db_sched_duration_sec", "evt_db_fetch_count", ] @@ -91,8 +96,8 @@ class ContextResourceUsage(object): return ContextResourceUsage(copy_from=self) def reset(self): - self.ru_stime = 0. - self.ru_utime = 0. + self.ru_stime = 0.0 + self.ru_utime = 0.0 self.db_txn_count = 0 self.db_txn_duration_sec = 0 @@ -100,15 +105,18 @@ class ContextResourceUsage(object): self.evt_db_fetch_count = 0 def __repr__(self): - return ("") % ( - self.ru_stime, - self.ru_utime, - self.db_txn_count, - self.db_txn_duration_sec, - self.db_sched_duration_sec, - self.evt_db_fetch_count,) + return ( + "" + ) % ( + self.ru_stime, + self.ru_utime, + self.db_txn_count, + self.db_txn_duration_sec, + self.db_sched_duration_sec, + self.evt_db_fetch_count, + ) def __iadd__(self, other): """Add another ContextResourceUsage's stats to this one's. @@ -159,11 +167,15 @@ class LoggingContext(object): """ __slots__ = [ - "previous_context", "name", "parent_context", + "previous_context", + "name", + "parent_context", "_resource_usage", "usage_start", - "main_thread", "alive", - "request", "tag", + "main_thread", + "alive", + "request", + "tag", ] thread_local = threading.local() @@ -196,6 +208,7 @@ class LoggingContext(object): def __nonzero__(self): return False + __bool__ = __nonzero__ # python3 sentinel = Sentinel() @@ -261,7 +274,8 @@ class LoggingContext(object): if self.previous_context != old_context: logger.warn( "Expected previous context %r, found %r", - self.previous_context, old_context + self.previous_context, + old_context, ) self.alive = True @@ -285,9 +299,8 @@ class LoggingContext(object): self.alive = False # if we have a parent, pass our CPU usage stats on - if ( - self.parent_context is not None - and hasattr(self.parent_context, '_resource_usage') + if self.parent_context is not None and hasattr( + self.parent_context, "_resource_usage" ): self.parent_context._resource_usage += self._resource_usage @@ -320,9 +333,7 @@ class LoggingContext(object): # When we stop, let's record the cpu used since we started if not self.usage_start: - logger.warning( - "Called stop on logcontext %s without calling start", self, - ) + logger.warning("Called stop on logcontext %s without calling start", self) return usage_end = get_thread_resource_usage() @@ -381,6 +392,7 @@ class LoggingContextFilter(logging.Filter): **defaults: Default values to avoid formatters complaining about missing fields """ + def __init__(self, **defaults): self.defaults = defaults @@ -416,17 +428,12 @@ class PreserveLoggingContext(object): def __enter__(self): """Captures the current logging context""" - self.current_context = LoggingContext.set_current_context( - self.new_context - ) + self.current_context = LoggingContext.set_current_context(self.new_context) if self.current_context: self.has_parent = self.current_context.previous_context is not None if not self.current_context.alive: - logger.debug( - "Entering dead context: %s", - self.current_context, - ) + logger.debug("Entering dead context: %s", self.current_context) def __exit__(self, type, value, traceback): """Restores the current logging context""" @@ -444,10 +451,7 @@ class PreserveLoggingContext(object): if self.current_context is not LoggingContext.sentinel: if not self.current_context.alive: - logger.debug( - "Restoring dead context: %s", - self.current_context, - ) + logger.debug("Restoring dead context: %s", self.current_context) def nested_logging_context(suffix, parent_context=None): @@ -474,15 +478,16 @@ def nested_logging_context(suffix, parent_context=None): if parent_context is None: parent_context = LoggingContext.current_context() return LoggingContext( - parent_context=parent_context, - request=parent_context.request + "-" + suffix, + parent_context=parent_context, request=parent_context.request + "-" + suffix ) def preserve_fn(f): """Function decorator which wraps the function with run_in_background""" + def g(*args, **kwargs): return run_in_background(f, *args, **kwargs) + return g @@ -502,7 +507,7 @@ def run_in_background(f, *args, **kwargs): current = LoggingContext.current_context() try: res = f(*args, **kwargs) - except: # noqa: E722 + except: # noqa: E722 # the assumption here is that the caller doesn't want to be disturbed # by synchronous exceptions, so let's turn them into Failures. return defer.fail() @@ -639,6 +644,4 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): with LoggingContext(parent_context=logcontext): return f(*args, **kwargs) - return make_deferred_yieldable( - threads.deferToThreadPool(reactor, threadpool, g) - ) + return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g)) diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py index a46bc47ce3..fbf570c756 100644 --- a/synapse/util/logformatter.py +++ b/synapse/util/logformatter.py @@ -29,6 +29,7 @@ class LogFormatter(logging.Formatter): (Normally only stack frames between the point the exception was raised and where it was caught are logged). """ + def __init__(self, *args, **kwargs): super(LogFormatter, self).__init__(*args, **kwargs) @@ -40,7 +41,7 @@ class LogFormatter(logging.Formatter): # check that we actually have an f_back attribute to work around # https://twistedmatrix.com/trac/ticket/9305 - if tb and hasattr(tb.tb_frame, 'f_back'): + if tb and hasattr(tb.tb_frame, "f_back"): sio.write("Capture point (most recent call last):\n") traceback.print_stack(tb.tb_frame.f_back, None, sio) diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index ef31458226..7df0fa6087 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -44,7 +44,7 @@ def _log_debug_as_f(f, msg, msg_args): lineno=lineno, msg=msg, args=msg_args, - exc_info=None + exc_info=None, ) logger.handle(record) @@ -70,20 +70,11 @@ def log_function(f): r = r[:50] + "..." return r - func_args = [ - "%s=%s" % (k, format(v)) for k, v in bound_args.items() - ] + func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()] - msg_args = { - "func_name": func_name, - "args": ", ".join(func_args) - } + msg_args = {"func_name": func_name, "args": ", ".join(func_args)} - _log_debug_as_f( - f, - "Invoked '%(func_name)s' with args: %(args)s", - msg_args - ) + _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args) return f(*args, **kwargs) @@ -103,19 +94,13 @@ def time_function(f): start = time.clock() try: - _log_debug_as_f( - f, - "[FUNC START] {%s-%d}", - (func_name, id), - ) + _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id)) r = f(*args, **kwargs) finally: end = time.clock() _log_debug_as_f( - f, - "[FUNC END] {%s-%d} %.3f sec", - (func_name, id, end - start,), + f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start) ) return r @@ -137,9 +122,8 @@ def trace_function(f): s = inspect.currentframe().f_back to_print = [ - "\t%s:%s %s. Args: args=%s, kwargs=%s" % ( - pathname, linenum, func_name, args, kwargs - ) + "\t%s:%s %s. Args: args=%s, kwargs=%s" + % (pathname, linenum, func_name, args, kwargs) ] while s: if True or s.f_globals["__name__"].startswith("synapse"): @@ -147,9 +131,7 @@ def trace_function(f): args_string = inspect.formatargvalues(*inspect.getargvalues(s)) to_print.append( - "\t%s:%d %s. Args: %s" % ( - filename, lineno, function, args_string - ) + "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string) ) s = s.f_back @@ -163,7 +145,7 @@ def trace_function(f): lineno=lineno, msg=msg, args=None, - exc_info=None + exc_info=None, ) logger.handle(record) @@ -182,13 +164,13 @@ def get_previous_frames(): filename, lineno, function, _, _ = inspect.getframeinfo(s) args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - to_return.append("{{ %s:%d %s - Args: %s }}" % ( - filename, lineno, function, args_string - )) + to_return.append( + "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string) + ) s = s.f_back - return ", ". join(to_return) + return ", ".join(to_return) def get_previous_frame(ignore=[]): @@ -201,7 +183,10 @@ def get_previous_frame(ignore=[]): args_string = inspect.formatargvalues(*inspect.getargvalues(s)) return "{{ %s:%d %s - Args: %s }}" % ( - filename, lineno, function, args_string + filename, + lineno, + function, + args_string, ) s = s.f_back diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 628a2962d9..631654f297 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -74,27 +74,25 @@ def manhole(username, password, globals): twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` """ if not isinstance(password, bytes): - password = password.encode('ascii') + password = password.encode("ascii") - checker = checkers.InMemoryUsernamePasswordDatabaseDontUse( - **{username: password} - ) + checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) rlm = manhole_ssh.TerminalRealm() rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( - SynapseManhole, - dict(globals, __name__="__console__") + SynapseManhole, dict(globals, __name__="__console__") ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY) - factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY) + factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY) + factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY) return factory class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" + def connectionMade(self): super(SynapseManhole, self).connectionMade() @@ -127,7 +125,7 @@ class SynapseManholeInterpreter(ManholeInterpreter): value = SyntaxError(msg, (filename, lineno, offset, line)) sys.last_value = value lines = traceback.format_exception_only(type, value) - self.write(''.join(lines)) + self.write("".join(lines)) def showtraceback(self): """Display the exception that just occurred. @@ -140,6 +138,6 @@ class SynapseManholeInterpreter(ManholeInterpreter): try: # We remove the first stack item because it is our own code. lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) - self.write(''.join(lines)) + self.write("".join(lines)) finally: last_tb = ei = None diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4b4ac5f6c7..01284d3cf8 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -30,25 +30,31 @@ block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"]) block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"]) block_ru_utime = Counter( - "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"]) + "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"] +) block_ru_stime = Counter( - "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"]) + "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"] +) block_db_txn_count = Counter( - "synapse_util_metrics_block_db_txn_count", "", ["block_name"]) + "synapse_util_metrics_block_db_txn_count", "", ["block_name"] +) # seconds spent waiting for db txns, excluding scheduling time, in this block block_db_txn_duration = Counter( - "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"]) + "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"] +) # seconds spent waiting for a db connection, in this block block_db_sched_duration = Counter( - "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]) + "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"] +) # Tracks the number of blocks currently active in_flight = InFlightGauge( - "synapse_util_metrics_block_in_flight", "", + "synapse_util_metrics_block_in_flight", + "", labels=["block_name"], sub_metrics=["real_time_max", "real_time_sum"], ) @@ -62,13 +68,18 @@ def measure_func(name): with Measure(self.clock, name): r = yield func(self, *args, **kwargs) defer.returnValue(r) + return measured_func + return wrapper class Measure(object): __slots__ = [ - "clock", "name", "start_context", "start", + "clock", + "name", + "start_context", + "start", "created_context", "start_usage", ] @@ -108,7 +119,9 @@ class Measure(object): if context != self.start_context: logger.warn( "Context has unexpectedly changed from '%s' to '%s'. (%r)", - self.start_context, context, self.name + self.start_context, + context, + self.name, ) return @@ -126,8 +139,7 @@ class Measure(object): block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: logger.warn( - "Failed to save metrics! OLD: %r, NEW: %r", - self.start_usage, current + "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current ) if self.created_context: diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 4288312b8a..522acd5aa8 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -28,15 +28,13 @@ def load_module(provider): """ # We need to import the module, and then pick the class out of # that, so we split based on the last dot. - module, clz = provider['module'].rsplit(".", 1) + module, clz = provider["module"].rsplit(".", 1) module = importlib.import_module(module) provider_class = getattr(module, clz) try: provider_config = provider_class.parse_config(provider["config"]) except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) + raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e)) return provider_class, provider_config diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py index a6c30e5265..c8bcbe297a 100644 --- a/synapse/util/msisdn.py +++ b/synapse/util/msisdn.py @@ -36,6 +36,6 @@ def phone_number_to_msisdn(country, number): phoneNumber = phonenumbers.parse(number, country) except phonenumbers.NumberParseException: raise SynapseError(400, "Unable to parse phone number") - return phonenumbers.format_number( - phoneNumber, phonenumbers.PhoneNumberFormat.E164 - )[1:] + return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[ + 1: + ] diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index b146d137f4..06defa8199 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -56,11 +56,7 @@ class FederationRateLimiter(object): _PerHostRatelimiter """ return self.ratelimiters.setdefault( - host, - _PerHostRatelimiter( - clock=self.clock, - config=self._config, - ) + host, _PerHostRatelimiter(clock=self.clock, config=self._config) ).ratelimit() @@ -112,8 +108,7 @@ class _PerHostRatelimiter(object): # remove any entries from request_times which aren't within the window self.request_times[:] = [ - r for r in self.request_times - if time_now - r < self.window_size + r for r in self.request_times if time_now - r < self.window_size ] # reject the request if we already have too many queued up (either @@ -121,9 +116,7 @@ class _PerHostRatelimiter(object): queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) if queue_size > self.reject_limit: raise LimitExceededError( - retry_after_ms=int( - self.window_size / self.sleep_limit - ), + retry_after_ms=int(self.window_size / self.sleep_limit) ) self.request_times.append(time_now) @@ -143,22 +136,18 @@ class _PerHostRatelimiter(object): logger.debug( "Ratelimit [%s]: len(self.request_times)=%d", - id(request_id), len(self.request_times), + id(request_id), + len(self.request_times), ) if len(self.request_times) > self.sleep_limit: - logger.debug( - "Ratelimiter: sleeping request for %f sec", self.sleep_sec, - ) + logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec) ret_defer = run_in_background(self.clock.sleep, self.sleep_sec) self.sleeping_requests.add(request_id) def on_wait_finished(_): - logger.debug( - "Ratelimit [%s]: Finished sleeping", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) self.sleeping_requests.discard(request_id) queue_defer = queue_request() return queue_defer @@ -168,10 +157,7 @@ class _PerHostRatelimiter(object): ret_defer = queue_request() def on_start(r): - logger.debug( - "Ratelimit [%s]: Processing req", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Processing req", id(request_id)) self.current_processing.add(request_id) return r @@ -193,10 +179,7 @@ class _PerHostRatelimiter(object): return make_deferred_yieldable(ret_defer) def _on_exit(self, request_id): - logger.debug( - "Ratelimit [%s]: Processed req", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Processed req", id(request_id)) self.current_processing.discard(request_id) try: # start processing the next item on the queue. diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 69dffd8244..982c6d81ca 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -20,9 +20,7 @@ import six from six import PY2, PY3 from six.moves import range -_string_with_symbols = ( - string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" -) +_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure @@ -31,13 +29,11 @@ rand = random.SystemRandom() def random_string(length): - return ''.join(rand.choice(string.ascii_letters) for _ in range(length)) + return "".join(rand.choice(string.ascii_letters) for _ in range(length)) def random_string_with_symbols(length): - return ''.join( - rand.choice(_string_with_symbols) for _ in range(length) - ) + return "".join(rand.choice(_string_with_symbols) for _ in range(length)) def is_ascii(s): @@ -45,7 +41,7 @@ def is_ascii(s): if PY3: if isinstance(s, bytes): try: - s.decode('ascii').encode('ascii') + s.decode("ascii").encode("ascii") except UnicodeDecodeError: return False except UnicodeEncodeError: @@ -104,12 +100,12 @@ def exception_to_unicode(e): # and instead look at what is in the args member. if len(e.args) == 0: - return u"" + return "" elif len(e.args) > 1: return six.text_type(repr(e.args)) msg = e.args[0] if isinstance(msg, bytes): - return msg.decode('utf-8', errors='replace') + return msg.decode("utf-8", errors="replace") else: return msg diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 75efa0117b..3ec1dfb0c2 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -35,11 +35,13 @@ def check_3pid_allowed(hs, medium, address): for constraint in hs.config.allowed_local_3pids: logger.debug( "Checking 3PID %s (%s) against %s (%s)", - address, medium, constraint['pattern'], constraint['medium'], + address, + medium, + constraint["pattern"], + constraint["medium"], ) - if ( - medium == constraint['medium'] and - re.match(constraint['pattern'], address) + if medium == constraint["medium"] and re.match( + constraint["pattern"], address ): return True else: diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index 3baba3225a..a4d9a462f7 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -23,44 +23,53 @@ logger = logging.getLogger(__name__) def get_version_string(module): try: - null = open(os.devnull, 'w') + null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) try: - git_branch = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_branch = ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) git_branch = "b=" + git_branch except subprocess.CalledProcessError: git_branch = "" try: - git_tag = subprocess.check_output( - ['git', 'describe', '--exact-match'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_tag = ( + subprocess.check_output( + ["git", "describe", "--exact-match"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) git_tag = "t=" + git_tag except subprocess.CalledProcessError: git_tag = "" try: - git_commit = subprocess.check_output( - ['git', 'rev-parse', '--short', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_commit = ( + subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) except subprocess.CalledProcessError: git_commit = "" try: dirty_string = "-this_is_a_dirty_checkout" - is_dirty = subprocess.check_output( - ['git', 'describe', '--dirty=' + dirty_string], - stderr=null, - cwd=cwd, - ).strip().decode('ascii').endswith(dirty_string) + is_dirty = ( + subprocess.check_output( + ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + .endswith(dirty_string) + ) git_dirty = "dirty" if is_dirty else "" except subprocess.CalledProcessError: @@ -68,16 +77,10 @@ def get_version_string(module): if git_branch or git_tag or git_commit or git_dirty: git_version = ",".join( - s for s in - (git_branch, git_tag, git_commit, git_dirty,) - if s + s for s in (git_branch, git_tag, git_commit, git_dirty) if s ) - return ( - "%s (%s)" % ( - module.__version__, git_version, - ) - ) + return "%s (%s)" % (module.__version__, git_version) except Exception as e: logger.info("Failed to check for git repository: %s", e) diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 7a9e45aca9..9bf6a44f75 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -69,9 +69,7 @@ class WheelTimer(object): # Add empty entries between the end of the current list and when we want # to insert. This ensures there are no gaps. - self.entries.extend( - _Entry(key) for key in range(last_key, then_key + 1) - ) + self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1)) self.entries[-1].queue.append(obj) diff --git a/synapse/visibility.py b/synapse/visibility.py index 16c40cd74c..2a11c83596 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -29,12 +29,7 @@ from synapse.types import get_domain_from_id logger = logging.getLogger(__name__) -VISIBILITY_PRIORITY = ( - "world_readable", - "shared", - "invited", - "joined", -) +VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined") MEMBERSHIP_PRIORITY = ( @@ -47,8 +42,9 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks -def filter_events_for_client(store, user_id, events, is_peeking=False, - always_include_ids=frozenset()): +def filter_events_for_client( + store, user_id, events, is_peeking=False, always_include_ids=frozenset() +): """ Check which events a user is allowed to see @@ -71,23 +67,21 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, # to clients. events = list(e for e in events if not e.internal_metadata.is_soft_failed()) - types = ( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), - ) + types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) event_id_to_state = yield store.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) ignore_dict_content = yield store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id, + "m.ignored_user_list", user_id ) # FIXME: This will explode if people upload something incorrect. ignore_list = frozenset( ignore_dict_content.get("ignored_users", {}).keys() - if ignore_dict_content else [] + if ignore_dict_content + else [] ) erased_senders = yield store.are_users_erased((e.sender for e in events)) @@ -185,9 +179,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, elif visibility == "invited": # user can also see the event if they were *invited* at the time # of the event. - return ( - event if membership == Membership.INVITE else None - ) + return event if membership == Membership.INVITE else None elif visibility == "shared" and is_peeking: # if the visibility is shared, users cannot see the event unless @@ -220,8 +212,9 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, @defer.inlineCallbacks -def filter_events_for_server(store, server_name, events, redact=True, - check_history_visibility_only=False): +def filter_events_for_server( + store, server_name, events, redact=True, check_history_visibility_only=False +): """Filter a list of events based on whether given server is allowed to see them. @@ -242,15 +235,12 @@ def filter_events_for_server(store, server_name, events, redact=True, def is_sender_erased(event, erased_senders): if erased_senders and erased_senders[event.sender]: - logger.info( - "Sender of %s has been erased, redacting", - event.event_id, - ) + logger.info("Sender of %s has been erased, redacting", event.event_id) return True return False def check_event_is_visible(event, state): - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + history = state.get((EventTypes.RoomHistoryVisibility, ""), None) if history: visibility = history.content.get("history_visibility", "shared") if visibility in ["invited", "joined"]: @@ -287,8 +277,8 @@ def filter_events_for_server(store, server_name, events, redact=True, event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( - types=((EventTypes.RoomHistoryVisibility, ""),), - ) + types=((EventTypes.RoomHistoryVisibility, ""),) + ), ) visibility_ids = set() @@ -309,9 +299,7 @@ def filter_events_for_server(store, server_name, events, redact=True, ) if not check_history_visibility_only: - erased_senders = yield store.are_users_erased( - (e.sender for e in events), - ) + erased_senders = yield store.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -343,11 +331,8 @@ def filter_events_for_server(store, server_name, events, redact=True, event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, None), - ), - ) + types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) + ), ) # We only want to pull out member events that correspond to the @@ -371,13 +356,15 @@ def filter_events_for_server(store, server_name, events, redact=True, idx = state_key.find(":") if idx == -1: return False - return state_key[idx + 1:] == server_name - - event_map = yield store.get_events([ - e_id - for e_id, key in iteritems(event_id_to_state_key) - if include(key[0], key[1]) - ]) + return state_key[idx + 1 :] == server_name + + event_map = yield store.get_events( + [ + e_id + for e_id, key in iteritems(event_id_to_state_key) + if include(key[0], key[1]) + ] + ) event_to_state = { e_id: { diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index d0d36f96fa..d4e75b5b2e 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -172,7 +172,7 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals( - requester.user.to_string(), masquerading_user_id.decode('utf8') + requester.user.to_string(), masquerading_user_id.decode("utf8") ) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): @@ -264,7 +264,7 @@ class AuthTestCase(unittest.TestCase): # check the token works request = Mock(args={}) - request.args[b"access_token"] = [token.encode('ascii')] + request.args[b"access_token"] = [token.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request, allow_guest=True) self.assertEqual(UserID.from_string(USER_ID), requester.user) @@ -277,7 +277,7 @@ class AuthTestCase(unittest.TestCase): # the token should *not* work now request = Mock(args={}) - request.args[b"access_token"] = [guest_tok.encode('ascii')] + request.args[b"access_token"] = [guest_tok.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() with self.assertRaises(AuthError) as cm: @@ -321,11 +321,11 @@ class AuthTestCase(unittest.TestCase): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 1 self.store.get_monthly_active_count = lambda: defer.succeed(2) - threepid = {'medium': 'email', 'address': 'reserved@server.com'} - unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'} + threepid = {"medium": "email", "address": "reserved@server.com"} + unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.hs.config.mau_limits_reserved_threepids = [threepid] - yield self.store.register(user_id='user1', token="123", password_hash=None) + yield self.store.register(user_id="user1", token="123", password_hash=None) with self.assertRaises(ResourceLimitError): yield self.auth.check_auth_blocking() diff --git a/tests/config/test_server.py b/tests/config/test_server.py index de64965a60..1ca5ea54ca 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -20,10 +20,10 @@ from tests import unittest class ServerConfigTestCase(unittest.TestCase): def test_is_threepid_reserved(self): - user1 = {'medium': 'email', 'address': 'user1@example.com'} - user2 = {'medium': 'email', 'address': 'user2@example.com'} - user3 = {'medium': 'email', 'address': 'user3@example.com'} - user1_msisdn = {'medium': 'msisdn', 'address': '447700000000'} + user1 = {"medium": "email", "address": "user1@example.com"} + user2 = {"medium": "email", "address": "user2@example.com"} + user3 = {"medium": "email", "address": "user3@example.com"} + user1_msisdn = {"medium": "msisdn", "address": "447700000000"} config = [user1, user2] self.assertTrue(is_threepid_reserved(config, user1)) diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 40ca428778..0cbbf4e885 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -32,7 +32,7 @@ class TLSConfigTests(TestCase): """ config_dir = self.mktemp() os.mkdir(config_dir) - with open(os.path.join(config_dir, "cert.pem"), 'w') as f: + with open(os.path.join(config_dir, "cert.pem"), "w") as f: f.write( """-----BEGIN CERTIFICATE----- MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 71aa731439..126e176004 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -41,25 +41,25 @@ class EventSigningTestCase(unittest.TestCase): def test_sign_minimal(self): event_dict = { - 'event_id': "$0:domain", - 'origin': "domain", - 'origin_server_ts': 1000000, - 'signatures': {}, - 'type': "X", - 'unsigned': {'age_ts': 1000000}, + "event_id": "$0:domain", + "origin": "domain", + "origin_server_ts": 1000000, + "signatures": {}, + "type": "X", + "unsigned": {"age_ts": 1000000}, } add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key) event = FrozenEvent(event_dict) - self.assertTrue(hasattr(event, 'hashes')) - self.assertIn('sha256', event.hashes) + self.assertTrue(hasattr(event, "hashes")) + self.assertIn("sha256", event.hashes) self.assertEquals( - event.hashes['sha256'], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI" + event.hashes["sha256"], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI" ) - self.assertTrue(hasattr(event, 'signatures')) + self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) self.assertEquals( @@ -70,28 +70,28 @@ class EventSigningTestCase(unittest.TestCase): def test_sign_message(self): event_dict = { - 'content': {'body': "Here is the message content"}, - 'event_id': "$0:domain", - 'origin': "domain", - 'origin_server_ts': 1000000, - 'type': "m.room.message", - 'room_id': "!r:domain", - 'sender': "@u:domain", - 'signatures': {}, - 'unsigned': {'age_ts': 1000000}, + "content": {"body": "Here is the message content"}, + "event_id": "$0:domain", + "origin": "domain", + "origin_server_ts": 1000000, + "type": "m.room.message", + "room_id": "!r:domain", + "sender": "@u:domain", + "signatures": {}, + "unsigned": {"age_ts": 1000000}, } add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key) event = FrozenEvent(event_dict) - self.assertTrue(hasattr(event, 'hashes')) - self.assertIn('sha256', event.hashes) + self.assertTrue(hasattr(event, "hashes")) + self.assertIn("sha256", event.hashes) self.assertEquals( - event.hashes['sha256'], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g" + event.hashes["sha256"], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g" ) - self.assertTrue(hasattr(event, 'signatures')) + self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) self.assertEquals( diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index d0cc492deb..9e3d4d0f47 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -37,88 +37,88 @@ class PruneEventTestCase(unittest.TestCase): def test_minimal(self): self.run_test( - {'type': 'A', 'event_id': '$test:domain'}, + {"type": "A", "event_id": "$test:domain"}, { - 'type': 'A', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "A", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) def test_basic_keys(self): self.run_test( { - 'type': 'A', - 'room_id': '!1:domain', - 'sender': '@2:domain', - 'event_id': '$3:domain', - 'origin': 'domain', + "type": "A", + "room_id": "!1:domain", + "sender": "@2:domain", + "event_id": "$3:domain", + "origin": "domain", }, { - 'type': 'A', - 'room_id': '!1:domain', - 'sender': '@2:domain', - 'event_id': '$3:domain', - 'origin': 'domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "A", + "room_id": "!1:domain", + "sender": "@2:domain", + "event_id": "$3:domain", + "origin": "domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) def test_unsigned_age_ts(self): self.run_test( - {'type': 'B', 'event_id': '$test:domain', 'unsigned': {'age_ts': 20}}, + {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}}, { - 'type': 'B', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {'age_ts': 20}, + "type": "B", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {"age_ts": 20}, }, ) self.run_test( { - 'type': 'B', - 'event_id': '$test:domain', - 'unsigned': {'other_key': 'here'}, + "type": "B", + "event_id": "$test:domain", + "unsigned": {"other_key": "here"}, }, { - 'type': 'B', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "B", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) def test_content(self): self.run_test( - {'type': 'C', 'event_id': '$test:domain', 'content': {'things': 'here'}}, + {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}}, { - 'type': 'C', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "C", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) self.run_test( { - 'type': 'm.room.create', - 'event_id': '$test:domain', - 'content': {'creator': '@2:domain', 'other_field': 'here'}, + "type": "m.room.create", + "event_id": "$test:domain", + "content": {"creator": "@2:domain", "other_field": "here"}, }, { - 'type': 'm.room.create', - 'event_id': '$test:domain', - 'content': {'creator': '@2:domain'}, - 'signatures': {}, - 'unsigned': {}, + "type": "m.room.create", + "event_id": "$test:domain", + "content": {"creator": "@2:domain"}, + "signatures": {}, + "unsigned": {}, }, ) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 1e3e5aec66..a5b03005d7 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -32,7 +32,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): login.register_servlets, ] - def default_config(self, name='test'): + def default_config(self, name="test"): config = super(RoomComplexityTests, self).default_config(name=name) config["limit_large_remote_room_joins"] = True config["limit_large_remote_room_complexity"] = 0.05 diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 7bb106b5f7..cce8d8c6de 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -51,16 +51,16 @@ class FederationSenderTestCases(HomeserverTestCase): json_cb = mock_send_transaction.call_args[0][1] data = json_cb() self.assertEqual( - data['edus'], + data["edus"], [ { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, + "edu_type": "m.receipt", + "content": { + "room_id": { + "m.read": { + "user_id": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, } } } @@ -93,16 +93,16 @@ class FederationSenderTestCases(HomeserverTestCase): json_cb = mock_send_transaction.call_args[0][1] data = json_cb() self.assertEqual( - data['edus'], + data["edus"], [ { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, + "edu_type": "m.receipt", + "content": { + "room_id": { + "m.read": { + "user_id": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, } } } @@ -128,16 +128,16 @@ class FederationSenderTestCases(HomeserverTestCase): json_cb = mock_send_transaction.call_args[0][1] data = json_cb() self.assertEqual( - data['edus'], + data["edus"], [ { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['other_id'], - 'data': {'ts': 1234}, + "edu_type": "m.receipt", + "content": { + "room_id": { + "m.read": { + "user_id": { + "event_ids": ["other_id"], + "data": {"ts": 1234}, } } } diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 1e39fe0ec2..b204a0700d 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -117,7 +117,7 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() @@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) @@ -150,7 +150,7 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) @@ -166,7 +166,7 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) @@ -185,7 +185,7 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 917548bb31..91c7a17070 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -132,7 +132,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", b"directory/room/%23test%3Atest", - ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), ) self.render(request) self.assertEquals(403, channel.code, channel.result) @@ -143,7 +143,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", b"directory/room/%23unofficial_test%3Atest", - ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -158,7 +158,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' + "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -190,7 +190,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list disabled so we shouldn't be allowed to publish rooms room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' + "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) self.render(request) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 2e72a1dd23..c4503c1611 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -395,37 +395,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.upload_room_keys(self.local_user, version, room_keys) new_room_keys = copy.deepcopy(room_keys) - new_room_key = new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33'] + new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"] # test that increasing the message_index doesn't replace the existing session - new_room_key['first_message_index'] = 2 - new_room_key['session_data'] = 'new' + new_room_key["first_message_index"] = 2 + new_room_key["session_data"] = "new" yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) # test that marking the session as verified however /does/ replace it - new_room_key['is_verified'] = True + new_room_key["is_verified"] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" + res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count - new_room_key['forwarded_count'] = 2 - new_room_key['session_data'] = 'other' + new_room_key["forwarded_count"] = 2 + new_room_key["session_data"] = "other" yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" + res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # TODO: check edge cases as well as the common variations here diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 5ffba2ca7a..4edce7af43 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -52,7 +52,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.mock_distributor.declare("registered_user") self.mock_captcha_client = Mock() self.macaroon_generator = Mock( - generate_access_token=Mock(return_value='secret') + generate_access_token=Mock(return_value="secret") ) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.handler = self.hs.get_registration_handler() @@ -71,7 +71,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) - self.assertEquals(result_token, 'secret') + self.assertEquals(result_token, "secret") def test_if_user_exists(self): store = self.hs.get_datastore() @@ -96,7 +96,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception self.get_success( - self.handler.get_or_create_user(self.requester, 'a', "display_name") + self.handler.get_or_create_user(self.requester, "a", "display_name") ) def test_get_or_create_user_mau_not_blocked(self): @@ -105,7 +105,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - self.get_success(self.handler.get_or_create_user(self.requester, 'c', "User")) + self.get_success(self.handler.get_or_create_user(self.requester, "c", "User")) def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True @@ -113,7 +113,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return_value=defer.succeed(self.lots_of_users) ) self.get_failure( - self.handler.get_or_create_user(self.requester, 'b', "display_name"), + self.handler.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) @@ -121,7 +121,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return_value=defer.succeed(self.hs.config.max_mau_value) ) self.get_failure( - self.handler.get_or_create_user(self.requester, 'b', "display_name"), + self.handler.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) @@ -144,13 +144,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = self.get_success(self.handler.register(localpart='jeff')) + res = self.get_success(self.handler.register(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(res[0])) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) - self.assertTrue(room_id['room_id'] in rooms) + self.assertTrue(room_id["room_id"] in rooms) self.assertEqual(len(rooms), 1) def test_auto_create_auto_join_rooms_with_no_rooms(self): @@ -173,7 +173,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.hs.config.autocreate_auto_join_rooms = False room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = self.get_success(self.handler.register(localpart='jeff')) + res = self.get_success(self.handler.register(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) @@ -182,7 +182,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.hs.config.auto_join_rooms = [room_alias_str] self.store.is_support_user = Mock(return_value=True) - res = self.get_success(self.handler.register(localpart='support')) + res = self.get_success(self.handler.register(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) directory_handler = self.hs.get_handlers().directory_handler @@ -211,7 +211,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # When:- # * the user is registered and post consent actions are called - res = self.get_success(self.handler.register(localpart='jeff')) + res = self.get_success(self.handler.register(localpart="jeff")) self.get_success(self.handler.post_consent_actions(res[0])) # Then:- @@ -221,17 +221,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_register_support_user(self): res = self.get_success( - self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) + self.handler.register(localpart="user", user_type=UserTypes.SUPPORT) ) self.assertTrue(self.store.is_support_user(res[0])) def test_register_not_support_user(self): - res = self.get_success(self.handler.register(localpart='user')) + res = self.get_success(self.handler.register(localpart="user")) self.assertFalse(self.store.is_support_user(res[0])) def test_invalid_user_id_length(self): invalid_user_id = "x" * 256 - self.get_failure( - self.handler.register(localpart=invalid_user_id), - SynapseError - ) + self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 2710c991cf..a8b858eb4f 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -265,10 +265,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): while not self.get_success(self.store.has_completed_background_updates()): self.get_success(self.store.do_next_background_update(100), by=0.1) - events = { - "a1": None, - "a2": {"membership": Membership.JOIN}, - } + events = {"a1": None, "a2": {"membership": Membership.JOIN}} def get_event(event_id, allow_none=True): if events.get(event_id): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index cb8b4d2913..5d5e324df2 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -47,7 +47,7 @@ def _expect_edu_transaction(edu_type, content, origin="test"): def _make_edu_transaction_json(edu_type, content): - return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8') + return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8") class TypingNotificationsTestCase(unittest.HomeserverTestCase): @@ -151,7 +151,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) @@ -209,12 +209,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "typing": True, }, ), - federation_auth_origin=b'farm', + federation_auth_origin=b"farm", ) self.render(request) self.assertEqual(channel.code, 200) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) @@ -247,7 +247,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( @@ -285,7 +285,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) @@ -303,7 +303,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.reactor.pump([16]) - self.on_new_event.assert_has_calls([call('typing_key', 2, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) @@ -320,7 +320,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 3, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 9021e647fe..b135486c48 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -60,15 +60,15 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) profile = self.get_success(self.store.get_user_in_directory(support_user_id)) self.assertTrue(profile is None) - display_name = 'display_name' + display_name = "display_name" - profile_info = ProfileInfo(avatar_url='avatar_url', display_name=display_name) - regular_user_id = '@regular:test' + profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) + regular_user_id = "@regular:test" self.get_success( self.handler.handle_local_profile_change(regular_user_id, profile_info) ) profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) - self.assertTrue(profile['display_name'] == display_name) + self.assertTrue(profile["display_name"] == display_name) def test_handle_user_deactivated_support_user(self): s_user_id = "@support:test" diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index ecce473b01..b1094c1448 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -53,13 +53,15 @@ def get_connection_factory(): # this needs to happen once, but not until we are ready to run the first test global test_server_connection_factory if test_server_connection_factory is None: - test_server_connection_factory = TestServerTLSConnectionFactory(sanlist=[ - b'DNS:testserv', - b'DNS:target-server', - b'DNS:xn--bcher-kva.com', - b'IP:1.2.3.4', - b'IP:::1', - ]) + test_server_connection_factory = TestServerTLSConnectionFactory( + sanlist=[ + b"DNS:testserv", + b"DNS:target-server", + b"DNS:xn--bcher-kva.com", + b"IP:1.2.3.4", + b"IP:::1", + ] + ) return test_server_connection_factory @@ -133,7 +135,7 @@ class MatrixFederationAgentTests(TestCase): Sends a simple GET request via the agent, and checks its logcontext management """ with LoggingContext("one") as context: - fetch_d = self.agent.request(b'GET', uri) + fetch_d = self.agent.request(b"GET", uri) # Nothing happened yet self.assertNoResult(fetch_d) @@ -177,9 +179,9 @@ class MatrixFederationAgentTests(TestCase): """Check that an incoming request looks like a valid .well-known request, and send back the response. """ - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/.well-known/matrix/server') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/.well-known/matrix/server") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # send back a response for k, v in headers.items(): request.setHeader(k, v) @@ -202,7 +204,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -210,20 +212,20 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448'] + request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"] ) content = request.content.read() - self.assertEqual(content, b'') + self.assertEqual(content, b"") # Deferred is still without a result self.assertNoResult(test_d) # send the headers - request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json']) - request.write('') + request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"]) + request.write("") self.reactor.pump((0.1,)) @@ -233,7 +235,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(response.code, 200) # Send the body - request.write('{ "a": 1 }'.encode('ascii')) + request.write('{ "a": 1 }'.encode("ascii")) request.finish() self.reactor.pump((0.1,)) @@ -258,7 +260,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -266,9 +268,9 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"1.2.3.4"]) # finish the request request.finish() @@ -293,7 +295,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '::1') + self.assertEqual(host, "::1") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -301,9 +303,9 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]"]) # finish the request request.finish() @@ -328,7 +330,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '::1') + self.assertEqual(host, "::1") self.assertEqual(port, 80) # make a test server, and wire up the client @@ -336,9 +338,9 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]:80"]) # finish the request request.finish() @@ -364,7 +366,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -382,11 +384,11 @@ class MatrixFederationAgentTests(TestCase): # we should fall back to a direct connection self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv1') + http_server = self._make_connection(client_factory, expected_sni=b"testserv1") # there should be no requests self.assertEqual(len(http_server.requests), 0) @@ -413,7 +415,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.5') + self.assertEqual(host, "1.2.3.5") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -447,7 +449,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -465,17 +467,17 @@ class MatrixFederationAgentTests(TestCase): # we should fall back to a direct connection self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -499,7 +501,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( @@ -516,20 +518,20 @@ class MatrixFederationAgentTests(TestCase): # now we should get a connection to the target server self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1::f') + self.assertEqual(host, "1::f") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -561,7 +563,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) redirect_server = self._make_connection( @@ -571,7 +573,7 @@ class MatrixFederationAgentTests(TestCase): # send a 302 redirect self.assertEqual(len(redirect_server.requests), 1) request = redirect_server.requests[0] - request.redirect(b'https://testserv/even_better_known') + request.redirect(b"https://testserv/even_better_known") request.finish() self.reactor.pump((0.1,)) @@ -580,7 +582,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) well_known_server = self._make_connection( @@ -589,8 +591,8 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(well_known_server.requests), 1, "No request after 302") request = well_known_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/even_better_known') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/even_better_known") request.write(b'{ "m.server": "target-server" }') request.finish() @@ -604,20 +606,20 @@ class MatrixFederationAgentTests(TestCase): # now we should get a connection to the target server self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1::f') + self.assertEqual(host, "1::f") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -652,11 +654,11 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", content=b'NOT JSON' + client_factory, expected_sni=b"testserv", content=b"NOT JSON" ) # now there should be a SRV lookup @@ -667,17 +669,17 @@ class MatrixFederationAgentTests(TestCase): # we should fall back to a direct connection self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -712,12 +714,10 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - http_proto = self._make_connection( - client_factory, expected_sni=b"testserv", - ) + http_proto = self._make_connection(client_factory, expected_sni=b"testserv") # there should be no requests self.assertEqual(len(http_proto.requests), 0) @@ -750,17 +750,17 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8443) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -783,7 +783,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self.mock_resolver.resolve_service.side_effect = lambda _: [ @@ -804,20 +804,20 @@ class MatrixFederationAgentTests(TestCase): # now we should get a connection to the target of the SRV record self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '5.6.7.8') + self.assertEqual(host, "5.6.7.8") self.assertEqual(port, 8443) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -846,7 +846,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -865,20 +865,20 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'xn--bcher-kva.com' + client_factory, expected_sni=b"xn--bcher-kva.com" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] + request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"] ) # finish the request @@ -907,20 +907,20 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8443) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'xn--bcher-kva.com' + client_factory, expected_sni=b"xn--bcher-kva.com" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] + request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"] ) # finish the request @@ -941,42 +941,42 @@ class MatrixFederationAgentTests(TestCase): def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) well_known_server = self._handle_well_known_connection( client_factory, expected_sni=b"testserv", - response_headers={b'Cache-Control': b'max-age=10'}, + response_headers={b"Cache-Control": b"max-age=10"}, content=b'{ "m.server": "target-server" }', ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b'target-server') + self.assertEqual(r, b"target-server") # close the tcp connection well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") r = self.successResultOf(fetch_d) - self.assertEqual(r, b'target-server') + self.assertEqual(r, b"target-server") # expire the cache self.reactor.pump((10.0,)) # now it should connect again - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( @@ -986,7 +986,7 @@ class MatrixFederationAgentTests(TestCase): ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b'other-server') + self.assertEqual(r, b"other-server") class TestCachePeriodFromHeaders(TestCase): @@ -994,27 +994,27 @@ class TestCachePeriodFromHeaders(TestCase): # uppercase self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}) + Headers({b"Cache-Control": [b"foo, Max-Age = 100, bar"]}) ), 100, ) # missing value self.assertIsNone( - _cache_period_from_headers(Headers({b'Cache-Control': [b'max-age=, bar']})) + _cache_period_from_headers(Headers({b"Cache-Control": [b"max-age=, bar"]})) ) # hackernews: bogus due to semicolon self.assertIsNone( _cache_period_from_headers( - Headers({b'Cache-Control': [b'private; max-age=0']}) + Headers({b"Cache-Control": [b"private; max-age=0"]}) ) ) # github self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}) + Headers({b"Cache-Control": [b"max-age=0, private, must-revalidate"]}) ), 0, ) @@ -1022,7 +1022,7 @@ class TestCachePeriodFromHeaders(TestCase): # google self.assertEqual( _cache_period_from_headers( - Headers({b'cache-control': [b'private, max-age=0']}) + Headers({b"cache-control": [b"private, max-age=0"]}) ), 0, ) @@ -1030,7 +1030,7 @@ class TestCachePeriodFromHeaders(TestCase): def test_expires(self): self.assertEqual( _cache_period_from_headers( - Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}), + Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}), time_now=lambda: 1548833700, ), 33, @@ -1041,8 +1041,8 @@ class TestCachePeriodFromHeaders(TestCase): _cache_period_from_headers( Headers( { - b'cache-control': [b'max-age=10'], - b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'], + b"cache-control": [b"max-age=10"], + b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"], } ), time_now=lambda: 1548833700, @@ -1051,7 +1051,7 @@ class TestCachePeriodFromHeaders(TestCase): ) # invalid expires means immediate expiry - self.assertEqual(_cache_period_from_headers(Headers({b'Expires': [b'0']})), 0) + self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0) def _check_logcontext(context): diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 034c0db8d2..cf6c6e95b5 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -100,7 +100,7 @@ class SrvResolverTestCase(unittest.TestCase): def test_from_cache(self): clock = MockClock() - dns_client_mock = Mock(spec_set=['lookupService']) + dns_client_mock = Mock(spec_set=["lookupService"]) dns_client_mock.lookupService = Mock(spec_set=[]) service_name = b"test_service.example.com" diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index 3b0155ed03..b2e9533b07 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -20,12 +20,12 @@ from tests import unittest class ServerNameTestCase(unittest.TestCase): def test_parse_server_name(self): test_data = { - 'localhost': ('localhost', None), - 'my-example.com:1234': ('my-example.com', 1234), - '1.2.3.4': ('1.2.3.4', None), - '[0abc:1def::1234]': ('[0abc:1def::1234]', None), - '1.2.3.4:1': ('1.2.3.4', 1), - '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080), + "localhost": ("localhost", None), + "my-example.com:1234": ("my-example.com", 1234), + "1.2.3.4": ("1.2.3.4", None), + "[0abc:1def::1234]": ("[0abc:1def::1234]", None), + "1.2.3.4:1": ("1.2.3.4", 1), + "[0abc:1def::1234]:8080": ("[0abc:1def::1234]", 8080), } for i, o in test_data.items(): diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index ee767f3a5a..c4c0d9b968 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -83,7 +83,7 @@ class FederationClientTests(HomeserverTestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8008) # complete the connection and wire it up to a fake transport @@ -99,7 +99,7 @@ class FederationClientTests(HomeserverTestCase): self.assertNoResult(test_d) # Send it the HTTP response - res_json = '{ "a": 1 }'.encode('ascii') + res_json = '{ "a": 1 }'.encode("ascii") protocol.dataReceived( b"HTTP/1.1 200 OK\r\n" b"Server: Fake\r\n" @@ -138,7 +138,7 @@ class FederationClientTests(HomeserverTestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8008) e = Exception("go away") factory.clientConnectionFailed(None, e) @@ -164,7 +164,7 @@ class FederationClientTests(HomeserverTestCase): # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) - self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][0], "1.2.3.4") self.assertEqual(clients[0][1], 8008) # Deferred is still without a result @@ -194,7 +194,7 @@ class FederationClientTests(HomeserverTestCase): # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) - self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][0], "1.2.3.4") self.assertEqual(clients[0][1], 8008) conn = Mock() @@ -215,10 +215,9 @@ class FederationClientTests(HomeserverTestCase): """Ensure that Synapse does not try to connect to blacklisted IPs""" # Set up the ip_range blacklist - self.hs.config.federation_ip_range_blacklist = IPSet([ - "127.0.0.0/8", - "fe80::/64", - ]) + self.hs.config.federation_ip_range_blacklist = IPSet( + ["127.0.0.0/8", "fe80::/64"] + ) self.reactor.lookups["internal"] = "127.0.0.1" self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337" self.reactor.lookups["fine"] = "10.20.30.40" @@ -382,7 +381,7 @@ class FederationClientTests(HomeserverTestCase): b"Content-Type: application/json\r\n" b"Content-Length: 2\r\n" b"\r\n" - b'{}' + b"{}" ) # We should get a successful response diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 72760a0733..358b593cd4 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -57,7 +57,7 @@ class EmailPusherTests(HomeserverTestCase): config["email"] = { "enable_notifs": True, "template_dir": os.path.abspath( - pkg_resources.resource_filename('synapse', 'res/templates') + pkg_resources.resource_filename("synapse", "res/templates") ), "expiry_template_html": "notice_expiry.html", "expiry_template_text": "notice_expiry.txt", @@ -120,7 +120,7 @@ class EmailPusherTests(HomeserverTestCase): # Create a simple room with two users room = self.helper.create_room_as(self.user_id, tok=self.access_token) self.helper.invite( - room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id, + room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id ) self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token) @@ -141,7 +141,7 @@ class EmailPusherTests(HomeserverTestCase): for other in self.others: self.helper.invite( - room=room, src=self.user_id, tok=self.access_token, targ=other.id, + room=room, src=self.user_id, tok=self.access_token, targ=other.id ) self.helper.join(room=room, user=other.id, tok=other.token) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index e5fc2fcd15..5877bb2133 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -30,7 +30,7 @@ from tests import unittest class VersionTestCase(unittest.HomeserverTestCase): - url = '/_synapse/admin/v1/server_version' + url = "/_synapse/admin/v1/server_version" def create_test_json_resource(self): resource = JsonResource(self.hs) @@ -43,7 +43,7 @@ class VersionTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - {'server_version', 'python_version'}, set(channel.json_body.keys()) + {"server_version", "python_version"}, set(channel.json_body.keys()) ) @@ -68,7 +68,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver() - self.hs.config.registration_shared_secret = u"shared" + self.hs.config.registration_shared_secret = "shared" self.hs.get_media_repository = Mock() self.hs.get_deactivate_account_handler = Mock() @@ -82,12 +82,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ self.hs.config.registration_shared_secret = None - request, channel = self.make_request("POST", self.url, b'{}') + request, channel = self.make_request("POST", self.url, b"{}") self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - 'Shared secret registration is not enabled', channel.json_body["error"] + "Shared secret registration is not enabled", channel.json_body["error"] ) def test_get_nonce(self): @@ -118,20 +118,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.reactor.advance(59) body = json.dumps({"nonce": nonce}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('username must be specified', channel.json_body["error"]) + self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds self.reactor.advance(2) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('unrecognised nonce', channel.json_body["error"]) + self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self): """ @@ -154,7 +154,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -171,7 +171,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update( - nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin\x00support" + nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) want_mac = want_mac.hexdigest() @@ -185,7 +185,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -200,7 +200,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() body = json.dumps( @@ -212,18 +212,18 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('unrecognised nonce', channel.json_body["error"]) + self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self): """ @@ -243,11 +243,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('nonce must be specified', channel.json_body["error"]) + self.assertEqual("nonce must be specified", channel.json_body["error"]) # # Username checks @@ -255,35 +255,35 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce()}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('username must be specified', channel.json_body["error"]) + self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid username', channel.json_body["error"]) + self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid username', channel.json_body["error"]) + self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid username', channel.json_body["error"]) + self.assertEqual("Invalid username", channel.json_body["error"]) # # Password checks @@ -291,37 +291,35 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce(), "username": "a"}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('password must be specified', channel.json_body["error"]) + self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid password', channel.json_body["error"]) + self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes - body = json.dumps( - {"nonce": nonce(), "username": "a", "password": u"abcd\u0000"} - ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid password', channel.json_body["error"]) + self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid password', channel.json_body["error"]) + self.assertEqual("Invalid password", channel.json_body["error"]) # # user_type check @@ -336,11 +334,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "user_type": "invalid", } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid user type', channel.json_body["error"]) + self.assertEqual("Invalid user type", channel.json_body["error"]) class ShutdownRoomTestCase(unittest.HomeserverTestCase): @@ -396,7 +394,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "admin/shutdown_room/" + room_id request, channel = self.make_request( "POST", - url.encode('ascii'), + url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) @@ -421,7 +419,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "rooms/%s/state/m.room.history_visibility" % (room_id,) request, channel = self.make_request( "PUT", - url.encode('ascii'), + url.encode("ascii"), json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_token, ) @@ -432,7 +430,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "admin/shutdown_room/" + room_id request, channel = self.make_request( "POST", - url.encode('ascii'), + url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) @@ -449,7 +447,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "rooms/%s/initialSync" % (room_id,) request, channel = self.make_request( - "GET", url.encode('ascii'), access_token=self.admin_user_tok + "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( @@ -458,7 +456,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "events?timeout=0&room_id=" + room_id request, channel = self.make_request( - "GET", url.encode('ascii'), access_token=self.admin_user_tok + "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( @@ -486,7 +484,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): # Create a new group request, channel = self.make_request( "POST", - "/create_group".encode('ascii'), + "/create_group".encode("ascii"), access_token=self.admin_user_tok, content={"localpart": "test"}, ) @@ -502,14 +500,14 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) request, channel = self.make_request( - "PUT", url.encode('ascii'), access_token=self.admin_user_tok, content={} + "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) request, channel = self.make_request( - "PUT", url.encode('ascii'), access_token=self.other_user_token, content={} + "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -522,7 +520,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/admin/delete_group/" + group_id request, channel = self.make_request( "POST", - url.encode('ascii'), + url.encode("ascii"), access_token=self.admin_user_tok, content={"localpart": "test"}, ) @@ -544,7 +542,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/groups/%s/profile" % (group_id,) request, channel = self.make_request( - "GET", url.encode('ascii'), access_token=self.admin_user_tok + "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.render(request) @@ -556,7 +554,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): """Returns the list of groups the user is in (given their access token) """ request, channel = self.make_request( - "GET", "/joined_groups".encode('ascii'), access_token=access_token + "GET", "/joined_groups".encode("ascii"), access_token=access_token ) self.render(request) diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index efc5a99db3..6803b372ac 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -42,17 +42,17 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): # Make some temporary templates... temp_consent_path = self.mktemp() os.mkdir(temp_consent_path) - os.mkdir(os.path.join(temp_consent_path, 'en')) + os.mkdir(os.path.join(temp_consent_path, "en")) config["user_consent"] = { "version": "1", "template_dir": os.path.abspath(temp_consent_path), } - with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f: + with open(os.path.join(temp_consent_path, "en/1.html"), "w") as f: f.write("{{version}},{{has_consented}}") - with open(os.path.join(temp_consent_path, "en/success.html"), 'w') as f: + with open(os.path.join(temp_consent_path, "en/success.html"), "w") as f: f.write("yay!") hs = self.setup_test_homeserver(config=config) @@ -88,7 +88,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) # Get the version from the body, and whether we've consented - version, consented = channel.result["body"].decode('ascii').split(",") + version, consented = channel.result["body"].decode("ascii").split(",") self.assertEqual(consented, "False") # POST to the consent page, saying we've agreed @@ -111,6 +111,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): # Get the version from the body, and check that it's the version we # agreed to, and that we've consented to it. - version, consented = channel.result["body"].decode('ascii').split(",") + version, consented = channel.result["body"].decode("ascii").split(",") self.assertEqual(consented, "True") self.assertEqual(version, "1") diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 68949307d9..c973521907 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -56,7 +56,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): "address": "test@example.com", } request_data = json.dumps(params) - request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii') + request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") request, channel = self.make_request( b"POST", request_url, request_data, access_token=tok ) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 72c7ed93cb..dff9b2f10c 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -183,7 +183,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): def test_set_displayname(self): request, channel = self.make_request( "PUT", - "/profile/%s/displayname" % (self.owner, ), + "/profile/%s/displayname" % (self.owner,), content=json.dumps({"displayname": "test"}), access_token=self.owner_tok, ) @@ -197,7 +197,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): """Attempts to set a stupid displayname should get a 400""" request, channel = self.make_request( "PUT", - "/profile/%s/displayname" % (self.owner, ), + "/profile/%s/displayname" % (self.owner,), content=json.dumps({"displayname": "test" * 100}), access_token=self.owner_tok, ) @@ -209,8 +209,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): def get_displayname(self): request, channel = self.make_request( - "GET", - "/profile/%s/displayname" % (self.owner, ), + "GET", "/profile/%s/displayname" % (self.owner,) ) self.render(request) self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 5f75ad7579..2e3a765bf3 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -79,7 +79,7 @@ class RoomPermissionsTestCase(RoomBase): # send a message in one of the rooms self.created_rmid_msg_path = ( "rooms/%s/send/m.room.message/a1" % (self.created_rmid) - ).encode('ascii') + ).encode("ascii") request, channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) @@ -89,7 +89,7 @@ class RoomPermissionsTestCase(RoomBase): # set topic for public room request, channel = self.make_request( "PUT", - ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'), + ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) self.render(request) @@ -193,7 +193,7 @@ class RoomPermissionsTestCase(RoomBase): request, channel = self.make_request("GET", topic_path) self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body) + self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) @@ -497,7 +497,7 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self): # missing keys or invalid json - request, channel = self.make_request("PUT", self.path, '{}') + request, channel = self.make_request("PUT", self.path, "{}") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -515,11 +515,11 @@ class RoomTopicTestCase(RoomBase): self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, 'text only') + request, channel = self.make_request("PUT", self.path, "text only") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, '') + request, channel = self.make_request("PUT", self.path, "") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -572,7 +572,7 @@ class RoomMemberStateTestCase(RoomBase): def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - request, channel = self.make_request("PUT", path, '{}') + request, channel = self.make_request("PUT", path, "{}") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -590,11 +590,11 @@ class RoomMemberStateTestCase(RoomBase): self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, 'text only') + request, channel = self.make_request("PUT", path, "text only") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, '') + request, channel = self.make_request("PUT", path, "") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -604,7 +604,7 @@ class RoomMemberStateTestCase(RoomBase): Membership.JOIN, Membership.LEAVE, ) - request, channel = self.make_request("PUT", path, content.encode('ascii')) + request, channel = self.make_request("PUT", path, content.encode("ascii")) self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -616,7 +616,7 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - request, channel = self.make_request("PUT", path, content.encode('ascii')) + request, channel = self.make_request("PUT", path, content.encode("ascii")) self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) @@ -678,7 +678,7 @@ class RoomMessagesTestCase(RoomBase): def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - request, channel = self.make_request("PUT", path, b'{}') + request, channel = self.make_request("PUT", path, b"{}") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -696,11 +696,11 @@ class RoomMessagesTestCase(RoomBase): self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'text only') + request, channel = self.make_request("PUT", path, b"text only") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'') + request, channel = self.make_request("PUT", path, b"") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -786,7 +786,7 @@ class RoomMessageListTestCase(RoomBase): self.render(request) self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body['start']) + self.assertEquals(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -798,7 +798,7 @@ class RoomMessageListTestCase(RoomBase): self.render(request) self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body['start']) + self.assertEquals(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -961,9 +961,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): # Set a profile for the test user self.displayname = "test user" - data = { - "displayname": self.displayname, - } + data = {"displayname": self.displayname} request_data = json.dumps(data) request, channel = self.make_request( "PUT", @@ -977,16 +975,12 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) def test_per_room_profile_forbidden(self): - data = { - "membership": "join", - "displayname": "other test user" - } + data = {"membership": "join", "displayname": "other test user"} request_data = json.dumps(data) request, channel = self.make_request( "PUT", - "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( - self.room_id, self.user_id, - ), + "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" + % (self.room_id, self.user_id), request_data, access_token=self.tok, ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index f7133fc12e..9915367144 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -44,7 +44,7 @@ class RestHelper(object): path = path + "?access_token=%s" % tok request, channel = make_request( - self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8') + self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8") ) render(request, self.resource, self.hs.get_reactor()) @@ -93,7 +93,7 @@ class RestHelper(object): data = {"membership": membership} request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8') + self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") ) render(request, self.resource, self.hs.get_reactor()) @@ -117,7 +117,7 @@ class RestHelper(object): path = path + "?access_token=%s" % tok request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8') + self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8") ) render(request, self.resource, self.hs.get_reactor()) @@ -134,7 +134,7 @@ class RestHelper(object): path = path + "?access_token=%s" % tok request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(body).encode('utf8') + self.hs.get_reactor(), "PUT", path, json.dumps(body).encode("utf8") ) render(request, self.resource, self.hs.get_reactor()) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index a60a4a3b87..920de41de4 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -135,9 +135,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertEquals(len(self.email_attempts), 1) # Attempt to reset password without clicking the link - self._reset_password( - new_password, session_id, client_secret, expected_code=401, - ) + self._reset_password(new_password, session_id, client_secret, expected_code=401) # Assert we can log in with the old password self.login("kermit", old_password) @@ -172,9 +170,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): session_id = "weasle" # Attempt to reset password without even requesting an email - self._reset_password( - new_password, session_id, client_secret, expected_code=401, - ) + self._reset_password(new_password, session_id, client_secret, expected_code=401) # Assert we can log in with the old password self.login("kermit", old_password) @@ -258,19 +254,18 @@ class DeactivateTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") - request_data = json.dumps({ - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "test", - }, - "erase": False, - }) + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + } + ) request, channel = self.make_request( - "POST", - "account/deactivate", - request_data, - access_token=tok, + "POST", "account/deactivate", request_data, access_token=tok ) self.render(request) self.assertEqual(request.code, 200) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index bce5b0cf4c..b9e01c9418 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -47,15 +47,15 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): request, channel = self.make_request("GET", self.url, access_token=access_token) self.render(request) - capabilities = channel.json_body['capabilities'] + capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) - for room_version in capabilities['m.room_versions']['available'].keys(): + for room_version in capabilities["m.room_versions"]["available"].keys(): self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) self.assertEqual( self.config.default_room_version.identifier, - capabilities['m.room_versions']['default'], + capabilities["m.room_versions"]["default"], ) def test_get_change_password_capabilities(self): @@ -66,16 +66,16 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): request, channel = self.make_request("GET", self.url, access_token=access_token) self.render(request) - capabilities = channel.json_body['capabilities'] + capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) # Test case where password is handled outside of Synapse - self.assertTrue(capabilities['m.change_password']['enabled']) + self.assertTrue(capabilities["m.change_password"]["enabled"]) self.get_success(self.store.user_set_password_hash(user, None)) request, channel = self.make_request("GET", self.url, access_token=access_token) self.render(request) - capabilities = channel.json_body['capabilities'] + capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) - self.assertFalse(capabilities['m.change_password']['enabled']) + self.assertFalse(capabilities["m.change_password"]["enabled"]) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index b35b215446..89a3f95c0a 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -335,7 +335,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): config["email"] = { "enable_notifs": True, "template_dir": os.path.abspath( - pkg_resources.resource_filename('synapse', 'res/templates') + pkg_resources.resource_filename("synapse", "res/templates") ), "expiry_template_html": "notice_expiry.html", "expiry_template_text": "notice_expiry.txt", @@ -400,19 +400,18 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): (user_id, tok) = self.create_user() - request_data = json.dumps({ - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "monkey", - }, - "erase": False, - }) + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + } + ) request, channel = self.make_request( - "POST", - "account/deactivate", - request_data, - access_token=tok, + "POST", "account/deactivate", request_data, access_token=tok ) self.render(request) self.assertEqual(request.code, 200) @@ -476,20 +475,16 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - ] + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor, clock): self.validity_period = 10 - self.max_delta = self.validity_period * 10. / 100. + self.max_delta = self.validity_period * 10.0 / 100.0 config = self.default_config() config["enable_registration"] = True - config["account_validity"] = { - "enabled": False, - } + config["account_validity"] = {"enabled": False} self.hs = self.setup_test_homeserver(config=config) self.hs.config.account_validity.period = self.validity_period diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 43b3049daa..3deeed3a70 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -56,7 +56,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): creates the right shape of event. """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍") + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEquals(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] @@ -76,7 +76,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "content": { "m.relates_to": { "event_id": self.parent_id, - "key": u"👍", + "key": "👍", "rel_type": RelationTypes.ANNOTATION, } }, @@ -187,7 +187,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_tokens.append(token) idx = 0 - sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1} + sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} for key in itertools.chain.from_iterable( itertools.repeat(key, num) for key, num in sent_groups.items() ): @@ -259,7 +259,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", - key=u"👍", + key="👍", access_token=access_tokens[idx], ) self.assertEquals(200, channel.code, channel.json_body) @@ -273,7 +273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): prev_token = None found_event_ids = [] - encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8")) + encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8")) for _ in range(20): from_token = "" if prev_token: diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py index 00688a7325..ebd7869208 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/rest/media/v1/test_base.py @@ -21,17 +21,17 @@ from tests import unittest class GetFileNameFromHeadersTests(unittest.TestCase): # input -> expected result TEST_CASES = { - b"inline; filename=abc.txt": u"abc.txt", - b'inline; filename="azerty"': u"azerty", - b'inline; filename="aze%20rty"': u"aze%20rty", - b'inline; filename="aze\"rty"': u'aze"rty', - b'inline; filename="azer;ty"': u"azer;ty", - b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar", + b"inline; filename=abc.txt": "abc.txt", + b'inline; filename="azerty"': "azerty", + b'inline; filename="aze%20rty"': "aze%20rty", + b'inline; filename="aze"rty"': 'aze"rty', + b'inline; filename="azer;ty"': "azer;ty", + b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar", } def tests(self): for hdr, expected in self.TEST_CASES.items(): - res = get_filename_from_headers({b'Content-Disposition': [hdr]}) + res = get_filename_from_headers({b"Content-Disposition": [hdr]}) self.assertEqual( res, expected, diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 1069a44145..e2d418b1df 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -143,7 +143,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() - self.download_resource = self.media_repo.children[b'download'] + self.download_resource = self.media_repo.children[b"download"] # smol png self.end_content = unhexlify( @@ -171,7 +171,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): headers = { b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b'image/png'], + b"Content-Type": [b"image/png"], } if content_disposition: headers[b"Content-Disposition"] = [content_disposition] @@ -204,7 +204,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): correctly decode it as the UTF-8 string, and use filename* in the response. """ - filename = parse.quote(u"\u2603".encode('utf8')).encode('ascii') + filename = parse.quote("\u2603".encode("utf8")).encode("ascii") channel = self._req(b"inline; filename*=utf-8''" + filename + b".png") headers = channel.headers diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 1ab0f7293a..8fe5961866 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -55,10 +55,10 @@ class URLPreviewTests(unittest.HomeserverTestCase): hijack_auth = True user_id = "@test:user" end_content = ( - b'' + b"" b'' b'' - b'' + b"" ) def make_homeserver(self, reactor, clock): @@ -98,7 +98,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() - self.preview_url = self.media_repo.children[b'preview_url'] + self.preview_url = self.media_repo.children[b"preview_url"] self.lookups = {} @@ -109,7 +109,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): hostName, portNumber=0, addressTypes=None, - transportSemantics='TCP', + transportSemantics="TCP", ): resolution = HostResolution(hostName) @@ -118,7 +118,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): raise DNSLookupError("OH NO") for i in self.lookups[hostName]: - resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber)) + resolutionReceiver.addressResolved(i[0]("TCP", i[1], portNumber)) resolutionReceiver.resolutionComplete() return resolutionReceiver @@ -184,11 +184,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] end_content = ( - b'' + b"" b'' b'' b'' - b'' + b"" ) request, channel = self.make_request( @@ -204,7 +204,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): client.dataReceived( ( b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' ) % (len(end_content),) + end_content @@ -212,16 +212,16 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.pump() self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") + self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_non_ascii_preview_content_type(self): self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] end_content = ( - b'' + b"" b'' b'' - b'' + b"" ) request, channel = self.make_request( @@ -237,7 +237,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): client.dataReceived( ( b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n" + b'Content-Type: text/html; charset="windows-1251"\r\n\r\n' ) % (len(end_content),) + end_content @@ -245,7 +245,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.pump() self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") + self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_ipaddr(self): """ @@ -293,8 +293,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -314,8 +314,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -334,8 +334,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + "errcode": "M_UNKNOWN", + "error": "IP address blocked by IP blacklist entry", }, ) self.assertEqual(channel.code, 403) @@ -354,8 +354,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + "errcode": "M_UNKNOWN", + "error": "IP address blocked by IP blacklist entry", }, ) @@ -396,7 +396,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): non-blacklisted one, it will be rejected. """ # Hardcode the URL resolving to the IP we want. - self.lookups[u"example.com"] = [ + self.lookups["example.com"] = [ (IPv4Address, "1.1.1.2"), (IPv4Address, "8.8.8.8"), ] @@ -410,8 +410,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -435,8 +435,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -456,7 +456,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) diff --git a/tests/server.py b/tests/server.py index c15a47f2a4..e573c4e4c5 100644 --- a/tests/server.py +++ b/tests/server.py @@ -46,7 +46,7 @@ class FakeChannel(object): def json_body(self): if not self.result: raise Exception("No result yet.") - return json.loads(self.result["body"].decode('utf8')) + return json.loads(self.result["body"].decode("utf8")) @property def code(self): @@ -151,10 +151,10 @@ def make_request( Tuple[synapse.http.site.SynapseRequest, channel] """ if not isinstance(method, bytes): - method = method.encode('ascii') + method = method.encode("ascii") if not isinstance(path, bytes): - path = path.encode('ascii') + path = path.encode("ascii") # Decorate it to be the full path, if we're using shorthand if shorthand and not path.startswith(b"/_matrix"): @@ -165,7 +165,7 @@ def make_request( path = b"/" + path if isinstance(content, text_type): - content = content.encode('utf8') + content = content.encode("utf8") site = FakeSite() channel = FakeChannel(reactor) @@ -173,11 +173,11 @@ def make_request( req = request(site, channel) req.process = lambda: b"" req.content = BytesIO(content) - req.postpath = list(map(unquote, path[1:].split(b'/'))) + req.postpath = list(map(unquote, path[1:].split(b"/"))) if access_token: req.requestHeaders.addRawHeader( - b"Authorization", b"Bearer " + access_token.encode('ascii') + b"Authorization", b"Bearer " + access_token.encode("ascii") ) if federation_auth_origin is not None: @@ -242,7 +242,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self.nameResolver = SimpleResolverComplexifier(FakeResolver()) super(ThreadedMemoryReactorClock, self).__init__() - def listenUDP(self, port, protocol, interface='', maxPacketSize=8196): + def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): p = udp.Port(port, protocol, interface, maxPacketSize, self) p.startListening() self._udp.append(p) @@ -371,7 +371,7 @@ class FakeTransport(object): disconnecting = False disconnected = False - buffer = attr.ib(default=b'') + buffer = attr.ib(default=b"") producer = attr.ib(default=None) autoflush = attr.ib(default=True) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 739ee59ce4..984feb623f 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -109,7 +109,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, 'foo') + side_effect=ResourceLimitError(403, "foo") ) mock_event = Mock( @@ -128,7 +128,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, 'foo') + side_effect=ResourceLimitError(403, "foo") ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 9c5311d916..8d3845c870 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -181,7 +181,7 @@ class StateTestCase(unittest.TestCase): id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), ] @@ -229,14 +229,14 @@ class StateTestCase(unittest.TestCase): id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50, CHARLIE: 50}}, ), FakeEvent( id="PC", sender=CHARLIE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50, CHARLIE: 0}}, ), ] @@ -256,7 +256,7 @@ class StateTestCase(unittest.TestCase): id="PA1", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -266,14 +266,14 @@ class StateTestCase(unittest.TestCase): id="PA2", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -296,7 +296,7 @@ class StateTestCase(unittest.TestCase): id="PA", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -326,7 +326,7 @@ class StateTestCase(unittest.TestCase): id="PA1", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -336,14 +336,14 @@ class StateTestCase(unittest.TestCase): id="PA2", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 25a6c89ef5..622b16a071 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -74,7 +74,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): namespaces={}, ) # use the token as the filename - with open(as_token, 'w') as outfile: + with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) @@ -135,7 +135,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): namespaces={}, ) # use the token as the filename - with open(as_token, 'w') as outfile: + with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index b62eae7abc..59c6f8c227 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -94,11 +94,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): result, [ { - 'access_token': 'access_token', - 'ip': 'ip', - 'user_agent': 'user_agent', - 'device_id': None, - 'last_seen': 12345678000, + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": None, + "last_seen": 12345678000, } ], ) @@ -125,11 +125,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): result, [ { - 'access_token': 'access_token', - 'ip': 'ip', - 'user_agent': 'user_agent', - 'device_id': None, - 'last_seen': 12345878000, + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": None, + "last_seen": 12345878000, } ], ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 6396ccddb5..3cc18f9f1c 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -77,12 +77,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # Add two device updates with a single stream_id yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["somehost"], + "user_id", device_ids, ["somehost"] ) # Get all device updates ever meant for this remote now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "somehost", -1, limit=100, + "somehost", -1, limit=100 ) # Check original device_ids are contained within these updates @@ -95,19 +95,19 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # first add one device device_ids1 = ["device_id0"] yield self.store.add_device_change_to_streams( - "user_id", device_ids1, ["someotherhost"], + "user_id", device_ids1, ["someotherhost"] ) # then add 101 device_ids2 = ["device_id" + str(i + 1) for i in range(101)] yield self.store.add_device_change_to_streams( - "user_id", device_ids2, ["someotherhost"], + "user_id", device_ids2, ["someotherhost"] ) # then one more device_ids3 = ["newdevice"] yield self.store.add_device_change_to_streams( - "user_id", device_ids3, ["someotherhost"], + "user_id", device_ids3, ["someotherhost"] ) # @@ -116,20 +116,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # first we should get a single update now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", -1, limit=100, + "someotherhost", -1, limit=100 ) self._check_devices_in_updates(device_ids1, device_updates) # Then we should get an empty list back as the 101 devices broke the limit now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100, + "someotherhost", now_stream_id, limit=100 ) self.assertEqual(len(device_updates), 0) # The 101 devices should've been cleared, so we should now just get one device # update now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100, + "someotherhost", now_stream_id, limit=100 ) self._check_devices_in_updates(device_ids3, device_updates) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index cd2bcd4ca3..c8ece15284 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -80,10 +80,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.store_device("user2", "device1", None) yield self.store.store_device("user2", "device2", None) - yield self.store.set_e2e_device_keys("user1", "device1", now, 'json11') - yield self.store.set_e2e_device_keys("user1", "device2", now, 'json12') - yield self.store.set_e2e_device_keys("user2", "device1", now, 'json21') - yield self.store.set_e2e_device_keys("user2", "device2", now, 'json22') + yield self.store.set_e2e_device_keys("user1", "device1", now, "json11") + yield self.store.set_e2e_device_keys("user1", "device2", now, "json12") + yield self.store.set_e2e_device_keys("user2", "device1", now, "json21") + yield self.store.set_e2e_device_keys("user2", "device2", now, "json22") res = yield self.store.get_e2e_device_keys( (("user1", "device1"), ("user2", "device2")) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 0d4e74d637..86c7ac350d 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -27,11 +27,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_prev_events_for_room(self): - room_id = '@ROOM:local' + room_id = "@ROOM:local" # add a bunch of events and hashes to act as forward extremities def insert_event(txn, i): - event_id = '$event_%i:local' % i + event_id = "$event_%i:local" % i txn.execute( ( @@ -45,19 +45,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): txn.execute( ( - 'INSERT INTO event_forward_extremities (room_id, event_id) ' - 'VALUES (?, ?)' + "INSERT INTO event_forward_extremities (room_id, event_id) " + "VALUES (?, ?)" ), (room_id, event_id), ) txn.execute( ( - 'INSERT INTO event_reference_hashes ' - '(event_id, algorithm, hash) ' + "INSERT INTO event_reference_hashes " + "(event_id, algorithm, hash) " "VALUES (?, 'sha256', ?)" ), - (event_id, b'ffff'), + (event_id, b"ffff"), ) for i in range(0, 11): diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 19f9ccf5e0..d44359ff93 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -61,22 +61,24 @@ class ExtremStatisticsTestCase(HomeserverTestCase): ) ) - expected = set([ - b'synapse_forward_extremities_bucket{le="1.0"} 0.0', - b'synapse_forward_extremities_bucket{le="2.0"} 2.0', - b'synapse_forward_extremities_bucket{le="3.0"} 2.0', - b'synapse_forward_extremities_bucket{le="5.0"} 2.0', - b'synapse_forward_extremities_bucket{le="7.0"} 3.0', - b'synapse_forward_extremities_bucket{le="10.0"} 3.0', - b'synapse_forward_extremities_bucket{le="15.0"} 3.0', - b'synapse_forward_extremities_bucket{le="20.0"} 3.0', - b'synapse_forward_extremities_bucket{le="50.0"} 3.0', - b'synapse_forward_extremities_bucket{le="100.0"} 3.0', - b'synapse_forward_extremities_bucket{le="200.0"} 3.0', - b'synapse_forward_extremities_bucket{le="500.0"} 3.0', - b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', - b'synapse_forward_extremities_count 3.0', - b'synapse_forward_extremities_sum 10.0', - ]) + expected = set( + [ + b'synapse_forward_extremities_bucket{le="1.0"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0"} 3.0', + b'synapse_forward_extremities_bucket{le="50.0"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0"} 3.0', + b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', + b"synapse_forward_extremities_count 3.0", + b"synapse_forward_extremities_sum 10.0", + ] + ) self.assertEqual(items, expected) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index f458c03054..0ce0b991f9 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -46,9 +46,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): user3_email = "user3@matrix.org" threepids = [ - {'medium': 'email', 'address': user1_email}, - {'medium': 'email', 'address': user2_email}, - {'medium': 'email', 'address': user3_email}, + {"medium": "email", "address": user1_email}, + {"medium": "email", "address": user2_email}, + {"medium": "email", "address": user3_email}, ] # -1 because user3 is a support user and does not count user_num = len(threepids) - 1 @@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) - self.store.populate_monthly_active_users('user_id') + self.store.populate_monthly_active_users("user_id") self.pump() self.store.upsert_monthly_active_user.assert_called_once() @@ -188,7 +188,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) - self.store.populate_monthly_active_users('user_id') + self.store.populate_monthly_active_users("user_id") self.pump() self.store.upsert_monthly_active_user.assert_not_called() @@ -198,13 +198,13 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.assertEquals(self.get_success(count), 0) # Test reserved users but no registered users - user1 = '@user1:example.com' - user2 = '@user2:example.com' - user1_email = 'user1@example.com' - user2_email = 'user2@example.com' + user1 = "@user1:example.com" + user2 = "@user2:example.com" + user1_email = "user1@example.com" + user2_email = "user2@example.com" threepids = [ - {'medium': 'email', 'address': user1_email}, - {'medium': 'email', 'address': user2_email}, + {"medium": "email", "address": user1_email}, + {"medium": "email", "address": user2_email}, ] self.hs.config.mau_limits_reserved_threepids = threepids self.store.runInteraction( diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 4823d44dec..732a778fab 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -82,7 +82,7 @@ class RedactionTestCase(unittest.TestCase): "sender": user.to_string(), "state_key": user.to_string(), "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, + "content": {"body": body, "msgtype": "message"}, }, ) @@ -118,7 +118,7 @@ class RedactionTestCase(unittest.TestCase): def test_redact(self): yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - msg_event = yield self.inject_message(self.room1, self.u_alice, u"t") + msg_event = yield self.inject_message(self.room1, self.u_alice, "t") # Check event has not been redacted: event = yield self.store.get_event(msg_event.event_id) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index c0e0155bb4..625b651e91 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -128,4 +128,4 @@ class TokenGenerator: def generate(self, user_id): self._last_issued_token += 1 - return u"%s-%d" % (user_id, self._last_issued_token) + return "%s-%d" % (user_id, self._last_issued_token) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index a1ea23b068..1bee45706f 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -78,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def STALE_test_room_name(self): - name = u"A-Room-Name" + name = "A-Room-Name" yield self.inject_room_event( etype=EventTypes.Name, name=name, content={"name": name}, depth=1 @@ -94,7 +94,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def STALE_test_room_topic(self): - topic = u"A place for things" + topic = "A place for things" yield self.inject_room_event( etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index b6169436de..212a7ae765 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -76,10 +76,10 @@ class StateStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_state_groups_ids(self): e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) state_group_map = yield self.store.get_state_groups_ids( @@ -89,16 +89,16 @@ class StateStoreTestCase(tests.unittest.TestCase): state_map = list(state_group_map.values())[0] self.assertDictEqual( state_map, - {(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id}, + {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, ) @defer.inlineCallbacks def test_get_state_groups(self): e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) @@ -113,10 +113,10 @@ class StateStoreTestCase(tests.unittest.TestCase): # this defaults to a linear DAG as each new injection defaults to whatever # forward extremities are currently in the DB for this room. e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) e3 = yield self.inject_state_event( self.room, @@ -158,7 +158,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can filter to the m.room.name event (with a '' state key) state = yield self.store.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, '')]) + e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) diff --git a/tests/test_preview.py b/tests/test_preview.py index 84ef5e5ba4..7f67ee9e1f 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -24,14 +24,14 @@ from . import unittest class PreviewTestCase(unittest.TestCase): def test_long_summarize(self): example_paras = [ - u"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: + """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in Troms county, Norway. The administrative centre of the municipality is the city of Tromsø. Outside of Norway, Tromso and Tromsö are alternative spellings of the city.Tromsø is considered the northernmost city in the world with a population above 50,000. The most populous town north of it is Alta, Norway, with a population of 14,272 (2013).""", - u"""Tromsø lies in Northern Norway. The municipality has a population of + """Tromsø lies in Northern Norway. The municipality has a population of (2015) 72,066, but with an annual influx of students it has over 75,000 most of the year. It is the largest urban area in Northern Norway and the third largest north of the Arctic Circle (following Murmansk and Norilsk). @@ -44,7 +44,7 @@ class PreviewTestCase(unittest.TestCase): Sandnessund Bridge. Tromsø Airport connects the city to many destinations in Europe. The city is warmer than most other places located on the same latitude, due to the warming effect of the Gulf Stream.""", - u"""The city centre of Tromsø contains the highest number of old wooden + """The city centre of Tromsø contains the highest number of old wooden houses in Northern Norway, the oldest house dating from 1789. The Arctic Cathedral, a modern church from 1965, is probably the most famous landmark in Tromsø. The city is a cultural centre for its region, with several @@ -58,87 +58,87 @@ class PreviewTestCase(unittest.TestCase): self.assertEquals( desc, - u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - u" Troms county, Norway. The administrative centre of the municipality is" - u" the city of Tromsø. Outside of Norway, Tromso and Tromsö are" - u" alternative spellings of the city.Tromsø is considered the northernmost" - u" city in the world with a population above 50,000. The most populous town" - u" north of it is Alta, Norway, with a population of 14,272 (2013).", + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway. The administrative centre of the municipality is" + " the city of Tromsø. Outside of Norway, Tromso and Tromsö are" + " alternative spellings of the city.Tromsø is considered the northernmost" + " city in the world with a population above 50,000. The most populous town" + " north of it is Alta, Norway, with a population of 14,272 (2013).", ) desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) self.assertEquals( desc, - u"Tromsø lies in Northern Norway. The municipality has a population of" - u" (2015) 72,066, but with an annual influx of students it has over 75,000" - u" most of the year. It is the largest urban area in Northern Norway and the" - u" third largest north of the Arctic Circle (following Murmansk and Norilsk)." - u" Most of Tromsø, including the city centre, is located on the island of" - u" Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," - u" Tromsøya had a population of 36,088. Substantial parts of the urban…", + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. It is the largest urban area in Northern Norway and the" + " third largest north of the Arctic Circle (following Murmansk and Norilsk)." + " Most of Tromsø, including the city centre, is located on the island of" + " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," + " Tromsøya had a population of 36,088. Substantial parts of the urban…", ) def test_short_summarize(self): example_paras = [ - u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - u" Troms county, Norway.", - u"Tromsø lies in Northern Norway. The municipality has a population of" - u" (2015) 72,066, but with an annual influx of students it has over 75,000" - u" most of the year.", - u"The city centre of Tromsø contains the highest number of old wooden" - u" houses in Northern Norway, the oldest house dating from 1789. The Arctic" - u" Cathedral, a modern church from 1965, is probably the most famous landmark" - u" in Tromsø.", + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year.", + "The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", ] desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) self.assertEquals( desc, - u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - u" Troms county, Norway.\n" - u"\n" - u"Tromsø lies in Northern Norway. The municipality has a population of" - u" (2015) 72,066, but with an annual influx of students it has over 75,000" - u" most of the year.", + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year.", ) def test_small_then_large_summarize(self): example_paras = [ - u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - u" Troms county, Norway.", - u"Tromsø lies in Northern Norway. The municipality has a population of" - u" (2015) 72,066, but with an annual influx of students it has over 75,000" - u" most of the year." - u" The city centre of Tromsø contains the highest number of old wooden" - u" houses in Northern Norway, the oldest house dating from 1789. The Arctic" - u" Cathedral, a modern church from 1965, is probably the most famous landmark" - u" in Tromsø.", + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year." + " The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", ] desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) self.assertEquals( desc, - u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - u" Troms county, Norway.\n" - u"\n" - u"Tromsø lies in Northern Norway. The municipality has a population of" - u" (2015) 72,066, but with an annual influx of students it has over 75,000" - u" most of the year. The city centre of Tromsø contains the highest number" - u" of old wooden houses in Northern Norway, the oldest house dating from" - u" 1789. The Arctic Cathedral, a modern church from…", + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. The city centre of Tromsø contains the highest number" + " of old wooden houses in Northern Norway, the oldest house dating from" + " 1789. The Arctic Cathedral, a modern church from…", ) class PreviewUrlTestCase(unittest.TestCase): def test_simple(self): - html = u""" + html = """ Foo @@ -149,10 +149,10 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment(self): - html = u""" + html = """ Foo @@ -164,10 +164,10 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment2(self): - html = u""" + html = """ Foo @@ -185,13 +185,13 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEquals( og, { - u"og:title": u"Foo", - u"og:description": u"Some text.\n\nSome more text.\n\nText\n\nMore text", + "og:title": "Foo", + "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text", }, ) def test_script(self): - html = u""" + html = """ Foo @@ -203,10 +203,10 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) def test_missing_title(self): - html = u""" + html = """ Some text. @@ -216,10 +216,10 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": None, u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": None, "og:description": "Some text."}) def test_h1_as_title(self): - html = u""" + html = """ @@ -230,10 +230,10 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": u"Title", u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."}) def test_missing_title_and_broken_h1(self): - html = u""" + html = """

@@ -244,4 +244,4 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": None, u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": None, "og:description": "Some text."}) diff --git a/tests/test_server.py b/tests/test_server.py index 08fb3fe02f..da29ae92ce 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -69,8 +69,8 @@ class JsonResourceTests(unittest.TestCase): ) render(request, res, self.reactor) - self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) - self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"}) + self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]}) + self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) def test_callback_direct_exception(self): """ @@ -87,7 +87,7 @@ class JsonResourceTests(unittest.TestCase): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'500') + self.assertEqual(channel.result["code"], b"500") def test_callback_indirect_exception(self): """ @@ -110,7 +110,7 @@ class JsonResourceTests(unittest.TestCase): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'500') + self.assertEqual(channel.result["code"], b"500") def test_callback_synapseerror(self): """ @@ -127,7 +127,7 @@ class JsonResourceTests(unittest.TestCase): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'403') + self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["error"], "Forbidden!!one!") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -148,7 +148,7 @@ class JsonResourceTests(unittest.TestCase): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'400') + self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") @@ -180,7 +180,7 @@ class SiteTestCase(unittest.HomeserverTestCase): # Make a resource and a Site, the resource will hang and allow us to # time out the request while it's 'processing' base_resource = Resource() - base_resource.putChild(b'', HangingResource()) + base_resource.putChild(b"", HangingResource()) site = SynapseSite("test", "site_tag", {}, base_resource, "1.0") server = site.buildProtocol(None) diff --git a/tests/test_state.py b/tests/test_state.py index 6491a7105a..6d33566f47 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -366,11 +366,11 @@ class StateTestCase(unittest.TestCase): def _add_depths(self, nodes, edges): def _get_depth(ev): node = nodes[ev] - if 'depth' not in node: + if "depth" not in node: prevs = edges[ev] depth = max(_get_depth(prev) for prev in prevs) + 1 - node['depth'] = depth - return node['depth'] + node["depth"] = depth + return node["depth"] for n in nodes: _get_depth(n) diff --git a/tests/test_types.py b/tests/test_types.py index d83c36559f..9ab5f829b0 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -102,7 +102,7 @@ class MapUsernameTestCase(unittest.TestCase): def testNonAscii(self): # this should work with either a unicode or a bytes - self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast") + self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast") self.assertEqual( - map_username_to_mxid_localpart(u'têst'.encode('utf-8')), "t=c3=aast" + map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast" ) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index fde0baee8e..813f984199 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -27,7 +27,7 @@ class ToTwistedHandler(logging.Handler): def emit(self, record): log_entry = self.format(record) - log_level = record.levelname.lower().replace('warning', 'warn') + log_level = record.levelname.lower().replace("warning", "warn") self.tx_log.emit( twisted.logger.LogLevel.levelWithName(log_level), log_entry.replace("{", r"(").replace("}", r")"), diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 6a180ddc32..118c3bd238 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -265,7 +265,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): pr.disable() with open("filter_events_for_server.profile", "w+") as f: - ps = pstats.Stats(pr, stream=f).sort_stats('cumulative') + ps = pstats.Stats(pr, stream=f).sort_stats("cumulative") ps.print_stats() # the result should be 5 redacted events, and 5 unredacted events. diff --git a/tests/unittest.py b/tests/unittest.py index b6dc7932ce..d64702b0c2 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -298,7 +298,7 @@ class HomeserverTestCase(TestCase): Tuple[synapse.http.site.SynapseRequest, channel] """ if isinstance(content, dict): - content = json.dumps(content).encode('utf8') + content = json.dumps(content).encode("utf8") return make_request( self.reactor, @@ -389,7 +389,7 @@ class HomeserverTestCase(TestCase): Returns: The MXID of the new user (unicode). """ - self.hs.config.registration_shared_secret = u"shared" + self.hs.config.registration_shared_secret = "shared" # Create the user request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") @@ -397,13 +397,13 @@ class HomeserverTestCase(TestCase): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')]) + nonce_str = b"\x00".join([username.encode("utf8"), password.encode("utf8")]) if admin: nonce_str += b"\x00admin" else: nonce_str += b"\x00notadmin" - want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str) + want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) want_mac = want_mac.hexdigest() body = json.dumps( @@ -416,7 +416,7 @@ class HomeserverTestCase(TestCase): } ) request, channel = self.make_request( - "POST", "/_matrix/client/r0/admin/register", body.encode('utf8') + "POST", "/_matrix/client/r0/admin/register", body.encode("utf8") ) self.render(request) self.assertEqual(channel.code, 200) @@ -435,7 +435,7 @@ class HomeserverTestCase(TestCase): body["device_id"] = device_id request, channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8') + "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.render(request) self.assertEqual(channel.code, 200, channel.result) @@ -481,9 +481,7 @@ class HomeserverTestCase(TestCase): if soft_failed: event.internal_metadata.soft_failed = True - self.get_success( - event_creator.send_nonmember_event(requester, event, context) - ) + self.get_success(event_creator.send_nonmember_event(requester, event, context)) return event.event_id @@ -508,7 +506,7 @@ class HomeserverTestCase(TestCase): body = {"type": "m.login.password", "user": username, "password": password} request, channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8') + "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.render(request) self.assertEqual(channel.code, 403, channel.result) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 463a737efa..6f8f52537c 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -88,24 +88,24 @@ class DescriptorTestCase(unittest.TestCase): obj = Cls() - obj.mock.return_value = 'fish' + obj.mock.return_value = "fish" r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = 'chips' + obj.mock.return_value = "chips" r = yield obj.fn(1, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_called_once_with(1, 3) obj.mock.reset_mock() # the two values should now be cached r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(1, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_not_called() @defer.inlineCallbacks @@ -121,25 +121,25 @@ class DescriptorTestCase(unittest.TestCase): return self.mock(arg1, arg2) obj = Cls() - obj.mock.return_value = 'fish' + obj.mock.return_value = "fish" r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = 'chips' + obj.mock.return_value = "chips" r = yield obj.fn(2, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_called_once_with(2, 3) obj.mock.reset_mock() # the two values should now be cached; we should be able to vary # the second argument and still get the cached result. r = yield obj.fn(1, 4) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(2, 5) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_not_called() def test_cache_logcontexts(self): @@ -248,30 +248,30 @@ class DescriptorTestCase(unittest.TestCase): obj = Cls() - obj.mock.return_value = 'fish' + obj.mock.return_value = "fish" r = yield obj.fn(1, 2, 3) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_called_once_with(1, 2, 3) obj.mock.reset_mock() # a call with same params shouldn't call the mock again r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_not_called() obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = 'chips' + obj.mock.return_value = "chips" r = yield obj.fn(2, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_called_once_with(2, 3, 3) obj.mock.reset_mock() # the two values should now be cached r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(2, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_not_called() @@ -297,7 +297,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): with logcontext.LoggingContext() as c1: c1.request = "c1" obj = Cls() - obj.mock.return_value = {10: 'fish', 20: 'chips'} + obj.mock.return_value = {10: "fish", 20: "chips"} d1 = obj.list_fn([10, 20], 2) self.assertEqual( logcontext.LoggingContext.current_context(), @@ -306,26 +306,26 @@ class CachedListDescriptorTestCase(unittest.TestCase): r = yield d1 self.assertEqual(logcontext.LoggingContext.current_context(), c1) obj.mock.assert_called_once_with([10, 20], 2) - self.assertEqual(r, {10: 'fish', 20: 'chips'}) + self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = {30: 'peas'} + obj.mock.return_value = {30: "peas"} r = yield obj.list_fn([20, 30], 2) obj.mock.assert_called_once_with([30], 2) - self.assertEqual(r, {20: 'chips', 30: 'peas'}) + self.assertEqual(r, {20: "chips", 30: "peas"}) obj.mock.reset_mock() # all the values should now be cached r = yield obj.fn(10, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(20, 2) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") r = yield obj.fn(30, 2) - self.assertEqual(r, 'peas') + self.assertEqual(r, "peas") r = yield obj.list_fn([10, 20, 30], 2) obj.mock.assert_not_called() - self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'}) + self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"}) @defer.inlineCallbacks def test_invalidate(self): @@ -350,16 +350,16 @@ class CachedListDescriptorTestCase(unittest.TestCase): invalidate1 = mock.Mock() # cache miss - obj.mock.return_value = {10: 'fish', 20: 'chips'} + obj.mock.return_value = {10: "fish", 20: "chips"} r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) obj.mock.assert_called_once_with([10, 20], 2) - self.assertEqual(r1, {10: 'fish', 20: 'chips'}) + self.assertEqual(r1, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # cache hit r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1) obj.mock.assert_not_called() - self.assertEqual(r2, {10: 'fish', 20: 'chips'}) + self.assertEqual(r2, {10: "fish", 20: "chips"}) invalidate0.assert_not_called() invalidate1.assert_not_called() diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py index 03b3c15db6..c94cbb662b 100644 --- a/tests/util/caches/test_ttlcache.py +++ b/tests/util/caches/test_ttlcache.py @@ -27,57 +27,57 @@ class CacheTestCase(unittest.TestCase): def test_get(self): """simple set/get tests""" - self.cache.set('one', '1', 10) - self.cache.set('two', '2', 20) - self.cache.set('three', '3', 30) + self.cache.set("one", "1", 10) + self.cache.set("two", "2", 20) + self.cache.set("three", "3", 30) self.assertEqual(len(self.cache), 3) - self.assertTrue('one' in self.cache) - self.assertEqual(self.cache.get('one'), '1') - self.assertEqual(self.cache['one'], '1') - self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110)) + self.assertTrue("one" in self.cache) + self.assertEqual(self.cache.get("one"), "1") + self.assertEqual(self.cache["one"], "1") + self.assertEqual(self.cache.get_with_expiry("one"), ("1", 110)) self.assertEqual(self.cache._metrics.hits, 3) self.assertEqual(self.cache._metrics.misses, 0) - self.cache.set('two', '2.5', 20) - self.assertEqual(self.cache['two'], '2.5') + self.cache.set("two", "2.5", 20) + self.assertEqual(self.cache["two"], "2.5") self.assertEqual(self.cache._metrics.hits, 4) # non-existent-item tests - self.assertEqual(self.cache.get('four', '4'), '4') - self.assertIs(self.cache.get('four', None), None) + self.assertEqual(self.cache.get("four", "4"), "4") + self.assertIs(self.cache.get("four", None), None) with self.assertRaises(KeyError): - self.cache['four'] + self.cache["four"] with self.assertRaises(KeyError): - self.cache.get('four') + self.cache.get("four") with self.assertRaises(KeyError): - self.cache.get_with_expiry('four') + self.cache.get_with_expiry("four") self.assertEqual(self.cache._metrics.hits, 4) self.assertEqual(self.cache._metrics.misses, 5) def test_expiry(self): - self.cache.set('one', '1', 10) - self.cache.set('two', '2', 20) - self.cache.set('three', '3', 30) + self.cache.set("one", "1", 10) + self.cache.set("two", "2", 20) + self.cache.set("three", "3", 30) self.assertEqual(len(self.cache), 3) - self.assertEqual(self.cache['one'], '1') - self.assertEqual(self.cache['two'], '2') + self.assertEqual(self.cache["one"], "1") + self.assertEqual(self.cache["two"], "2") # enough for the first entry to expire, but not the rest self.mock_timer.side_effect = lambda: 110.0 self.assertEqual(len(self.cache), 2) - self.assertFalse('one' in self.cache) - self.assertEqual(self.cache['two'], '2') - self.assertEqual(self.cache['three'], '3') + self.assertFalse("one" in self.cache) + self.assertEqual(self.cache["two"], "2") + self.assertEqual(self.cache["three"], "3") - self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120)) + self.assertEqual(self.cache.get_with_expiry("two"), ("2", 120)) self.assertEqual(self.cache._metrics.hits, 5) self.assertEqual(self.cache._metrics.misses, 0) diff --git a/tests/utils.py b/tests/utils.py index f8c7ad2604..bd2c7c954c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -358,9 +358,9 @@ def setup_test_homeserver( # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest() + hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest() hs.get_auth_handler().validate_hash = ( - lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h + lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h ) fed = kargs.get("resource_for_federation", None) @@ -407,7 +407,7 @@ class MockHttpResource(HttpServer): def trigger_get(self, path): return self.trigger(b"GET", path, None) - @patch('twisted.web.http.Request') + @patch("twisted.web.http.Request") @defer.inlineCallbacks def trigger( self, http_method, path, content, mock_request, federation_auth_origin=None @@ -431,12 +431,12 @@ class MockHttpResource(HttpServer): # annoyingly we return a twisted http request which has chained calls # to get at the http content, hence mock it here. mock_content = Mock() - config = {'read.return_value': content} + config = {"read.return_value": content} mock_content.configure_mock(**config) mock_request.content = mock_content - mock_request.method = http_method.encode('ascii') - mock_request.uri = path.encode('ascii') + mock_request.method = http_method.encode("ascii") + mock_request.uri = path.encode("ascii") mock_request.getClientIP.return_value = "-" @@ -452,14 +452,14 @@ class MockHttpResource(HttpServer): # add in query params to the right place try: - mock_request.args = urlparse.parse_qs(path.split('?')[1]) - mock_request.path = path.split('?')[0] + mock_request.args = urlparse.parse_qs(path.split("?")[1]) + mock_request.path = path.split("?")[0] path = mock_request.path except Exception: pass if isinstance(path, bytes): - path = path.decode('utf8') + path = path.decode("utf8") for (method, pattern, func) in self.callbacks: if http_method != method: diff --git a/tox.ini b/tox.ini index 0c4d562766..09b4b8fc3c 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = packaging, py35, py36, py37, pep8, check_isort +envlist = packaging, py35, py36, py37, check_codestyle, check_isort [base] deps = @@ -112,12 +112,15 @@ deps = commands = check-manifest -[testenv:pep8] +[testenv:check_codestyle] skip_install = True basepython = python3.6 deps = flake8 -commands = /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}" + black +commands = + python -m black --check --diff . + /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}" [testenv:check_isort] skip_install = True -- cgit 1.5.1 From c3c6b00d956937ad50673cec75eab7989938a39b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 24 Jun 2019 11:34:45 +0100 Subject: Pass config_dir_path and data_dir_path into Config.read_config. (#5522) * Pull config_dir_path and data_dir_path calculation out of read_config_files * Pass config_dir_path and data_dir_path into read_config --- changelog.d/5522.misc | 1 + synapse/config/_base.py | 104 ++++++++++++++------- synapse/config/api.py | 2 +- synapse/config/appservice.py | 2 +- synapse/config/captcha.py | 2 +- synapse/config/cas.py | 2 +- synapse/config/consent_config.py | 2 +- synapse/config/database.py | 2 +- synapse/config/emailconfig.py | 2 +- synapse/config/groups.py | 2 +- synapse/config/jwt_config.py | 2 +- synapse/config/key.py | 2 +- synapse/config/logger.py | 2 +- synapse/config/metrics.py | 2 +- synapse/config/password.py | 2 +- synapse/config/password_auth_providers.py | 2 +- synapse/config/push.py | 2 +- synapse/config/ratelimiting.py | 2 +- synapse/config/registration.py | 2 +- synapse/config/repository.py | 2 +- synapse/config/room_directory.py | 2 +- synapse/config/saml2_config.py | 2 +- synapse/config/server.py | 2 +- synapse/config/server_notices_config.py | 2 +- synapse/config/spam_checker.py | 2 +- synapse/config/stats.py | 2 +- synapse/config/third_party_event_rules.py | 2 +- synapse/config/tls.py | 2 +- synapse/config/user_directory.py | 2 +- synapse/config/voip.py | 2 +- synapse/config/workers.py | 2 +- tests/config/test_tls.py | 2 +- .../federation/test_matrix_federation_agent.py | 2 +- tests/unittest.py | 2 +- tests/utils.py | 2 +- 35 files changed, 104 insertions(+), 67 deletions(-) create mode 100644 changelog.d/5522.misc (limited to 'tests') diff --git a/changelog.d/5522.misc b/changelog.d/5522.misc new file mode 100644 index 0000000000..17a7be5c99 --- /dev/null +++ b/changelog.d/5522.misc @@ -0,0 +1 @@ +Pass config_dir_path and data_dir_path into Config.read_config. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 36e9c04cee..6baa315874 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -216,7 +218,7 @@ class Config(object): "--keys-directory", metavar="DIRECTORY", help="Where files such as certs and signing keys are stored when" - " their location is given explicitly in the config." + " their location is not given explicitly in the config." " Defaults to the directory containing the last config file", ) @@ -228,10 +230,22 @@ class Config(object): config_files = find_config_files(search_paths=config_args.config_path) + if not config_files: + config_parser.error("Must supply a config file.") + + if config_args.keys_directory: + config_dir_path = config_args.keys_directory + else: + config_dir_path = os.path.dirname(config_files[-1]) + config_dir_path = os.path.abspath(config_dir_path) + data_dir_path = os.getcwd() + config_dict = obj.read_config_files( - config_files, keys_directory=config_args.keys_directory + config_files, config_dir_path=config_dir_path, data_dir_path=data_dir_path + ) + obj.parse_config_dict( + config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path ) - obj.parse_config_dict(config_dict) obj.invoke_all("read_arguments", config_args) @@ -282,7 +296,7 @@ class Config(object): metavar="DIRECTORY", help=( "Specify where additional config files such as signing keys and log" - " config should be stored. Defaults to the same directory as the main" + " config should be stored. Defaults to the same directory as the last" " config file." ), ) @@ -290,6 +304,20 @@ class Config(object): config_files = find_config_files(search_paths=config_args.config_path) + if not config_files: + config_parser.error( + "Must supply a config file.\nA config file can be automatically" + ' generated using "--generate-config -H SERVER_NAME' + ' -c CONFIG-FILE"' + ) + + if config_args.config_directory: + config_dir_path = config_args.config_directory + else: + config_dir_path = os.path.dirname(config_files[-1]) + config_dir_path = os.path.abspath(config_dir_path) + data_dir_path = os.getcwd() + generate_missing_configs = config_args.generate_missing_configs obj = cls() @@ -300,20 +328,10 @@ class Config(object): "Please specify either --report-stats=yes or --report-stats=no\n\n" + MISSING_REPORT_STATS_SPIEL ) - if not config_files: - config_parser.error( - "Must supply a config file.\nA config file can be automatically" - ' generated using "--generate-config -H SERVER_NAME' - ' -c CONFIG-FILE"' - ) + (config_path,) = config_files if not cls.path_exists(config_path): print("Generating config file %s" % (config_path,)) - if config_args.config_directory: - config_dir_path = config_args.config_directory - else: - config_dir_path = os.path.dirname(config_path) - config_dir_path = os.path.abspath(config_dir_path) server_name = config_args.server_name if not server_name: @@ -324,7 +342,7 @@ class Config(object): config_str = obj.generate_config( config_dir_path=config_dir_path, - data_dir_path=os.getcwd(), + data_dir_path=data_dir_path, server_name=server_name, report_stats=(config_args.report_stats == "yes"), generate_secrets=True, @@ -367,35 +385,37 @@ class Config(object): obj.invoke_all("add_arguments", parser) args = parser.parse_args(remaining_args) - if not config_files: - config_parser.error( - "Must supply a config file.\nA config file can be automatically" - ' generated using "--generate-config -H SERVER_NAME' - ' -c CONFIG-FILE"' - ) - config_dict = obj.read_config_files( - config_files, keys_directory=config_args.config_directory + config_files, config_dir_path=config_dir_path, data_dir_path=data_dir_path ) if generate_missing_configs: obj.generate_missing_files(config_dict) return None - obj.parse_config_dict(config_dict) + obj.parse_config_dict( + config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path + ) obj.invoke_all("read_arguments", args) return obj - def read_config_files(self, config_files, keys_directory=None): + def read_config_files(self, config_files, config_dir_path, data_dir_path): """Read the config files into a dict + Args: + config_files (iterable[str]): A list of the config files to read + + config_dir_path (str): The path where the config files are kept. Used to + create filenames for things like the log config and the signing key. + + data_dir_path (str): The path where the data files are kept. Used to create + filenames for things like the database and media store. + Returns: dict """ - if not keys_directory: - keys_directory = os.path.dirname(config_files[-1]) - - self.config_dir_path = os.path.abspath(keys_directory) + # FIXME: get rid of this + self.config_dir_path = config_dir_path # first we read the config files into a dict specified_config = {} @@ -409,8 +429,8 @@ class Config(object): raise ConfigError(MISSING_SERVER_NAME) server_name = specified_config["server_name"] config_string = self.generate_config( - config_dir_path=self.config_dir_path, - data_dir_path=os.getcwd(), + config_dir_path=config_dir_path, + data_dir_path=data_dir_path, server_name=server_name, generate_secrets=False, ) @@ -430,8 +450,24 @@ class Config(object): ) return config - def parse_config_dict(self, config_dict): - self.invoke_all("read_config", config_dict) + def parse_config_dict(self, config_dict, config_dir_path, data_dir_path): + """Read the information from the config dict into this Config object. + + Args: + config_dict (dict): Configuration data, as read from the yaml + + config_dir_path (str): The path where the config files are kept. Used to + create filenames for things like the log config and the signing key. + + data_dir_path (str): The path where the data files are kept. Used to create + filenames for things like the database and media store. + """ + self.invoke_all( + "read_config", + config_dict, + config_dir_path=config_dir_path, + data_dir_path=data_dir_path, + ) def generate_missing_files(self, config_dict): self.invoke_all("generate_files", config_dict) diff --git a/synapse/config/api.py b/synapse/config/api.py index 23b0ea6962..d9eff9ae1f 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -18,7 +18,7 @@ from ._base import Config class ApiConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.room_invite_state_types = config.get( "room_invite_state_types", [ diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 679ee62480..b74cebfca9 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class AppServiceConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.app_service_config_files = config.get("app_service_config_files", []) self.notify_appservices = config.get("notify_appservices", True) self.track_appservice_user_ips = config.get("track_appservice_user_ips", False) diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index e2eb473a92..a08b08570b 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -16,7 +16,7 @@ from ._base import Config class CaptchaConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.recaptcha_private_key = config.get("recaptcha_private_key") self.recaptcha_public_key = config.get("recaptcha_public_key") self.enable_registration_captcha = config.get( diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 609c0815c8..a5f0449955 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -22,7 +22,7 @@ class CasConfig(Config): cas_server_url: URL of CAS server """ - def read_config(self, config): + def read_config(self, config, **kwargs): cas_config = config.get("cas_config", None) if cas_config: self.cas_enabled = cas_config.get("enabled", True) diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index 5b0bf919c7..6fd4931681 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -84,7 +84,7 @@ class ConsentConfig(Config): self.user_consent_at_registration = False self.user_consent_policy_name = "Privacy Policy" - def read_config(self, config): + def read_config(self, config, **kwargs): consent_config = config.get("user_consent") if consent_config is None: return diff --git a/synapse/config/database.py b/synapse/config/database.py index adc0a47ddf..c8963e276a 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -18,7 +18,7 @@ from ._base import Config class DatabaseConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) self.database_config = config.get("database") diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 3a6cb07206..07df7b7173 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -27,7 +27,7 @@ from ._base import Config, ConfigError class EmailConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): # TODO: We should separate better the email configuration from the notification # and account validity config. diff --git a/synapse/config/groups.py b/synapse/config/groups.py index e4be172a79..d11f4d3b96 100644 --- a/synapse/config/groups.py +++ b/synapse/config/groups.py @@ -17,7 +17,7 @@ from ._base import Config class GroupsConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.enable_group_creation = config.get("enable_group_creation", False) self.group_creation_prefix = config.get("group_creation_prefix", "") diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py index b190dcbe38..a2c97dea95 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py @@ -23,7 +23,7 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login. class JWTConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): jwt_config = config.get("jwt_config", None) if jwt_config: self.jwt_enabled = jwt_config.get("enabled", False) diff --git a/synapse/config/key.py b/synapse/config/key.py index 21c4f5c51c..e58638f708 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -65,7 +65,7 @@ class TrustedKeyServer(object): class KeyConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): # the signing key can be specified inline or in a separate file if "signing_key" in config: self.signing_key = read_signing_keys([config["signing_key"]]) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 9db2e087e4..153a137517 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -74,7 +74,7 @@ root: class LoggingConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.verbosity = config.get("verbose", 0) self.no_redirect_stdio = config.get("no_redirect_stdio", False) self.log_config = self.abspath(config.get("log_config")) diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index c85e234d22..6af82e1329 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -21,7 +21,7 @@ MISSING_SENTRY = """Missing sentry-sdk library. This is required to enable sentr class MetricsConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.enable_metrics = config.get("enable_metrics", False) self.report_stats = config.get("report_stats", None) self.metrics_port = config.get("metrics_port") diff --git a/synapse/config/password.py b/synapse/config/password.py index eea59e772b..300b67f236 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -20,7 +20,7 @@ class PasswordConfig(Config): """Password login configuration """ - def read_config(self, config): + def read_config(self, config, **kwargs): password_config = config.get("password_config", {}) if password_config is None: password_config = {} diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index fcf279e8e1..8ffefd2639 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -21,7 +21,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider" class PasswordAuthProviderConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.password_providers = [] providers = [] diff --git a/synapse/config/push.py b/synapse/config/push.py index 62c0060c9c..99d15e4461 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -18,7 +18,7 @@ from ._base import Config class PushConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): push_config = config.get("push", {}) self.push_include_content = push_config.get("include_content", True) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 5a9adac480..b03047f2b5 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -36,7 +36,7 @@ class FederationRateLimitConfig(object): class RatelimitConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): # Load the new-style messages config if it exists. Otherwise fall back # to the old method. diff --git a/synapse/config/registration.py b/synapse/config/registration.py index a1e27ba66c..6d8a2df29b 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -46,7 +46,7 @@ class AccountValidityConfig(Config): class RegistrationConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.enable_registration = bool( strtobool(str(config.get("enable_registration", False))) ) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 9f9669ebb1..15a19e0911 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -86,7 +86,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): class ContentRepositoryConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M")) self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M")) self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M")) diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index c1da0e20e0..24223db7a1 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -19,7 +19,7 @@ from ._base import Config, ConfigError class RoomDirectoryConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.enable_room_list_search = config.get("enable_room_list_search", True) alias_creation_rules = config.get("alias_creation_rules") diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 2ec38e48e9..d86cf0e6ee 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -17,7 +17,7 @@ from ._base import Config, ConfigError class SAML2Config(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.saml2_enabled = False saml2_config = config.get("saml2_config") diff --git a/synapse/config/server.py b/synapse/config/server.py index 9ceca0a606..1e58b2e91b 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -40,7 +40,7 @@ DEFAULT_ROOM_VERSION = "4" class ServerConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.server_name = config["server_name"] self.server_context = config.get("server_context", None) diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index d930eb33b5..05110c17a6 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -66,7 +66,7 @@ class ServerNoticesConfig(Config): self.server_notices_mxid_avatar_url = None self.server_notices_room_name = None - def read_config(self, config): + def read_config(self, config, **kwargs): c = config.get("server_notices") if c is None: return diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py index 1502e9faba..1968003cb3 100644 --- a/synapse/config/spam_checker.py +++ b/synapse/config/spam_checker.py @@ -19,7 +19,7 @@ from ._base import Config class SpamCheckerConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.spam_checker = None provider = config.get("spam_checker", None) diff --git a/synapse/config/stats.py b/synapse/config/stats.py index 80fc1b9dd0..73a87c73f2 100644 --- a/synapse/config/stats.py +++ b/synapse/config/stats.py @@ -25,7 +25,7 @@ class StatsConfig(Config): Configuration for the behaviour of synapse's stats engine """ - def read_config(self, config): + def read_config(self, config, **kwargs): self.stats_enabled = True self.stats_bucket_size = 86400 self.stats_retention = sys.maxsize diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py index a89dd5f98a..1bedd607b6 100644 --- a/synapse/config/third_party_event_rules.py +++ b/synapse/config/third_party_event_rules.py @@ -19,7 +19,7 @@ from ._base import Config class ThirdPartyRulesConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.third_party_event_rules = None provider = config.get("third_party_event_rules", None) diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 7951bf21fa..28be4366d6 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) class TlsConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): acme_config = config.get("acme", None) if acme_config is None: diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index e031b11599..0665dc3fcf 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -21,7 +21,7 @@ class UserDirectoryConfig(Config): Configuration for the behaviour of the /user_directory API """ - def read_config(self, config): + def read_config(self, config, **kwargs): self.user_directory_search_enabled = True self.user_directory_search_all_users = False user_directory_config = config.get("user_directory", None) diff --git a/synapse/config/voip.py b/synapse/config/voip.py index 82cf8c53a8..01e0cb2e28 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -16,7 +16,7 @@ from ._base import Config class VoipConfig(Config): - def read_config(self, config): + def read_config(self, config, **kwargs): self.turn_uris = config.get("turn_uris", []) self.turn_shared_secret = config.get("turn_shared_secret") self.turn_username = config.get("turn_username") diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 4f283a0c2f..3b75471d85 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -21,7 +21,7 @@ class WorkerConfig(Config): They have their own pid_file and listener configuration. They use the replication_url to talk to the main synapse process.""" - def read_config(self, config): + def read_config(self, config, **kwargs): self.worker_app = config.get("worker_app") # Canonicalise worker_app so that master always has None diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 0cbbf4e885..a5d88d644a 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -65,7 +65,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= } t = TestConfig() - t.read_config(config) + t.read_config(config, config_dir_path="", data_dir_path="") t.read_certificate_from_disk(require_cert_and_key=False) warnings = self.flushWarnings() diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index b1094c1448..417fda3ab2 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -78,7 +78,7 @@ class MatrixFederationAgentTests(TestCase): # config_dict["trusted_key_servers"] = [] self._config = config = HomeServerConfig() - config.parse_config_dict(config_dict) + config.parse_config_dict(config_dict, "", "") self.agent = MatrixFederationAgent( reactor=self.reactor, diff --git a/tests/unittest.py b/tests/unittest.py index d64702b0c2..36df43c137 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -342,7 +342,7 @@ class HomeserverTestCase(TestCase): # Parse the config from a config dict into a HomeServerConfig config_obj = HomeServerConfig() - config_obj.parse_config_dict(config) + config_obj.parse_config_dict(config, "", "") kwargs["config"] = config_obj hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) diff --git a/tests/utils.py b/tests/utils.py index bd2c7c954c..da43166f3a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -182,7 +182,7 @@ def default_config(name, parse=False): if parse: config = HomeServerConfig() - config.parse_config_dict(config_dict) + config.parse_config_dict(config_dict, "", "") return config return config_dict -- cgit 1.5.1 From bfe84e051e8472a1c3534e1287f1fb727df63259 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 24 Jun 2019 11:45:11 +0100 Subject: Split public rooms directory auth config in two --- changelog.d/5534.feature | 1 + docs/sample_config.yaml | 12 ++++++---- synapse/config/server.py | 44 ++++++++++++++++++++++++++-------- synapse/federation/transport/server.py | 8 +++---- synapse/rest/client/v1/room.py | 2 +- tests/rest/client/v1/test_rooms.py | 2 +- 6 files changed, 49 insertions(+), 20 deletions(-) create mode 100644 changelog.d/5534.feature (limited to 'tests') diff --git a/changelog.d/5534.feature b/changelog.d/5534.feature new file mode 100644 index 0000000000..2e279c9b77 --- /dev/null +++ b/changelog.d/5534.feature @@ -0,0 +1 @@ +Split public rooms directory auth config in two settings, in order to manage client auth independently from the federation part of it. Obsoletes the "restrict_public_rooms_to_local_users" configuration setting. If "restrict_public_rooms_to_local_users" is set in the config, Synapse will act as if both new options are enabled, i.e. require authentication through the client API and deny federation requests. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d5cc3e7abc..f4fd113211 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -54,11 +54,15 @@ pid_file: DATADIR/homeserver.pid # #require_auth_for_profile_requests: true -# If set to 'true', requires authentication to access the server's -# public rooms directory through the client API, and forbids any other -# homeserver to fetch it via federation. Defaults to 'false'. +# If set to 'false', requires authentication to access the server's public rooms +# directory through the client API. Defaults to 'true'. # -#restrict_public_rooms_to_local_users: true +#allow_public_rooms_without_auth: false + +# If set to 'false', forbids any other homeserver to fetch the server's public +# rooms directory via federation. Defaults to 'true'. +# +#allow_public_rooms_over_federation: false # The default room version for newly created rooms. # diff --git a/synapse/config/server.py b/synapse/config/server.py index 1e58b2e91b..7cbb699a66 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -82,12 +82,32 @@ class ServerConfig(Config): "require_auth_for_profile_requests", False ) - # If set to 'True', requires authentication to access the server's - # public rooms directory through the client API, and forbids any other - # homeserver to fetch it via federation. - self.restrict_public_rooms_to_local_users = config.get( - "restrict_public_rooms_to_local_users", False - ) + if "restrict_public_rooms_to_local_users" in config and ( + "allow_public_rooms_without_auth" in config + or "allow_public_rooms_over_federation" in config + ): + raise ConfigError( + "Can't use 'restrict_public_rooms_to_local_users' if" + " 'allow_public_rooms_without_auth' and/or" + " 'allow_public_rooms_over_federation' is set." + ) + + # Check if the legacy "restrict_public_rooms_to_local_users" flag is set. This + # flag is now obsolete but we need to check it for backward-compatibility. + if config.get("restrict_public_rooms_to_local_users", False): + self.allow_public_rooms_without_auth = False + self.allow_public_rooms_over_federation = False + else: + # If set to 'False', requires authentication to access the server's public + # rooms directory through the client API. Defaults to 'True'. + self.allow_public_rooms_without_auth = config.get( + "allow_public_rooms_without_auth", True + ) + # If set to 'False', forbids any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'True'. + self.allow_public_rooms_over_federation = config.get( + "allow_public_rooms_over_federation", True + ) default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION) @@ -366,11 +386,15 @@ class ServerConfig(Config): # #require_auth_for_profile_requests: true - # If set to 'true', requires authentication to access the server's - # public rooms directory through the client API, and forbids any other - # homeserver to fetch it via federation. Defaults to 'false'. + # If set to 'false', requires authentication to access the server's public rooms + # directory through the client API. Defaults to 'true'. + # + #allow_public_rooms_without_auth: false + + # If set to 'false', forbids any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'true'. # - #restrict_public_rooms_to_local_users: true + #allow_public_rooms_over_federation: false # The default room version for newly created rooms. # diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index b4854e82f6..955f0f4308 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -721,15 +721,15 @@ class PublicRoomList(BaseFederationServlet): PATH = "/publicRooms" - def __init__(self, handler, authenticator, ratelimiter, server_name, deny_access): + def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access): super(PublicRoomList, self).__init__( handler, authenticator, ratelimiter, server_name ) - self.deny_access = deny_access + self.allow_access = allow_access @defer.inlineCallbacks def on_GET(self, origin, content, query): - if self.deny_access: + if not self.allow_access: raise FederationDeniedError(origin) limit = parse_integer_from_args(query, "limit", 0) @@ -1436,7 +1436,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, - deny_access=hs.config.restrict_public_rooms_to_local_users, + allow_access=hs.config.allow_public_rooms_over_federation, ).register(resource) if "group_server" in servlet_groups: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index a028337125..cca7e45ddb 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -311,7 +311,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private # federations. - if self.hs.config.restrict_public_rooms_to_local_users: + if not self.hs.config.allow_public_rooms_without_auth: raise # We allow people to not be authed if they're just looking at our diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2e3a765bf3..fe741637f5 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -920,7 +920,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): self.url = b"/_matrix/client/r0/publicRooms" config = self.default_config() - config["restrict_public_rooms_to_local_users"] = True + config["allow_public_rooms_without_auth"] = False self.hs = self.setup_test_homeserver(config=config) return self.hs -- cgit 1.5.1 From be3b901ccdf28d0f81d312d7cd8b7bedb22b4049 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 28 Jun 2019 18:19:09 +1000 Subject: Update the TLS cipher string and provide configurability for TLS on outgoing federation (#5550) --- changelog.d/5550.feature | 1 + changelog.d/5550.misc | 1 + docs/sample_config.yaml | 9 +++ scripts/generate_config | 2 +- synapse/config/tls.py | 32 ++++++++++- synapse/crypto/context_factory.py | 39 +++++++++++-- tests/config/test_tls.py | 115 +++++++++++++++++++++++++++++++++++++- 7 files changed, 190 insertions(+), 9 deletions(-) create mode 100644 changelog.d/5550.feature create mode 100644 changelog.d/5550.misc (limited to 'tests') diff --git a/changelog.d/5550.feature b/changelog.d/5550.feature new file mode 100644 index 0000000000..79ecedf3b8 --- /dev/null +++ b/changelog.d/5550.feature @@ -0,0 +1 @@ +The minimum TLS version used for outgoing federation requests can now be set with `federation_client_minimum_tls_version`. diff --git a/changelog.d/5550.misc b/changelog.d/5550.misc new file mode 100644 index 0000000000..ad5693338e --- /dev/null +++ b/changelog.d/5550.misc @@ -0,0 +1 @@ +Synapse will now only allow TLS v1.2 connections when serving federation, if it terminates TLS. As Synapse's allowed ciphers were only able to be used in TLSv1.2 before, this does not change behaviour. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index a01e1152f7..bf9cd88b15 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -317,6 +317,15 @@ listeners: # #federation_verify_certificates: false +# The minimum TLS version that will be used for outbound federation requests. +# +# Defaults to `1`. Configurable to `1`, `1.1`, `1.2`, or `1.3`. Note +# that setting this value higher than `1.2` will prevent federation to most +# of the public Matrix network: only configure it to `1.3` if you have an +# entirely private federation setup and you can ensure TLS 1.3 support. +# +#federation_client_minimum_tls_version: 1.2 + # Skip federation certificate verification on the following whitelist # of domains. # diff --git a/scripts/generate_config b/scripts/generate_config index 93b6406992..771cbf8d95 100755 --- a/scripts/generate_config +++ b/scripts/generate_config @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import argparse import shutil diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 8fcf801418..ca508a224f 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -23,7 +23,7 @@ import six from unpaddedbase64 import encode_base64 -from OpenSSL import crypto +from OpenSSL import SSL, crypto from twisted.internet._sslverify import Certificate, trustRootFromCertificates from synapse.config._base import Config, ConfigError @@ -81,6 +81,27 @@ class TlsConfig(Config): "federation_verify_certificates", True ) + # Minimum TLS version to use for outbound federation traffic + self.federation_client_minimum_tls_version = str( + config.get("federation_client_minimum_tls_version", 1) + ) + + if self.federation_client_minimum_tls_version not in ["1", "1.1", "1.2", "1.3"]: + raise ConfigError( + "federation_client_minimum_tls_version must be one of: 1, 1.1, 1.2, 1.3" + ) + + # Prevent people shooting themselves in the foot here by setting it to + # the biggest number blindly + if self.federation_client_minimum_tls_version == "1.3": + if getattr(SSL, "OP_NO_TLSv1_3", None) is None: + raise ConfigError( + ( + "federation_client_minimum_tls_version cannot be 1.3, " + "your OpenSSL does not support it" + ) + ) + # Whitelist of domains to not verify certificates for fed_whitelist_entries = config.get( "federation_certificate_verification_whitelist", [] @@ -261,6 +282,15 @@ class TlsConfig(Config): # #federation_verify_certificates: false + # The minimum TLS version that will be used for outbound federation requests. + # + # Defaults to `1`. Configurable to `1`, `1.1`, `1.2`, or `1.3`. Note + # that setting this value higher than `1.2` will prevent federation to most + # of the public Matrix network: only configure it to `1.3` if you have an + # entirely private federation setup and you can ensure TLS 1.3 support. + # + #federation_client_minimum_tls_version: 1.2 + # Skip federation certificate verification on the following whitelist # of domains. # diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 2bc5cc3807..4f48e8e88d 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -24,12 +24,25 @@ from OpenSSL import SSL, crypto from twisted.internet._sslverify import _defaultCurveName from twisted.internet.abstract import isIPAddress, isIPv6Address from twisted.internet.interfaces import IOpenSSLClientConnectionCreator -from twisted.internet.ssl import CertificateOptions, ContextFactory, platformTrust +from twisted.internet.ssl import ( + CertificateOptions, + ContextFactory, + TLSVersion, + platformTrust, +) from twisted.python.failure import Failure logger = logging.getLogger(__name__) +_TLS_VERSION_MAP = { + "1": TLSVersion.TLSv1_0, + "1.1": TLSVersion.TLSv1_1, + "1.2": TLSVersion.TLSv1_2, + "1.3": TLSVersion.TLSv1_3, +} + + class ServerContextFactory(ContextFactory): """Factory for PyOpenSSL SSL contexts that are used to handle incoming connections.""" @@ -43,16 +56,18 @@ class ServerContextFactory(ContextFactory): try: _ecCurve = crypto.get_elliptic_curve(_defaultCurveName) context.set_tmp_ecdh(_ecCurve) - except Exception: logger.exception("Failed to enable elliptic curve for TLS") - context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) + + context.set_options( + SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1 + ) context.use_certificate_chain_file(config.tls_certificate_file) context.use_privatekey(config.tls_private_key) # https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ context.set_cipher_list( - "ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1" + "ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM" ) def getContext(self): @@ -79,10 +94,22 @@ class ClientTLSOptionsFactory(object): # Use CA root certs provided by OpenSSL trust_root = platformTrust() - self._verify_ssl_context = CertificateOptions(trustRoot=trust_root).getContext() + # "insecurelyLowerMinimumTo" is the argument that will go lower than + # Twisted's default, which is why it is marked as "insecure" (since + # Twisted's defaults are reasonably secure). But, since Twisted is + # moving to TLS 1.2 by default, we want to respect the config option if + # it is set to 1.0 (which the alternate option, raiseMinimumTo, will not + # let us do). + minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version] + + self._verify_ssl = CertificateOptions( + trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS + ) + self._verify_ssl_context = self._verify_ssl.getContext() self._verify_ssl_context.set_info_callback(self._context_info_cb) - self._no_verify_ssl_context = CertificateOptions().getContext() + self._no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS) + self._no_verify_ssl_context = self._no_verify_ssl.getContext() self._no_verify_ssl_context.set_info_callback(self._context_info_cb) def get_options(self, host): diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index a5d88d644a..4f8a87a3df 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +16,10 @@ import os -from synapse.config.tls import TlsConfig +from OpenSSL import SSL + +from synapse.config.tls import ConfigError, TlsConfig +from synapse.crypto.context_factory import ClientTLSOptionsFactory from tests.unittest import TestCase @@ -78,3 +82,112 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= "or use Synapse's ACME support to provision one." ), ) + + def test_tls_client_minimum_default(self): + """ + The default client TLS version is 1.0. + """ + config = {} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(t.federation_client_minimum_tls_version, "1") + + def test_tls_client_minimum_set(self): + """ + The default client TLS version can be set to 1.0, 1.1, and 1.2. + """ + config = {"federation_client_minimum_tls_version": 1} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.federation_client_minimum_tls_version, "1") + + config = {"federation_client_minimum_tls_version": 1.1} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.federation_client_minimum_tls_version, "1.1") + + config = {"federation_client_minimum_tls_version": 1.2} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + + # Also test a string version + config = {"federation_client_minimum_tls_version": "1"} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.federation_client_minimum_tls_version, "1") + + config = {"federation_client_minimum_tls_version": "1.2"} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + + def test_tls_client_minimum_1_point_3_missing(self): + """ + If TLS 1.3 support is missing and it's configured, it will raise a + ConfigError. + """ + # thanks i hate it + if hasattr(SSL, "OP_NO_TLSv1_3"): + OP_NO_TLSv1_3 = SSL.OP_NO_TLSv1_3 + delattr(SSL, "OP_NO_TLSv1_3") + self.addCleanup(setattr, SSL, "SSL.OP_NO_TLSv1_3", OP_NO_TLSv1_3) + assert not hasattr(SSL, "OP_NO_TLSv1_3") + + config = {"federation_client_minimum_tls_version": 1.3} + t = TestConfig() + with self.assertRaises(ConfigError) as e: + t.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual( + e.exception.args[0], + ( + "federation_client_minimum_tls_version cannot be 1.3, " + "your OpenSSL does not support it" + ), + ) + + def test_tls_client_minimum_1_point_3_exists(self): + """ + If TLS 1.3 support exists and it's configured, it will be settable. + """ + # thanks i hate it, still + if not hasattr(SSL, "OP_NO_TLSv1_3"): + SSL.OP_NO_TLSv1_3 = 0x00 + self.addCleanup(lambda: delattr(SSL, "OP_NO_TLSv1_3")) + assert hasattr(SSL, "OP_NO_TLSv1_3") + + config = {"federation_client_minimum_tls_version": 1.3} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.federation_client_minimum_tls_version, "1.3") + + def test_tls_client_minimum_set_passed_through_1_2(self): + """ + The configured TLS version is correctly configured by the ContextFactory. + """ + config = {"federation_client_minimum_tls_version": 1.2} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cf = ClientTLSOptionsFactory(t) + + # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2 + self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) + self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) + self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) + + def test_tls_client_minimum_set_passed_through_1_0(self): + """ + The configured TLS version is correctly configured by the ContextFactory. + """ + config = {"federation_client_minimum_tls_version": 1} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cf = ClientTLSOptionsFactory(t) + + # The context has not had any of the NO_TLS set. + self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) + self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) + self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) -- cgit 1.5.1 From f40a7dc41fdd738a74546ff22b110a4c8ab850fe Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 29 Jun 2019 17:06:55 +1000 Subject: Make the http server handle coroutine-making REST servlets (#5475) --- changelog.d/5475.misc | 1 + synapse/http/server.py | 77 ++++++++++++++------------- synapse/rest/consent/consent_resource.py | 35 +++++------- synapse/rest/key/v2/remote_key_resource.py | 28 ++++------ synapse/rest/media/v1/config_resource.py | 21 ++++---- synapse/rest/media/v1/download_resource.py | 26 ++++----- synapse/rest/media/v1/preview_url_resource.py | 18 +++---- synapse/rest/media/v1/thumbnail_resource.py | 27 +++++----- synapse/rest/media/v1/upload_resource.py | 23 ++++---- synapse/rest/saml2/response_resource.py | 15 ++---- tests/rest/media/v1/test_media_storage.py | 25 +++++---- tests/unittest.py | 40 +++++++++++--- 12 files changed, 162 insertions(+), 174 deletions(-) create mode 100644 changelog.d/5475.misc (limited to 'tests') diff --git a/changelog.d/5475.misc b/changelog.d/5475.misc new file mode 100644 index 0000000000..6be06d4d0b --- /dev/null +++ b/changelog.d/5475.misc @@ -0,0 +1 @@ +Synapse can now handle RestServlets that return coroutines. diff --git a/synapse/http/server.py b/synapse/http/server.py index 6fd13e87d1..f067c163c1 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -16,10 +16,11 @@ import cgi import collections +import http.client import logging - -from six import PY3 -from six.moves import http_client, urllib +import types +import urllib +from io import BytesIO from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json @@ -41,11 +42,6 @@ from synapse.api.errors import ( from synapse.util.caches import intern_dict from synapse.util.logcontext import preserve_fn -if PY3: - from io import BytesIO -else: - from cStringIO import StringIO as BytesIO - logger = logging.getLogger(__name__) HTML_ERROR_TEMPLATE = """ @@ -75,10 +71,9 @@ def wrap_json_request_handler(h): deferred fails with any other type of error we send a 500 reponse. """ - @defer.inlineCallbacks - def wrapped_request_handler(self, request): + async def wrapped_request_handler(self, request): try: - yield h(self, request) + await h(self, request) except SynapseError as e: code = e.code logger.info("%s SynapseError: %s - %s", request, code, e.msg) @@ -142,10 +137,12 @@ def wrap_html_request_handler(h): where "request" must be a SynapseRequest. """ - def wrapped_request_handler(self, request): - d = defer.maybeDeferred(h, self, request) - d.addErrback(_return_html_error, request) - return d + async def wrapped_request_handler(self, request): + try: + return await h(self, request) + except Exception: + f = failure.Failure() + return _return_html_error(f, request) return wrap_async_request_handler(wrapped_request_handler) @@ -171,7 +168,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) else: - code = http_client.INTERNAL_SERVER_ERROR + code = http.client.INTERNAL_SERVER_ERROR msg = "Internal server error" logger.error( @@ -201,10 +198,9 @@ def wrap_async_request_handler(h): logged until the deferred completes. """ - @defer.inlineCallbacks - def wrapped_async_request_handler(self, request): + async def wrapped_async_request_handler(self, request): with request.processing(): - yield h(self, request) + await h(self, request) # we need to preserve_fn here, because the synchronous render method won't yield for # us (obviously) @@ -270,12 +266,11 @@ class JsonResource(HttpServer, resource.Resource): def render(self, request): """ This gets called by twisted every time someone sends us a request. """ - self._async_render(request) + defer.ensureDeferred(self._async_render(request)) return NOT_DONE_YET @wrap_json_request_handler - @defer.inlineCallbacks - def _async_render(self, request): + async def _async_render(self, request): """ This gets called from render() every time someone sends us a request. This checks if anyone has registered a callback for that method and path. @@ -292,26 +287,19 @@ class JsonResource(HttpServer, resource.Resource): # Now trigger the callback. If it returns a response, we send it # here. If it throws an exception, that is handled by the wrapper # installed by @request_handler. - - def _unquote(s): - if PY3: - # On Python 3, unquote is unicode -> unicode - return urllib.parse.unquote(s) - else: - # On Python 2, unquote is bytes -> bytes We need to encode the - # URL again (as it was decoded by _get_handler_for request), as - # ASCII because it's a URL, and then decode it to get the UTF-8 - # characters that were quoted. - return urllib.parse.unquote(s.encode("ascii")).decode("utf8") - kwargs = intern_dict( { - name: _unquote(value) if value else value + name: urllib.parse.unquote(value) if value else value for name, value in group_dict.items() } ) - callback_return = yield callback(request, **kwargs) + callback_return = callback(request, **kwargs) + + # Is it synchronous? We'll allow this for now. + if isinstance(callback_return, (defer.Deferred, types.CoroutineType)): + callback_return = await callback_return + if callback_return is not None: code, response = callback_return self._send_response(request, code, response) @@ -360,6 +348,23 @@ class JsonResource(HttpServer, resource.Resource): ) +class DirectServeResource(resource.Resource): + def render(self, request): + """ + Render the request, using an asynchronous render handler if it exists. + """ + render_callback_name = "_async_render_" + request.method.decode("ascii") + + if hasattr(self, render_callback_name): + # Call the handler + callback = getattr(self, render_callback_name) + defer.ensureDeferred(callback(request)) + + return NOT_DONE_YET + else: + super().render(request) + + def _options_handler(request): """Request handler for OPTIONS requests diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 9a32892d8b..624c42441e 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -23,13 +23,13 @@ from six.moves import http_client import jinja2 from jinja2 import TemplateNotFound -from twisted.internet import defer -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET - from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.config import ConfigError -from synapse.http.server import finish_request, wrap_html_request_handler +from synapse.http.server import ( + DirectServeResource, + finish_request, + wrap_html_request_handler, +) from synapse.http.servlet import parse_string from synapse.types import UserID @@ -47,7 +47,7 @@ else: return a == b -class ConsentResource(Resource): +class ConsentResource(DirectServeResource): """A twisted Resource to display a privacy policy and gather consent to it When accessed via GET, returns the privacy policy via a template. @@ -87,7 +87,7 @@ class ConsentResource(Resource): Args: hs (synapse.server.HomeServer): homeserver """ - Resource.__init__(self) + super().__init__() self.hs = hs self.store = hs.get_datastore() @@ -118,18 +118,12 @@ class ConsentResource(Resource): self._hmac_secret = hs.config.form_secret.encode("utf-8") - def render_GET(self, request): - self._async_render_GET(request) - return NOT_DONE_YET - @wrap_html_request_handler - @defer.inlineCallbacks - def _async_render_GET(self, request): + async def _async_render_GET(self, request): """ Args: request (twisted.web.http.Request): """ - version = parse_string(request, "v", default=self._default_consent_version) username = parse_string(request, "u", required=False, default="") userhmac = None @@ -145,7 +139,7 @@ class ConsentResource(Resource): else: qualified_user_id = UserID(username, self.hs.hostname).to_string() - u = yield self.store.get_user_by_id(qualified_user_id) + u = await self.store.get_user_by_id(qualified_user_id) if u is None: raise NotFoundError("Unknown user") @@ -165,13 +159,8 @@ class ConsentResource(Resource): except TemplateNotFound: raise NotFoundError("Unknown policy version") - def render_POST(self, request): - self._async_render_POST(request) - return NOT_DONE_YET - @wrap_html_request_handler - @defer.inlineCallbacks - def _async_render_POST(self, request): + async def _async_render_POST(self, request): """ Args: request (twisted.web.http.Request): @@ -188,12 +177,12 @@ class ConsentResource(Resource): qualified_user_id = UserID(username, self.hs.hostname).to_string() try: - yield self.store.user_set_consent_version(qualified_user_id, version) + await self.store.user_set_consent_version(qualified_user_id, version) except StoreError as e: if e.code != 404: raise raise NotFoundError("Unknown user") - yield self.registration_handler.post_consent_actions(qualified_user_id) + await self.registration_handler.post_consent_actions(qualified_user_id) try: self._render_template(request, "success.html") diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index ec8b9d7269..031a316693 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -16,18 +16,20 @@ import logging from io import BytesIO from twisted.internet import defer -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler +from synapse.http.server import ( + DirectServeResource, + respond_with_json_bytes, + wrap_json_request_handler, +) from synapse.http.servlet import parse_integer, parse_json_object_from_request logger = logging.getLogger(__name__) -class RemoteKey(Resource): +class RemoteKey(DirectServeResource): """HTTP resource for retreiving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks @@ -94,13 +96,8 @@ class RemoteKey(Resource): self.clock = hs.get_clock() self.federation_domain_whitelist = hs.config.federation_domain_whitelist - def render_GET(self, request): - self.async_render_GET(request) - return NOT_DONE_YET - @wrap_json_request_handler - @defer.inlineCallbacks - def async_render_GET(self, request): + async def _async_render_GET(self, request): if len(request.postpath) == 1: server, = request.postpath query = {server.decode("ascii"): {}} @@ -114,20 +111,15 @@ class RemoteKey(Resource): else: raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) - yield self.query_keys(request, query, query_remote_on_cache_miss=True) - - def render_POST(self, request): - self.async_render_POST(request) - return NOT_DONE_YET + await self.query_keys(request, query, query_remote_on_cache_miss=True) @wrap_json_request_handler - @defer.inlineCallbacks - def async_render_POST(self, request): + async def _async_render_POST(self, request): content = parse_json_object_from_request(request) query = content["server_keys"] - yield self.query_keys(request, query, query_remote_on_cache_miss=True) + await self.query_keys(request, query, query_remote_on_cache_miss=True) @defer.inlineCallbacks def query_keys(self, request, query, query_remote_on_cache_miss=False): diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index fa3d6680fc..9f747de263 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -14,31 +14,28 @@ # limitations under the License. # -from twisted.internet import defer -from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET -from synapse.http.server import respond_with_json, wrap_json_request_handler +from synapse.http.server import ( + DirectServeResource, + respond_with_json, + wrap_json_request_handler, +) -class MediaConfigResource(Resource): +class MediaConfigResource(DirectServeResource): isLeaf = True def __init__(self, hs): - Resource.__init__(self) + super().__init__() config = hs.get_config() self.clock = hs.get_clock() self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.max_upload_size} - def render_GET(self, request): - self._async_render_GET(request) - return NOT_DONE_YET - @wrap_json_request_handler - @defer.inlineCallbacks - def _async_render_GET(self, request): - yield self.auth.get_user_by_req(request) + async def _async_render_GET(self, request): + await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) def render_OPTIONS(self, request): diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index a21a35f843..66a01559e1 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -14,37 +14,31 @@ # limitations under the License. import logging -from twisted.internet import defer -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET - import synapse.http.servlet -from synapse.http.server import set_cors_headers, wrap_json_request_handler +from synapse.http.server import ( + DirectServeResource, + set_cors_headers, + wrap_json_request_handler, +) from ._base import parse_media_id, respond_404 logger = logging.getLogger(__name__) -class DownloadResource(Resource): +class DownloadResource(DirectServeResource): isLeaf = True def __init__(self, hs, media_repo): - Resource.__init__(self) - + super().__init__() self.media_repo = media_repo self.server_name = hs.hostname # this is expected by @wrap_json_request_handler self.clock = hs.get_clock() - def render_GET(self, request): - self._async_render_GET(request) - return NOT_DONE_YET - @wrap_json_request_handler - @defer.inlineCallbacks - def _async_render_GET(self, request): + async def _async_render_GET(self, request): set_cors_headers(request) request.setHeader( b"Content-Security-Policy", @@ -58,7 +52,7 @@ class DownloadResource(Resource): ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: - yield self.media_repo.get_local_media(request, media_id, name) + await self.media_repo.get_local_media(request, media_id, name) else: allow_remote = synapse.http.servlet.parse_boolean( request, "allow_remote", default=True @@ -72,4 +66,4 @@ class DownloadResource(Resource): respond_404(request) return - yield self.media_repo.get_remote_media(request, server_name, media_id, name) + await self.media_repo.get_remote_media(request, server_name, media_id, name) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index de6f292ffb..0337b64dc2 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -32,12 +32,11 @@ from canonicaljson import json from twisted.internet import defer from twisted.internet.error import DNSLookupError -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient from synapse.http.server import ( + DirectServeResource, respond_with_json, respond_with_json_bytes, wrap_json_request_handler, @@ -58,11 +57,11 @@ _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=r _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) -class PreviewUrlResource(Resource): +class PreviewUrlResource(DirectServeResource): isLeaf = True def __init__(self, hs, media_repo, media_storage): - Resource.__init__(self) + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() @@ -98,16 +97,11 @@ class PreviewUrlResource(Resource): def render_OPTIONS(self, request): return respond_with_json(request, 200, {}, send_cors=True) - def render_GET(self, request): - self._async_render_GET(request) - return NOT_DONE_YET - @wrap_json_request_handler - @defer.inlineCallbacks - def _async_render_GET(self, request): + async def _async_render_GET(self, request): # XXX: if get_user_by_req fails, what should we do in an async render? - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) url = parse_string(request, "url") if b"ts" in request.args: ts = parse_integer(request, "ts") @@ -159,7 +153,7 @@ class PreviewUrlResource(Resource): else: logger.info("Returning cached response") - og = yield make_deferred_yieldable(observable.observe()) + og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) respond_with_json_bytes(request, 200, og, send_cors=True) @defer.inlineCallbacks diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index ca84c9f139..08329884ac 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -17,10 +17,12 @@ import logging from twisted.internet import defer -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET -from synapse.http.server import set_cors_headers, wrap_json_request_handler +from synapse.http.server import ( + DirectServeResource, + set_cors_headers, + wrap_json_request_handler, +) from synapse.http.servlet import parse_integer, parse_string from ._base import ( @@ -34,11 +36,11 @@ from ._base import ( logger = logging.getLogger(__name__) -class ThumbnailResource(Resource): +class ThumbnailResource(DirectServeResource): isLeaf = True def __init__(self, hs, media_repo, media_storage): - Resource.__init__(self) + super().__init__() self.store = hs.get_datastore() self.media_repo = media_repo @@ -47,13 +49,8 @@ class ThumbnailResource(Resource): self.server_name = hs.hostname self.clock = hs.get_clock() - def render_GET(self, request): - self._async_render_GET(request) - return NOT_DONE_YET - @wrap_json_request_handler - @defer.inlineCallbacks - def _async_render_GET(self, request): + async def _async_render_GET(self, request): set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) width = parse_integer(request, "width", required=True) @@ -63,21 +60,21 @@ class ThumbnailResource(Resource): if server_name == self.server_name: if self.dynamic_thumbnails: - yield self._select_or_generate_local_thumbnail( + await self._select_or_generate_local_thumbnail( request, media_id, width, height, method, m_type ) else: - yield self._respond_local_thumbnail( + await self._respond_local_thumbnail( request, media_id, width, height, method, m_type ) self.media_repo.mark_recently_accessed(None, media_id) else: if self.dynamic_thumbnails: - yield self._select_or_generate_remote_thumbnail( + await self._select_or_generate_remote_thumbnail( request, server_name, media_id, width, height, method, m_type ) else: - yield self._respond_remote_thumbnail( + await self._respond_remote_thumbnail( request, server_name, media_id, width, height, method, m_type ) self.media_repo.mark_recently_accessed(server_name, media_id) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index d1d7e959f0..5d76bbdf68 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -15,22 +15,24 @@ import logging -from twisted.internet import defer -from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from synapse.api.errors import SynapseError -from synapse.http.server import respond_with_json, wrap_json_request_handler +from synapse.http.server import ( + DirectServeResource, + respond_with_json, + wrap_json_request_handler, +) from synapse.http.servlet import parse_string logger = logging.getLogger(__name__) -class UploadResource(Resource): +class UploadResource(DirectServeResource): isLeaf = True def __init__(self, hs, media_repo): - Resource.__init__(self) + super().__init__() self.media_repo = media_repo self.filepaths = media_repo.filepaths @@ -41,18 +43,13 @@ class UploadResource(Resource): self.max_upload_size = hs.config.max_upload_size self.clock = hs.get_clock() - def render_POST(self, request): - self._async_render_POST(request) - return NOT_DONE_YET - def render_OPTIONS(self, request): respond_with_json(request, 200, {}, send_cors=True) return NOT_DONE_YET @wrap_json_request_handler - @defer.inlineCallbacks - def _async_render_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def _async_render_POST(self, request): + requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point content_length = request.getHeader(b"Content-Length").decode("ascii") @@ -81,7 +78,7 @@ class UploadResource(Resource): # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion - content_uri = yield self.media_repo.create_content( + content_uri = await self.media_repo.create_content( media_type, upload_name, request.content, content_length, requester.user ) diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index ab14b70675..939c87306c 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -18,34 +18,27 @@ import logging import saml2 from saml2.client import Saml2Client -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET - from synapse.api.errors import CodeMessageException -from synapse.http.server import wrap_html_request_handler +from synapse.http.server import DirectServeResource, wrap_html_request_handler from synapse.http.servlet import parse_string from synapse.rest.client.v1.login import SSOAuthHandler logger = logging.getLogger(__name__) -class SAML2ResponseResource(Resource): +class SAML2ResponseResource(DirectServeResource): """A Twisted web resource which handles the SAML response""" isLeaf = 1 def __init__(self, hs): - Resource.__init__(self) + super().__init__() self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._sso_auth_handler = SSOAuthHandler(hs) - def render_POST(self, request): - self._async_render_POST(request) - return NOT_DONE_YET - @wrap_html_request_handler - def _async_render_POST(self, request): + async def _async_render_POST(self, request): resp_bytes = parse_string(request, "SAMLResponse", required=True) relay_state = parse_string(request, "RelayState", required=True) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index e2d418b1df..39c9342423 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -22,7 +22,6 @@ from binascii import unhexlify from mock import Mock from six.moves.urllib import parse -from twisted.internet import defer, reactor from twisted.internet.defer import Deferred from synapse.rest.media.v1._base import FileInfo @@ -34,15 +33,17 @@ from synapse.util.logcontext import make_deferred_yieldable from tests import unittest -class MediaStorageTests(unittest.TestCase): - def setUp(self): +class MediaStorageTests(unittest.HomeserverTestCase): + + needs_threadpool = True + + def prepare(self, reactor, clock, hs): self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") + self.addCleanup(shutil.rmtree, self.test_dir) self.primary_base_path = os.path.join(self.test_dir, "primary") self.secondary_base_path = os.path.join(self.test_dir, "secondary") - hs = Mock() - hs.get_reactor = Mock(return_value=reactor) hs.config.media_store_path = self.primary_base_path storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)] @@ -52,10 +53,6 @@ class MediaStorageTests(unittest.TestCase): hs, self.primary_base_path, self.filepaths, storage_providers ) - def tearDown(self): - shutil.rmtree(self.test_dir) - - @defer.inlineCallbacks def test_ensure_media_is_in_local_cache(self): media_id = "some_media_id" test_body = "Test\n" @@ -73,7 +70,15 @@ class MediaStorageTests(unittest.TestCase): # Now we run ensure_media_is_in_local_cache, which should copy the file # to the local cache. file_info = FileInfo(None, media_id) - local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info) + + # This uses a real blocking threadpool so we have to wait for it to be + # actually done :/ + x = self.media_storage.ensure_media_is_in_local_cache(file_info) + + # Hotloop until the threadpool does its job... + self.wait_on_thread(x) + + local_path = self.get_success(x) self.assertTrue(os.path.exists(local_path)) diff --git a/tests/unittest.py b/tests/unittest.py index 36df43c137..d26804b5b5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -17,6 +17,7 @@ import gc import hashlib import hmac import logging +import time from mock import Mock @@ -24,7 +25,8 @@ from canonicaljson import json import twisted import twisted.logger -from twisted.internet.defer import Deferred +from twisted.internet.defer import Deferred, succeed +from twisted.python.threadpool import ThreadPool from twisted.trial import unittest from synapse.api.constants import EventTypes @@ -164,6 +166,7 @@ class HomeserverTestCase(TestCase): servlets = [] hijack_auth = True + needs_threadpool = False def setUp(self): """ @@ -192,15 +195,19 @@ class HomeserverTestCase(TestCase): if self.hijack_auth: def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } + return succeed( + { + "user": UserID.from_string(self.helper.auth_user_id), + "token_id": 1, + "is_guest": False, + } + ) def get_user_by_req(request, allow_guest=False, rights="access"): - return create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None + return succeed( + create_requester( + UserID.from_string(self.helper.auth_user_id), 1, False, None + ) ) self.hs.get_auth().get_user_by_req = get_user_by_req @@ -209,9 +216,26 @@ class HomeserverTestCase(TestCase): return_value="1234" ) + if self.needs_threadpool: + self.reactor.threadpool = ThreadPool() + self.addCleanup(self.reactor.threadpool.stop) + self.reactor.threadpool.start() + if hasattr(self, "prepare"): self.prepare(self.reactor, self.clock, self.hs) + def wait_on_thread(self, deferred, timeout=10): + """ + Wait until a Deferred is done, where it's waiting on a real thread. + """ + start_time = time.time() + + while not deferred.called: + if start_time + timeout < time.time(): + raise ValueError("Timed out waiting for threadpool") + self.reactor.advance(0.01) + time.sleep(0.01) + def make_homeserver(self, reactor, clock): """ Make and return a homeserver. -- cgit 1.5.1 From 8ee69f299cb3360de5c88f0c6b07525d35247fbd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 1 Jul 2019 17:55:11 +0100 Subject: Add basic function to get all data for a user out of synapse --- synapse/handlers/admin.py | 247 ++++++++++++++++++++++++++++++++++++++++++ synapse/storage/roommember.py | 20 ++++ tests/handlers/test_admin.py | 210 +++++++++++++++++++++++++++++++++++ tests/unittest.py | 2 +- 4 files changed, 478 insertions(+), 1 deletion(-) create mode 100644 tests/handlers/test_admin.py (limited to 'tests') diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 941ebfa107..e424fc46bd 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -14,9 +14,17 @@ # limitations under the License. import logging +import os +import tempfile + +from canonicaljson import json from twisted.internet import defer +from synapse.api.constants import Membership +from synapse.types import RoomStreamToken +from synapse.visibility import filter_events_for_client + from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -89,3 +97,242 @@ class AdminHandler(BaseHandler): ret = yield self.store.search_users(term) defer.returnValue(ret) + + @defer.inlineCallbacks + def exfiltrate_user_data(self, user_id, writer): + """Write all data we have of the user to the specified directory. + + Args: + user_id (str) + writer (ExfiltrationWriter) + + Returns: + defer.Deferred + """ + # Get all rooms the user is in or has been in + rooms = yield self.store.get_rooms_for_user_where_membership_is( + user_id, + membership_list=( + Membership.JOIN, + Membership.LEAVE, + Membership.BAN, + Membership.INVITE, + ), + ) + + # We only try and fetch events for rooms the user has been in. If + # they've been e.g. invited to a room without joining then we handle + # those seperately. + rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id) + + for index, room in enumerate(rooms): + room_id = room.room_id + + logger.info( + "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms) + ) + + forgotten = yield self.store.did_forget(user_id, room_id) + if forgotten: + logger.info("[%s] User forgot room %d, ignoring", room_id) + continue + + if room_id not in rooms_user_has_been_in: + # If we haven't been in the rooms then the filtering code below + # won't return anything, so we need to handle these cases + # explicitly. + + if room.membership == Membership.INVITE: + event_id = room.event_id + invite = yield self.store.get_event(event_id, allow_none=True) + if invite: + invited_state = invite.unsigned["invite_room_state"] + writer.write_invite(room_id, invite, invited_state) + + continue + + # We only want to bother fetching events up to the last time they + # were joined. We estimate that point by looking at the + # stream_ordering of the last membership if it wasn't a join. + if room.membership == Membership.JOIN: + stream_ordering = yield self.store.get_room_max_stream_ordering() + else: + stream_ordering = room.stream_ordering + + from_key = str(RoomStreamToken(0, 0)) + to_key = str(RoomStreamToken(None, stream_ordering)) + + written_events = set() # Events that we've processed in this room + + # We need to track gaps in the events stream so that we can then + # write out the state at those events. We do this by keeping track + # of events whose prev events we haven't seen. + + # Map from event ID to prev events that haven't been processed, + # dict[str, set[str]]. + event_to_unseen_prevs = {} + + # The reverse mapping to above, i.e. map from unseen event to parent + # events. dict[str, set[str]] + unseen_event_to_parents = {} + + # We fetch events in the room the user could see by fetching *all* + # events that we have and then filtering, this isn't the most + # efficient method perhaps but it does guarentee we get everything. + while True: + events, _ = yield self.store.paginate_room_events( + room_id, from_key, to_key, limit=100, direction="f" + ) + if not events: + break + + from_key = events[-1].internal_metadata.after + + events = yield filter_events_for_client(self.store, user_id, events) + + writer.write_events(room_id, events) + + # Update the extremity tracking dicts + for event in events: + # Check if we have any prev events that haven't been + # processed yet, and add those to the appropriate dicts. + unseen_events = set(event.prev_event_ids()) - written_events + if unseen_events: + event_to_unseen_prevs[event.event_id] = unseen_events + for unseen in unseen_events: + unseen_event_to_parents.setdefault(unseen, set()).add( + event.event_id + ) + + # Now check if this event is an unseen prev event, if so + # then we remove this event from the appropriate dicts. + for event_id in unseen_event_to_parents.pop(event.event_id, []): + event_to_unseen_prevs.get(event_id, set()).discard( + event.event_id + ) + + written_events.add(event.event_id) + + logger.info( + "Written %d events in room %s", len(written_events), room_id + ) + + # Extremities are the events who have at least one unseen prev event. + extremities = ( + event_id + for event_id, unseen_prevs in event_to_unseen_prevs.items() + if unseen_prevs + ) + for event_id in extremities: + if not event_to_unseen_prevs[event_id]: + continue + state = yield self.store.get_state_for_event(event_id) + writer.write_state(room_id, event_id, state) + + defer.returnValue(writer.finished()) + + +class ExfiltrationWriter(object): + """Interfaced used to specify how to write exfilrated data. + """ + + def write_events(self, room_id, events): + """Write a batch of events for a room. + + Args: + room_id (str) + events (list[FrozenEvent]) + """ + pass + + def write_state(self, room_id, event_id, state): + """Write the state at the given event in the room. + + This only gets called for backward extremities rather than for each + event. + + Args: + room_id (str) + event_id (str) + state (list[FrozenEvent]) + """ + pass + + def write_invite(self, room_id, event, state): + """Write an invite for the room, with associated invite state. + + Args: + room_id (str) + invite (FrozenEvent) + state (list[dict]): A subset of the state at the invite, with a + subset of the event keys (type, state_key, content and sender) + """ + + def finished(self): + """Called when exfiltration is complete, and the return valus is passed + to the requester. + """ + pass + + +class FileExfiltrationWriter(ExfiltrationWriter): + """An ExfiltrationWriter that writes the users data to a directory. + + Returns the directory location on completion. + + Args: + user_id (str): The user whose data is being exfiltrated. + directory (str|None): The directory to write the data to, if None then + will write to a temporary directory. + """ + + def __init__(self, user_id, directory=None): + self.user_id = user_id + + if directory: + self.base_directory = directory + else: + self.base_directory = tempfile.mkdtemp( + prefix="synapse-exfiltrate__%s__" % (user_id,) + ) + + os.makedirs(self.base_directory, exist_ok=True) + if list(os.listdir(self.base_directory)): + raise Exception("Directory must be empty") + + def write_events(self, room_id, events): + room_directory = os.path.join(self.base_directory, "rooms", room_id) + os.makedirs(room_directory, exist_ok=True) + events_file = os.path.join(room_directory, "events") + + with open(events_file, "a") as f: + for event in events: + print(json.dumps(event.get_pdu_json()), file=f) + + def write_state(self, room_id, event_id, state): + room_directory = os.path.join(self.base_directory, "rooms", room_id) + state_directory = os.path.join(room_directory, "state") + os.makedirs(state_directory, exist_ok=True) + + event_file = os.path.join(state_directory, event_id) + + with open(event_file, "a") as f: + for event in state.values(): + print(json.dumps(event.get_pdu_json()), file=f) + + def write_invite(self, room_id, event, state): + self.write_events(room_id, [event]) + + # We write the invite state somewhere else as they aren't full events + # and are only a subset of the state at the event. + room_directory = os.path.join(self.base_directory, "rooms", room_id) + os.makedirs(room_directory, exist_ok=True) + + invite_state = os.path.join(room_directory, "invite_state") + + with open(invite_state, "a") as f: + for event in state.values(): + print(json.dumps(event), file=f) + + def finished(self): + return self.base_directory diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8004aeb909..32cfd010a5 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -575,6 +575,26 @@ class RoomMemberWorkerStore(EventsWorkerStore): count = yield self.runInteraction("did_forget_membership", f) defer.returnValue(count == 0) + @defer.inlineCallbacks + def get_rooms_user_has_been_in(self, user_id): + """Get all rooms that the user has ever been in. + + Args: + user_id (str) + + Returns: + Deferred[set[str]]: Set of room IDs. + """ + + room_ids = yield self._simple_select_onecol( + table="room_memberships", + keyvalues={"membership": Membership.JOIN, "user_id": user_id}, + retcol="room_id", + desc="get_rooms_user_has_been_in", + ) + + return set(room_ids) + class RoomMemberStore(RoomMemberWorkerStore): def __init__(self, db_conn, hs): diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py new file mode 100644 index 0000000000..5e7d2d3361 --- /dev/null +++ b/tests/handlers/test_admin.py @@ -0,0 +1,210 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import Counter + +from mock import Mock + +import synapse.api.errors +import synapse.handlers.admin +import synapse.rest.admin +import synapse.storage +from synapse.api.constants import EventTypes +from synapse.rest.client.v1 import login, room + +from tests import unittest + + +class ExfiltrateData(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_handler = hs.get_handlers().admin_handler + + self.user1 = self.register_user("user1", "password") + self.token1 = self.login("user1", "password") + + self.user2 = self.register_user("user2", "password") + self.token2 = self.login("user2", "password") + + def test_single_public_joined_room(self): + """Test that we write *all* events for a public room + """ + room_id = self.helper.create_room_as( + self.user1, tok=self.token1, is_public=True + ) + self.helper.send(room_id, body="Hello!", tok=self.token1) + self.helper.join(room_id, self.user2, tok=self.token2) + self.helper.send(room_id, body="Hello again!", tok=self.token1) + + writer = Mock() + + self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) + + writer.write_events.assert_called() + + # Since we can see all events there shouldn't be any extremities, so no + # state should be written + writer.write_state.assert_not_called() + + # Collect all events that were written + written_events = [] + for (called_room_id, events), _ in writer.write_events.call_args_list: + self.assertEqual(called_room_id, room_id) + written_events.extend(events) + + # Check that the right number of events were written + counter = Counter( + (event.type, getattr(event, "state_key", None)) for event in written_events + ) + self.assertEqual(counter[(EventTypes.Message, None)], 2) + self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) + self.assertEqual(counter[(EventTypes.Member, self.user2)], 1) + + def test_single_private_joined_room(self): + """Tests that we correctly write state when we can't see all events in + a room. + """ + room_id = self.helper.create_room_as(self.user1, tok=self.token1) + self.helper.send_state( + room_id, + EventTypes.RoomHistoryVisibility, + body={"history_visibility": "joined"}, + tok=self.token1, + ) + self.helper.send(room_id, body="Hello!", tok=self.token1) + self.helper.join(room_id, self.user2, tok=self.token2) + self.helper.send(room_id, body="Hello again!", tok=self.token1) + + writer = Mock() + + self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) + + writer.write_events.assert_called() + + # Since we can't see all events there should be one extremity. + writer.write_state.assert_called_once() + + # Collect all events that were written + written_events = [] + for (called_room_id, events), _ in writer.write_events.call_args_list: + self.assertEqual(called_room_id, room_id) + written_events.extend(events) + + # Check that the right number of events were written + counter = Counter( + (event.type, getattr(event, "state_key", None)) for event in written_events + ) + self.assertEqual(counter[(EventTypes.Message, None)], 1) + self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) + self.assertEqual(counter[(EventTypes.Member, self.user2)], 1) + + def test_single_left_room(self): + """Tests that we don't see events in the room after we leave. + """ + room_id = self.helper.create_room_as(self.user1, tok=self.token1) + self.helper.send(room_id, body="Hello!", tok=self.token1) + self.helper.join(room_id, self.user2, tok=self.token2) + self.helper.send(room_id, body="Hello again!", tok=self.token1) + self.helper.leave(room_id, self.user2, tok=self.token2) + self.helper.send(room_id, body="Helloooooo!", tok=self.token1) + + writer = Mock() + + self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) + + writer.write_events.assert_called() + + # Since we can see all events there shouldn't be any extremities, so no + # state should be written + writer.write_state.assert_not_called() + + written_events = [] + for (called_room_id, events), _ in writer.write_events.call_args_list: + self.assertEqual(called_room_id, room_id) + written_events.extend(events) + + # Check that the right number of events were written + counter = Counter( + (event.type, getattr(event, "state_key", None)) for event in written_events + ) + self.assertEqual(counter[(EventTypes.Message, None)], 2) + self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) + self.assertEqual(counter[(EventTypes.Member, self.user2)], 2) + + def test_single_left_rejoined_private_room(self): + """Tests that see the correct events in private rooms when we + repeatedly join and leave. + """ + room_id = self.helper.create_room_as(self.user1, tok=self.token1) + self.helper.send_state( + room_id, + EventTypes.RoomHistoryVisibility, + body={"history_visibility": "joined"}, + tok=self.token1, + ) + self.helper.send(room_id, body="Hello!", tok=self.token1) + self.helper.join(room_id, self.user2, tok=self.token2) + self.helper.send(room_id, body="Hello again!", tok=self.token1) + self.helper.leave(room_id, self.user2, tok=self.token2) + self.helper.send(room_id, body="Helloooooo!", tok=self.token1) + self.helper.join(room_id, self.user2, tok=self.token2) + self.helper.send(room_id, body="Helloooooo!!", tok=self.token1) + + writer = Mock() + + self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) + + writer.write_events.assert_called_once() + + # Since we joined/left/joined again we expect there to be two gaps. + self.assertEqual(writer.write_state.call_count, 2) + + written_events = [] + for (called_room_id, events), _ in writer.write_events.call_args_list: + self.assertEqual(called_room_id, room_id) + written_events.extend(events) + + # Check that the right number of events were written + counter = Counter( + (event.type, getattr(event, "state_key", None)) for event in written_events + ) + self.assertEqual(counter[(EventTypes.Message, None)], 2) + self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) + self.assertEqual(counter[(EventTypes.Member, self.user2)], 3) + + def test_invite(self): + """Tests that pending invites get handled correctly. + """ + room_id = self.helper.create_room_as(self.user1, tok=self.token1) + self.helper.send(room_id, body="Hello!", tok=self.token1) + self.helper.invite(room_id, self.user1, self.user2, tok=self.token1) + + writer = Mock() + + self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) + + writer.write_events.assert_not_called() + writer.write_state.assert_not_called() + writer.write_invite.assert_called_once() + + args = writer.write_invite.call_args[0] + self.assertEqual(args[0], room_id) + self.assertEqual(args[1].content["membership"], "invite") + self.assertTrue(args[2]) # Assert there is at least one bit of state diff --git a/tests/unittest.py b/tests/unittest.py index d26804b5b5..684d5cb1cf 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -443,7 +443,7 @@ class HomeserverTestCase(TestCase): "POST", "/_matrix/client/r0/admin/register", body.encode("utf8") ) self.render(request) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, 200, channel.json_body) user_id = channel.json_body["user_id"] return user_id -- cgit 1.5.1 From 0ee9076ffe40140db5e86763af077efbdb5d86b0 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 3 Jul 2019 04:01:28 +1000 Subject: Fix media repo breaking (#5593) --- changelog.d/5593.bugfix | 1 + synapse/http/server.py | 26 +++++++++++++-------- synapse/rest/media/v1/preview_url_resource.py | 1 + synapse/util/logcontext.py | 9 ++++++-- tests/rest/media/v1/test_url_preview.py | 12 ++++++++++ tests/util/test_logcontext.py | 33 +++++++++++++++++---------- 6 files changed, 58 insertions(+), 24 deletions(-) create mode 100644 changelog.d/5593.bugfix (limited to 'tests') diff --git a/changelog.d/5593.bugfix b/changelog.d/5593.bugfix new file mode 100644 index 0000000000..e981589ac3 --- /dev/null +++ b/changelog.d/5593.bugfix @@ -0,0 +1 @@ +Fix regression in 1.1rc1 where OPTIONS requests to the media repo would fail. diff --git a/synapse/http/server.py b/synapse/http/server.py index f067c163c1..d993161a3e 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -65,8 +65,8 @@ def wrap_json_request_handler(h): The handler method must have a signature of "handle_foo(self, request)", where "request" must be a SynapseRequest. - The handler must return a deferred. If the deferred succeeds we assume that - a response has been sent. If the deferred fails with a SynapseError we use + The handler must return a deferred or a coroutine. If the deferred succeeds + we assume that a response has been sent. If the deferred fails with a SynapseError we use it to send a JSON response with the appropriate HTTP reponse code. If the deferred fails with any other type of error we send a 500 reponse. """ @@ -353,16 +353,22 @@ class DirectServeResource(resource.Resource): """ Render the request, using an asynchronous render handler if it exists. """ - render_callback_name = "_async_render_" + request.method.decode("ascii") + async_render_callback_name = "_async_render_" + request.method.decode("ascii") - if hasattr(self, render_callback_name): - # Call the handler - callback = getattr(self, render_callback_name) - defer.ensureDeferred(callback(request)) + # Try and get the async renderer + callback = getattr(self, async_render_callback_name, None) - return NOT_DONE_YET - else: - super().render(request) + # No async renderer for this request method. + if not callback: + return super().render(request) + + resp = callback(request) + + # If it's a coroutine, turn it into a Deferred + if isinstance(resp, types.CoroutineType): + defer.ensureDeferred(resp) + + return NOT_DONE_YET def _options_handler(request): diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0337b64dc2..053346fb86 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -95,6 +95,7 @@ class PreviewUrlResource(DirectServeResource): ) def render_OPTIONS(self, request): + request.setHeader(b"Allow", b"OPTIONS, GET") return respond_with_json(request, 200, {}, send_cors=True) @wrap_json_request_handler diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 6b0d2deea0..9e1b537804 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -24,6 +24,7 @@ See doc/log_contexts.rst for details on how this works. import logging import threading +import types from twisted.internet import defer, threads @@ -528,8 +529,9 @@ def run_in_background(f, *args, **kwargs): return from the function, and that the sentinel context is set once the deferred returned by the function completes. - Useful for wrapping functions that return a deferred which you don't yield - on (for instance because you want to pass it to deferred.gatherResults()). + Useful for wrapping functions that return a deferred or coroutine, which you don't + yield or await on (for instance because you want to pass it to + deferred.gatherResults()). Note that if you completely discard the result, you should make sure that `f` doesn't raise any deferred exceptions, otherwise a scary-looking @@ -544,6 +546,9 @@ def run_in_background(f, *args, **kwargs): # by synchronous exceptions, so let's turn them into Failures. return defer.fail() + if isinstance(res, types.CoroutineType): + res = defer.ensureDeferred(res) + if not isinstance(res, defer.Deferred): return res diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 8fe5961866..976652aee8 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -460,3 +460,15 @@ class URLPreviewTests(unittest.HomeserverTestCase): "error": "DNS resolution failure during URL preview generation", }, ) + + def test_OPTIONS(self): + """ + OPTIONS returns the OPTIONS. + """ + request, channel = self.make_request( + "OPTIONS", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body, {}) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 8adaee3c8d..8d69fbf111 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -39,24 +39,17 @@ class LoggingContextTestCase(unittest.TestCase): callback_completed = [False] - def test(): + with LoggingContext() as context_one: context_one.request = "one" - d = function() + + # fire off function, but don't wait on it. + d2 = logcontext.run_in_background(function) def cb(res): - self._check_test_key("one") callback_completed[0] = True return res - d.addCallback(cb) - - return d - - with LoggingContext() as context_one: - context_one.request = "one" - - # fire off function, but don't wait on it. - logcontext.run_in_background(test) + d2.addCallback(cb) self._check_test_key("one") @@ -105,6 +98,22 @@ class LoggingContextTestCase(unittest.TestCase): return self._test_run_in_background(testfunc) + def test_run_in_background_with_coroutine(self): + async def testfunc(): + self._check_test_key("one") + d = Clock(reactor).sleep(0) + self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + await d + self._check_test_key("one") + + return self._test_run_in_background(testfunc) + + def test_run_in_background_with_nonblocking_coroutine(self): + async def testfunc(): + self._check_test_key("one") + + return self._test_run_in_background(testfunc) + @defer.inlineCallbacks def test_make_deferred_yieldable(self): # a function which retuns an incomplete deferred, but doesn't follow -- cgit 1.5.1 From 463b072b1290a8cb75bc1d7b6688fa76bbfdb14f Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 4 Jul 2019 00:07:04 +1000 Subject: Move logging utilities out of the side drawer of util/ and into logging/ (#5606) --- changelog.d/5606.misc | 1 + contrib/example_log_config.yaml | 8 +- contrib/experiments/test_messaging.py | 2 +- contrib/systemd/log_config.yaml | 2 +- debian/changelog | 3 + debian/log.yaml | 2 +- docker/conf/log.config | 2 +- docs/log_contexts.rst | 38 +- synapse/app/_base.py | 2 +- synapse/app/appservice.py | 2 +- synapse/app/client_reader.py | 2 +- synapse/app/event_creator.py | 2 +- synapse/app/federation_reader.py | 2 +- synapse/app/federation_sender.py | 2 +- synapse/app/frontend_proxy.py | 2 +- synapse/app/homeserver.py | 2 +- synapse/app/media_repository.py | 2 +- synapse/app/pusher.py | 2 +- synapse/app/synchrotron.py | 2 +- synapse/app/user_dir.py | 2 +- synapse/appservice/scheduler.py | 2 +- synapse/config/logger.py | 4 +- synapse/crypto/keyring.py | 15 +- synapse/events/snapshot.py | 2 +- synapse/federation/federation_base.py | 24 +- synapse/federation/federation_client.py | 8 +- synapse/federation/federation_server.py | 4 +- synapse/federation/persistence.py | 2 +- synapse/federation/sender/__init__.py | 12 +- synapse/federation/transport/client.py | 2 +- synapse/federation/transport/server.py | 2 +- synapse/groups/attestations.py | 2 +- synapse/handlers/account_validity.py | 2 +- synapse/handlers/appservice.py | 2 +- synapse/handlers/auth.py | 6 +- synapse/handlers/e2e_keys.py | 2 +- synapse/handlers/events.py | 2 +- synapse/handlers/federation.py | 43 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/message.py | 2 +- synapse/handlers/pagination.py | 2 +- synapse/handlers/presence.py | 4 +- synapse/handlers/sync.py | 2 +- synapse/handlers/typing.py | 2 +- synapse/http/client.py | 2 +- synapse/http/federation/matrix_federation_agent.py | 2 +- synapse/http/federation/srv_resolver.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/http/request_metrics.py | 2 +- synapse/http/server.py | 2 +- synapse/http/site.py | 2 +- synapse/logging/__init__.py | 0 synapse/logging/context.py | 693 +++++++++++++++++++++ synapse/logging/formatter.py | 53 ++ synapse/logging/utils.py | 194 ++++++ synapse/metrics/background_process_metrics.py | 2 +- synapse/notifier.py | 4 +- synapse/push/mailer.py | 2 +- synapse/replication/tcp/protocol.py | 2 +- synapse/rest/client/transactions.py | 2 +- synapse/rest/media/v1/_base.py | 6 +- synapse/rest/media/v1/media_repository.py | 12 +- synapse/rest/media/v1/media_storage.py | 5 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/rest/media/v1/storage_provider.py | 5 +- synapse/state/__init__.py | 2 +- synapse/storage/_base.py | 2 +- synapse/storage/events.py | 4 +- synapse/storage/events_worker.py | 6 +- synapse/storage/stream.py | 2 +- synapse/util/__init__.py | 12 +- synapse/util/async_helpers.py | 9 +- synapse/util/caches/descriptors.py | 11 +- synapse/util/caches/response_cache.py | 4 +- synapse/util/distributor.py | 2 +- synapse/util/file_consumer.py | 2 +- synapse/util/logcontext.py | 693 --------------------- synapse/util/logformatter.py | 53 -- synapse/util/logutils.py | 194 ------ synapse/util/metrics.py | 2 +- synapse/util/ratelimitutils.py | 2 +- synapse/util/retryutils.py | 4 +- tests/appservice/test_scheduler.py | 2 +- tests/crypto/test_keyring.py | 13 +- .../federation/test_matrix_federation_agent.py | 2 +- tests/http/federation/test_srv_resolver.py | 2 +- tests/http/test_fedclient.py | 2 +- tests/patch_inline_callbacks.py | 2 +- tests/push/test_http.py | 2 +- tests/rest/client/test_transactions.py | 2 +- tests/rest/media/v1/test_media_storage.py | 2 +- tests/test_federation.py | 2 +- tests/test_server.py | 2 +- tests/test_utils/logging_setup.py | 2 +- tests/unittest.py | 2 +- tests/util/caches/test_descriptors.py | 49 +- tests/util/test_async_utils.py | 7 +- tests/util/test_linearizer.py | 9 +- tests/util/test_logcontext.py | 24 +- tests/util/test_logformatter.py | 2 +- tests/utils.py | 2 +- 101 files changed, 1189 insertions(+), 1173 deletions(-) create mode 100644 changelog.d/5606.misc create mode 100644 synapse/logging/__init__.py create mode 100644 synapse/logging/context.py create mode 100644 synapse/logging/formatter.py create mode 100644 synapse/logging/utils.py delete mode 100644 synapse/util/logcontext.py delete mode 100644 synapse/util/logformatter.py delete mode 100644 synapse/util/logutils.py (limited to 'tests') diff --git a/changelog.d/5606.misc b/changelog.d/5606.misc new file mode 100644 index 0000000000..bb3c028167 --- /dev/null +++ b/changelog.d/5606.misc @@ -0,0 +1 @@ +Move logging code out of `synapse.util` and into `synapse.logging`. diff --git a/contrib/example_log_config.yaml b/contrib/example_log_config.yaml index 06592963da..3a76a7a33f 100644 --- a/contrib/example_log_config.yaml +++ b/contrib/example_log_config.yaml @@ -1,7 +1,7 @@ -# Example log_config file for synapse. To enable, point `log_config` to it in +# Example log_config file for synapse. To enable, point `log_config` to it in # `homeserver.yaml`, and restart synapse. # -# This configuration will produce similar results to the defaults within +# This configuration will produce similar results to the defaults within # synapse, but can be edited to give more flexibility. version: 1 @@ -12,7 +12,7 @@ formatters: filters: context: - (): synapse.util.logcontext.LoggingContextFilter + (): synapse.logging.context.LoggingContextFilter request: "" handlers: @@ -35,7 +35,7 @@ handlers: root: level: INFO handlers: [console] # to use file handler instead, switch to [file] - + loggers: synapse: level: INFO diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index c7e55d8aa7..5ef140ae48 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -36,7 +36,7 @@ from synapse.util import origin_from_ucid from synapse.app.homeserver import SynapseHomeServer -# from synapse.util.logutils import log_function +# from synapse.logging.utils import log_function from twisted.internet import reactor, defer from twisted.python import log diff --git a/contrib/systemd/log_config.yaml b/contrib/systemd/log_config.yaml index d85bdd1208..22f67a50ce 100644 --- a/contrib/systemd/log_config.yaml +++ b/contrib/systemd/log_config.yaml @@ -8,7 +8,7 @@ formatters: filters: context: - (): synapse.util.logcontext.LoggingContextFilter + (): synapse.logging.context.LoggingContextFilter request: "" handlers: diff --git a/debian/changelog b/debian/changelog index 91653e7243..bf5a59b335 100644 --- a/debian/changelog +++ b/debian/changelog @@ -3,6 +3,9 @@ matrix-synapse-py3 (1.0.0+nmu1) UNRELEASED; urgency=medium [ Silke Hofstra ] * Include systemd-python to allow logging to the systemd journal. + [ Amber Brown ] + * Update logging config defaults to match API changes in Synapse. + -- Silke Hofstra Wed, 29 May 2019 09:45:29 +0200 matrix-synapse-py3 (1.0.0) stable; urgency=medium diff --git a/debian/log.yaml b/debian/log.yaml index 206b65f1ac..95b655dd35 100644 --- a/debian/log.yaml +++ b/debian/log.yaml @@ -7,7 +7,7 @@ formatters: filters: context: - (): synapse.util.logcontext.LoggingContextFilter + (): synapse.logging.context.LoggingContextFilter request: "" handlers: diff --git a/docker/conf/log.config b/docker/conf/log.config index 895e45d20b..ea5ccfd68b 100644 --- a/docker/conf/log.config +++ b/docker/conf/log.config @@ -6,7 +6,7 @@ formatters: filters: context: - (): synapse.util.logcontext.LoggingContextFilter + (): synapse.logging.context.LoggingContextFilter request: "" handlers: diff --git a/docs/log_contexts.rst b/docs/log_contexts.rst index 27cde11cf7..f5cd5de8ab 100644 --- a/docs/log_contexts.rst +++ b/docs/log_contexts.rst @@ -1,4 +1,4 @@ -Log contexts +Log Contexts ============ .. contents:: @@ -12,7 +12,7 @@ record. Logcontexts are also used for CPU and database accounting, so that we can track which requests were responsible for high CPU use or database activity. -The ``synapse.util.logcontext`` module provides a facilities for managing the +The ``synapse.logging.context`` module provides a facilities for managing the current log context (as well as providing the ``LoggingContextFilter`` class). Deferreds make the whole thing complicated, so this document describes how it @@ -27,19 +27,19 @@ found them: .. code:: python - from synapse.util import logcontext # omitted from future snippets + from synapse.logging import context # omitted from future snippets def handle_request(request_id): - request_context = logcontext.LoggingContext() + request_context = context.LoggingContext() - calling_context = logcontext.LoggingContext.current_context() - logcontext.LoggingContext.set_current_context(request_context) + calling_context = context.LoggingContext.current_context() + context.LoggingContext.set_current_context(request_context) try: request_context.request = request_id do_request_handling() logger.debug("finished") finally: - logcontext.LoggingContext.set_current_context(calling_context) + context.LoggingContext.set_current_context(calling_context) def do_request_handling(): logger.debug("phew") # this will be logged against request_id @@ -51,7 +51,7 @@ written much more succinctly as: .. code:: python def handle_request(request_id): - with logcontext.LoggingContext() as request_context: + with context.LoggingContext() as request_context: request_context.request = request_id do_request_handling() logger.debug("finished") @@ -74,7 +74,7 @@ blocking operation, and returns a deferred: @defer.inlineCallbacks def handle_request(request_id): - with logcontext.LoggingContext() as request_context: + with context.LoggingContext() as request_context: request_context.request = request_id yield do_request_handling() logger.debug("finished") @@ -179,7 +179,7 @@ though, we need to make up a new Deferred, or we get a Deferred back from external code. We need to make it follow our rules. The easy way to do it is with a combination of ``defer.inlineCallbacks``, and -``logcontext.PreserveLoggingContext``. Suppose we want to implement ``sleep``, +``context.PreserveLoggingContext``. Suppose we want to implement ``sleep``, which returns a deferred which will run its callbacks after a given number of seconds. That might look like: @@ -204,13 +204,13 @@ That doesn't follow the rules, but we can fix it by wrapping it with This technique works equally for external functions which return deferreds, or deferreds we have made ourselves. -You can also use ``logcontext.make_deferred_yieldable``, which just does the +You can also use ``context.make_deferred_yieldable``, which just does the boilerplate for you, so the above could be written: .. code:: python def sleep(seconds): - return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds)) + return context.make_deferred_yieldable(get_sleep_deferred(seconds)) Fire-and-forget @@ -279,7 +279,7 @@ Obviously that option means that the operations done in that might be fixed by setting a different logcontext via a ``with LoggingContext(...)`` in ``background_operation``). -The second option is to use ``logcontext.run_in_background``, which wraps a +The second option is to use ``context.run_in_background``, which wraps a function so that it doesn't reset the logcontext even when it returns an incomplete deferred, and adds a callback to the returned deferred to reset the logcontext. In other words, it turns a function that follows the Synapse rules @@ -293,7 +293,7 @@ It can be used like this: def do_request_handling(): yield foreground_operation() - logcontext.run_in_background(background_operation) + context.run_in_background(background_operation) # this will now be logged against the request context logger.debug("Request handling complete") @@ -332,7 +332,7 @@ gathered: result = yield defer.gatherResults([d1, d2]) In this case particularly, though, option two, of using -``logcontext.preserve_fn`` almost certainly makes more sense, so that +``context.preserve_fn`` almost certainly makes more sense, so that ``operation1`` and ``operation2`` are both logged against the original logcontext. This looks like: @@ -340,8 +340,8 @@ logcontext. This looks like: @defer.inlineCallbacks def do_request_handling(): - d1 = logcontext.preserve_fn(operation1)() - d2 = logcontext.preserve_fn(operation2)() + d1 = context.preserve_fn(operation1)() + d2 = context.preserve_fn(operation2)() with PreserveLoggingContext(): result = yield defer.gatherResults([d1, d2]) @@ -381,7 +381,7 @@ off the background process, and then leave the ``with`` block to wait for it: .. code:: python def handle_request(request_id): - with logcontext.LoggingContext() as request_context: + with context.LoggingContext() as request_context: request_context.request = request_id d = do_request_handling() @@ -414,7 +414,7 @@ runs its callbacks in the original logcontext, all is happy. The business of a Deferred which runs its callbacks in the original logcontext isn't hard to achieve — we have it today, in the shape of -``logcontext._PreservingContextDeferred``: +``context._PreservingContextDeferred``: .. code:: python diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9b5c30f506..1ebb7ae539 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -27,7 +27,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory import synapse from synapse.app import check_bind_error from synapse.crypto import context_factory -from synapse.util import PreserveLoggingContext +from synapse.logging.context import PreserveLoggingContext from synapse.util.async_helpers import Linearizer from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 9120bdb143..be44249ed6 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -26,6 +26,7 @@ from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext, run_in_background from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore @@ -36,7 +37,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 90bc79cdda..ff11beca82 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -27,6 +27,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.server import JsonResource from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore @@ -64,7 +65,6 @@ from synapse.rest.client.versions import VersionsRestServlet from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index ff522e4499..cacad25eac 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -27,6 +27,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.server import JsonResource from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore @@ -59,7 +60,6 @@ from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.storage.user_directory import UserDirectoryStore from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 9421420930..11e80dbae0 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -28,6 +28,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.federation.transport.server import TransportLayerServer from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore @@ -48,7 +49,6 @@ from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 969be58d0b..97da7bdcbf 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -27,6 +27,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.federation import send_queue from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext, run_in_background from synapse.metrics import RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.resource import METRICS_PREFIX, MetricsResource @@ -44,7 +45,6 @@ from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 2fd7d57ebf..417a10bbd2 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -29,6 +29,7 @@ from synapse.config.logger import setup_logging from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore @@ -41,7 +42,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 49da105cf6..639b1429c0 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -54,6 +54,7 @@ from synapse.federation.transport.server import TransportLayerServer from synapse.http.additional_resource import AdditionalResource from synapse.http.server import RootRedirect from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext from synapse.metrics import RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.resource import METRICS_PREFIX, MetricsResource @@ -72,7 +73,6 @@ from synapse.storage.engines import IncorrectDatabaseSetup, create_engine from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.module_loader import load_module from synapse.util.rlimit import change_resource_limit diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index cf0e2036c3..f23b9b6eda 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -27,6 +27,7 @@ from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore @@ -40,7 +41,6 @@ from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.storage.media_repository import MediaRepositoryStore from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index df29ea5ecb..4f929edf86 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -26,6 +26,7 @@ from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext, run_in_background from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import __func__ @@ -38,7 +39,6 @@ from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 858949910d..de4797fddc 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -31,6 +31,7 @@ from synapse.config.logger import setup_logging from synapse.handlers.presence import PresenceHandler, get_interested_parties from synapse.http.server import JsonResource from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext, run_in_background from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ @@ -57,7 +58,6 @@ from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.storage.presence import UserPresenceState from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.stringutils import random_string from synapse.util.versionstring import get_version_string diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 2d9d2e1bbc..1177ddd72e 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -28,6 +28,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.server import JsonResource from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext, run_in_background from synapse.metrics import RegistryProxy from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage._base import BaseSlavedStore @@ -46,7 +47,6 @@ from synapse.storage.engines import create_engine from synapse.storage.user_directory import UserDirectoryStore from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index b54bf5411f..e5b36494f5 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -53,8 +53,8 @@ import logging from twisted.internet import defer from synapse.appservice import ApplicationServiceState +from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.logcontext import run_in_background logger = logging.getLogger(__name__) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 931aec41c0..0f5554211c 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -24,7 +24,7 @@ from twisted.logger import STDLibLogObserver, globalLogBeginner import synapse from synapse.app import _base as appbase -from synapse.util.logcontext import LoggingContextFilter +from synapse.logging.context import LoggingContextFilter from synapse.util.versionstring import get_version_string from ._base import Config @@ -40,7 +40,7 @@ formatters: filters: context: - (): synapse.util.logcontext.LoggingContextFilter + (): synapse.logging.context.LoggingContextFilter request: "" handlers: diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 10c2eb7f0f..341c863152 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -44,15 +44,16 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) -from synapse.storage.keys import FetchKeyResult -from synapse.util import logcontext, unwrapFirstError -from synapse.util.async_helpers import yieldable_gather_results -from synapse.util.logcontext import ( +from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, + make_deferred_yieldable, preserve_fn, run_in_background, ) +from synapse.storage.keys import FetchKeyResult +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import yieldable_gather_results from synapse.util.metrics import Measure from synapse.util.retryutils import NotRetryingDestination @@ -140,7 +141,7 @@ class Keyring(object): """ req = VerifyJsonRequest(server_name, json_object, validity_time, request_name) requests = (req,) - return logcontext.make_deferred_yieldable(self._verify_objects(requests)[0]) + return make_deferred_yieldable(self._verify_objects(requests)[0]) def verify_json_objects_for_server(self, server_and_json): """Bulk verifies signatures of json objects, bulk fetching keys as @@ -557,7 +558,7 @@ class BaseV2KeyFetcher(object): signed_key_json_bytes = encode_canonical_json(signed_key_json) - yield logcontext.make_deferred_yieldable( + yield make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -612,7 +613,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): defer.returnValue({}) - results = yield logcontext.make_deferred_yieldable( + results = yield make_deferred_yieldable( defer.gatherResults( [run_in_background(get_key, server) for server in self.key_servers], consumeErrors=True, diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index a96cdada3d..a9545e6c1b 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -19,7 +19,7 @@ from frozendict import frozendict from twisted.internet import defer -from synapse.util.logcontext import make_deferred_yieldable, run_in_background +from synapse.logging.context import make_deferred_yieldable, run_in_background class EventContext(object): diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 1e925b19e7..f7bb806ae7 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -27,8 +27,14 @@ from synapse.crypto.event_signing import check_event_content_hash from synapse.events import event_type_from_format_version from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict +from synapse.logging.context import ( + LoggingContext, + PreserveLoggingContext, + make_deferred_yieldable, + preserve_fn, +) from synapse.types import get_domain_from_id -from synapse.util import logcontext, unwrapFirstError +from synapse.util import unwrapFirstError logger = logging.getLogger(__name__) @@ -73,7 +79,7 @@ class FederationBase(object): @defer.inlineCallbacks def handle_check_result(pdu, deferred): try: - res = yield logcontext.make_deferred_yieldable(deferred) + res = yield make_deferred_yieldable(deferred) except SynapseError: res = None @@ -102,10 +108,10 @@ class FederationBase(object): defer.returnValue(res) - handle = logcontext.preserve_fn(handle_check_result) + handle = preserve_fn(handle_check_result) deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] - valid_pdus = yield logcontext.make_deferred_yieldable( + valid_pdus = yield make_deferred_yieldable( defer.gatherResults(deferreds2, consumeErrors=True) ).addErrback(unwrapFirstError) @@ -115,7 +121,7 @@ class FederationBase(object): defer.returnValue([p for p in valid_pdus if p]) def _check_sigs_and_hash(self, room_version, pdu): - return logcontext.make_deferred_yieldable( + return make_deferred_yieldable( self._check_sigs_and_hashes(room_version, [pdu])[0] ) @@ -133,14 +139,14 @@ class FederationBase(object): * returns a redacted version of the event (if the signature matched but the hash did not) * throws a SynapseError if the signature check failed. - The deferreds run their callbacks in the sentinel logcontext. + The deferreds run their callbacks in the sentinel """ deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus) - ctx = logcontext.LoggingContext.current_context() + ctx = LoggingContext.current_context() def callback(_, pdu): - with logcontext.PreserveLoggingContext(ctx): + with PreserveLoggingContext(ctx): if not check_event_content_hash(pdu): # let's try to distinguish between failures because the event was # redacted (which are somewhat expected) vs actual ball-tampering @@ -178,7 +184,7 @@ class FederationBase(object): def errback(failure, pdu): failure.trap(SynapseError) - with logcontext.PreserveLoggingContext(ctx): + with PreserveLoggingContext(ctx): logger.warn( "Signature check failed for %s: %s", pdu.event_id, diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 3883eb525e..3cb4b94420 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -39,10 +39,10 @@ from synapse.api.room_versions import ( ) from synapse.events import builder, room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json -from synapse.util import logcontext, unwrapFirstError +from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.utils import log_function +from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.logcontext import make_deferred_yieldable, run_in_background -from synapse.util.logutils import log_function from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -207,7 +207,7 @@ class FederationClient(FederationBase): ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield logcontext.make_deferred_yieldable( + pdus[:] = yield make_deferred_yieldable( defer.gatherResults( self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True ).addErrback(unwrapFirstError) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2e0cebb638..8c0a18b120 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -42,6 +42,8 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.endpoint import parse_server_name +from synapse.logging.context import nested_logging_context +from synapse.logging.utils import log_function from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, @@ -50,8 +52,6 @@ from synapse.types import get_domain_from_id from synapse.util import glob_to_regex from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache -from synapse.util.logcontext import nested_logging_context -from synapse.util.logutils import log_function # when processing incoming transactions, we try to handle multiple rooms in # parallel, up to this limit. diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 7535f79203..d086e04243 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -23,7 +23,7 @@ import logging from twisted.internet import defer -from synapse.util.logutils import log_function +from synapse.logging.utils import log_function logger = logging.getLogger(__name__) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 766c5a37cd..d46f4aaeb1 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -26,6 +26,11 @@ from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.units import Edu from synapse.handlers.presence import get_interested_remotes +from synapse.logging.context import ( + make_deferred_yieldable, + preserve_fn, + run_in_background, +) from synapse.metrics import ( LaterGauge, event_processing_loop_counter, @@ -33,7 +38,6 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util import logcontext from synapse.util.metrics import measure_func logger = logging.getLogger(__name__) @@ -210,10 +214,10 @@ class FederationSender(object): for event in events: events_by_room.setdefault(event.room_id, []).append(event) - yield logcontext.make_deferred_yieldable( + yield make_deferred_yieldable( defer.gatherResults( [ - logcontext.run_in_background(handle_room_events, evs) + run_in_background(handle_room_events, evs) for evs in itervalues(events_by_room) ], consumeErrors=True, @@ -360,7 +364,7 @@ class FederationSender(object): for queue in queues: queue.flush_read_receipts_for_room(room_id) - @logcontext.preserve_fn # the caller should not yield on this + @preserve_fn # the caller should not yield on this @defer.inlineCallbacks def send_presence(self, states): """Send the new presence states to the appropriate destinations. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index aecd142309..1aae9ec9e7 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -22,7 +22,7 @@ from twisted.internet import defer from synapse.api.constants import Membership from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX -from synapse.util.logutils import log_function +from synapse.logging.utils import log_function logger = logging.getLogger(__name__) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 955f0f4308..2efdcff7ef 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -36,8 +36,8 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string_from_args, ) +from synapse.logging.context import run_in_background from synapse.types import ThirdPartyInstanceID, get_domain_from_id -from synapse.util.logcontext import run_in_background from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index e73757570c..f497711133 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -43,9 +43,9 @@ from signedjson.sign import sign_json from twisted.internet import defer from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError +from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id -from synapse.util.logcontext import run_in_background logger = logging.getLogger(__name__) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index edb48054a0..1f1708ba7d 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -22,10 +22,10 @@ from email.mime.text import MIMEText from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.logging.context import make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID from synapse.util import stringutils -from synapse.util.logcontext import make_deferred_yieldable try: from synapse.push.mailer import load_jinja2_templates diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 5cc89d43f6..8f089f0e33 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -23,13 +23,13 @@ from twisted.internet import defer import synapse from synapse.api.constants import EventTypes +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import ( event_processing_loop_counter, event_processing_loop_room_count, ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util import log_failure -from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.metrics import Measure logger = logging.getLogger(__name__) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c8c1ed3246..ef5585aa99 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -36,9 +36,9 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter +from synapse.logging.context import defer_to_thread from synapse.module_api import ModuleApi from synapse.types import UserID -from synapse.util import logcontext from synapse.util.caches.expiringcache import ExpiringCache from ._base import BaseHandler @@ -987,7 +987,7 @@ class AuthHandler(BaseHandler): bcrypt.gensalt(self.bcrypt_rounds), ).decode("ascii") - return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash) + return defer_to_thread(self.hs.get_reactor(), _do_hash) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. @@ -1013,7 +1013,7 @@ class AuthHandler(BaseHandler): if not isinstance(stored_hash, bytes): stored_hash = stored_hash.encode("ascii") - return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash) + return defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: return defer.succeed(False) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 807900fe52..55b4ab3a1a 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -23,8 +23,8 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from synapse.api.errors import CodeMessageException, FederationDeniedError, SynapseError +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import UserID, get_domain_from_id -from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 5836d3c639..6a38328af3 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -21,8 +21,8 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase +from synapse.logging.utils import log_function from synapse.types import UserID -from synapse.util.logutils import log_function from synapse.visibility import filter_events_for_client from ._base import BaseHandler diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 02d397c498..57be968c67 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,13 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event from synapse.events.validator import EventValidator +from synapse.logging.context import ( + make_deferred_yieldable, + nested_logging_context, + preserve_fn, + run_in_background, +) +from synapse.logging.utils import log_function from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationFederationSendEventsRestServlet, @@ -52,10 +59,9 @@ from synapse.replication.http.federation import ( from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.types import UserID, get_domain_from_id -from synapse.util import logcontext, unwrapFirstError +from synapse.util import unwrapFirstError from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room -from synapse.util.logutils import log_function from synapse.util.retryutils import NotRetryingDestination from synapse.visibility import filter_events_for_server @@ -338,7 +344,7 @@ class FederationHandler(BaseHandler): room_version = yield self.store.get_room_version(room_id) - with logcontext.nested_logging_context(p): + with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. @@ -532,7 +538,7 @@ class FederationHandler(BaseHandler): event_id, ev.event_id, ) - with logcontext.nested_logging_context(ev.event_id): + with nested_logging_context(ev.event_id): try: yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: @@ -725,10 +731,10 @@ class FederationHandler(BaseHandler): missing_auth - failed_to_fetch, ) - results = yield logcontext.make_deferred_yieldable( + results = yield make_deferred_yieldable( defer.gatherResults( [ - logcontext.run_in_background( + run_in_background( self.federation_client.get_pdu, [dest], event_id, @@ -994,10 +1000,8 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) logger.debug("calling resolve_state_groups in _maybe_backfill") - resolve = logcontext.preserve_fn( - self.state_handler.resolve_state_groups_for_events - ) - states = yield logcontext.make_deferred_yieldable( + resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) + states = yield make_deferred_yieldable( defer.gatherResults( [resolve(room_id, [e]) for e in event_ids], consumeErrors=True ) @@ -1171,7 +1175,7 @@ class FederationHandler(BaseHandler): # lots of requests for missing prev_events which we do actually # have. Hence we fire off the deferred, but don't wait for it. - logcontext.run_in_background(self._handle_queued_pdus, room_queue) + run_in_background(self._handle_queued_pdus, room_queue) defer.returnValue(True) @@ -1191,7 +1195,7 @@ class FederationHandler(BaseHandler): p.event_id, p.room_id, ) - with logcontext.nested_logging_context(p.event_id): + with nested_logging_context(p.event_id): yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warn( @@ -1610,7 +1614,7 @@ class FederationHandler(BaseHandler): success = True finally: if not success: - logcontext.run_in_background( + run_in_background( self.store.remove_push_actions_from_staging, event.event_id ) @@ -1629,7 +1633,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def prep(ev_info): event = ev_info["event"] - with logcontext.nested_logging_context(suffix=event.event_id): + with nested_logging_context(suffix=event.event_id): res = yield self._prep_event( origin, event, @@ -1639,12 +1643,9 @@ class FederationHandler(BaseHandler): ) defer.returnValue(res) - contexts = yield logcontext.make_deferred_yieldable( + contexts = yield make_deferred_yieldable( defer.gatherResults( - [ - logcontext.run_in_background(prep, ev_info) - for ev_info in event_infos - ], + [run_in_background(prep, ev_info) for ev_info in event_infos], consumeErrors=True, ) ) @@ -2106,10 +2107,10 @@ class FederationHandler(BaseHandler): room_version = yield self.store.get_room_version(event.room_id) - different_events = yield logcontext.make_deferred_yieldable( + different_events = yield make_deferred_yieldable( defer.gatherResults( [ - logcontext.run_in_background( + run_in_background( self.store.get_event, d, allow_none=True, allow_rejected=False ) for d in different_auth diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index a1fe9d116f..54c966c8a6 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -21,12 +21,12 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.visibility import filter_events_for_client from ._base import BaseHandler diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 683da6bf32..eaeda7a5cb 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -34,13 +34,13 @@ from synapse.api.errors import ( from synapse.api.room_versions import RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.events.validator import EventValidator +from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder -from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 76ee97ddd3..20bcfed334 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -20,10 +20,10 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError +from synapse.logging.context import run_in_background from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock -from synapse.util.logcontext import run_in_background from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index c80dc2eba0..6f3537e435 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -34,14 +34,14 @@ from twisted.internet import defer import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError +from synapse.logging.context import run_in_background +from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cachedInlineCallbacks -from synapse.util.logcontext import run_in_background -from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a3f550554f..cd1ac0a27a 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -25,6 +25,7 @@ from prometheus_client import Counter from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.logging.context import LoggingContext from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter @@ -33,7 +34,6 @@ from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache from synapse.util.caches.response_cache import ResponseCache -from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure, measure_func from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index f8062c8671..c3e0c8fc7e 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -19,9 +19,9 @@ from collections import namedtuple from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError +from synapse.logging.context import run_in_background from synapse.types import UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache -from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer diff --git a/synapse/http/client.py b/synapse/http/client.py index 9bc7035c8d..45d5010952 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -45,9 +45,9 @@ from synapse.http import ( cancelled_to_request_timed_out_error, redact_uri, ) +from synapse.logging.context import make_deferred_yieldable from synapse.util.async_helpers import timeout_deferred from synapse.util.caches import CACHE_SIZE_FACTOR -from synapse.util.logcontext import make_deferred_yieldable logger = logging.getLogger(__name__) diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 414cde0777..054c321a20 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -30,9 +30,9 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list +from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock from synapse.util.caches.ttlcache import TTLCache -from synapse.util.logcontext import make_deferred_yieldable from synapse.util.metrics import Measure # period to cache .well-known results for by default diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 1f22f78a75..ecc88f9b96 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -25,7 +25,7 @@ from twisted.internet.error import ConnectError from twisted.names import client, dns from twisted.names.error import DNSNameError, DomainError -from synapse.util.logcontext import make_deferred_yieldable +from synapse.logging.context import make_deferred_yieldable logger = logging.getLogger(__name__) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 5ef8bb60a3..dee3710f68 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -48,8 +48,8 @@ from synapse.api.errors import ( from synapse.http import QuieterFileBodyProducer from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent +from synapse.logging.context import make_deferred_yieldable from synapse.util.async_helpers import timeout_deferred -from synapse.util.logcontext import make_deferred_yieldable from synapse.util.metrics import Measure logger = logging.getLogger(__name__) diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 62045a918b..46af27c8f6 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -19,8 +19,8 @@ import threading from prometheus_client.core import Counter, Histogram +from synapse.logging.context import LoggingContext from synapse.metrics import LaterGauge -from synapse.util.logcontext import LoggingContext logger = logging.getLogger(__name__) diff --git a/synapse/http/server.py b/synapse/http/server.py index d993161a3e..72a3d67eb6 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -39,8 +39,8 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) +from synapse.logging.context import preserve_fn from synapse.util.caches import intern_dict -from synapse.util.logcontext import preserve_fn logger = logging.getLogger(__name__) diff --git a/synapse/http/site.py b/synapse/http/site.py index 93f679ea48..df5274c177 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -19,7 +19,7 @@ from twisted.web.server import Request, Site from synapse.http import redact_uri from synapse.http.request_metrics import RequestMetrics, requests_counter -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +from synapse.logging.context import LoggingContext, PreserveLoggingContext logger = logging.getLogger(__name__) diff --git a/synapse/logging/__init__.py b/synapse/logging/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/synapse/logging/context.py b/synapse/logging/context.py new file mode 100644 index 0000000000..30dfa1d6b2 --- /dev/null +++ b/synapse/logging/context.py @@ -0,0 +1,693 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Thread-local-alike tracking of log contexts within synapse + +This module provides objects and utilities for tracking contexts through +synapse code, so that log lines can include a request identifier, and so that +CPU and database activity can be accounted for against the request that caused +them. + +See doc/log_contexts.rst for details on how this works. +""" + +import logging +import threading +import types + +from twisted.internet import defer, threads + +logger = logging.getLogger(__name__) + +try: + import resource + + # Python doesn't ship with a definition of RUSAGE_THREAD but it's defined + # to be 1 on linux so we hard code it. + RUSAGE_THREAD = 1 + + # If the system doesn't support RUSAGE_THREAD then this should throw an + # exception. + resource.getrusage(RUSAGE_THREAD) + + def get_thread_resource_usage(): + return resource.getrusage(RUSAGE_THREAD) + + +except Exception: + # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we + # won't track resource usage by returning None. + def get_thread_resource_usage(): + return None + + +# get an id for the current thread. +# +# threading.get_ident doesn't actually return an OS-level tid, and annoyingly, +# on Linux it actually returns the same value either side of a fork() call. However +# we only fork in one place, so it's not worth the hoop-jumping to get a real tid. +# +get_thread_id = threading.get_ident + + +class ContextResourceUsage(object): + """Object for tracking the resources used by a log context + + Attributes: + ru_utime (float): user CPU time (in seconds) + ru_stime (float): system CPU time (in seconds) + db_txn_count (int): number of database transactions done + db_sched_duration_sec (float): amount of time spent waiting for a + database connection + db_txn_duration_sec (float): amount of time spent doing database + transactions (excluding scheduling time) + evt_db_fetch_count (int): number of events requested from the database + """ + + __slots__ = [ + "ru_stime", + "ru_utime", + "db_txn_count", + "db_txn_duration_sec", + "db_sched_duration_sec", + "evt_db_fetch_count", + ] + + def __init__(self, copy_from=None): + """Create a new ContextResourceUsage + + Args: + copy_from (ContextResourceUsage|None): if not None, an object to + copy stats from + """ + if copy_from is None: + self.reset() + else: + self.ru_utime = copy_from.ru_utime + self.ru_stime = copy_from.ru_stime + self.db_txn_count = copy_from.db_txn_count + + self.db_txn_duration_sec = copy_from.db_txn_duration_sec + self.db_sched_duration_sec = copy_from.db_sched_duration_sec + self.evt_db_fetch_count = copy_from.evt_db_fetch_count + + def copy(self): + return ContextResourceUsage(copy_from=self) + + def reset(self): + self.ru_stime = 0.0 + self.ru_utime = 0.0 + self.db_txn_count = 0 + + self.db_txn_duration_sec = 0 + self.db_sched_duration_sec = 0 + self.evt_db_fetch_count = 0 + + def __repr__(self): + return ( + "" + ) % ( + self.ru_stime, + self.ru_utime, + self.db_txn_count, + self.db_txn_duration_sec, + self.db_sched_duration_sec, + self.evt_db_fetch_count, + ) + + def __iadd__(self, other): + """Add another ContextResourceUsage's stats to this one's. + + Args: + other (ContextResourceUsage): the other resource usage object + """ + self.ru_utime += other.ru_utime + self.ru_stime += other.ru_stime + self.db_txn_count += other.db_txn_count + self.db_txn_duration_sec += other.db_txn_duration_sec + self.db_sched_duration_sec += other.db_sched_duration_sec + self.evt_db_fetch_count += other.evt_db_fetch_count + return self + + def __isub__(self, other): + self.ru_utime -= other.ru_utime + self.ru_stime -= other.ru_stime + self.db_txn_count -= other.db_txn_count + self.db_txn_duration_sec -= other.db_txn_duration_sec + self.db_sched_duration_sec -= other.db_sched_duration_sec + self.evt_db_fetch_count -= other.evt_db_fetch_count + return self + + def __add__(self, other): + res = ContextResourceUsage(copy_from=self) + res += other + return res + + def __sub__(self, other): + res = ContextResourceUsage(copy_from=self) + res -= other + return res + + +class LoggingContext(object): + """Additional context for log formatting. Contexts are scoped within a + "with" block. + + If a parent is given when creating a new context, then: + - logging fields are copied from the parent to the new context on entry + - when the new context exits, the cpu usage stats are copied from the + child to the parent + + Args: + name (str): Name for the context for debugging. + parent_context (LoggingContext|None): The parent of the new context + """ + + __slots__ = [ + "previous_context", + "name", + "parent_context", + "_resource_usage", + "usage_start", + "main_thread", + "alive", + "request", + "tag", + ] + + thread_local = threading.local() + + class Sentinel(object): + """Sentinel to represent the root context""" + + __slots__ = [] + + def __str__(self): + return "sentinel" + + def copy_to(self, record): + pass + + def start(self): + pass + + def stop(self): + pass + + def add_database_transaction(self, duration_sec): + pass + + def add_database_scheduled(self, sched_sec): + pass + + def record_event_fetch(self, event_count): + pass + + def __nonzero__(self): + return False + + __bool__ = __nonzero__ # python3 + + sentinel = Sentinel() + + def __init__(self, name=None, parent_context=None, request=None): + self.previous_context = LoggingContext.current_context() + self.name = name + + # track the resources used by this context so far + self._resource_usage = ContextResourceUsage() + + # If alive has the thread resource usage when the logcontext last + # became active. + self.usage_start = None + + self.main_thread = get_thread_id() + self.request = None + self.tag = "" + self.alive = True + + self.parent_context = parent_context + + if self.parent_context is not None: + self.parent_context.copy_to(self) + + if request is not None: + # the request param overrides the request from the parent context + self.request = request + + def __str__(self): + if self.request: + return str(self.request) + return "%s@%x" % (self.name, id(self)) + + @classmethod + def current_context(cls): + """Get the current logging context from thread local storage + + Returns: + LoggingContext: the current logging context + """ + return getattr(cls.thread_local, "current_context", cls.sentinel) + + @classmethod + def set_current_context(cls, context): + """Set the current logging context in thread local storage + Args: + context(LoggingContext): The context to activate. + Returns: + The context that was previously active + """ + current = cls.current_context() + + if current is not context: + current.stop() + cls.thread_local.current_context = context + context.start() + return current + + def __enter__(self): + """Enters this logging context into thread local storage""" + old_context = self.set_current_context(self) + if self.previous_context != old_context: + logger.warn( + "Expected previous context %r, found %r", + self.previous_context, + old_context, + ) + self.alive = True + + return self + + def __exit__(self, type, value, traceback): + """Restore the logging context in thread local storage to the state it + was before this context was entered. + Returns: + None to avoid suppressing any exceptions that were thrown. + """ + current = self.set_current_context(self.previous_context) + if current is not self: + if current is self.sentinel: + logger.warning("Expected logging context %s was lost", self) + else: + logger.warning( + "Expected logging context %s but found %s", self, current + ) + self.previous_context = None + self.alive = False + + # if we have a parent, pass our CPU usage stats on + if self.parent_context is not None and hasattr( + self.parent_context, "_resource_usage" + ): + self.parent_context._resource_usage += self._resource_usage + + # reset them in case we get entered again + self._resource_usage.reset() + + def copy_to(self, record): + """Copy logging fields from this context to a log record or + another LoggingContext + """ + + # 'request' is the only field we currently use in the logger, so that's + # all we need to copy + record.request = self.request + + def start(self): + if get_thread_id() != self.main_thread: + logger.warning("Started logcontext %s on different thread", self) + return + + # If we haven't already started record the thread resource usage so + # far + if not self.usage_start: + self.usage_start = get_thread_resource_usage() + + def stop(self): + if get_thread_id() != self.main_thread: + logger.warning("Stopped logcontext %s on different thread", self) + return + + # When we stop, let's record the cpu used since we started + if not self.usage_start: + logger.warning("Called stop on logcontext %s without calling start", self) + return + + utime_delta, stime_delta = self._get_cputime() + self._resource_usage.ru_utime += utime_delta + self._resource_usage.ru_stime += stime_delta + + self.usage_start = None + + def get_resource_usage(self): + """Get resources used by this logcontext so far. + + Returns: + ContextResourceUsage: a *copy* of the object tracking resource + usage so far + """ + # we always return a copy, for consistency + res = self._resource_usage.copy() + + # If we are on the correct thread and we're currently running then we + # can include resource usage so far. + is_main_thread = get_thread_id() == self.main_thread + if self.alive and self.usage_start and is_main_thread: + utime_delta, stime_delta = self._get_cputime() + res.ru_utime += utime_delta + res.ru_stime += stime_delta + + return res + + def _get_cputime(self): + """Get the cpu usage time so far + + Returns: Tuple[float, float]: seconds in user mode, seconds in system mode + """ + current = get_thread_resource_usage() + + utime_delta = current.ru_utime - self.usage_start.ru_utime + stime_delta = current.ru_stime - self.usage_start.ru_stime + + # sanity check + if utime_delta < 0: + logger.error( + "utime went backwards! %f < %f", + current.ru_utime, + self.usage_start.ru_utime, + ) + utime_delta = 0 + + if stime_delta < 0: + logger.error( + "stime went backwards! %f < %f", + current.ru_stime, + self.usage_start.ru_stime, + ) + stime_delta = 0 + + return utime_delta, stime_delta + + def add_database_transaction(self, duration_sec): + if duration_sec < 0: + raise ValueError("DB txn time can only be non-negative") + self._resource_usage.db_txn_count += 1 + self._resource_usage.db_txn_duration_sec += duration_sec + + def add_database_scheduled(self, sched_sec): + """Record a use of the database pool + + Args: + sched_sec (float): number of seconds it took us to get a + connection + """ + if sched_sec < 0: + raise ValueError("DB scheduling time can only be non-negative") + self._resource_usage.db_sched_duration_sec += sched_sec + + def record_event_fetch(self, event_count): + """Record a number of events being fetched from the db + + Args: + event_count (int): number of events being fetched + """ + self._resource_usage.evt_db_fetch_count += event_count + + +class LoggingContextFilter(logging.Filter): + """Logging filter that adds values from the current logging context to each + record. + Args: + **defaults: Default values to avoid formatters complaining about + missing fields + """ + + def __init__(self, **defaults): + self.defaults = defaults + + def filter(self, record): + """Add each fields from the logging contexts to the record. + Returns: + True to include the record in the log output. + """ + context = LoggingContext.current_context() + for key, value in self.defaults.items(): + setattr(record, key, value) + + # context should never be None, but if it somehow ends up being, then + # we end up in a death spiral of infinite loops, so let's check, for + # robustness' sake. + if context is not None: + context.copy_to(record) + + return True + + +class PreserveLoggingContext(object): + """Captures the current logging context and restores it when the scope is + exited. Used to restore the context after a function using + @defer.inlineCallbacks is resumed by a callback from the reactor.""" + + __slots__ = ["current_context", "new_context", "has_parent"] + + def __init__(self, new_context=None): + if new_context is None: + new_context = LoggingContext.sentinel + self.new_context = new_context + + def __enter__(self): + """Captures the current logging context""" + self.current_context = LoggingContext.set_current_context(self.new_context) + + if self.current_context: + self.has_parent = self.current_context.previous_context is not None + if not self.current_context.alive: + logger.debug("Entering dead context: %s", self.current_context) + + def __exit__(self, type, value, traceback): + """Restores the current logging context""" + context = LoggingContext.set_current_context(self.current_context) + + if context != self.new_context: + if context is LoggingContext.sentinel: + logger.warning("Expected logging context %s was lost", self.new_context) + else: + logger.warning( + "Expected logging context %s but found %s", + self.new_context, + context, + ) + + if self.current_context is not LoggingContext.sentinel: + if not self.current_context.alive: + logger.debug("Restoring dead context: %s", self.current_context) + + +def nested_logging_context(suffix, parent_context=None): + """Creates a new logging context as a child of another. + + The nested logging context will have a 'request' made up of the parent context's + request, plus the given suffix. + + CPU/db usage stats will be added to the parent context's on exit. + + Normal usage looks like: + + with nested_logging_context(suffix): + # ... do stuff + + Args: + suffix (str): suffix to add to the parent context's 'request'. + parent_context (LoggingContext|None): parent context. Will use the current context + if None. + + Returns: + LoggingContext: new logging context. + """ + if parent_context is None: + parent_context = LoggingContext.current_context() + return LoggingContext( + parent_context=parent_context, request=parent_context.request + "-" + suffix + ) + + +def preserve_fn(f): + """Function decorator which wraps the function with run_in_background""" + + def g(*args, **kwargs): + return run_in_background(f, *args, **kwargs) + + return g + + +def run_in_background(f, *args, **kwargs): + """Calls a function, ensuring that the current context is restored after + return from the function, and that the sentinel context is set once the + deferred returned by the function completes. + + Useful for wrapping functions that return a deferred or coroutine, which you don't + yield or await on (for instance because you want to pass it to + deferred.gatherResults()). + + Note that if you completely discard the result, you should make sure that + `f` doesn't raise any deferred exceptions, otherwise a scary-looking + CRITICAL error about an unhandled error will be logged without much + indication about where it came from. + """ + current = LoggingContext.current_context() + try: + res = f(*args, **kwargs) + except: # noqa: E722 + # the assumption here is that the caller doesn't want to be disturbed + # by synchronous exceptions, so let's turn them into Failures. + return defer.fail() + + if isinstance(res, types.CoroutineType): + res = defer.ensureDeferred(res) + + if not isinstance(res, defer.Deferred): + return res + + if res.called and not res.paused: + # The function should have maintained the logcontext, so we can + # optimise out the messing about + return res + + # The function may have reset the context before returning, so + # we need to restore it now. + ctx = LoggingContext.set_current_context(current) + + # The original context will be restored when the deferred + # completes, but there is nothing waiting for it, so it will + # get leaked into the reactor or some other function which + # wasn't expecting it. We therefore need to reset the context + # here. + # + # (If this feels asymmetric, consider it this way: we are + # effectively forking a new thread of execution. We are + # probably currently within a ``with LoggingContext()`` block, + # which is supposed to have a single entry and exit point. But + # by spawning off another deferred, we are effectively + # adding a new exit point.) + res.addBoth(_set_context_cb, ctx) + return res + + +def make_deferred_yieldable(deferred): + """Given a deferred, make it follow the Synapse logcontext rules: + + If the deferred has completed (or is not actually a Deferred), essentially + does nothing (just returns another completed deferred with the + result/failure). + + If the deferred has not yet completed, resets the logcontext before + returning a deferred. Then, when the deferred completes, restores the + current logcontext before running callbacks/errbacks. + + (This is more-or-less the opposite operation to run_in_background.) + """ + if not isinstance(deferred, defer.Deferred): + return deferred + + if deferred.called and not deferred.paused: + # it looks like this deferred is ready to run any callbacks we give it + # immediately. We may as well optimise out the logcontext faffery. + return deferred + + # ok, we can't be sure that a yield won't block, so let's reset the + # logcontext, and add a callback to the deferred to restore it. + prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + deferred.addBoth(_set_context_cb, prev_context) + return deferred + + +def _set_context_cb(result, context): + """A callback function which just sets the logging context""" + LoggingContext.set_current_context(context) + return result + + +def defer_to_thread(reactor, f, *args, **kwargs): + """ + Calls the function `f` using a thread from the reactor's default threadpool and + returns the result as a Deferred. + + Creates a new logcontext for `f`, which is created as a child of the current + logcontext (so its CPU usage metrics will get attributed to the current + logcontext). `f` should preserve the logcontext it is given. + + The result deferred follows the Synapse logcontext rules: you should `yield` + on it. + + Args: + reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread + the Deferred will be invoked, and whose threadpool we should use for the + function. + + Normally this will be hs.get_reactor(). + + f (callable): The function to call. + + args: positional arguments to pass to f. + + kwargs: keyword arguments to pass to f. + + Returns: + Deferred: A Deferred which fires a callback with the result of `f`, or an + errback if `f` throws an exception. + """ + return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) + + +def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): + """ + A wrapper for twisted.internet.threads.deferToThreadpool, which handles + logcontexts correctly. + + Calls the function `f` using a thread from the given threadpool and returns + the result as a Deferred. + + Creates a new logcontext for `f`, which is created as a child of the current + logcontext (so its CPU usage metrics will get attributed to the current + logcontext). `f` should preserve the logcontext it is given. + + The result deferred follows the Synapse logcontext rules: you should `yield` + on it. + + Args: + reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread + the Deferred will be invoked. Normally this will be hs.get_reactor(). + + threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for + running `f`. Normally this will be hs.get_reactor().getThreadPool(). + + f (callable): The function to call. + + args: positional arguments to pass to f. + + kwargs: keyword arguments to pass to f. + + Returns: + Deferred: A Deferred which fires a callback with the result of `f`, or an + errback if `f` throws an exception. + """ + logcontext = LoggingContext.current_context() + + def g(): + with LoggingContext(parent_context=logcontext): + return f(*args, **kwargs) + + return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g)) diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py new file mode 100644 index 0000000000..fbf570c756 --- /dev/null +++ b/synapse/logging/formatter.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import traceback + +from six import StringIO + + +class LogFormatter(logging.Formatter): + """Log formatter which gives more detail for exceptions + + This is the same as the standard log formatter, except that when logging + exceptions [typically via log.foo("msg", exc_info=1)], it prints the + sequence that led up to the point at which the exception was caught. + (Normally only stack frames between the point the exception was raised and + where it was caught are logged). + """ + + def __init__(self, *args, **kwargs): + super(LogFormatter, self).__init__(*args, **kwargs) + + def formatException(self, ei): + sio = StringIO() + (typ, val, tb) = ei + + # log the stack above the exception capture point if possible, but + # check that we actually have an f_back attribute to work around + # https://twistedmatrix.com/trac/ticket/9305 + + if tb and hasattr(tb.tb_frame, "f_back"): + sio.write("Capture point (most recent call last):\n") + traceback.print_stack(tb.tb_frame.f_back, None, sio) + + traceback.print_exception(typ, val, tb, None, sio) + s = sio.getvalue() + sio.close() + if s[-1:] == "\n": + s = s[:-1] + return s diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py new file mode 100644 index 0000000000..7df0fa6087 --- /dev/null +++ b/synapse/logging/utils.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import logging +import time +from functools import wraps +from inspect import getcallargs + +from six import PY3 + +_TIME_FUNC_ID = 0 + + +def _log_debug_as_f(f, msg, msg_args): + name = f.__module__ + logger = logging.getLogger(name) + + if logger.isEnabledFor(logging.DEBUG): + if PY3: + lineno = f.__code__.co_firstlineno + pathname = f.__code__.co_filename + else: + lineno = f.func_code.co_firstlineno + pathname = f.func_code.co_filename + + record = logging.LogRecord( + name=name, + level=logging.DEBUG, + pathname=pathname, + lineno=lineno, + msg=msg, + args=msg_args, + exc_info=None, + ) + + logger.handle(record) + + +def log_function(f): + """ Function decorator that logs every call to that function. + """ + func_name = f.__name__ + + @wraps(f) + def wrapped(*args, **kwargs): + name = f.__module__ + logger = logging.getLogger(name) + level = logging.DEBUG + + if logger.isEnabledFor(level): + bound_args = getcallargs(f, *args, **kwargs) + + def format(value): + r = str(value) + if len(r) > 50: + r = r[:50] + "..." + return r + + func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()] + + msg_args = {"func_name": func_name, "args": ", ".join(func_args)} + + _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args) + + return f(*args, **kwargs) + + wrapped.__name__ = func_name + return wrapped + + +def time_function(f): + func_name = f.__name__ + + @wraps(f) + def wrapped(*args, **kwargs): + global _TIME_FUNC_ID + id = _TIME_FUNC_ID + _TIME_FUNC_ID += 1 + + start = time.clock() + + try: + _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id)) + + r = f(*args, **kwargs) + finally: + end = time.clock() + _log_debug_as_f( + f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start) + ) + + return r + + return wrapped + + +def trace_function(f): + func_name = f.__name__ + linenum = f.func_code.co_firstlineno + pathname = f.func_code.co_filename + + @wraps(f) + def wrapped(*args, **kwargs): + name = f.__module__ + logger = logging.getLogger(name) + level = logging.DEBUG + + s = inspect.currentframe().f_back + + to_print = [ + "\t%s:%s %s. Args: args=%s, kwargs=%s" + % (pathname, linenum, func_name, args, kwargs) + ] + while s: + if True or s.f_globals["__name__"].startswith("synapse"): + filename, lineno, function, _, _ = inspect.getframeinfo(s) + args_string = inspect.formatargvalues(*inspect.getargvalues(s)) + + to_print.append( + "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string) + ) + + s = s.f_back + + msg = "\nTraceback for %s:\n" % (func_name,) + "\n".join(to_print) + + record = logging.LogRecord( + name=name, + level=level, + pathname=pathname, + lineno=lineno, + msg=msg, + args=None, + exc_info=None, + ) + + logger.handle(record) + + return f(*args, **kwargs) + + wrapped.__name__ = func_name + return wrapped + + +def get_previous_frames(): + s = inspect.currentframe().f_back.f_back + to_return = [] + while s: + if s.f_globals["__name__"].startswith("synapse"): + filename, lineno, function, _, _ = inspect.getframeinfo(s) + args_string = inspect.formatargvalues(*inspect.getargvalues(s)) + + to_return.append( + "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string) + ) + + s = s.f_back + + return ", ".join(to_return) + + +def get_previous_frame(ignore=[]): + s = inspect.currentframe().f_back.f_back + + while s: + if s.f_globals["__name__"].startswith("synapse"): + if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore): + filename, lineno, function, _, _ = inspect.getframeinfo(s) + args_string = inspect.formatargvalues(*inspect.getargvalues(s)) + + return "{{ %s:%d %s - Args: %s }}" % ( + filename, + lineno, + function, + args_string, + ) + + s = s.f_back + + return None diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 167e2c068a..edd6b42db3 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -22,7 +22,7 @@ from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily from twisted.internet import defer -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +from synapse.logging.context import LoggingContext, PreserveLoggingContext logger = logging.getLogger(__name__) diff --git a/synapse/notifier.py b/synapse/notifier.py index d398078eed..918ef64897 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -23,12 +23,12 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state +from synapse.logging.context import PreserveLoggingContext +from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import StreamToken from synapse.util.async_helpers import ObservableDeferred, timeout_deferred -from synapse.util.logcontext import PreserveLoggingContext -from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 809199fe88..521c6e2cd7 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -29,6 +29,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import StoreError +from synapse.logging.context import make_deferred_yieldable from synapse.push.presentable_names import ( calculate_room_name, descriptor_from_member_events, @@ -36,7 +37,6 @@ from synapse.push.presentable_names import ( ) from synapse.types import UserID from synapse.util.async_helpers import concurrently_execute -from synapse.util.logcontext import make_deferred_yieldable from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 97efb835ad..5ffdf2675d 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -62,9 +62,9 @@ from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import random_string from .commands import ( diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 36404b797d..6da71dc46f 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -17,8 +17,8 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.util.async_helpers import ObservableDeferred -from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 3318638d3e..5fefee4dde 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -25,7 +25,7 @@ from twisted.protocols.basic import FileSender from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json -from synapse.util import logcontext +from synapse.logging.context import make_deferred_yieldable from synapse.util.stringutils import is_ascii logger = logging.getLogger(__name__) @@ -75,9 +75,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam add_file_headers(request, media_type, file_size, upload_name) with open(file_path, "rb") as f: - yield logcontext.make_deferred_yieldable( - FileSender().beginFileTransfer(f, request) - ) + yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) finish_request(request) else: diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index df3d985a38..65afffbb42 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -33,8 +33,8 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util import logcontext from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import random_string @@ -463,7 +463,7 @@ class MediaRepository(object): ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -511,7 +511,7 @@ class MediaRepository(object): ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -596,7 +596,7 @@ class MediaRepository(object): return if thumbnailer.transpose_method is not None: - m_width, m_height = yield logcontext.defer_to_thread( + m_width, m_height = yield defer_to_thread( self.hs.get_reactor(), thumbnailer.transpose ) @@ -616,11 +616,11 @@ class MediaRepository(object): for (t_width, t_height, t_type), t_method in iteritems(thumbnails): # Generate the thumbnail if t_method == "crop": - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type ) elif t_method == "scale": - t_byte_source = yield logcontext.defer_to_thread( + t_byte_source = yield defer_to_thread( self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type ) else: diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index eff86836fb..25e5ac2848 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -24,9 +24,8 @@ import six from twisted.internet import defer from twisted.protocols.basic import FileSender -from synapse.util import logcontext +from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.util.file_consumer import BackgroundFileConsumer -from synapse.util.logcontext import make_deferred_yieldable from ._base import Responder @@ -65,7 +64,7 @@ class MediaStorage(object): with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - yield logcontext.defer_to_thread( + yield defer_to_thread( self.hs.get_reactor(), _write_file_synchronously, source, f ) yield finish_cb() diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 053346fb86..5871737bfd 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -42,11 +42,11 @@ from synapse.http.server import ( wrap_json_request_handler, ) from synapse.http.servlet import parse_integer, parse_string +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import random_string from ._base import FileInfo diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 359b45ebfc..e8f559acc1 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -20,8 +20,7 @@ import shutil from twisted.internet import defer from synapse.config._base import Config -from synapse.util import logcontext -from synapse.util.logcontext import run_in_background +from synapse.logging.context import defer_to_thread, run_in_background from .media_storage import FileResponder @@ -125,7 +124,7 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return logcontext.defer_to_thread( + return defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 1b454a56a1..9f708fa205 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -28,11 +28,11 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events.snapshot import EventContext +from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.logutils import log_function from synapse.util.metrics import Measure logger = logging.getLogger(__name__) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 29589853c6..2f940dbae6 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -30,12 +30,12 @@ from prometheus_client import Histogram from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id from synapse.util import batch_iter from synapse.util.caches.descriptors import Cache -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.stringutils import exception_to_unicode # import a function which will return a monotonic time, in seconds diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 86f8485704..b486ca50eb 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -33,6 +33,8 @@ from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.logging.utils import log_function from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateResolutionStore @@ -45,8 +47,6 @@ from synapse.util import batch_iter from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder -from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable -from synapse.util.logutils import log_function from synapse.util.metrics import Measure logger = logging.getLogger(__name__) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 6d680d405a..09db872511 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -29,14 +29,14 @@ from synapse.api.room_versions import EventFormatVersions from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import get_domain_from_id -from synapse.util.logcontext import ( +from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, make_deferred_yieldable, run_in_background, ) +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import get_domain_from_id from synapse.util.metrics import Measure from ._base import SQLBaseStore diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d9482a3843..386a9dbe14 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -41,12 +41,12 @@ from six.moves import range from twisted.internet import defer +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache -from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 954e32fb2a..c6d2ce4404 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -21,10 +21,14 @@ import attr from twisted.internet import defer, task -from synapse.util.logcontext import PreserveLoggingContext +from synapse.logging import context, formatter logger = logging.getLogger(__name__) +# Compatibility alias, for existing logconfigs. +logcontext = context +logformatter = formatter + def unwrapFirstError(failure): # defer.gatherResults and DeferredLists wrap failures. @@ -46,7 +50,7 @@ class Clock(object): @defer.inlineCallbacks def sleep(self, seconds): d = defer.Deferred() - with PreserveLoggingContext(): + with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) res = yield d defer.returnValue(res) @@ -91,10 +95,10 @@ class Clock(object): """ def wrapped_callback(*args, **kwargs): - with PreserveLoggingContext(): + with context.PreserveLoggingContext(): callback(*args, **kwargs) - with PreserveLoggingContext(): + with context.PreserveLoggingContext(): return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) def cancel_call_later(self, timer, ignore_errs=False): diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 7757b8708a..58a6b8764f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -23,13 +23,12 @@ from twisted.internet import defer from twisted.internet.defer import CancelledError from twisted.python import failure -from synapse.util import Clock, logcontext, unwrapFirstError - -from .logcontext import ( +from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, run_in_background, ) +from synapse.util import Clock, unwrapFirstError logger = logging.getLogger(__name__) @@ -153,7 +152,7 @@ def concurrently_execute(func, args, limit): except StopIteration: pass - return logcontext.make_deferred_yieldable( + return make_deferred_yieldable( defer.gatherResults( [run_in_background(_concurrently_execute_inner) for _ in range(limit)], consumeErrors=True, @@ -174,7 +173,7 @@ def yieldable_gather_results(func, iter, *args, **kwargs): Deferred[list]: Resolved when all functions have been invoked, or errors if one of the function calls fails. """ - return logcontext.make_deferred_yieldable( + return make_deferred_yieldable( defer.gatherResults( [run_in_background(func, item, *args, **kwargs) for item in iter], consumeErrors=True, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index d2f25063aa..675db2f448 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -24,7 +24,8 @@ from six import itervalues, string_types from twisted.internet import defer -from synapse.util import logcontext, unwrapFirstError +from synapse.logging.context import make_deferred_yieldable, preserve_fn +from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache @@ -388,7 +389,7 @@ class CacheDescriptor(_CacheDescriptorBase): except KeyError: ret = defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs + preserve_fn(self.function_to_call), obj, *args, **kwargs ) def onErr(f): @@ -408,7 +409,7 @@ class CacheDescriptor(_CacheDescriptorBase): observer = result_d.observe() if isinstance(observer, defer.Deferred): - return logcontext.make_deferred_yieldable(observer) + return make_deferred_yieldable(observer) else: return observer @@ -563,7 +564,7 @@ class CacheListDescriptor(_CacheDescriptorBase): cached_defers.append( defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), **args_to_call + preserve_fn(self.function_to_call), **args_to_call ).addCallbacks(complete_all, errback) ) @@ -571,7 +572,7 @@ class CacheListDescriptor(_CacheDescriptorBase): d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( lambda _: results, unwrapFirstError ) - return logcontext.make_deferred_yieldable(d) + return make_deferred_yieldable(d) else: return results diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index cbe54d45dd..d6908e169d 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -16,9 +16,9 @@ import logging from twisted.internet import defer +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache -from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) @@ -78,7 +78,7 @@ class ResponseCache(object): *deferred* should run its callbacks in the sentinel logcontext (ie, you should wrap normal synapse deferreds with - logcontext.run_in_background). + synapse.logging.context.run_in_background). Can return either a new Deferred (which also doesn't follow the synapse logcontext rules), or, if *deferred* was already complete, the actual diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 5a79db821c..45af8d3eeb 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -17,8 +17,8 @@ import logging from twisted.internet import defer +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 629ed44149..8b17d1c8b8 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -17,7 +17,7 @@ from six.moves import queue from twisted.internet import threads -from synapse.util.logcontext import make_deferred_yieldable, run_in_background +from synapse.logging.context import make_deferred_yieldable, run_in_background class BackgroundFileConsumer(object): diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py deleted file mode 100644 index 30dfa1d6b2..0000000000 --- a/synapse/util/logcontext.py +++ /dev/null @@ -1,693 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" Thread-local-alike tracking of log contexts within synapse - -This module provides objects and utilities for tracking contexts through -synapse code, so that log lines can include a request identifier, and so that -CPU and database activity can be accounted for against the request that caused -them. - -See doc/log_contexts.rst for details on how this works. -""" - -import logging -import threading -import types - -from twisted.internet import defer, threads - -logger = logging.getLogger(__name__) - -try: - import resource - - # Python doesn't ship with a definition of RUSAGE_THREAD but it's defined - # to be 1 on linux so we hard code it. - RUSAGE_THREAD = 1 - - # If the system doesn't support RUSAGE_THREAD then this should throw an - # exception. - resource.getrusage(RUSAGE_THREAD) - - def get_thread_resource_usage(): - return resource.getrusage(RUSAGE_THREAD) - - -except Exception: - # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we - # won't track resource usage by returning None. - def get_thread_resource_usage(): - return None - - -# get an id for the current thread. -# -# threading.get_ident doesn't actually return an OS-level tid, and annoyingly, -# on Linux it actually returns the same value either side of a fork() call. However -# we only fork in one place, so it's not worth the hoop-jumping to get a real tid. -# -get_thread_id = threading.get_ident - - -class ContextResourceUsage(object): - """Object for tracking the resources used by a log context - - Attributes: - ru_utime (float): user CPU time (in seconds) - ru_stime (float): system CPU time (in seconds) - db_txn_count (int): number of database transactions done - db_sched_duration_sec (float): amount of time spent waiting for a - database connection - db_txn_duration_sec (float): amount of time spent doing database - transactions (excluding scheduling time) - evt_db_fetch_count (int): number of events requested from the database - """ - - __slots__ = [ - "ru_stime", - "ru_utime", - "db_txn_count", - "db_txn_duration_sec", - "db_sched_duration_sec", - "evt_db_fetch_count", - ] - - def __init__(self, copy_from=None): - """Create a new ContextResourceUsage - - Args: - copy_from (ContextResourceUsage|None): if not None, an object to - copy stats from - """ - if copy_from is None: - self.reset() - else: - self.ru_utime = copy_from.ru_utime - self.ru_stime = copy_from.ru_stime - self.db_txn_count = copy_from.db_txn_count - - self.db_txn_duration_sec = copy_from.db_txn_duration_sec - self.db_sched_duration_sec = copy_from.db_sched_duration_sec - self.evt_db_fetch_count = copy_from.evt_db_fetch_count - - def copy(self): - return ContextResourceUsage(copy_from=self) - - def reset(self): - self.ru_stime = 0.0 - self.ru_utime = 0.0 - self.db_txn_count = 0 - - self.db_txn_duration_sec = 0 - self.db_sched_duration_sec = 0 - self.evt_db_fetch_count = 0 - - def __repr__(self): - return ( - "" - ) % ( - self.ru_stime, - self.ru_utime, - self.db_txn_count, - self.db_txn_duration_sec, - self.db_sched_duration_sec, - self.evt_db_fetch_count, - ) - - def __iadd__(self, other): - """Add another ContextResourceUsage's stats to this one's. - - Args: - other (ContextResourceUsage): the other resource usage object - """ - self.ru_utime += other.ru_utime - self.ru_stime += other.ru_stime - self.db_txn_count += other.db_txn_count - self.db_txn_duration_sec += other.db_txn_duration_sec - self.db_sched_duration_sec += other.db_sched_duration_sec - self.evt_db_fetch_count += other.evt_db_fetch_count - return self - - def __isub__(self, other): - self.ru_utime -= other.ru_utime - self.ru_stime -= other.ru_stime - self.db_txn_count -= other.db_txn_count - self.db_txn_duration_sec -= other.db_txn_duration_sec - self.db_sched_duration_sec -= other.db_sched_duration_sec - self.evt_db_fetch_count -= other.evt_db_fetch_count - return self - - def __add__(self, other): - res = ContextResourceUsage(copy_from=self) - res += other - return res - - def __sub__(self, other): - res = ContextResourceUsage(copy_from=self) - res -= other - return res - - -class LoggingContext(object): - """Additional context for log formatting. Contexts are scoped within a - "with" block. - - If a parent is given when creating a new context, then: - - logging fields are copied from the parent to the new context on entry - - when the new context exits, the cpu usage stats are copied from the - child to the parent - - Args: - name (str): Name for the context for debugging. - parent_context (LoggingContext|None): The parent of the new context - """ - - __slots__ = [ - "previous_context", - "name", - "parent_context", - "_resource_usage", - "usage_start", - "main_thread", - "alive", - "request", - "tag", - ] - - thread_local = threading.local() - - class Sentinel(object): - """Sentinel to represent the root context""" - - __slots__ = [] - - def __str__(self): - return "sentinel" - - def copy_to(self, record): - pass - - def start(self): - pass - - def stop(self): - pass - - def add_database_transaction(self, duration_sec): - pass - - def add_database_scheduled(self, sched_sec): - pass - - def record_event_fetch(self, event_count): - pass - - def __nonzero__(self): - return False - - __bool__ = __nonzero__ # python3 - - sentinel = Sentinel() - - def __init__(self, name=None, parent_context=None, request=None): - self.previous_context = LoggingContext.current_context() - self.name = name - - # track the resources used by this context so far - self._resource_usage = ContextResourceUsage() - - # If alive has the thread resource usage when the logcontext last - # became active. - self.usage_start = None - - self.main_thread = get_thread_id() - self.request = None - self.tag = "" - self.alive = True - - self.parent_context = parent_context - - if self.parent_context is not None: - self.parent_context.copy_to(self) - - if request is not None: - # the request param overrides the request from the parent context - self.request = request - - def __str__(self): - if self.request: - return str(self.request) - return "%s@%x" % (self.name, id(self)) - - @classmethod - def current_context(cls): - """Get the current logging context from thread local storage - - Returns: - LoggingContext: the current logging context - """ - return getattr(cls.thread_local, "current_context", cls.sentinel) - - @classmethod - def set_current_context(cls, context): - """Set the current logging context in thread local storage - Args: - context(LoggingContext): The context to activate. - Returns: - The context that was previously active - """ - current = cls.current_context() - - if current is not context: - current.stop() - cls.thread_local.current_context = context - context.start() - return current - - def __enter__(self): - """Enters this logging context into thread local storage""" - old_context = self.set_current_context(self) - if self.previous_context != old_context: - logger.warn( - "Expected previous context %r, found %r", - self.previous_context, - old_context, - ) - self.alive = True - - return self - - def __exit__(self, type, value, traceback): - """Restore the logging context in thread local storage to the state it - was before this context was entered. - Returns: - None to avoid suppressing any exceptions that were thrown. - """ - current = self.set_current_context(self.previous_context) - if current is not self: - if current is self.sentinel: - logger.warning("Expected logging context %s was lost", self) - else: - logger.warning( - "Expected logging context %s but found %s", self, current - ) - self.previous_context = None - self.alive = False - - # if we have a parent, pass our CPU usage stats on - if self.parent_context is not None and hasattr( - self.parent_context, "_resource_usage" - ): - self.parent_context._resource_usage += self._resource_usage - - # reset them in case we get entered again - self._resource_usage.reset() - - def copy_to(self, record): - """Copy logging fields from this context to a log record or - another LoggingContext - """ - - # 'request' is the only field we currently use in the logger, so that's - # all we need to copy - record.request = self.request - - def start(self): - if get_thread_id() != self.main_thread: - logger.warning("Started logcontext %s on different thread", self) - return - - # If we haven't already started record the thread resource usage so - # far - if not self.usage_start: - self.usage_start = get_thread_resource_usage() - - def stop(self): - if get_thread_id() != self.main_thread: - logger.warning("Stopped logcontext %s on different thread", self) - return - - # When we stop, let's record the cpu used since we started - if not self.usage_start: - logger.warning("Called stop on logcontext %s without calling start", self) - return - - utime_delta, stime_delta = self._get_cputime() - self._resource_usage.ru_utime += utime_delta - self._resource_usage.ru_stime += stime_delta - - self.usage_start = None - - def get_resource_usage(self): - """Get resources used by this logcontext so far. - - Returns: - ContextResourceUsage: a *copy* of the object tracking resource - usage so far - """ - # we always return a copy, for consistency - res = self._resource_usage.copy() - - # If we are on the correct thread and we're currently running then we - # can include resource usage so far. - is_main_thread = get_thread_id() == self.main_thread - if self.alive and self.usage_start and is_main_thread: - utime_delta, stime_delta = self._get_cputime() - res.ru_utime += utime_delta - res.ru_stime += stime_delta - - return res - - def _get_cputime(self): - """Get the cpu usage time so far - - Returns: Tuple[float, float]: seconds in user mode, seconds in system mode - """ - current = get_thread_resource_usage() - - utime_delta = current.ru_utime - self.usage_start.ru_utime - stime_delta = current.ru_stime - self.usage_start.ru_stime - - # sanity check - if utime_delta < 0: - logger.error( - "utime went backwards! %f < %f", - current.ru_utime, - self.usage_start.ru_utime, - ) - utime_delta = 0 - - if stime_delta < 0: - logger.error( - "stime went backwards! %f < %f", - current.ru_stime, - self.usage_start.ru_stime, - ) - stime_delta = 0 - - return utime_delta, stime_delta - - def add_database_transaction(self, duration_sec): - if duration_sec < 0: - raise ValueError("DB txn time can only be non-negative") - self._resource_usage.db_txn_count += 1 - self._resource_usage.db_txn_duration_sec += duration_sec - - def add_database_scheduled(self, sched_sec): - """Record a use of the database pool - - Args: - sched_sec (float): number of seconds it took us to get a - connection - """ - if sched_sec < 0: - raise ValueError("DB scheduling time can only be non-negative") - self._resource_usage.db_sched_duration_sec += sched_sec - - def record_event_fetch(self, event_count): - """Record a number of events being fetched from the db - - Args: - event_count (int): number of events being fetched - """ - self._resource_usage.evt_db_fetch_count += event_count - - -class LoggingContextFilter(logging.Filter): - """Logging filter that adds values from the current logging context to each - record. - Args: - **defaults: Default values to avoid formatters complaining about - missing fields - """ - - def __init__(self, **defaults): - self.defaults = defaults - - def filter(self, record): - """Add each fields from the logging contexts to the record. - Returns: - True to include the record in the log output. - """ - context = LoggingContext.current_context() - for key, value in self.defaults.items(): - setattr(record, key, value) - - # context should never be None, but if it somehow ends up being, then - # we end up in a death spiral of infinite loops, so let's check, for - # robustness' sake. - if context is not None: - context.copy_to(record) - - return True - - -class PreserveLoggingContext(object): - """Captures the current logging context and restores it when the scope is - exited. Used to restore the context after a function using - @defer.inlineCallbacks is resumed by a callback from the reactor.""" - - __slots__ = ["current_context", "new_context", "has_parent"] - - def __init__(self, new_context=None): - if new_context is None: - new_context = LoggingContext.sentinel - self.new_context = new_context - - def __enter__(self): - """Captures the current logging context""" - self.current_context = LoggingContext.set_current_context(self.new_context) - - if self.current_context: - self.has_parent = self.current_context.previous_context is not None - if not self.current_context.alive: - logger.debug("Entering dead context: %s", self.current_context) - - def __exit__(self, type, value, traceback): - """Restores the current logging context""" - context = LoggingContext.set_current_context(self.current_context) - - if context != self.new_context: - if context is LoggingContext.sentinel: - logger.warning("Expected logging context %s was lost", self.new_context) - else: - logger.warning( - "Expected logging context %s but found %s", - self.new_context, - context, - ) - - if self.current_context is not LoggingContext.sentinel: - if not self.current_context.alive: - logger.debug("Restoring dead context: %s", self.current_context) - - -def nested_logging_context(suffix, parent_context=None): - """Creates a new logging context as a child of another. - - The nested logging context will have a 'request' made up of the parent context's - request, plus the given suffix. - - CPU/db usage stats will be added to the parent context's on exit. - - Normal usage looks like: - - with nested_logging_context(suffix): - # ... do stuff - - Args: - suffix (str): suffix to add to the parent context's 'request'. - parent_context (LoggingContext|None): parent context. Will use the current context - if None. - - Returns: - LoggingContext: new logging context. - """ - if parent_context is None: - parent_context = LoggingContext.current_context() - return LoggingContext( - parent_context=parent_context, request=parent_context.request + "-" + suffix - ) - - -def preserve_fn(f): - """Function decorator which wraps the function with run_in_background""" - - def g(*args, **kwargs): - return run_in_background(f, *args, **kwargs) - - return g - - -def run_in_background(f, *args, **kwargs): - """Calls a function, ensuring that the current context is restored after - return from the function, and that the sentinel context is set once the - deferred returned by the function completes. - - Useful for wrapping functions that return a deferred or coroutine, which you don't - yield or await on (for instance because you want to pass it to - deferred.gatherResults()). - - Note that if you completely discard the result, you should make sure that - `f` doesn't raise any deferred exceptions, otherwise a scary-looking - CRITICAL error about an unhandled error will be logged without much - indication about where it came from. - """ - current = LoggingContext.current_context() - try: - res = f(*args, **kwargs) - except: # noqa: E722 - # the assumption here is that the caller doesn't want to be disturbed - # by synchronous exceptions, so let's turn them into Failures. - return defer.fail() - - if isinstance(res, types.CoroutineType): - res = defer.ensureDeferred(res) - - if not isinstance(res, defer.Deferred): - return res - - if res.called and not res.paused: - # The function should have maintained the logcontext, so we can - # optimise out the messing about - return res - - # The function may have reset the context before returning, so - # we need to restore it now. - ctx = LoggingContext.set_current_context(current) - - # The original context will be restored when the deferred - # completes, but there is nothing waiting for it, so it will - # get leaked into the reactor or some other function which - # wasn't expecting it. We therefore need to reset the context - # here. - # - # (If this feels asymmetric, consider it this way: we are - # effectively forking a new thread of execution. We are - # probably currently within a ``with LoggingContext()`` block, - # which is supposed to have a single entry and exit point. But - # by spawning off another deferred, we are effectively - # adding a new exit point.) - res.addBoth(_set_context_cb, ctx) - return res - - -def make_deferred_yieldable(deferred): - """Given a deferred, make it follow the Synapse logcontext rules: - - If the deferred has completed (or is not actually a Deferred), essentially - does nothing (just returns another completed deferred with the - result/failure). - - If the deferred has not yet completed, resets the logcontext before - returning a deferred. Then, when the deferred completes, restores the - current logcontext before running callbacks/errbacks. - - (This is more-or-less the opposite operation to run_in_background.) - """ - if not isinstance(deferred, defer.Deferred): - return deferred - - if deferred.called and not deferred.paused: - # it looks like this deferred is ready to run any callbacks we give it - # immediately. We may as well optimise out the logcontext faffery. - return deferred - - # ok, we can't be sure that a yield won't block, so let's reset the - # logcontext, and add a callback to the deferred to restore it. - prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) - deferred.addBoth(_set_context_cb, prev_context) - return deferred - - -def _set_context_cb(result, context): - """A callback function which just sets the logging context""" - LoggingContext.set_current_context(context) - return result - - -def defer_to_thread(reactor, f, *args, **kwargs): - """ - Calls the function `f` using a thread from the reactor's default threadpool and - returns the result as a Deferred. - - Creates a new logcontext for `f`, which is created as a child of the current - logcontext (so its CPU usage metrics will get attributed to the current - logcontext). `f` should preserve the logcontext it is given. - - The result deferred follows the Synapse logcontext rules: you should `yield` - on it. - - Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked, and whose threadpool we should use for the - function. - - Normally this will be hs.get_reactor(). - - f (callable): The function to call. - - args: positional arguments to pass to f. - - kwargs: keyword arguments to pass to f. - - Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an - errback if `f` throws an exception. - """ - return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) - - -def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): - """ - A wrapper for twisted.internet.threads.deferToThreadpool, which handles - logcontexts correctly. - - Calls the function `f` using a thread from the given threadpool and returns - the result as a Deferred. - - Creates a new logcontext for `f`, which is created as a child of the current - logcontext (so its CPU usage metrics will get attributed to the current - logcontext). `f` should preserve the logcontext it is given. - - The result deferred follows the Synapse logcontext rules: you should `yield` - on it. - - Args: - reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread - the Deferred will be invoked. Normally this will be hs.get_reactor(). - - threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for - running `f`. Normally this will be hs.get_reactor().getThreadPool(). - - f (callable): The function to call. - - args: positional arguments to pass to f. - - kwargs: keyword arguments to pass to f. - - Returns: - Deferred: A Deferred which fires a callback with the result of `f`, or an - errback if `f` throws an exception. - """ - logcontext = LoggingContext.current_context() - - def g(): - with LoggingContext(parent_context=logcontext): - return f(*args, **kwargs) - - return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g)) diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py deleted file mode 100644 index fbf570c756..0000000000 --- a/synapse/util/logformatter.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -import traceback - -from six import StringIO - - -class LogFormatter(logging.Formatter): - """Log formatter which gives more detail for exceptions - - This is the same as the standard log formatter, except that when logging - exceptions [typically via log.foo("msg", exc_info=1)], it prints the - sequence that led up to the point at which the exception was caught. - (Normally only stack frames between the point the exception was raised and - where it was caught are logged). - """ - - def __init__(self, *args, **kwargs): - super(LogFormatter, self).__init__(*args, **kwargs) - - def formatException(self, ei): - sio = StringIO() - (typ, val, tb) = ei - - # log the stack above the exception capture point if possible, but - # check that we actually have an f_back attribute to work around - # https://twistedmatrix.com/trac/ticket/9305 - - if tb and hasattr(tb.tb_frame, "f_back"): - sio.write("Capture point (most recent call last):\n") - traceback.print_stack(tb.tb_frame.f_back, None, sio) - - traceback.print_exception(typ, val, tb, None, sio) - s = sio.getvalue() - sio.close() - if s[-1:] == "\n": - s = s[:-1] - return s diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py deleted file mode 100644 index 7df0fa6087..0000000000 --- a/synapse/util/logutils.py +++ /dev/null @@ -1,194 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import inspect -import logging -import time -from functools import wraps -from inspect import getcallargs - -from six import PY3 - -_TIME_FUNC_ID = 0 - - -def _log_debug_as_f(f, msg, msg_args): - name = f.__module__ - logger = logging.getLogger(name) - - if logger.isEnabledFor(logging.DEBUG): - if PY3: - lineno = f.__code__.co_firstlineno - pathname = f.__code__.co_filename - else: - lineno = f.func_code.co_firstlineno - pathname = f.func_code.co_filename - - record = logging.LogRecord( - name=name, - level=logging.DEBUG, - pathname=pathname, - lineno=lineno, - msg=msg, - args=msg_args, - exc_info=None, - ) - - logger.handle(record) - - -def log_function(f): - """ Function decorator that logs every call to that function. - """ - func_name = f.__name__ - - @wraps(f) - def wrapped(*args, **kwargs): - name = f.__module__ - logger = logging.getLogger(name) - level = logging.DEBUG - - if logger.isEnabledFor(level): - bound_args = getcallargs(f, *args, **kwargs) - - def format(value): - r = str(value) - if len(r) > 50: - r = r[:50] + "..." - return r - - func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()] - - msg_args = {"func_name": func_name, "args": ", ".join(func_args)} - - _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args) - - return f(*args, **kwargs) - - wrapped.__name__ = func_name - return wrapped - - -def time_function(f): - func_name = f.__name__ - - @wraps(f) - def wrapped(*args, **kwargs): - global _TIME_FUNC_ID - id = _TIME_FUNC_ID - _TIME_FUNC_ID += 1 - - start = time.clock() - - try: - _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id)) - - r = f(*args, **kwargs) - finally: - end = time.clock() - _log_debug_as_f( - f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start) - ) - - return r - - return wrapped - - -def trace_function(f): - func_name = f.__name__ - linenum = f.func_code.co_firstlineno - pathname = f.func_code.co_filename - - @wraps(f) - def wrapped(*args, **kwargs): - name = f.__module__ - logger = logging.getLogger(name) - level = logging.DEBUG - - s = inspect.currentframe().f_back - - to_print = [ - "\t%s:%s %s. Args: args=%s, kwargs=%s" - % (pathname, linenum, func_name, args, kwargs) - ] - while s: - if True or s.f_globals["__name__"].startswith("synapse"): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - to_print.append( - "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string) - ) - - s = s.f_back - - msg = "\nTraceback for %s:\n" % (func_name,) + "\n".join(to_print) - - record = logging.LogRecord( - name=name, - level=level, - pathname=pathname, - lineno=lineno, - msg=msg, - args=None, - exc_info=None, - ) - - logger.handle(record) - - return f(*args, **kwargs) - - wrapped.__name__ = func_name - return wrapped - - -def get_previous_frames(): - s = inspect.currentframe().f_back.f_back - to_return = [] - while s: - if s.f_globals["__name__"].startswith("synapse"): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - to_return.append( - "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string) - ) - - s = s.f_back - - return ", ".join(to_return) - - -def get_previous_frame(ignore=[]): - s = inspect.currentframe().f_back.f_back - - while s: - if s.f_globals["__name__"].startswith("synapse"): - if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - return "{{ %s:%d %s - Args: %s }}" % ( - filename, - lineno, - function, - args_string, - ) - - s = s.f_back - - return None diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 01284d3cf8..c30b6de19c 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -20,8 +20,8 @@ from prometheus_client import Counter from twisted.internet import defer +from synapse.logging.context import LoggingContext from synapse.metrics import InFlightGauge -from synapse.util.logcontext import LoggingContext logger = logging.getLogger(__name__) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 06defa8199..27bceac00e 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -20,7 +20,7 @@ import logging from twisted.internet import defer from synapse.api.errors import LimitExceededError -from synapse.util.logcontext import ( +from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, run_in_background, diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 1a77456498..d8d0ceae51 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -17,7 +17,7 @@ import random from twisted.internet import defer -import synapse.util.logcontext +import synapse.logging.context from synapse.api.errors import CodeMessageException logger = logging.getLogger(__name__) @@ -225,4 +225,4 @@ class RetryDestinationLimiter(object): logger.exception("Failed to store destination_retry_timings") # we deliberately do this in the background. - synapse.util.logcontext.run_in_background(store_retry_timings) + synapse.logging.context.run_in_background(store_retry_timings) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index db9f86bdac..04b8c2c07c 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -22,7 +22,7 @@ from synapse.appservice.scheduler import ( _ServiceQueuer, _TransactionController, ) -from synapse.util.logcontext import make_deferred_yieldable +from synapse.logging.context import make_deferred_yieldable from tests import unittest diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 5a355f00cc..795703967d 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -30,9 +30,12 @@ from synapse.crypto.keyring import ( ServerKeyFetcher, StoreKeyFetcher, ) +from synapse.logging.context import ( + LoggingContext, + PreserveLoggingContext, + make_deferred_yieldable, +) from synapse.storage.keys import FetchKeyResult -from synapse.util import logcontext -from synapse.util.logcontext import LoggingContext from tests import unittest @@ -131,7 +134,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def get_perspectives(**kwargs): self.assertEquals(LoggingContext.current_context().request, "11") - with logcontext.PreserveLoggingContext(): + with PreserveLoggingContext(): yield persp_deferred defer.returnValue(persp_resp) @@ -158,7 +161,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.assertFalse(res_deferreds[0].called) res_deferreds[0].addBoth(self.check_context, None) - yield logcontext.make_deferred_yieldable(res_deferreds[0]) + yield make_deferred_yieldable(res_deferreds[0]) # let verify_json_objects_for_server finish its work before we kill the # logcontext @@ -184,7 +187,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): [("server10", json1, 0, "test")] ) res_deferreds_2[0].addBoth(self.check_context, None) - yield logcontext.make_deferred_yieldable(res_deferreds_2[0]) + yield make_deferred_yieldable(res_deferreds_2[0]) # let verify_json_objects_for_server finish its work before we kill the # logcontext diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 417fda3ab2..a49f9b3224 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -36,8 +36,8 @@ from synapse.http.federation.matrix_federation_agent import ( _cache_period_from_headers, ) from synapse.http.federation.srv_resolver import Server +from synapse.logging.context import LoggingContext from synapse.util.caches.ttlcache import TTLCache -from synapse.util.logcontext import LoggingContext from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.server import FakeTransport, ThreadedMemoryReactorClock diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index cf6c6e95b5..65b51dc981 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error from synapse.http.federation.srv_resolver import SrvResolver -from synapse.util.logcontext import LoggingContext +from synapse.logging.context import LoggingContext from tests import unittest from tests.utils import MockClock diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index c4c0d9b968..b9d6d7ad1c 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -29,7 +29,7 @@ from synapse.http.matrixfederationclient import ( MatrixFederationHttpClient, MatrixFederationRequest, ) -from synapse.util.logcontext import LoggingContext +from synapse.logging.context import LoggingContext from tests.server import FakeTransport from tests.unittest import HomeserverTestCase diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py index ee0add3455..220884311c 100644 --- a/tests/patch_inline_callbacks.py +++ b/tests/patch_inline_callbacks.py @@ -28,7 +28,7 @@ def do_patch(): Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit """ - from synapse.util.logcontext import LoggingContext + from synapse.logging.context import LoggingContext orig_inline_callbacks = defer.inlineCallbacks diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 22c3f73ef3..8ce6bb62da 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -18,8 +18,8 @@ from mock import Mock from twisted.internet.defer import Deferred import synapse.rest.admin +from synapse.logging.context import make_deferred_yieldable from synapse.rest.client.v1 import login, room -from synapse.util.logcontext import make_deferred_yieldable from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 708dc26e61..a8adc9a61d 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -2,9 +2,9 @@ from mock import Mock, call from twisted.internet import defer, reactor +from synapse.logging.context import LoggingContext from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache from synapse.util import Clock -from synapse.util.logcontext import LoggingContext from tests import unittest from tests.utils import MockClock diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 39c9342423..bc662b61db 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -24,11 +24,11 @@ from six.moves.urllib import parse from twisted.internet.defer import Deferred +from synapse.logging.context import make_deferred_yieldable from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend -from synapse.util.logcontext import make_deferred_yieldable from tests import unittest diff --git a/tests/test_federation.py b/tests/test_federation.py index 6a8339b561..a73f18f88e 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -3,9 +3,9 @@ from mock import Mock from twisted.internet.defer import maybeDeferred, succeed from synapse.events import FrozenEvent +from synapse.logging.context import LoggingContext from synapse.types import Requester, UserID from synapse.util import Clock -from synapse.util.logcontext import LoggingContext from tests import unittest from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver diff --git a/tests/test_server.py b/tests/test_server.py index da29ae92ce..ba08483a4b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -26,8 +26,8 @@ from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError from synapse.http.server import JsonResource from synapse.http.site import SynapseSite, logger +from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock -from synapse.util.logcontext import make_deferred_yieldable from tests import unittest from tests.server import ( diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index 813f984199..2d96b0fa8d 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -17,7 +17,7 @@ import os import twisted.logger -from synapse.util.logcontext import LoggingContextFilter +from synapse.logging.context import LoggingContextFilter class ToTwistedHandler(logging.Handler): diff --git a/tests/unittest.py b/tests/unittest.py index d26804b5b5..a09e76c7c2 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -33,9 +33,9 @@ from synapse.api.constants import EventTypes from synapse.config.homeserver import HomeServerConfig from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest +from synapse.logging.context import LoggingContext from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester -from synapse.util.logcontext import LoggingContext from tests.server import get_clock, make_request, render, setup_test_homeserver from tests.test_utils.logging_setup import setup_logging diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 6f8f52537c..7807328e2f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -21,7 +21,11 @@ import mock from twisted.internet import defer, reactor from synapse.api.errors import SynapseError -from synapse.util import logcontext +from synapse.logging.context import ( + LoggingContext, + PreserveLoggingContext, + make_deferred_yieldable, +) from synapse.util.caches import descriptors from tests import unittest @@ -32,7 +36,7 @@ logger = logging.getLogger(__name__) def run_on_reactor(): d = defer.Deferred() reactor.callLater(0, d.callback, 0) - return logcontext.make_deferred_yieldable(d) + return make_deferred_yieldable(d) class CacheTestCase(unittest.TestCase): @@ -153,7 +157,7 @@ class DescriptorTestCase(unittest.TestCase): def fn(self, arg1): @defer.inlineCallbacks def inner_fn(): - with logcontext.PreserveLoggingContext(): + with PreserveLoggingContext(): yield complete_lookup defer.returnValue(1) @@ -161,10 +165,10 @@ class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def do_lookup(): - with logcontext.LoggingContext() as c1: + with LoggingContext() as c1: c1.name = "c1" r = yield obj.fn(1) - self.assertEqual(logcontext.LoggingContext.current_context(), c1) + self.assertEqual(LoggingContext.current_context(), c1) defer.returnValue(r) def check_result(r): @@ -174,18 +178,12 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual( - logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel, - ) + self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) d1.addCallback(check_result) # and another d2 = do_lookup() - self.assertEqual( - logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel, - ) + self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) d2.addCallback(check_result) # let the lookup complete @@ -210,29 +208,25 @@ class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def do_lookup(): - with logcontext.LoggingContext() as c1: + with LoggingContext() as c1: c1.name = "c1" try: d = obj.fn(1) self.assertEqual( - logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel, + LoggingContext.current_context(), LoggingContext.sentinel ) yield d self.fail("No exception thrown") except SynapseError: pass - self.assertEqual(logcontext.LoggingContext.current_context(), c1) + self.assertEqual(LoggingContext.current_context(), c1) obj = Cls() # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual( - logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel, - ) + self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) return d1 @@ -288,23 +282,20 @@ class CachedListDescriptorTestCase(unittest.TestCase): @descriptors.cachedList("fn", "args1", inlineCallbacks=True) def list_fn(self, args1, arg2): - assert logcontext.LoggingContext.current_context().request == "c1" + assert LoggingContext.current_context().request == "c1" # we want this to behave like an asynchronous function yield run_on_reactor() - assert logcontext.LoggingContext.current_context().request == "c1" + assert LoggingContext.current_context().request == "c1" defer.returnValue(self.mock(args1, arg2)) - with logcontext.LoggingContext() as c1: + with LoggingContext() as c1: c1.request = "c1" obj = Cls() obj.mock.return_value = {10: "fish", 20: "chips"} d1 = obj.list_fn([10, 20], 2) - self.assertEqual( - logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel, - ) + self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) r = yield d1 - self.assertEqual(logcontext.LoggingContext.current_context(), c1) + self.assertEqual(LoggingContext.current_context(), c1) obj.mock.assert_called_once_with([10, 20], 2) self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py index bf85d3b8ec..f60918069a 100644 --- a/tests/util/test_async_utils.py +++ b/tests/util/test_async_utils.py @@ -16,9 +16,8 @@ from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred from twisted.internet.task import Clock -from synapse.util import logcontext +from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.util.async_helpers import timeout_deferred -from synapse.util.logcontext import LoggingContext from tests.unittest import TestCase @@ -69,14 +68,14 @@ class TimeoutDeferredTest(TestCase): @defer.inlineCallbacks def blocking(): non_completing_d = Deferred() - with logcontext.PreserveLoggingContext(): + with PreserveLoggingContext(): try: yield non_completing_d except CancelledError: blocking_was_cancelled[0] = True raise - with logcontext.LoggingContext("one") as context_one: + with LoggingContext("one") as context_one: # the errbacks should be run in the test logcontext def errback(res, deferred_name): self.assertIs( diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index ec7ba9719c..0ec8ef90ce 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -19,7 +19,8 @@ from six.moves import range from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError -from synapse.util import Clock, logcontext +from synapse.logging.context import LoggingContext +from synapse.util import Clock from synapse.util.async_helpers import Linearizer from tests import unittest @@ -51,13 +52,13 @@ class LinearizerTestCase(unittest.TestCase): @defer.inlineCallbacks def func(i, sleep=False): - with logcontext.LoggingContext("func(%s)" % i) as lc: + with LoggingContext("func(%s)" % i) as lc: with (yield linearizer.queue("")): - self.assertEqual(logcontext.LoggingContext.current_context(), lc) + self.assertEqual(LoggingContext.current_context(), lc) if sleep: yield Clock(reactor).sleep(0) - self.assertEqual(logcontext.LoggingContext.current_context(), lc) + self.assertEqual(LoggingContext.current_context(), lc) func(0, sleep=True) for i in range(1, 100): diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 8d69fbf111..8b8455c8b7 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -1,8 +1,14 @@ import twisted.python.failure from twisted.internet import defer, reactor -from synapse.util import Clock, logcontext -from synapse.util.logcontext import LoggingContext +from synapse.logging.context import ( + LoggingContext, + PreserveLoggingContext, + make_deferred_yieldable, + nested_logging_context, + run_in_background, +) +from synapse.util import Clock from .. import unittest @@ -43,7 +49,7 @@ class LoggingContextTestCase(unittest.TestCase): context_one.request = "one" # fire off function, but don't wait on it. - d2 = logcontext.run_in_background(function) + d2 = run_in_background(function) def cb(res): callback_completed[0] = True @@ -85,7 +91,7 @@ class LoggingContextTestCase(unittest.TestCase): def test_run_in_background_with_non_blocking_fn(self): @defer.inlineCallbacks def nonblocking_function(): - with logcontext.PreserveLoggingContext(): + with PreserveLoggingContext(): yield defer.succeed(None) return self._test_run_in_background(nonblocking_function) @@ -94,7 +100,7 @@ class LoggingContextTestCase(unittest.TestCase): # a function which returns a deferred which looks like it has been # called, but is actually paused def testfunc(): - return logcontext.make_deferred_yieldable(_chained_deferred_function()) + return make_deferred_yieldable(_chained_deferred_function()) return self._test_run_in_background(testfunc) @@ -128,7 +134,7 @@ class LoggingContextTestCase(unittest.TestCase): with LoggingContext() as context_one: context_one.request = "one" - d1 = logcontext.make_deferred_yieldable(blocking_function()) + d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable self.assertIs(LoggingContext.current_context(), sentinel_context) @@ -144,7 +150,7 @@ class LoggingContextTestCase(unittest.TestCase): with LoggingContext() as context_one: context_one.request = "one" - d1 = logcontext.make_deferred_yieldable(_chained_deferred_function()) + d1 = make_deferred_yieldable(_chained_deferred_function()) # make sure that the context was reset by make_deferred_yieldable self.assertIs(LoggingContext.current_context(), sentinel_context) @@ -161,7 +167,7 @@ class LoggingContextTestCase(unittest.TestCase): with LoggingContext() as context_one: context_one.request = "one" - d1 = logcontext.make_deferred_yieldable("bum") + d1 = make_deferred_yieldable("bum") self._check_test_key("one") r = yield d1 @@ -170,7 +176,7 @@ class LoggingContextTestCase(unittest.TestCase): def test_nested_logging_context(self): with LoggingContext(request="foo"): - nested_context = logcontext.nested_logging_context(suffix="bar") + nested_context = nested_logging_context(suffix="bar") self.assertEqual(nested_context.request, "foo-bar") diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py index 297aebbfbe..0fb60caacb 100644 --- a/tests/util/test_logformatter.py +++ b/tests/util/test_logformatter.py @@ -14,7 +14,7 @@ # limitations under the License. import sys -from synapse.util.logformatter import LogFormatter +from synapse.logging.formatter import LogFormatter from tests import unittest diff --git a/tests/utils.py b/tests/utils.py index da43166f3a..d8e55b0801 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -34,6 +34,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer +from synapse.logging.context import LoggingContext from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -42,7 +43,6 @@ from synapse.storage.prepare_database import ( _setup_new_database, prepare_database, ) -from synapse.util.logcontext import LoggingContext from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. -- cgit 1.5.1 From c061d4f237273f3400dc8e62aa7421f02caec3dd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 4 Jul 2019 11:07:09 +0100 Subject: Fixup from review comments. --- synapse/handlers/admin.py | 39 ++++++++++++++++++++++----------------- tests/handlers/test_admin.py | 10 +++++----- 2 files changed, 27 insertions(+), 22 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 6c905e97a7..69d2c8c36f 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -99,7 +99,7 @@ class AdminHandler(BaseHandler): defer.returnValue(ret) @defer.inlineCallbacks - def exfiltrate_user_data(self, user_id, writer): + def export_user_data(self, user_id, writer): """Write all data we have on the user to the given writer. Args: @@ -107,7 +107,8 @@ class AdminHandler(BaseHandler): writer (ExfiltrationWriter) Returns: - defer.Deferred + defer.Deferred: Resolves when all data for a user has been written. + The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in rooms = yield self.store.get_rooms_for_user_where_membership_is( @@ -134,7 +135,7 @@ class AdminHandler(BaseHandler): forgotten = yield self.store.did_forget(user_id, room_id) if forgotten: - logger.info("[%s] User forgot room %d, ignoring", room_id) + logger.info("[%s] User forgot room %d, ignoring", user_id, room_id) continue if room_id not in rooms_user_has_been_in: @@ -172,9 +173,10 @@ class AdminHandler(BaseHandler): # dict[str, set[str]]. event_to_unseen_prevs = {} - # The reverse mapping to above, i.e. map from unseen event to parent - # events. dict[str, set[str]] - unseen_event_to_parents = {} + # The reverse mapping to above, i.e. map from unseen event to events + # that have the unseen event in their prev_events, i.e. the unseen + # events "children". dict[str, set[str]] + unseen_to_child_events = {} # We fetch events in the room the user could see by fetching *all* # events that we have and then filtering, this isn't the most @@ -200,14 +202,14 @@ class AdminHandler(BaseHandler): if unseen_events: event_to_unseen_prevs[event.event_id] = unseen_events for unseen in unseen_events: - unseen_event_to_parents.setdefault(unseen, set()).add( + unseen_to_child_events.setdefault(unseen, set()).add( event.event_id ) # Now check if this event is an unseen prev event, if so # then we remove this event from the appropriate dicts. - for event_id in unseen_event_to_parents.pop(event.event_id, []): - event_to_unseen_prevs.get(event_id, set()).discard( + for child_id in unseen_to_child_events.pop(event.event_id, []): + event_to_unseen_prevs.get(child_id, set()).discard( event.event_id ) @@ -233,7 +235,7 @@ class AdminHandler(BaseHandler): class ExfiltrationWriter(object): - """Interface used to specify how to write exfiltrated data. + """Interface used to specify how to write exported data. """ def write_events(self, room_id, events): @@ -254,7 +256,7 @@ class ExfiltrationWriter(object): Args: room_id (str) event_id (str) - state (list[FrozenEvent]) + state (dict[tuple[str, str], FrozenEvent]) """ pass @@ -264,13 +266,16 @@ class ExfiltrationWriter(object): Args: room_id (str) event (FrozenEvent) - state (list[dict]): A subset of the state at the invite, with a - subset of the event keys (type, state_key, content and sender) + state (dict[tuple[str, str], dict]): A subset of the state at the + invite, with a subset of the event keys (type, state_key + content and sender) """ def finished(self): - """Called when exfiltration is complete, and the return valus is passed - to the requester. + """Called when all data has succesfully been exported and written. + + This functions return value is passed to the caller of + `export_user_data`. """ pass @@ -281,7 +286,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): Returns the directory location on completion. Args: - user_id (str): The user whose data is being exfiltrated. + user_id (str): The user whose data is being exported. directory (str|None): The directory to write the data to. If None then will write to a temporary directory. """ @@ -293,7 +298,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): self.base_directory = directory else: self.base_directory = tempfile.mkdtemp( - prefix="synapse-exfiltrate__%s__" % (user_id,) + prefix="synapse-exported__%s__" % (user_id,) ) os.makedirs(self.base_directory, exist_ok=True) diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 5e7d2d3361..fc37c4328c 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -55,7 +55,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): writer = Mock() - self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) writer.write_events.assert_called() @@ -94,7 +94,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): writer = Mock() - self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) writer.write_events.assert_called() @@ -127,7 +127,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): writer = Mock() - self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) writer.write_events.assert_called() @@ -169,7 +169,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): writer = Mock() - self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) writer.write_events.assert_called_once() @@ -198,7 +198,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): writer = Mock() - self.get_success(self.admin_handler.exfiltrate_user_data(self.user2, writer)) + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) writer.write_events.assert_not_called() writer.write_state.assert_not_called() -- cgit 1.5.1 From 9481707a52a89022bc5d99c181ceb9e7ec9ec9e9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 5 Jul 2019 11:10:19 +0100 Subject: Fixes to the federation rate limiter (#5621) - Put the default window_size back to 1000ms (broken by #5181) - Make the `rc_federation` config actually do something - fix an off-by-one error in the 'concurrent' limit - Avoid creating an unused `_PerHostRatelimiter` object for every single incoming request --- changelog.d/5621.bugfix | 1 + synapse/config/ratelimiting.py | 4 +- synapse/util/ratelimitutils.py | 16 +++---- tests/config/test_ratelimiting.py | 40 ++++++++++++++++ tests/util/test_ratelimitutils.py | 97 +++++++++++++++++++++++++++++++++++++++ tests/utils.py | 6 --- 6 files changed, 148 insertions(+), 16 deletions(-) create mode 100644 changelog.d/5621.bugfix create mode 100644 tests/config/test_ratelimiting.py create mode 100644 tests/util/test_ratelimitutils.py (limited to 'tests') diff --git a/changelog.d/5621.bugfix b/changelog.d/5621.bugfix new file mode 100644 index 0000000000..f1a2851f45 --- /dev/null +++ b/changelog.d/5621.bugfix @@ -0,0 +1 @@ +Various minor fixes to the federation request rate limiter. diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 8c587f3fd2..33f31cf213 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -23,7 +23,7 @@ class RateLimitConfig(object): class FederationRateLimitConfig(object): _items_and_default = { - "window_size": 10000, + "window_size": 1000, "sleep_limit": 10, "sleep_delay": 500, "reject_limit": 50, @@ -54,7 +54,7 @@ class RatelimitConfig(Config): # Load the new-style federation config, if it exists. Otherwise, fall # back to the old method. - if "federation_rc" in config: + if "rc_federation" in config: self.rc_federation = FederationRateLimitConfig(**config["rc_federation"]) else: self.rc_federation = FederationRateLimitConfig( diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 27bceac00e..5ca4521ce3 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -36,9 +36,11 @@ class FederationRateLimiter(object): clock (Clock) config (FederationRateLimitConfig) """ - self.clock = clock - self._config = config - self.ratelimiters = {} + + def new_limiter(): + return _PerHostRatelimiter(clock=clock, config=config) + + self.ratelimiters = collections.defaultdict(new_limiter) def ratelimit(self, host): """Used to ratelimit an incoming request from given host @@ -53,11 +55,9 @@ class FederationRateLimiter(object): host (str): Origin of incoming request. Returns: - _PerHostRatelimiter + context manager which returns a deferred. """ - return self.ratelimiters.setdefault( - host, _PerHostRatelimiter(clock=self.clock, config=self._config) - ).ratelimit() + return self.ratelimiters[host].ratelimit() class _PerHostRatelimiter(object): @@ -122,7 +122,7 @@ class _PerHostRatelimiter(object): self.request_times.append(time_now) def queue_request(): - if len(self.current_processing) > self.concurrent_requests: + if len(self.current_processing) >= self.concurrent_requests: queue_defer = defer.Deferred() self.ready_request_queue[request_id] = queue_defer logger.info( diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py new file mode 100644 index 0000000000..13ab282384 --- /dev/null +++ b/tests/config/test_ratelimiting.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config.homeserver import HomeServerConfig + +from tests.unittest import TestCase +from tests.utils import default_config + + +class RatelimitConfigTestCase(TestCase): + def test_parse_rc_federation(self): + config_dict = default_config("test") + config_dict["rc_federation"] = { + "window_size": 20000, + "sleep_limit": 693, + "sleep_delay": 252, + "reject_limit": 198, + "concurrent": 7, + } + + config = HomeServerConfig() + config.parse_config_dict(config_dict, "", "") + config_obj = config.rc_federation + + self.assertEqual(config_obj.window_size, 20000) + self.assertEqual(config_obj.sleep_limit, 693) + self.assertEqual(config_obj.sleep_delay, 252) + self.assertEqual(config_obj.reject_limit, 198) + self.assertEqual(config_obj.concurrent, 7) diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py new file mode 100644 index 0000000000..4d1aee91d5 --- /dev/null +++ b/tests/util/test_ratelimitutils.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config.homeserver import HomeServerConfig +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests.server import get_clock +from tests.unittest import TestCase +from tests.utils import default_config + + +class FederationRateLimiterTestCase(TestCase): + def test_ratelimit(self): + """A simple test with the default values""" + reactor, clock = get_clock() + rc_config = build_rc_config() + ratelimiter = FederationRateLimiter(clock, rc_config) + + with ratelimiter.ratelimit("testhost") as d1: + # shouldn't block + self.successResultOf(d1) + + def test_concurrent_limit(self): + """Test what happens when we hit the concurrent limit""" + reactor, clock = get_clock() + rc_config = build_rc_config({"rc_federation": {"concurrent": 2}}) + ratelimiter = FederationRateLimiter(clock, rc_config) + + with ratelimiter.ratelimit("testhost") as d1: + # shouldn't block + self.successResultOf(d1) + + cm2 = ratelimiter.ratelimit("testhost") + d2 = cm2.__enter__() + # also shouldn't block + self.successResultOf(d2) + + cm3 = ratelimiter.ratelimit("testhost") + d3 = cm3.__enter__() + # this one should block, though ... + self.assertNoResult(d3) + + # ... until we complete an earlier request + cm2.__exit__(None, None, None) + self.successResultOf(d3) + + def test_sleep_limit(self): + """Test what happens when we hit the sleep limit""" + reactor, clock = get_clock() + rc_config = build_rc_config( + {"rc_federation": {"sleep_limit": 2, "sleep_delay": 500}} + ) + ratelimiter = FederationRateLimiter(clock, rc_config) + + with ratelimiter.ratelimit("testhost") as d1: + # shouldn't block + self.successResultOf(d1) + + with ratelimiter.ratelimit("testhost") as d2: + # nor this + self.successResultOf(d2) + + with ratelimiter.ratelimit("testhost") as d3: + # this one should block, though ... + self.assertNoResult(d3) + sleep_time = _await_resolution(reactor, d3) + self.assertAlmostEqual(sleep_time, 500, places=3) + + +def _await_resolution(reactor, d): + """advance the clock until the deferred completes. + + Returns the number of milliseconds it took to complete. + """ + start_time = reactor.seconds() + while not d.called: + reactor.advance(0.01) + return (reactor.seconds() - start_time) * 1000 + + +def build_rc_config(settings={}): + config_dict = default_config("test") + config_dict.update(settings) + config = HomeServerConfig() + config.parse_config_dict(config_dict, "", "") + return config.rc_federation diff --git a/tests/utils.py b/tests/utils.py index d8e55b0801..8a94ce0b47 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -152,12 +152,6 @@ def default_config(name, parse=False): "mau_stats_only": False, "mau_limits_reserved_threepids": [], "admin_contact": None, - "rc_federation": { - "reject_limit": 10, - "sleep_limit": 10, - "sleep_delay": 10, - "concurrent": 10, - }, "rc_message": {"per_second": 10000, "burst_count": 10000}, "rc_registration": {"per_second": 10000, "burst_count": 10000}, "rc_login": { -- cgit 1.5.1 From ad8b909ce98cce5e955b10a3cbe04f3e868ecb70 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 5 Jul 2019 17:20:02 +0100 Subject: Add origin_server_ts and sender fields to m.replace (#5613) Riot team would like some extra fields as part of m.replace, so here you go. Fixes: #5598 --- changelog.d/5613.feature | 1 + synapse/events/utils.py | 6 +++++- tests/rest/client/v2_alpha/test_relations.py | 24 ++++++++++++++++++------ 3 files changed, 24 insertions(+), 7 deletions(-) create mode 100644 changelog.d/5613.feature (limited to 'tests') diff --git a/changelog.d/5613.feature b/changelog.d/5613.feature new file mode 100644 index 0000000000..4b7bb2745c --- /dev/null +++ b/changelog.d/5613.feature @@ -0,0 +1 @@ +Add `sender` and `origin_server_ts` fields to `m.replace`. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index f24f0c16f0..987de5cab7 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -392,7 +392,11 @@ class EventClientSerializer(object): serialized_event["content"].pop("m.relates_to", None) r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REPLACE] = {"event_id": edit.event_id} + r[RelationTypes.REPLACE] = { + "event_id": edit.event_id, + "origin_server_ts": edit.origin_server_ts, + "sender": edit.sender, + } defer.returnValue(serialized_event) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 3deeed3a70..6bb7d92638 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -466,9 +466,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["content"], new_body) - self.assertEquals( - channel.json_body["unsigned"].get("m.relations"), - {RelationTypes.REPLACE: {"event_id": edit_event_id}}, + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) def test_multi_edit(self): @@ -518,9 +524,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["content"], new_body) - self.assertEquals( - channel.json_body["unsigned"].get("m.relations"), - {RelationTypes.REPLACE: {"event_id": edit_event_id}}, + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) def _send_relation( -- cgit 1.5.1 From 1af2fcd492c81075a9d75a5578d8d599af3cea0c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 8 Jul 2019 14:52:26 +0100 Subject: Move get_or_create_user to test code (#5628) This is only used in tests, so... --- changelog.d/5628.misc | 1 + synapse/handlers/register.py | 51 ------------------------------- tests/handlers/test_register.py | 68 +++++++++++++++++++++++++++++++++++------ 3 files changed, 60 insertions(+), 60 deletions(-) create mode 100644 changelog.d/5628.misc (limited to 'tests') diff --git a/changelog.d/5628.misc b/changelog.d/5628.misc new file mode 100644 index 0000000000..fec8446793 --- /dev/null +++ b/changelog.d/5628.misc @@ -0,0 +1 @@ +Move RegistrationHandler.get_or_create_user to test code. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index fd55eb1654..853020180b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -505,57 +505,6 @@ class RegistrationHandler(BaseHandler): ) defer.returnValue(data) - @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, password_hash=None): - """Creates a new user if the user does not exist, - else revokes all previous access tokens and generates a new one. - - Args: - localpart : The local part of the user ID to register. If None, - one will be randomly generated. - Returns: - A tuple of (user_id, access_token). - Raises: - RegistrationError if there was a problem registering. - - NB this is only used in tests. TODO: move it to the test package! - """ - if localpart is None: - raise SynapseError(400, "Request must include user id") - yield self.auth.check_auth_blocking() - need_register = True - - try: - yield self.check_username(localpart) - except SynapseError as e: - if e.errcode == Codes.USER_IN_USE: - need_register = False - else: - raise - - user = UserID(localpart, self.hs.hostname) - user_id = user.to_string() - token = self.macaroon_gen.generate_access_token(user_id) - - if need_register: - yield self.register_with_store( - user_id=user_id, - token=token, - password_hash=password_hash, - create_profile_with_displayname=user.localpart, - ) - else: - yield self._auth_handler.delete_access_tokens_for_user(user_id) - yield self.store.add_access_token_to_user(user_id=user_id, token=token) - - if displayname is not None: - logger.info("setting user display name: %s -> %s", user_id, displayname) - yield self.profile_handler.set_displayname( - user, requester, displayname, by_admin=True - ) - - defer.returnValue((user_id, token)) - @defer.inlineCallbacks def _join_user_to_room(self, requester, room_identifier): room_id = None diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 4edce7af43..1c7ded7397 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -18,7 +18,7 @@ from mock import Mock from twisted.internet import defer from synapse.api.constants import UserTypes -from synapse.api.errors import ResourceLimitError, SynapseError +from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler from synapse.types import RoomAlias, UserID, create_requester @@ -67,7 +67,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = frank.to_string() requester = create_requester(user_id) result_user_id, result_token = self.get_success( - self.handler.get_or_create_user(requester, frank.localpart, "Frankie") + self.get_or_create_user(requester, frank.localpart, "Frankie") ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) @@ -87,7 +87,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = frank.to_string() requester = create_requester(user_id) result_user_id, result_token = self.get_success( - self.handler.get_or_create_user(requester, local_part, None) + self.get_or_create_user(requester, local_part, None) ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) @@ -95,9 +95,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - self.get_success( - self.handler.get_or_create_user(self.requester, "a", "display_name") - ) + self.get_success(self.get_or_create_user(self.requester, "a", "display_name")) def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True @@ -105,7 +103,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - self.get_success(self.handler.get_or_create_user(self.requester, "c", "User")) + self.get_success(self.get_or_create_user(self.requester, "c", "User")) def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True @@ -113,7 +111,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return_value=defer.succeed(self.lots_of_users) ) self.get_failure( - self.handler.get_or_create_user(self.requester, "b", "display_name"), + self.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) @@ -121,7 +119,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return_value=defer.succeed(self.hs.config.max_mau_value) ) self.get_failure( - self.handler.get_or_create_user(self.requester, "b", "display_name"), + self.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) @@ -232,3 +230,55 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_invalid_user_id_length(self): invalid_user_id = "x" * 256 self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError) + + @defer.inlineCallbacks + def get_or_create_user(self, requester, localpart, displayname, password_hash=None): + """Creates a new user if the user does not exist, + else revokes all previous access tokens and generates a new one. + + XXX: this used to be in the main codebase, but was only used by this file, + so got moved here. TODO: get rid of it, probably + + Args: + localpart : The local part of the user ID to register. If None, + one will be randomly generated. + Returns: + A tuple of (user_id, access_token). + Raises: + RegistrationError if there was a problem registering. + """ + if localpart is None: + raise SynapseError(400, "Request must include user id") + yield self.hs.get_auth().check_auth_blocking() + need_register = True + + try: + yield self.handler.check_username(localpart) + except SynapseError as e: + if e.errcode == Codes.USER_IN_USE: + need_register = False + else: + raise + + user = UserID(localpart, self.hs.hostname) + user_id = user.to_string() + token = self.macaroon_generator.generate_access_token(user_id) + + if need_register: + yield self.handler.register_with_store( + user_id=user_id, + token=token, + password_hash=password_hash, + create_profile_with_displayname=user.localpart, + ) + else: + yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) + yield self.store.add_access_token_to_user(user_id=user_id, token=token) + + if displayname is not None: + # logger.info("setting user display name: %s -> %s", user_id, displayname) + yield self.hs.get_profile_handler().set_displayname( + user, requester, displayname, by_admin=True + ) + + defer.returnValue((user_id, token)) -- cgit 1.5.1 From 5e01e9ac1914cff89d54350df5270c1a2b7ccc42 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 8 Jul 2019 17:41:16 +0100 Subject: Add test case --- tests/rest/client/v1/test_profile.py | 47 ++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index dff9b2f10c..a76dda9503 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -288,3 +288,50 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): # if the user isn't already in the room), because we only want to # make sure the user isn't in the room. pass + + +class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["require_auth_for_profile_requests"] = True + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, hs): + # User requesting the profile. + self.requester = self.register_user("requester", "pass") + self.requester_tok = self.login("requester", "pass") + + def test_can_lookup_own_profile(self): + """Tests that a user can lookup their own profile without having to be in a room + if 'require_auth_for_profile_requests' is set to true in the server's config. + """ + request, channel = self.make_request( + "GET", "/profile/" + self.requester, access_token=self.requester_tok + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + request, channel = self.make_request( + "GET", + "/profile/" + self.requester + "/displayname", + access_token=self.requester_tok + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + request, channel = self.make_request( + "GET", + "/profile/" + self.requester + "/avatar_url", + access_token=self.requester_tok + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) -- cgit 1.5.1 From 73cb716b3c97f018efe00c6ca7a80b7c6d48c0e1 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 8 Jul 2019 17:44:20 +0100 Subject: Lint --- tests/rest/client/v1/test_profile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index a76dda9503..140d8b3772 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -323,7 +323,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/profile/" + self.requester + "/displayname", - access_token=self.requester_tok + access_token=self.requester_tok, ) self.render(request) self.assertEqual(channel.code, 200, channel.result) @@ -331,7 +331,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/profile/" + self.requester + "/avatar_url", - access_token=self.requester_tok + access_token=self.requester_tok, ) self.render(request) self.assertEqual(channel.code, 200, channel.result) -- cgit 1.5.1 From 824707383bea618ea809a0f13cf72168b87184f9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 8 Jul 2019 19:01:08 +0100 Subject: Remove access-token support from RegistrationHandler.register (#5641) Nothing uses this now, so we can remove the dead code, and clean up the API. Since we're changing the shape of the return value anyway, we take the opportunity to give the method a better name. --- changelog.d/5641.misc | 1 + synapse/handlers/register.py | 27 ++-------------- synapse/module_api/__init__.py | 10 ++---- synapse/replication/http/register.py | 6 ---- synapse/rest/admin/__init__.py | 3 +- synapse/rest/client/v1/login.py | 14 +++------ synapse/rest/client/v2_alpha/register.py | 11 +++---- tests/handlers/test_register.py | 53 +++++++++++++++++--------------- 8 files changed, 44 insertions(+), 81 deletions(-) create mode 100644 changelog.d/5641.misc (limited to 'tests') diff --git a/changelog.d/5641.misc b/changelog.d/5641.misc new file mode 100644 index 0000000000..1899bc963d --- /dev/null +++ b/changelog.d/5641.misc @@ -0,0 +1 @@ +Remove access-token support from RegistrationHandler.register, and rename it. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c72735067f..a3e553d5f5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -138,11 +138,10 @@ class RegistrationHandler(BaseHandler): ) @defer.inlineCallbacks - def register( + def register_user( self, localpart=None, password=None, - generate_token=True, guest_access_token=None, make_guest=False, admin=False, @@ -160,11 +159,6 @@ class RegistrationHandler(BaseHandler): password (unicode) : The password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). - generate_token (bool): Whether a new access token should be - generated. Having this be True should be considered deprecated, - since it offers no means of associating a device_id with the - access_token. Instead you should call auth_handler.issue_access_token - after registration. user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. default_display_name (unicode|None): if set, the new user's displayname @@ -172,7 +166,7 @@ class RegistrationHandler(BaseHandler): address (str|None): the IP address used to perform the registration. bind_emails (List[str]): list of emails to bind to this account. Returns: - A tuple of (user_id, access_token). + Deferred[str]: user_id Raises: RegistrationError if there was a problem registering. """ @@ -206,12 +200,8 @@ class RegistrationHandler(BaseHandler): elif default_display_name is None: default_display_name = localpart - token = None - if generate_token: - token = self.macaroon_gen.generate_access_token(user_id) yield self.register_with_store( user_id=user_id, - token=token, password_hash=password_hash, was_guest=was_guest, make_guest=make_guest, @@ -230,21 +220,17 @@ class RegistrationHandler(BaseHandler): else: # autogen a sequential user ID attempts = 0 - token = None user = None while not user: localpart = yield self._generate_user_id(attempts > 0) user = UserID(localpart, self.hs.hostname) user_id = user.to_string() yield self.check_user_id_not_appservice_exclusive(user_id) - if generate_token: - token = self.macaroon_gen.generate_access_token(user_id) if default_display_name is None: default_display_name = localpart try: yield self.register_with_store( user_id=user_id, - token=token, password_hash=password_hash, make_guest=make_guest, create_profile_with_displayname=default_display_name, @@ -254,7 +240,6 @@ class RegistrationHandler(BaseHandler): # if user id is taken, just generate another user = None user_id = None - token = None attempts += 1 if not self.hs.config.user_consent_at_registration: @@ -278,7 +263,7 @@ class RegistrationHandler(BaseHandler): # Bind email to new account yield self._register_email_threepid(user_id, threepid_dict, None, False) - defer.returnValue((user_id, token)) + defer.returnValue(user_id) @defer.inlineCallbacks def _auto_join_rooms(self, user_id): @@ -541,7 +526,6 @@ class RegistrationHandler(BaseHandler): def register_with_store( self, user_id, - token=None, password_hash=None, was_guest=False, make_guest=False, @@ -555,9 +539,6 @@ class RegistrationHandler(BaseHandler): Args: user_id (str): The desired user ID to register. - token (str): The desired access token to use for this user. If this - is not None, the given access token is associated with the user - id. password_hash (str|None): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. @@ -593,7 +574,6 @@ class RegistrationHandler(BaseHandler): if self.hs.config.worker_app: return self._register_client( user_id=user_id, - token=token, password_hash=password_hash, was_guest=was_guest, make_guest=make_guest, @@ -606,7 +586,6 @@ class RegistrationHandler(BaseHandler): else: return self.store.register( user_id=user_id, - token=token, password_hash=password_hash, was_guest=was_guest, make_guest=make_guest, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index a0be2c5ca3..7bb020cb45 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -103,7 +103,6 @@ class ModuleApi(object): _, access_token = yield self.register_device(user_id) defer.returnValue((user_id, access_token)) - @defer.inlineCallbacks def register_user(self, localpart, displayname=None, emails=[]): """Registers a new user with given localpart and optional displayname, emails. @@ -115,15 +114,10 @@ class ModuleApi(object): Returns: Deferred[str]: user_id """ - user_id, _ = yield self.hs.get_registration_handler().register( - localpart=localpart, - default_display_name=displayname, - bind_emails=emails, - generate_token=False, + return self.hs.get_registration_handler().register_user( + localpart=localpart, default_display_name=displayname, bind_emails=emails ) - defer.returnValue(user_id) - def register_device(self, user_id, device_id=None, initial_display_name=None): """Register a device for a user and generate an access token. diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index f81a0f1b8f..2bf2173895 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -38,7 +38,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint): @staticmethod def _serialize_payload( user_id, - token, password_hash, was_guest, make_guest, @@ -51,9 +50,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint): """ Args: user_id (str): The desired user ID to register. - token (str): The desired access token to use for this user. If this - is not None, the given access token is associated with the user - id. password_hash (str|None): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. @@ -68,7 +64,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint): address (str|None): the IP address used to perform the regitration. """ return { - "token": token, "password_hash": password_hash, "was_guest": was_guest, "make_guest": make_guest, @@ -85,7 +80,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint): yield self.registration_handler.register_with_store( user_id=user_id, - token=content["token"], password_hash=content["password_hash"], was_guest=content["was_guest"], make_guest=content["make_guest"], diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 9843a902c6..6888ae5590 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -219,11 +219,10 @@ class UserRegisterServlet(RestServlet): register = RegisterRestServlet(self.hs) - (user_id, _) = yield register.registration_handler.register( + user_id = yield register.registration_handler.register_user( localpart=body["username"].lower(), password=body["password"], admin=bool(admin), - generate_token=False, user_type=user_type, ) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index b13043cc64..0d05945f0a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -314,10 +314,8 @@ class LoginRestServlet(RestServlet): registered_user_id = yield self.auth_handler.check_user_exists(user_id) if not registered_user_id: - registered_user_id, _ = ( - yield self.registration_handler.register( - localpart=user, generate_token=False - ) + registered_user_id = yield self.registration_handler.register_user( + localpart=user ) result = yield self._register_device_with_callback( @@ -505,12 +503,8 @@ class SSOAuthHandler(object): user_id = UserID(localpart, self._hostname).to_string() registered_user_id = yield self._auth_handler.check_user_exists(user_id) if not registered_user_id: - registered_user_id, _ = ( - yield self._registration_handler.register( - localpart=localpart, - generate_token=False, - default_display_name=user_display_name, - ) + registered_user_id = yield self._registration_handler.register_user( + localpart=localpart, default_display_name=user_display_name ) login_token = self._macaroon_gen.generate_short_term_login_token( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5c120e4dd5..f327999e59 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -464,11 +464,10 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) - (registered_user_id, _) = yield self.registration_handler.register( + registered_user_id = yield self.registration_handler.register_user( localpart=desired_username, password=new_password, guest_access_token=guest_access_token, - generate_token=False, threepid=threepid, address=client_addr, ) @@ -542,8 +541,8 @@ class RegisterRestServlet(RestServlet): if not compare_digest(want_mac, got_mac): raise SynapseError(403, "HMAC incorrect") - (user_id, _) = yield self.registration_handler.register( - localpart=username, password=password, generate_token=False + user_id = yield self.registration_handler.register_user( + localpart=username, password=password ) result = yield self._create_registration_details(user_id, body) @@ -577,8 +576,8 @@ class RegisterRestServlet(RestServlet): def _do_guest_registration(self, params, address=None): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") - user_id, _ = yield self.registration_handler.register( - generate_token=False, make_guest=True, address=address + user_id = yield self.registration_handler.register_user( + make_guest=True, address=address ) # we don't allow guests to specify their own device_id, because diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1c7ded7397..8197f26d4f 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -129,21 +129,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return_value=defer.succeed(self.lots_of_users) ) self.get_failure( - self.handler.register(localpart="local_part"), ResourceLimitError + self.handler.register_user(localpart="local_part"), ResourceLimitError ) self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) self.get_failure( - self.handler.register(localpart="local_part"), ResourceLimitError + self.handler.register_user(localpart="local_part"), ResourceLimitError ) def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = self.get_success(self.handler.register(localpart="jeff")) - rooms = self.get_success(self.store.get_rooms_for_user(res[0])) + user_id = self.get_success(self.handler.register_user(localpart="jeff")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -154,25 +154,25 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms_with_no_rooms(self): self.hs.config.auto_join_rooms = [] frank = UserID.from_string("@frank:test") - res = self.get_success(self.handler.register(frank.localpart)) - self.assertEqual(res[0], frank.to_string()) - rooms = self.get_success(self.store.get_rooms_for_user(res[0])) + user_id = self.get_success(self.handler.register_user(frank.localpart)) + self.assertEqual(user_id, frank.to_string()) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) def test_auto_create_auto_join_where_room_is_another_domain(self): self.hs.config.auto_join_rooms = ["#room:another"] frank = UserID.from_string("@frank:test") - res = self.get_success(self.handler.register(frank.localpart)) - self.assertEqual(res[0], frank.to_string()) - rooms = self.get_success(self.store.get_rooms_for_user(res[0])) + user_id = self.get_success(self.handler.register_user(frank.localpart)) + self.assertEqual(user_id, frank.to_string()) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) def test_auto_create_auto_join_where_auto_create_is_false(self): self.hs.config.autocreate_auto_join_rooms = False room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = self.get_success(self.handler.register(localpart="jeff")) - rooms = self.get_success(self.store.get_rooms_for_user(res[0])) + user_id = self.get_success(self.handler.register_user(localpart="jeff")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) def test_auto_create_auto_join_rooms_when_support_user_exists(self): @@ -180,8 +180,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.hs.config.auto_join_rooms = [room_alias_str] self.store.is_support_user = Mock(return_value=True) - res = self.get_success(self.handler.register(localpart="support")) - rooms = self.get_success(self.store.get_rooms_for_user(res[0])) + user_id = self.get_success(self.handler.register_user(localpart="support")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) @@ -209,27 +209,31 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # When:- # * the user is registered and post consent actions are called - res = self.get_success(self.handler.register(localpart="jeff")) - self.get_success(self.handler.post_consent_actions(res[0])) + user_id = self.get_success(self.handler.register_user(localpart="jeff")) + self.get_success(self.handler.post_consent_actions(user_id)) # Then:- # * Ensure that they have not been joined to the room - rooms = self.get_success(self.store.get_rooms_for_user(res[0])) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) def test_register_support_user(self): - res = self.get_success( - self.handler.register(localpart="user", user_type=UserTypes.SUPPORT) + user_id = self.get_success( + self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT) ) - self.assertTrue(self.store.is_support_user(res[0])) + d = self.store.is_support_user(user_id) + self.assertTrue(self.get_success(d)) def test_register_not_support_user(self): - res = self.get_success(self.handler.register(localpart="user")) - self.assertFalse(self.store.is_support_user(res[0])) + user_id = self.get_success(self.handler.register_user(localpart="user")) + d = self.store.is_support_user(user_id) + self.assertFalse(self.get_success(d)) def test_invalid_user_id_length(self): invalid_user_id = "x" * 256 - self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError) + self.get_failure( + self.handler.register_user(localpart=invalid_user_id), SynapseError + ) @defer.inlineCallbacks def get_or_create_user(self, requester, localpart, displayname, password_hash=None): @@ -267,13 +271,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase): if need_register: yield self.handler.register_with_store( user_id=user_id, - token=token, password_hash=password_hash, create_profile_with_displayname=user.localpart, ) else: yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) - yield self.store.add_access_token_to_user(user_id=user_id, token=token) + yield self.store.add_access_token_to_user(user_id=user_id, token=token) if displayname is not None: # logger.info("setting user display name: %s -> %s", user_id, displayname) -- cgit 1.5.1 From d88421ab03e72a6c7f69ca38a57b4b6212f1bc82 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 9 Jul 2019 13:43:08 +0100 Subject: Include the original event in /relations (#5626) When asking for the relations of an event, include the original event in the response. This will mostly be used for efficiently showing edit history, but could be useful in other circumstances. --- changelog.d/5626.feature | 1 + synapse/rest/client/v2_alpha/relations.py | 8 +++++--- synapse/storage/relations.py | 2 +- tests/rest/client/v2_alpha/test_relations.py | 5 +++++ 4 files changed, 12 insertions(+), 4 deletions(-) create mode 100644 changelog.d/5626.feature (limited to 'tests') diff --git a/changelog.d/5626.feature b/changelog.d/5626.feature new file mode 100644 index 0000000000..5ef793b943 --- /dev/null +++ b/changelog.d/5626.feature @@ -0,0 +1 @@ +Include the original event when asking for its relations. diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 8e362782cc..458afd135f 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -145,9 +145,9 @@ class RelationPaginationServlet(RestServlet): room_id, requester.user.to_string() ) - # This checks that a) the event exists and b) the user is allowed to - # view it. - yield self.event_handler.get_event(requester.user, room_id, parent_id) + # This gets the original event and checks that a) the event exists and + # b) the user is allowed to view it. + event = yield self.event_handler.get_event(requester.user, room_id, parent_id) limit = parse_integer(request, "limit", default=5) from_token = parse_string(request, "from") @@ -173,10 +173,12 @@ class RelationPaginationServlet(RestServlet): ) now = self.clock.time_msec() + original_event = yield self._event_serializer.serialize_event(event, now) events = yield self._event_serializer.serialize_events(events, now) return_value = result.to_dict() return_value["chunk"] = events + return_value["original_event"] = original_event defer.returnValue((200, return_value)) diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 1b01934c19..9954bc094f 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -60,7 +60,7 @@ class PaginationChunk(object): class RelationPaginationToken(object): """Pagination token for relation pagination API. - As the results are order by topological ordering, we can use the + As the results are in topological order, we can use the `topological_ordering` and `stream_ordering` fields of the events at the boundaries of the chunk as pagination tokens. diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 6bb7d92638..58c6951852 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -126,6 +126,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel.json_body["chunk"][0], ) + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEquals( + channel.json_body["original_event"]["event_id"], self.parent_id + ) + # Make sure next_batch has something in it that looks like it could be a # valid token. self.assertIsInstance( -- cgit 1.5.1 From 953dbb79808c018fe34999a662f4c7cef8ea3721 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 10 Jul 2019 16:26:49 +0100 Subject: Remove access-token support from RegistrationStore.register (#5642) The 'token' param is no longer used anywhere except the tests, so let's kill that off too. --- changelog.d/5642.misc | 1 + synapse/handlers/register.py | 2 +- synapse/storage/registration.py | 24 ++++------------------- tests/api/test_auth.py | 2 +- tests/handlers/test_register.py | 6 +----- tests/handlers/test_user_directory.py | 16 +++++---------- tests/storage/test_client_ips.py | 4 +--- tests/storage/test_monthly_active_users.py | 23 +++++++++------------- tests/storage/test_registration.py | 31 +++++++----------------------- 9 files changed, 30 insertions(+), 79 deletions(-) create mode 100644 changelog.d/5642.misc (limited to 'tests') diff --git a/changelog.d/5642.misc b/changelog.d/5642.misc new file mode 100644 index 0000000000..e7f8e214a4 --- /dev/null +++ b/changelog.d/5642.misc @@ -0,0 +1 @@ +Remove access-token support from `RegistrationStore.register`, and rename it. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a3e553d5f5..420c5cb5bc 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -584,7 +584,7 @@ class RegistrationHandler(BaseHandler): address=address, ) else: - return self.store.register( + return self.store.register_user( user_id=user_id, password_hash=password_hash, was_guest=was_guest, diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index aea5b3276b..8e217c9408 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -698,10 +698,9 @@ class RegistrationStore( desc="add_access_token_to_user", ) - def register( + def register_user( self, user_id, - token=None, password_hash=None, was_guest=False, make_guest=False, @@ -714,9 +713,6 @@ class RegistrationStore( Args: user_id (str): The desired user ID to register. - token (str): The desired access token to use for this user. If this - is not None, the given access token is associated with the user - id. password_hash (str): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. @@ -733,10 +729,9 @@ class RegistrationStore( StoreError if the user_id could not be registered. """ return self.runInteraction( - "register", - self._register, + "register_user", + self._register_user, user_id, - token, password_hash, was_guest, make_guest, @@ -746,11 +741,10 @@ class RegistrationStore( user_type, ) - def _register( + def _register_user( self, txn, user_id, - token, password_hash, was_guest, make_guest, @@ -763,8 +757,6 @@ class RegistrationStore( now = int(self.clock.time()) - next_id = self._access_tokens_id_gen.get_next() - try: if was_guest: # Ensure that the guest user actually exists @@ -812,14 +804,6 @@ class RegistrationStore( if self._account_validity.enabled: self.set_expiration_date_for_user_txn(txn, user_id) - if token: - # it's possible for this to get a conflict, but only for a single user - # since tokens are namespaced based on their user ID - txn.execute( - "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)", - (next_id, user_id, token), - ) - if create_profile_with_displayname: # set a default displayname serverside to avoid ugly race # between auto-joins and clients trying to set displaynames diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index d4e75b5b2e..96b26f974b 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -325,7 +325,7 @@ class AuthTestCase(unittest.TestCase): unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.hs.config.mau_limits_reserved_threepids = [threepid] - yield self.store.register(user_id="user1", token="123", password_hash=None) + yield self.store.register_user(user_id="user1", password_hash=None) with self.assertRaises(ResourceLimitError): yield self.auth.check_auth_blocking() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 8197f26d4f..1b7e1dacee 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -77,11 +77,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): store = self.hs.get_datastore() frank = UserID.from_string("@frank:test") self.get_success( - store.register( - user_id=frank.to_string(), - token="jkv;g498752-43gj['eamb!-5", - password_hash=None, - ) + store.register_user(user_id=frank.to_string(), password_hash=None) ) local_part = frank.localpart user_id = frank.to_string() diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index b135486c48..c5e91a8c41 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -47,11 +47,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_handle_local_profile_change_with_support_user(self): support_user_id = "@support:test" self.get_success( - self.store.register( - user_id=support_user_id, - token="123", - password_hash=None, - user_type=UserTypes.SUPPORT, + self.store.register_user( + user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT ) ) @@ -73,11 +70,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_handle_user_deactivated_support_user(self): s_user_id = "@support:test" self.get_success( - self.store.register( - user_id=s_user_id, - token="123", - password_hash=None, - user_type=UserTypes.SUPPORT, + self.store.register_user( + user_id=s_user_id, password_hash=None, user_type=UserTypes.SUPPORT ) ) @@ -90,7 +84,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_handle_user_deactivated_regular_user(self): r_user_id = "@regular:test" self.get_success( - self.store.register(user_id=r_user_id, token="123", password_hash=None) + self.store.register_user(user_id=r_user_id, password_hash=None) ) self.store.remove_from_user_dir = Mock() self.get_success(self.handler.handle_user_deactivated(r_user_id)) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 59c6f8c227..09305c3bf1 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -185,9 +185,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" - self.get_success( - self.store.register(user_id=user_id, token="123", password_hash=None) - ) + self.get_success(self.store.register_user(user_id=user_id, password_hash=None)) active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 0ce0b991f9..1494650d10 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -53,10 +53,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # -1 because user3 is a support user and does not count user_num = len(threepids) - 1 - self.store.register(user_id=user1, token="123", password_hash=None) - self.store.register(user_id=user2, token="456", password_hash=None) - self.store.register( - user_id=user3, token="789", password_hash=None, user_type=UserTypes.SUPPORT + self.store.register_user(user_id=user1, password_hash=None) + self.store.register_user(user_id=user2, password_hash=None) + self.store.register_user( + user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT ) self.pump() @@ -161,9 +161,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_is_guest(self): # Test that guest users are not added to mau list user_id = "@user_id:host" - self.store.register( - user_id=user_id, token="123", password_hash=None, make_guest=True - ) + self.store.register_user(user_id=user_id, password_hash=None, make_guest=True) self.store.upsert_monthly_active_user = Mock() self.store.populate_monthly_active_users(user_id) self.pump() @@ -216,8 +214,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.assertEquals(self.get_success(count), 0) # Test reserved registed users - self.store.register(user_id=user1, token="123", password_hash=None) - self.store.register(user_id=user2, token="456", password_hash=None) + self.store.register_user(user_id=user1, password_hash=None) + self.store.register_user(user_id=user2, password_hash=None) self.pump() now = int(self.hs.get_clock().time_msec()) @@ -232,11 +230,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.pump() self.assertEqual(self.get_success(count), 0) - self.store.register( - user_id=support_user_id, - token="123", - password_hash=None, - user_type=UserTypes.SUPPORT, + self.store.register_user( + user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT ) self.store.upsert_monthly_active_user(support_user_id) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 625b651e91..9365c4622d 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -37,7 +37,7 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_register(self): - yield self.store.register(self.user_id, self.tokens[0], self.pwhash) + yield self.store.register_user(self.user_id, self.pwhash) self.assertEquals( { @@ -53,15 +53,9 @@ class RegistrationStoreTestCase(unittest.TestCase): (yield self.store.get_user_by_id(self.user_id)), ) - result = yield self.store.get_user_by_access_token(self.tokens[0]) - - self.assertDictContainsSubset({"name": self.user_id}, result) - - self.assertTrue("token_id" in result) - @defer.inlineCallbacks def test_add_tokens(self): - yield self.store.register(self.user_id, self.tokens[0], self.pwhash) + yield self.store.register_user(self.user_id, self.pwhash) yield self.store.add_access_token_to_user( self.user_id, self.tokens[1], self.device_id ) @@ -77,7 +71,8 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_user_delete_access_tokens(self): # add some tokens - yield self.store.register(self.user_id, self.tokens[0], self.pwhash) + yield self.store.register_user(self.user_id, self.pwhash) + yield self.store.add_access_token_to_user(self.user_id, self.tokens[0]) yield self.store.add_access_token_to_user( self.user_id, self.tokens[1], self.device_id ) @@ -108,24 +103,12 @@ class RegistrationStoreTestCase(unittest.TestCase): res = yield self.store.is_support_user(None) self.assertFalse(res) - yield self.store.register(user_id=TEST_USER, token="123", password_hash=None) + yield self.store.register_user(user_id=TEST_USER, password_hash=None) res = yield self.store.is_support_user(TEST_USER) self.assertFalse(res) - yield self.store.register( - user_id=SUPPORT_USER, - token="456", - password_hash=None, - user_type=UserTypes.SUPPORT, + yield self.store.register_user( + user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT ) res = yield self.store.is_support_user(SUPPORT_USER) self.assertTrue(res) - - -class TokenGenerator: - def __init__(self): - self._last_issued_token = 0 - - def generate(self, user_id): - self._last_issued_token += 1 - return "%s-%d" % (user_id, self._last_issued_token) -- cgit 1.5.1 From 1890cfcf82aa3e69530e97bf9ce783e390f22fbe Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 10 Jul 2019 19:10:07 +0100 Subject: Inline issue_access_token (#5659) this is only used in one place, so it's clearer if we inline it and reduce the API surface. Also, fixes a buglet where we would create an access token even if we were about to block the user (we would never return the AT, so the user could never use it, but it was still created and added to the db.) --- changelog.d/5659.misc | 1 + synapse/handlers/auth.py | 10 +++------- tests/api/test_auth.py | 2 +- 3 files changed, 5 insertions(+), 8 deletions(-) create mode 100644 changelog.d/5659.misc (limited to 'tests') diff --git a/changelog.d/5659.misc b/changelog.d/5659.misc new file mode 100644 index 0000000000..686001295c --- /dev/null +++ b/changelog.d/5659.misc @@ -0,0 +1 @@ +Inline issue_access_token. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index ef5585aa99..da312b188e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -578,9 +578,11 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. """ logger.info("Logging in user %s on device %s", user_id, device_id) - access_token = yield self.issue_access_token(user_id, device_id) yield self.auth.check_auth_blocking(user_id) + access_token = self.macaroon_gen.generate_access_token(user_id) + yield self.store.add_access_token_to_user(user_id, access_token, device_id) + # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we # really don't want is active access_tokens without a record of the @@ -831,12 +833,6 @@ class AuthHandler(BaseHandler): defer.returnValue(None) defer.returnValue(user_id) - @defer.inlineCallbacks - def issue_access_token(self, user_id, device_id=None): - access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token, device_id) - defer.returnValue(access_token) - @defer.inlineCallbacks def validate_short_term_login_token_and_get_user_id(self, login_token): auth_api = self.hs.get_auth() diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 96b26f974b..ddf2b578b3 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -244,7 +244,7 @@ class AuthTestCase(unittest.TestCase): USER_ID = "@percy:matrix.org" self.store.add_access_token_to_user = Mock() - token = yield self.hs.handlers.auth_handler.issue_access_token( + token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id( USER_ID, "DEVICE" ) self.store.add_access_token_to_user.assert_called_with(USER_ID, token, "DEVICE") -- cgit 1.5.1 From 0a4001eba1eb22fc7c39f257c8d5a326b1a489ad Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 11 Jul 2019 11:06:23 +0100 Subject: Clean up exception handling for access_tokens (#5656) First of all, let's get rid of `TOKEN_NOT_FOUND_HTTP_STATUS`. It was a hack we did at one point when it was possible to return either a 403 or a 401 if the creds were missing. We always return a 401 in these cases now (thankfully), so it's not needed. Let's also stop abusing `AuthError` for these cases. Honestly they have nothing that relates them to the other places that `AuthError` is used, other than the fact that they are loosely under the 'Auth' banner. It makes no sense for them to share exception classes. Instead, let's add a couple of new exception classes: `InvalidClientTokenError` and `MissingClientTokenError`, for the `M_UNKNOWN_TOKEN` and `M_MISSING_TOKEN` cases respectively - and an `InvalidClientCredentialsError` base class for the two of them. --- changelog.d/5656.misc | 1 + synapse/api/auth.py | 127 +++++++++++------------------------- synapse/api/errors.py | 33 +++++++++- synapse/rest/client/v1/directory.py | 10 ++- synapse/rest/client/v1/room.py | 9 ++- tests/api/test_auth.py | 31 +++++++-- 6 files changed, 111 insertions(+), 100 deletions(-) create mode 100644 changelog.d/5656.misc (limited to 'tests') diff --git a/changelog.d/5656.misc b/changelog.d/5656.misc new file mode 100644 index 0000000000..a8de20a7d0 --- /dev/null +++ b/changelog.d/5656.misc @@ -0,0 +1 @@ +Clean up exception handling around client access tokens. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 86f145649c..afc6400948 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -25,7 +25,13 @@ from twisted.internet import defer import synapse.types from synapse import event_auth from synapse.api.constants import EventTypes, JoinRules, Membership -from synapse.api.errors import AuthError, Codes, ResourceLimitError +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientTokenError, + MissingClientTokenError, + ResourceLimitError, +) from synapse.config.server import is_threepid_reserved from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache @@ -63,7 +69,6 @@ class Auth(object): self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) register_cache("cache", "token_cache", self.token_cache) @@ -189,18 +194,17 @@ class Auth(object): Returns: defer.Deferred: resolves to a ``synapse.types.Requester`` object Raises: - AuthError if no user by that token exists or the token is invalid. + InvalidClientCredentialsError if no user by that token exists or the token + is invalid. + AuthError if access is denied for the user in the access token """ - # Can optionally look elsewhere in the request (e.g. headers) try: ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( b"User-Agent", default=[b""] )[0].decode("ascii", "surrogateescape") - access_token = self.get_access_token_from_request( - request, self.TOKEN_NOT_FOUND_HTTP_STATUS - ) + access_token = self.get_access_token_from_request(request) user_id, app_service = yield self._get_appservice_user_id(request) if user_id: @@ -264,18 +268,12 @@ class Auth(object): ) ) except KeyError: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Missing access token.", - errcode=Codes.MISSING_TOKEN, - ) + raise MissingClientTokenError() @defer.inlineCallbacks def _get_appservice_user_id(self, request): app_service = self.store.get_app_service_by_token( - self.get_access_token_from_request( - request, self.TOKEN_NOT_FOUND_HTTP_STATUS - ) + self.get_access_token_from_request(request) ) if app_service is None: defer.returnValue((None, None)) @@ -313,7 +311,8 @@ class Auth(object): `token_id` (int|None): access token id. May be None if guest `device_id` (str|None): device corresponding to access token Raises: - AuthError if no user by that token exists or the token is invalid. + InvalidClientCredentialsError if no user by that token exists or the token + is invalid. """ if rights == "access": @@ -331,11 +330,7 @@ class Auth(object): if not guest: # non-guest access tokens must be in the database logger.warning("Unrecognised access token - not in store.") - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Unrecognised access token.", - errcode=Codes.UNKNOWN_TOKEN, - ) + raise InvalidClientTokenError() # Guest access tokens are not stored in the database (there can # only be one access token per guest, anyway). @@ -350,16 +345,10 @@ class Auth(object): # guest tokens. stored_user = yield self.store.get_user_by_id(user_id) if not stored_user: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Unknown user_id %s" % user_id, - errcode=Codes.UNKNOWN_TOKEN, - ) + raise InvalidClientTokenError("Unknown user_id %s" % user_id) if not stored_user["is_guest"]: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Guest access token used for regular user", - errcode=Codes.UNKNOWN_TOKEN, + raise InvalidClientTokenError( + "Guest access token used for regular user" ) ret = { "user": user, @@ -386,11 +375,7 @@ class Auth(object): ValueError, ) as e: logger.warning("Invalid macaroon in auth: %s %s", type(e), e) - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN, - ) + raise InvalidClientTokenError("Invalid macaroon passed.") def _parse_and_validate_macaroon(self, token, rights="access"): """Takes a macaroon and tries to parse and validate it. This is cached @@ -430,11 +415,7 @@ class Auth(object): macaroon, rights, self.hs.config.expire_access_token, user_id=user_id ) except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN, - ) + raise InvalidClientTokenError("Invalid macaroon passed.") if not has_expiry and rights == "access": self.token_cache[token] = (user_id, guest) @@ -453,17 +434,14 @@ class Auth(object): (str) user id Raises: - AuthError if there is no user_id caveat in the macaroon + InvalidClientCredentialsError if there is no user_id caveat in the + macaroon """ user_prefix = "user_id = " for caveat in macaroon.caveats: if caveat.caveat_id.startswith(user_prefix): return caveat.caveat_id[len(user_prefix) :] - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN, - ) + raise InvalidClientTokenError("No user caveat in macaroon") def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): """ @@ -531,22 +509,13 @@ class Auth(object): defer.returnValue(user_info) def get_appservice_by_req(self, request): - try: - token = self.get_access_token_from_request( - request, self.TOKEN_NOT_FOUND_HTTP_STATUS - ) - service = self.store.get_app_service_by_token(token) - if not service: - logger.warn("Unrecognised appservice access token.") - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "Unrecognised access token.", - errcode=Codes.UNKNOWN_TOKEN, - ) - request.authenticated_entity = service.sender - return defer.succeed(service) - except KeyError: - raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.") + token = self.get_access_token_from_request(request) + service = self.store.get_app_service_by_token(token) + if not service: + logger.warn("Unrecognised appservice access token.") + raise InvalidClientTokenError() + request.authenticated_entity = service.sender + return defer.succeed(service) def is_server_admin(self, user): """ Check if the given user is a local server admin. @@ -692,20 +661,16 @@ class Auth(object): return bool(query_params) or bool(auth_headers) @staticmethod - def get_access_token_from_request(request, token_not_found_http_status=401): + def get_access_token_from_request(request): """Extracts the access_token from the request. Args: request: The http request. - token_not_found_http_status(int): The HTTP status code to set in the - AuthError if the token isn't found. This is used in some of the - legacy APIs to change the status code to 403 from the default of - 401 since some of the old clients depended on auth errors returning - 403. Returns: unicode: The access_token Raises: - AuthError: If there isn't an access_token in the request. + MissingClientTokenError: If there isn't a single access_token in the + request """ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") @@ -714,34 +679,20 @@ class Auth(object): # Try the get the access_token from a "Authorization: Bearer" # header if query_params is not None: - raise AuthError( - token_not_found_http_status, - "Mixing Authorization headers and access_token query parameters.", - errcode=Codes.MISSING_TOKEN, + raise MissingClientTokenError( + "Mixing Authorization headers and access_token query parameters." ) if len(auth_headers) > 1: - raise AuthError( - token_not_found_http_status, - "Too many Authorization headers.", - errcode=Codes.MISSING_TOKEN, - ) + raise MissingClientTokenError("Too many Authorization headers.") parts = auth_headers[0].split(b" ") if parts[0] == b"Bearer" and len(parts) == 2: return parts[1].decode("ascii") else: - raise AuthError( - token_not_found_http_status, - "Invalid Authorization header.", - errcode=Codes.MISSING_TOKEN, - ) + raise MissingClientTokenError("Invalid Authorization header.") else: # Try to get the access_token from the query params. if not query_params: - raise AuthError( - token_not_found_http_status, - "Missing access token.", - errcode=Codes.MISSING_TOKEN, - ) + raise MissingClientTokenError() return query_params[0].decode("ascii") diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 28b5c2af9b..41fd04cd54 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -210,7 +210,9 @@ class NotFoundError(SynapseError): class AuthError(SynapseError): - """An error raised when there was a problem authorising an event.""" + """An error raised when there was a problem authorising an event, and at various + other poorly-defined times. + """ def __init__(self, *args, **kwargs): if "errcode" not in kwargs: @@ -218,6 +220,35 @@ class AuthError(SynapseError): super(AuthError, self).__init__(*args, **kwargs) +class InvalidClientCredentialsError(SynapseError): + """An error raised when there was a problem with the authorisation credentials + in a client request. + + https://matrix.org/docs/spec/client_server/r0.5.0#using-access-tokens: + + When credentials are required but missing or invalid, the HTTP call will + return with a status of 401 and the error code, M_MISSING_TOKEN or + M_UNKNOWN_TOKEN respectively. + """ + + def __init__(self, msg, errcode): + super().__init__(code=401, msg=msg, errcode=errcode) + + +class MissingClientTokenError(InvalidClientCredentialsError): + """Raised when we couldn't find the access token in a request""" + + def __init__(self, msg="Missing access token"): + super().__init__(msg=msg, errcode="M_MISSING_TOKEN") + + +class InvalidClientTokenError(InvalidClientCredentialsError): + """Raised when we didn't understand the access token in a request""" + + def __init__(self, msg="Unrecognised access token"): + super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN") + + class ResourceLimitError(SynapseError): """ Any error raised when there is a problem with resource usage. diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index dd0d38ea5c..57542c2b4b 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -18,7 +18,13 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientCredentialsError, + NotFoundError, + SynapseError, +) from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.rest.client.v2_alpha._base import client_patterns from synapse.types import RoomAlias @@ -97,7 +103,7 @@ class ClientDirectoryServer(RestServlet): room_alias.to_string(), ) defer.returnValue((200, {})) - except AuthError: + except InvalidClientCredentialsError: # fallback to default user behaviour if they aren't an AS pass diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index cca7e45ddb..7709c2d705 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -24,7 +24,12 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientCredentialsError, + SynapseError, +) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( @@ -307,7 +312,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): try: yield self.auth.get_user_by_req(request, allow_guest=True) - except AuthError as e: + except InvalidClientCredentialsError as e: # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private # federations. diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index ddf2b578b3..ee92ceeb60 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -21,7 +21,14 @@ from twisted.internet import defer import synapse.handlers.auth from synapse.api.auth import Auth -from synapse.api.errors import AuthError, Codes, ResourceLimitError +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientCredentialsError, + InvalidClientTokenError, + MissingClientTokenError, + ResourceLimitError, +) from synapse.types import UserID from tests import unittest @@ -70,7 +77,9 @@ class AuthTestCase(unittest.TestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) - self.failureResultOf(d, AuthError) + f = self.failureResultOf(d, InvalidClientTokenError).value + self.assertEqual(f.code, 401) + self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): user_info = {"name": self.test_user, "token_id": "ditto"} @@ -79,7 +88,9 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) - self.failureResultOf(d, AuthError) + f = self.failureResultOf(d, MissingClientTokenError).value + self.assertEqual(f.code, 401) + self.assertEqual(f.errcode, "M_MISSING_TOKEN") @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token(self): @@ -133,7 +144,9 @@ class AuthTestCase(unittest.TestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) - self.failureResultOf(d, AuthError) + f = self.failureResultOf(d, InvalidClientTokenError).value + self.assertEqual(f.code, 401) + self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) @@ -143,7 +156,9 @@ class AuthTestCase(unittest.TestCase): request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) - self.failureResultOf(d, AuthError) + f = self.failureResultOf(d, InvalidClientTokenError).value + self.assertEqual(f.code, 401) + self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) @@ -153,7 +168,9 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) - self.failureResultOf(d, AuthError) + f = self.failureResultOf(d, MissingClientTokenError).value + self.assertEqual(f.code, 401) + self.assertEqual(f.errcode, "M_MISSING_TOKEN") @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_valid_user_id(self): @@ -280,7 +297,7 @@ class AuthTestCase(unittest.TestCase): request.args[b"access_token"] = [guest_tok.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - with self.assertRaises(AuthError) as cm: + with self.assertRaises(InvalidClientCredentialsError) as cm: yield self.auth.get_user_by_req(request, allow_guest=True) self.assertEqual(401, cm.exception.code) -- cgit 1.5.1 From 6bb0357c949dcc433e2316c395199644ef75e02c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 12 Jul 2019 10:16:23 +0100 Subject: Add a mechanism for per-test configs (#5657) It's useful to be able to tweak the homeserver config to be used for each test. This PR adds a mechanism to do so. --- changelog.d/5657.misc | 1 + tests/unittest.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 changelog.d/5657.misc (limited to 'tests') diff --git a/changelog.d/5657.misc b/changelog.d/5657.misc new file mode 100644 index 0000000000..bdec9ae4c0 --- /dev/null +++ b/changelog.d/5657.misc @@ -0,0 +1 @@ +Add a mechanism for per-test homeserver configuration in the unit tests. diff --git a/tests/unittest.py b/tests/unittest.py index a09e76c7c2..0f0c2ad69d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -157,6 +157,21 @@ class HomeserverTestCase(TestCase): """ A base TestCase that reduces boilerplate for HomeServer-using test cases. + Defines a setUp method which creates a mock reactor, and instantiates a homeserver + running on that reactor. + + There are various hooks for modifying the way that the homeserver is instantiated: + + * override make_homeserver, for example by making it pass different parameters into + setup_test_homeserver. + + * override default_config, to return a modified configuration dictionary for use + by setup_test_homeserver. + + * On a per-test basis, you can use the @override_config decorator to give a + dictionary containing additional configuration settings to be added to the basic + config dict. + Attributes: servlets (list[function]): List of servlet registration function. user_id (str): The user ID to assume if auth is hijacked. @@ -168,6 +183,13 @@ class HomeserverTestCase(TestCase): hijack_auth = True needs_threadpool = False + def __init__(self, methodName, *args, **kwargs): + super().__init__(methodName, *args, **kwargs) + + # see if we have any additional config for this test + method = getattr(self, methodName) + self._extra_config = getattr(method, "_extra_config", None) + def setUp(self): """ Set up the TestCase by calling the homeserver constructor, optionally @@ -276,7 +298,14 @@ class HomeserverTestCase(TestCase): Args: name (str): The homeserver name/domain. """ - return default_config(name) + config = default_config(name) + + # apply any additional config which was specified via the override_config + # decorator. + if self._extra_config is not None: + config.update(self._extra_config) + + return config def prepare(self, reactor, clock, homeserver): """ @@ -534,3 +563,27 @@ class HomeserverTestCase(TestCase): ) self.render(request) self.assertEqual(channel.code, 403, channel.result) + + +def override_config(extra_config): + """A decorator which can be applied to test functions to give additional HS config + + For use + + For example: + + class MyTestCase(HomeserverTestCase): + @override_config({"enable_registration": False, ...}) + def test_foo(self): + ... + + Args: + extra_config(dict): Additional config settings to be merged into the default + config dict before instantiating the test homeserver. + """ + + def decorator(func): + func._extra_config = extra_config + return func + + return decorator -- cgit 1.5.1 From 5f158ec039e4753959aad9b8d288b3d8cb4959a1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 12 Jul 2019 17:26:02 +0100 Subject: Implement access token expiry (#5660) Record how long an access token is valid for, and raise a soft-logout once it expires. --- changelog.d/5660.feature | 1 + docs/sample_config.yaml | 11 +++ synapse/api/auth.py | 12 +++ synapse/api/errors.py | 8 +- synapse/config/registration.py | 16 +++ synapse/handlers/auth.py | 17 +++- synapse/handlers/register.py | 35 ++++--- synapse/storage/registration.py | 19 +++- .../schema/delta/55/access_token_expiry.sql | 18 ++++ tests/api/test_auth.py | 6 +- tests/handlers/test_auth.py | 20 +++- tests/handlers/test_register.py | 5 +- tests/rest/client/v1/test_login.py | 108 +++++++++++++++++++++ tests/storage/test_registration.py | 8 +- 14 files changed, 253 insertions(+), 31 deletions(-) create mode 100644 changelog.d/5660.feature create mode 100644 synapse/storage/schema/delta/55/access_token_expiry.sql (limited to 'tests') diff --git a/changelog.d/5660.feature b/changelog.d/5660.feature new file mode 100644 index 0000000000..82889fdaf1 --- /dev/null +++ b/changelog.d/5660.feature @@ -0,0 +1 @@ +Implement `session_lifetime` configuration option, after which access tokens will expire. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 0462f0a17a..663ff31622 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -786,6 +786,17 @@ uploads_path: "DATADIR/uploads" # renew_at: 1w # renew_email_subject: "Renew your %(app)s account" +# Time that a user's session remains valid for, after they log in. +# +# Note that this is not currently compatible with guest logins. +# +# Note also that this is calculated at login time: changes are not applied +# retrospectively to users who have already logged in. +# +# By default, this is infinite. +# +#session_lifetime: 24h + # The user must provide all of the below types of 3PID when registering. # #registrations_require_3pid: diff --git a/synapse/api/auth.py b/synapse/api/auth.py index afc6400948..d9e943c39c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -319,6 +319,17 @@ class Auth(object): # first look in the database r = yield self._look_up_user_by_access_token(token) if r: + valid_until_ms = r["valid_until_ms"] + if ( + valid_until_ms is not None + and valid_until_ms < self.clock.time_msec() + ): + # there was a valid access token, but it has expired. + # soft-logout the user. + raise InvalidClientTokenError( + msg="Access token has expired", soft_logout=True + ) + defer.returnValue(r) # otherwise it needs to be a valid macaroon @@ -505,6 +516,7 @@ class Auth(object): "token_id": ret.get("token_id", None), "is_guest": False, "device_id": ret.get("device_id"), + "valid_until_ms": ret.get("valid_until_ms"), } defer.returnValue(user_info) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 41fd04cd54..a6e753c30c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -245,8 +245,14 @@ class MissingClientTokenError(InvalidClientCredentialsError): class InvalidClientTokenError(InvalidClientCredentialsError): """Raised when we didn't understand the access token in a request""" - def __init__(self, msg="Unrecognised access token"): + def __init__(self, msg="Unrecognised access token", soft_logout=False): super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN") + self._soft_logout = soft_logout + + def error_dict(self): + d = super().error_dict() + d["soft_logout"] = self._soft_logout + return d class ResourceLimitError(SynapseError): diff --git a/synapse/config/registration.py b/synapse/config/registration.py index b895c4e9f4..34cb11468c 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -84,6 +84,11 @@ class RegistrationConfig(Config): "disable_msisdn_registration", False ) + session_lifetime = config.get("session_lifetime") + if session_lifetime is not None: + session_lifetime = self.parse_duration(session_lifetime) + self.session_lifetime = session_lifetime + def generate_config_section(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( @@ -141,6 +146,17 @@ class RegistrationConfig(Config): # renew_at: 1w # renew_email_subject: "Renew your %%(app)s account" + # Time that a user's session remains valid for, after they log in. + # + # Note that this is not currently compatible with guest logins. + # + # Note also that this is calculated at login time: changes are not applied + # retrospectively to users who have already logged in. + # + # By default, this is infinite. + # + #session_lifetime: 24h + # The user must provide all of the below types of 3PID when registering. # #registrations_require_3pid: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index da312b188e..b74a6e9c62 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +import time import unicodedata import attr @@ -558,7 +559,7 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id=None): + def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms): """ Creates a new access token for the user with the given user ID. @@ -572,16 +573,26 @@ class AuthHandler(BaseHandler): device_id (str|None): the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) + valid_until_ms (int|None): when the token is valid until. None for + no expiry. Returns: The access token for the user's session. Raises: StoreError if there was a problem storing the token. """ - logger.info("Logging in user %s on device %s", user_id, device_id) + fmt_expiry = "" + if valid_until_ms is not None: + fmt_expiry = time.strftime( + " until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0) + ) + logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry) + yield self.auth.check_auth_blocking(user_id) access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token, device_id) + yield self.store.add_access_token_to_user( + user_id, access_token, device_id, valid_until_ms + ) # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 420c5cb5bc..bb7cfd71b9 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -84,6 +84,8 @@ class RegistrationHandler(BaseHandler): self.device_handler = hs.get_device_handler() self.pusher_pool = hs.get_pusherpool() + self.session_lifetime = hs.config.session_lifetime + @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): if types.contains_invalid_mxid_characters(localpart): @@ -599,6 +601,8 @@ class RegistrationHandler(BaseHandler): def register_device(self, user_id, device_id, initial_display_name, is_guest=False): """Register a device for a user and generate an access token. + The access token will be limited by the homeserver's session_lifetime config. + Args: user_id (str): full canonical @user:id device_id (str|None): The device ID to check, or None to generate @@ -619,20 +623,29 @@ class RegistrationHandler(BaseHandler): is_guest=is_guest, ) defer.returnValue((r["device_id"], r["access_token"])) - else: - device_id = yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) + + valid_until_ms = None + if self.session_lifetime is not None: if is_guest: - access_token = self.macaroon_gen.generate_access_token( - user_id, ["guest = true"] - ) - else: - access_token = yield self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id + raise Exception( + "session_lifetime is not currently implemented for guest access" ) + valid_until_ms = self.clock.time_msec() + self.session_lifetime + + device_id = yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + if is_guest: + assert valid_until_ms is None + access_token = self.macaroon_gen.generate_access_token( + user_id, ["guest = true"] + ) + else: + access_token = yield self._auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, valid_until_ms=valid_until_ms + ) - defer.returnValue((device_id, access_token)) + defer.returnValue((device_id, access_token)) @defer.inlineCallbacks def post_registration_actions( diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 73580f1725..8b2c2a97ab 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -90,7 +90,8 @@ class RegistrationWorkerStore(SQLBaseStore): token (str): The access token of a user. Returns: defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`. + including the keys `name`, `is_guest`, `device_id`, `token_id`, + `valid_until_ms`. """ return self.runInteraction( "get_user_by_access_token", self._query_for_auth, token @@ -284,7 +285,7 @@ class RegistrationWorkerStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( "SELECT users.name, users.is_guest, access_tokens.id as token_id," - " access_tokens.device_id" + " access_tokens.device_id, access_tokens.valid_until_ms" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" @@ -679,14 +680,16 @@ class RegistrationStore( defer.returnValue(batch_size) @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token, device_id=None): + def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): """Adds an access token for the given user. Args: user_id (str): The user ID. token (str): The new access token to add. device_id (str): ID of the device to associate with the access - token + token + valid_until_ms (int|None): when the token is valid until. None for + no expiry. Raises: StoreError if there was a problem adding this. """ @@ -694,7 +697,13 @@ class RegistrationStore( yield self._simple_insert( "access_tokens", - {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id}, + { + "id": next_id, + "user_id": user_id, + "token": token, + "device_id": device_id, + "valid_until_ms": valid_until_ms, + }, desc="add_access_token_to_user", ) diff --git a/synapse/storage/schema/delta/55/access_token_expiry.sql b/synapse/storage/schema/delta/55/access_token_expiry.sql new file mode 100644 index 0000000000..4590604bfd --- /dev/null +++ b/synapse/storage/schema/delta/55/access_token_expiry.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- when this access token can be used until, in ms since the epoch. NULL means the token +-- never expires. +ALTER TABLE access_tokens ADD COLUMN valid_until_ms BIGINT; diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index ee92ceeb60..c0cb8ef296 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -262,9 +262,11 @@ class AuthTestCase(unittest.TestCase): self.store.add_access_token_to_user = Mock() token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id( - USER_ID, "DEVICE" + USER_ID, "DEVICE", valid_until_ms=None + ) + self.store.add_access_token_to_user.assert_called_with( + USER_ID, token, "DEVICE", None ) - self.store.add_access_token_to_user.assert_called_with(USER_ID, token, "DEVICE") def get_user(tok): if token != tok: diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index b204a0700d..b03103d96f 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -117,7 +117,9 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.auth_handler.get_access_token_for_user_id("user_a") + yield self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() @@ -131,7 +133,9 @@ class AuthTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id("user_a") + yield self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) @@ -150,7 +154,9 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id("user_a") + yield self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) @@ -166,7 +172,9 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - yield self.auth_handler.get_access_token_for_user_id("user_a") + yield self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) @@ -185,7 +193,9 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception - yield self.auth_handler.get_access_token_for_user_id("user_a") + yield self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1b7e1dacee..90d0129374 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -272,7 +272,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) else: yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) - yield self.store.add_access_token_to_user(user_id=user_id, token=token) + + yield self.store.add_access_token_to_user( + user_id=user_id, token=token, device_id=None, valid_until_ms=None + ) if displayname is not None: # logger.info("setting user display name: %s -> %s", user_id, displayname) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 0397f91a9e..eae5411325 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -2,10 +2,14 @@ import json import synapse.rest.admin from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import devices +from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from tests import unittest +from tests.unittest import override_config LOGIN_URL = b"/_matrix/client/r0/login" +TEST_URL = b"/_matrix/client/r0/account/whoami" class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -13,6 +17,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, + devices.register_servlets, + lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), ] def make_homeserver(self, reactor, clock): @@ -144,3 +150,105 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_soft_logout(self): + self.register_user("kermit", "monkey") + + # we shouldn't be able to make requests without an access token + request, channel = self.make_request(b"GET", TEST_URL) + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") + + # log in as normal + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + request, channel = self.make_request(b"POST", LOGIN_URL, params) + self.render(request) + + self.assertEquals(channel.code, 200, channel.result) + access_token = channel.json_body["access_token"] + device_id = channel.json_body["device_id"] + + # we should now be able to make requests with the access token + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # + # test behaviour after deleting the expired device + # + + # we now log in as a different device + access_token_2 = self.login("kermit", "monkey") + + # more requests with the expired token should still return a soft-logout + self.reactor.advance(3600) + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # ... but if we delete that device, it will be a proper logout + self._delete_device(access_token_2, "kermit", "monkey", device_id) + + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], False) + + def _delete_device(self, access_token, user_id, password, device_id): + """Perform the UI-Auth to delete a device""" + request, channel = self.make_request( + b"DELETE", "devices/" + device_id, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + # check it's a UI-Auth fail + self.assertEqual( + set(channel.json_body.keys()), + {"flows", "params", "session"}, + channel.result, + ) + + auth = { + "type": "m.login.password", + # https://github.com/matrix-org/synapse/issues/5665 + # "identifier": {"type": "m.id.user", "user": user_id}, + "user": user_id, + "password": password, + "session": channel.json_body["session"], + } + + request, channel = self.make_request( + b"DELETE", + "devices/" + device_id, + access_token=access_token, + content={"auth": auth}, + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 9365c4622d..0253c4ac05 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -57,7 +57,7 @@ class RegistrationStoreTestCase(unittest.TestCase): def test_add_tokens(self): yield self.store.register_user(self.user_id, self.pwhash) yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None ) result = yield self.store.get_user_by_access_token(self.tokens[1]) @@ -72,9 +72,11 @@ class RegistrationStoreTestCase(unittest.TestCase): def test_user_delete_access_tokens(self): # add some tokens yield self.store.register_user(self.user_id, self.pwhash) - yield self.store.add_access_token_to_user(self.user_id, self.tokens[0]) yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id + self.user_id, self.tokens[0], device_id=None, valid_until_ms=None + ) + yield self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None ) # now delete some -- cgit 1.5.1 From 2091c91fdec0c90360992b4dbbd71048b940bcb0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 17 Jul 2019 17:34:13 +0100 Subject: More refactoring in `get_events_as_list` (#5707) We can now use `_get_events_from_cache_or_db` rather than going right back to the database, which means that (a) we can benefit from caching, and (b) it opens the way forward to more extensive checks on the original event. We now always require the original event to exist before we will serve up a redaction. --- changelog.d/5707.bugfix | 1 + synapse/storage/events_worker.py | 64 ++++++++------ tests/rest/client/test_redactions.py | 159 +++++++++++++++++++++++++++++++++++ tests/unittest.py | 1 + 4 files changed, 198 insertions(+), 27 deletions(-) create mode 100644 changelog.d/5707.bugfix create mode 100644 tests/rest/client/test_redactions.py (limited to 'tests') diff --git a/changelog.d/5707.bugfix b/changelog.d/5707.bugfix new file mode 100644 index 0000000000..aa3046c5e1 --- /dev/null +++ b/changelog.d/5707.bugfix @@ -0,0 +1 @@ +Fix some problems with authenticating redactions in recent room versions. diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index e2fc7bc047..1d969d70be 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -236,38 +236,48 @@ class EventsWorkerStore(SQLBaseStore): "allow_rejected=False" ) - # Starting in room version v3, some redactions need to be rechecked if we - # didn't have the redacted event at the time, so we recheck on read - # instead. + # We may not have had the original event when we received a redaction, so + # we have to recheck auth now. + if not allow_rejected and entry.event.type == EventTypes.Redaction: - if entry.event.internal_metadata.need_to_check_redaction(): - # XXX: we should probably use _get_events_from_cache_or_db here, - # so that we can benefit from caching. - orig_sender = yield self._simple_select_one_onecol( - table="events", - keyvalues={"event_id": entry.event.redacts}, - retcol="sender", - allow_none=True, + redacted_event_id = entry.event.redacts + event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + original_event_entry = event_map.get(redacted_event_id) + if not original_event_entry: + # we don't have the redacted event (or it was rejected). + # + # We assume that the redaction isn't authorized for now; if the + # redacted event later turns up, the redaction will be re-checked, + # and if it is found valid, the original will get redacted before it + # is served to the client. + logger.debug( + "Withholding redaction event %s since we don't (yet) have the " + "original %s", + event_id, + redacted_event_id, ) + continue - expected_domain = get_domain_from_id(entry.event.sender) - if ( - orig_sender - and get_domain_from_id(orig_sender) == expected_domain - ): - # This redaction event is allowed. Mark as not needing a - # recheck. - entry.event.internal_metadata.recheck_redaction = False - else: - # We either don't have the event that is being redacted (so we - # assume that the event isn't authorised for now), or the - # senders don't match (so it will never be authorised). Either - # way, we shouldn't return it. - # - # (If we later receive the event, then we will redact it anyway, - # since we have this redaction) + original_event = original_event_entry.event + + if entry.event.internal_metadata.need_to_check_redaction(): + original_domain = get_domain_from_id(original_event.sender) + redaction_domain = get_domain_from_id(entry.event.sender) + if original_domain != redaction_domain: + # the senders don't match, so this is forbidden + logger.info( + "Withholding redaction %s whose sender domain %s doesn't " + "match that of redacted event %s %s", + event_id, + redaction_domain, + redacted_event_id, + original_domain, + ) continue + # Update the cache to save doing the checks again. + entry.event.internal_metadata.recheck_redaction = False + if check_redacted and entry.redacted_event: event = entry.redacted_event else: diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py new file mode 100644 index 0000000000..7d5df95855 --- /dev/null +++ b/tests/rest/client/test_redactions.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync + +from tests.unittest import HomeserverTestCase + + +class RedactionsTestCase(HomeserverTestCase): + """Tests that various redaction events are handled correctly""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # register a couple of users + self.mod_user_id = self.register_user("user1", "pass") + self.mod_access_token = self.login("user1", "pass") + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + # Create a room + self.room_id = self.helper.create_room_as( + self.mod_user_id, tok=self.mod_access_token + ) + + # Invite the other user + self.helper.invite( + room=self.room_id, + src=self.mod_user_id, + tok=self.mod_access_token, + targ=self.other_user_id, + ) + # The other user joins + self.helper.join( + room=self.room_id, user=self.other_user_id, tok=self.other_access_token + ) + + def _redact_event(self, access_token, room_id, event_id, expect_code=200): + """Helper function to send a redaction event. + + Returns the json body. + """ + path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) + + request, channel = self.make_request( + "POST", path, content={}, access_token=access_token + ) + self.render(request) + self.assertEqual(int(channel.result["code"]), expect_code) + return channel.json_body + + def _sync_room_timeline(self, access_token, room_id): + request, channel = self.make_request( + "GET", "sync", access_token=self.mod_access_token + ) + self.render(request) + self.assertEqual(channel.result["code"], b"200") + room_sync = channel.json_body["rooms"]["join"][room_id] + return room_sync["timeline"]["events"] + + def test_redact_event_as_moderator(self): + # as a regular user, send a message to redact + b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) + msg_id = b["event_id"] + + # as the moderator, send a redaction + b = self._redact_event(self.mod_access_token, self.room_id, msg_id) + redaction_id = b["event_id"] + + # now sync + timeline = self._sync_room_timeline(self.mod_access_token, self.room_id) + + # the last event should be the redaction + self.assertEqual(timeline[-1]["event_id"], redaction_id) + self.assertEqual(timeline[-1]["redacts"], msg_id) + + # and the penultimate should be the redacted original + self.assertEqual(timeline[-2]["event_id"], msg_id) + self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id) + self.assertEqual(timeline[-2]["content"], {}) + + def test_redact_event_as_normal(self): + # as a regular user, send a message to redact + b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) + normal_msg_id = b["event_id"] + + # also send one as the admin + b = self.helper.send(room_id=self.room_id, tok=self.mod_access_token) + admin_msg_id = b["event_id"] + + # as a normal, try to redact the admin's event + self._redact_event( + self.other_access_token, self.room_id, admin_msg_id, expect_code=403 + ) + + # now try to redact our own event + b = self._redact_event(self.other_access_token, self.room_id, normal_msg_id) + redaction_id = b["event_id"] + + # now sync + timeline = self._sync_room_timeline(self.other_access_token, self.room_id) + + # the last event should be the redaction of the normal event + self.assertEqual(timeline[-1]["event_id"], redaction_id) + self.assertEqual(timeline[-1]["redacts"], normal_msg_id) + + # the penultimate should be the unredacted one from the admin + self.assertEqual(timeline[-2]["event_id"], admin_msg_id) + self.assertNotIn("redacted_by", timeline[-2]["unsigned"]) + self.assertTrue(timeline[-2]["content"]["body"], {}) + + # and the antepenultimate should be the redacted normal + self.assertEqual(timeline[-3]["event_id"], normal_msg_id) + self.assertEqual(timeline[-3]["unsigned"]["redacted_by"], redaction_id) + self.assertEqual(timeline[-3]["content"], {}) + + def test_redact_nonexistent_event(self): + # control case: an existing event + b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) + msg_id = b["event_id"] + b = self._redact_event(self.other_access_token, self.room_id, msg_id) + redaction_id = b["event_id"] + + # room moderators can send redactions for non-existent events + self._redact_event(self.mod_access_token, self.room_id, "$zzz") + + # ... but normals cannot + self._redact_event( + self.other_access_token, self.room_id, "$zzz", expect_code=404 + ) + + # when we sync, we should see only the valid redaction + timeline = self._sync_room_timeline(self.other_access_token, self.room_id) + self.assertEqual(timeline[-1]["event_id"], redaction_id) + self.assertEqual(timeline[-1]["redacts"], msg_id) + + # and the penultimate should be the redacted original + self.assertEqual(timeline[-2]["event_id"], msg_id) + self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id) + self.assertEqual(timeline[-2]["content"], {}) diff --git a/tests/unittest.py b/tests/unittest.py index cabe787cb4..f5fae21317 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -447,6 +447,7 @@ class HomeserverTestCase(TestCase): # Create the user request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") self.render(request) + self.assertEqual(channel.code, 200) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) -- cgit 1.5.1 From 9c70a02a9cddf36521c3d6ae6f72e3b46a5f4c2d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 17 Jul 2019 19:08:02 +0100 Subject: Ignore redactions of m.room.create events (#5701) --- changelog.d/5701.bugfix | 1 + synapse/api/auth.py | 15 --------------- synapse/handlers/message.py | 33 ++++++++++++++++++++++++--------- synapse/storage/events_worker.py | 12 ++++++++++++ tests/rest/client/test_redactions.py | 20 ++++++++++++++++++++ 5 files changed, 57 insertions(+), 24 deletions(-) create mode 100644 changelog.d/5701.bugfix (limited to 'tests') diff --git a/changelog.d/5701.bugfix b/changelog.d/5701.bugfix new file mode 100644 index 0000000000..fd2866e16a --- /dev/null +++ b/changelog.d/5701.bugfix @@ -0,0 +1 @@ +Ignore redactions of m.room.create events. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index d9e943c39c..7ce6540bdd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -606,21 +606,6 @@ class Auth(object): defer.returnValue(auth_ids) - def check_redaction(self, room_version, event, auth_events): - """Check whether the event sender is allowed to redact the target event. - - Returns: - True if the the sender is allowed to redact the target event if the - target event was created by them. - False if the sender is allowed to redact the target event with no - further checks. - - Raises: - AuthError if the event sender is definitely not allowed to redact - the target event. - """ - return event_auth.check_redaction(room_version, event, auth_events) - @defer.inlineCallbacks def check_can_change_room_list(self, room_id, user): """Check if the user is allowed to edit the room's entry in the diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index eaeda7a5cb..6d7a987f13 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -23,6 +23,7 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed +from synapse import event_auth from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.api.errors import ( AuthError, @@ -784,6 +785,20 @@ class EventCreationHandler(object): event.signatures.update(returned_invite.signatures) if event.type == EventTypes.Redaction: + original_event = yield self.store.get_event( + event.redacts, + check_redacted=False, + get_prev_content=False, + allow_rejected=False, + allow_none=True, + check_room_id=event.room_id, + ) + + # we can make some additional checks now if we have the original event. + if original_event: + if original_event.type == EventTypes.Create: + raise AuthError(403, "Redacting create events is not permitted") + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True @@ -791,18 +806,18 @@ class EventCreationHandler(object): auth_events = yield self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version = yield self.store.get_room_version(event.room_id) - if self.auth.check_redaction(room_version, event, auth_events=auth_events): - original_event = yield self.store.get_event( - event.redacts, - check_redacted=False, - get_prev_content=False, - allow_rejected=False, - allow_none=False, - ) + + if event_auth.check_redaction(room_version, event, auth_events=auth_events): + # this user doesn't have 'redact' rights, so we need to do some more + # checks on the original event. Let's start by checking the original + # event exists. + if not original_event: + raise NotFoundError("Could not find event %s" % (event.redacts,)) + if event.user_id != original_event.user_id: raise AuthError(403, "You don't have permission to redact events") - # We've already checked. + # all the checks are done. event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 1d969d70be..858fc755a1 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -259,6 +259,14 @@ class EventsWorkerStore(SQLBaseStore): continue original_event = original_event_entry.event + if original_event.type == EventTypes.Create: + # we never serve redactions of Creates to clients. + logger.info( + "Withholding redaction %s of create event %s", + event_id, + redacted_event_id, + ) + continue if entry.event.internal_metadata.need_to_check_redaction(): original_domain = get_domain_from_id(original_event.sender) @@ -617,6 +625,10 @@ class EventsWorkerStore(SQLBaseStore): Deferred[EventBase|None]: if the event should be redacted, a pruned event object. Otherwise, None. """ + if original_ev.type == "m.room.create": + # we choose to ignore redactions of m.room.create events. + return None + redaction_map = yield self._get_events_from_cache_or_db(redactions) for redaction_id in redactions: diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 7d5df95855..fe66e397c4 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -157,3 +157,23 @@ class RedactionsTestCase(HomeserverTestCase): self.assertEqual(timeline[-2]["event_id"], msg_id) self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id) self.assertEqual(timeline[-2]["content"], {}) + + def test_redact_create_event(self): + # control case: an existing event + b = self.helper.send(room_id=self.room_id, tok=self.mod_access_token) + msg_id = b["event_id"] + self._redact_event(self.mod_access_token, self.room_id, msg_id) + + # sync the room, to get the id of the create event + timeline = self._sync_room_timeline(self.other_access_token, self.room_id) + create_event_id = timeline[0]["event_id"] + + # room moderators cannot send redactions for create events + self._redact_event( + self.mod_access_token, self.room_id, create_event_id, expect_code=403 + ) + + # and nor can normals + self._redact_event( + self.other_access_token, self.room_id, create_event_id, expect_code=403 + ) -- cgit 1.5.1 From b2a382efdb0cf7b68e194070b616f10732fa0f36 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 18 Jul 2019 14:41:42 +0100 Subject: Remove the ability to query relations when the original event was redacted. (#5629) Fixes #5594 Forbid viewing relations on an event once it has been redacted. --- changelog.d/5629.bugfix | 1 + synapse/events/__init__.py | 11 +++ synapse/events/utils.py | 16 +++- synapse/rest/client/v2_alpha/relations.py | 75 +++++++++-------- tests/rest/client/v2_alpha/test_relations.py | 116 ++++++++++++++++++++++++++- 5 files changed, 180 insertions(+), 39 deletions(-) create mode 100644 changelog.d/5629.bugfix (limited to 'tests') diff --git a/changelog.d/5629.bugfix b/changelog.d/5629.bugfix new file mode 100644 index 0000000000..672eabad40 --- /dev/null +++ b/changelog.d/5629.bugfix @@ -0,0 +1 @@ +Forbid viewing relations on an event once it has been redacted. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index d3de70e671..88ed6d764f 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -104,6 +104,17 @@ class _EventInternalMetadata(object): """ return getattr(self, "proactively_send", True) + def is_redacted(self): + """Whether the event has been redacted. + + This is used for efficiently checking whether an event has been + marked as redacted without needing to make another database call. + + Returns: + bool + """ + return getattr(self, "redacted", False) + def _event_dict_property(key): # We want to be able to use hasattr with the event dict properties. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 987de5cab7..9487a886f5 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -52,10 +52,15 @@ def prune_event(event): from . import event_type_from_format_version - return event_type_from_format_version(event.format_version)( + pruned_event = event_type_from_format_version(event.format_version)( pruned_event_dict, event.internal_metadata.get_dict() ) + # Mark the event as redacted + pruned_event.internal_metadata.redacted = True + + return pruned_event + def prune_event_dict(event_dict): """Redacts the event_dict in the same way as `prune_event`, except it @@ -360,9 +365,12 @@ class EventClientSerializer(object): event_id = event.event_id serialized_event = serialize_event(event, time_now, **kwargs) - # If MSC1849 is enabled then we need to look if thre are any relations - # we need to bundle in with the event - if self.experimental_msc1849_support_enabled and bundle_aggregations: + # If MSC1849 is enabled then we need to look if there are any relations + # we need to bundle in with the event. + # Do not bundle relations if the event has been redacted + if not event.internal_metadata.is_redacted() and ( + self.experimental_msc1849_support_enabled and bundle_aggregations + ): annotations = yield self.store.get_aggregation_groups_for_event(event_id) references = yield self.store.get_relations_for_event( event_id, RelationTypes.REFERENCE, direction="f" diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 7ce485b471..6e52f6d284 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -34,6 +34,7 @@ from synapse.http.servlet import ( from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.relations import ( AggregationPaginationToken, + PaginationChunk, RelationPaginationToken, ) @@ -153,23 +154,28 @@ class RelationPaginationServlet(RestServlet): from_token = parse_string(request, "from") to_token = parse_string(request, "to") - if from_token: - from_token = RelationPaginationToken.from_string(from_token) - - if to_token: - to_token = RelationPaginationToken.from_string(to_token) - - result = yield self.store.get_relations_for_event( - event_id=parent_id, - relation_type=relation_type, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) + if event.internal_metadata.is_redacted(): + # If the event is redacted, return an empty list of relations + pagination_chunk = PaginationChunk(chunk=[]) + else: + # Return the relations + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + + pagination_chunk = yield self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) events = yield self.store.get_events_as_list( - [c["event_id"] for c in result.chunk] + [c["event_id"] for c in pagination_chunk.chunk] ) now = self.clock.time_msec() @@ -186,7 +192,7 @@ class RelationPaginationServlet(RestServlet): events, now, bundle_aggregations=False ) - return_value = result.to_dict() + return_value = pagination_chunk.to_dict() return_value["chunk"] = events return_value["original_event"] = original_event @@ -234,7 +240,7 @@ class RelationAggregationPaginationServlet(RestServlet): # This checks that a) the event exists and b) the user is allowed to # view it. - yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = yield self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type not in (RelationTypes.ANNOTATION, None): raise SynapseError(400, "Relation type must be 'annotation'") @@ -243,21 +249,26 @@ class RelationAggregationPaginationServlet(RestServlet): from_token = parse_string(request, "from") to_token = parse_string(request, "to") - if from_token: - from_token = AggregationPaginationToken.from_string(from_token) - - if to_token: - to_token = AggregationPaginationToken.from_string(to_token) - - res = yield self.store.get_aggregation_groups_for_event( - event_id=parent_id, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - defer.returnValue((200, res.to_dict())) + if event.internal_metadata.is_redacted(): + # If the event is redacted, return an empty list of relations + pagination_chunk = PaginationChunk(chunk=[]) + else: + # Return the relations + if from_token: + from_token = AggregationPaginationToken.from_string(from_token) + + if to_token: + to_token = AggregationPaginationToken.from_string(to_token) + + pagination_chunk = yield self.store.get_aggregation_groups_for_event( + event_id=parent_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + defer.returnValue((200, pagination_chunk.to_dict())) class RelationAggregationGroupPaginationServlet(RestServlet): diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 58c6951852..c7e5859970 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -93,7 +93,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_deny_double_react(self): """Test that we deny relations on membership events """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") self.assertEquals(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -540,14 +540,122 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + def test_relations_redaction_redacts_edits(self): + """Test that edits of an event are redacted when the original event + is redacted. + """ + # Send a new event + res = self.helper.send(self.room, body="Heyo!", tok=self.user_token) + original_event_id = res["event_id"] + + # Add a relation + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + parent_id=original_event_id, + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Check the relation is returned + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEquals(len(channel.json_body["chunk"]), 1) + + # Redact the original event + request, channel = self.make_request( + "PUT", + "/rooms/%s/redact/%s/%s" + % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), + access_token=self.user_token, + content="{}", + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + # Try to check for remaining m.replace relations + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + # Check that no relations are returned + self.assertIn("chunk", channel.json_body) + self.assertEquals(channel.json_body["chunk"], []) + + def test_aggregations_redaction_prevents_access_to_aggregations(self): + """Test that annotations of an event are redacted when the original event + is redacted. + """ + # Send a new event + res = self.helper.send(self.room, body="Hello!", tok=self.user_token) + original_event_id = res["event_id"] + + # Add a relation + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Redact the original + request, channel = self.make_request( + "PUT", + "/rooms/%s/redact/%s/%s" + % ( + self.room, + original_event_id, + "test_aggregations_redaction_prevents_access_to_aggregations", + ), + access_token=self.user_token, + content="{}", + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + # Check that aggregations returns zero + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEquals(channel.json_body["chunk"], []) + def _send_relation( - self, relation_type, event_type, key=None, content={}, access_token=None + self, + relation_type, + event_type, + key=None, + content={}, + access_token=None, + parent_id=None, ): """Helper function to send a relation pointing at `self.parent_id` Args: relation_type (str): One of `RelationTypes` event_type (str): The type of the event to create + parent_id (str): The event_id this relation relates to. If None, then self.parent_id key (str|None): The aggregation key used for m.annotation relation type. content(dict|None): The content of the created event. @@ -564,10 +672,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): if key: query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8")) + original_id = parent_id if parent_id else self.parent_id + request, channel = self.make_request( "POST", "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" - % (self.room, self.parent_id, relation_type, event_type, query), + % (self.room, original_id, relation_type, event_type, query), json.dumps(content).encode("utf-8"), access_token=access_token, ) -- cgit 1.5.1 From 7ad1d763566eb34bd32234811aa9901d8f3668aa Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 18 Jul 2019 23:57:15 +1000 Subject: Support Prometheus_client 0.4.0+ (#5636) --- UPGRADE.rst | 7 + changelog.d/5636.misc | 1 + docs/metrics-howto.rst | 102 ++++++++++++++ synapse/app/_base.py | 3 +- synapse/app/appservice.py | 3 +- synapse/app/client_reader.py | 3 +- synapse/app/event_creator.py | 3 +- synapse/app/federation_reader.py | 3 +- synapse/app/federation_sender.py | 3 +- synapse/app/frontend_proxy.py | 3 +- synapse/app/homeserver.py | 3 +- synapse/app/media_repository.py | 3 +- synapse/app/pusher.py | 3 +- synapse/app/synchrotron.py | 3 +- synapse/app/user_dir.py | 3 +- synapse/metrics/__init__.py | 17 +++ synapse/metrics/_exposition.py | 258 ++++++++++++++++++++++++++++++++++++ synapse/metrics/resource.py | 20 --- synapse/python_dependencies.py | 4 +- tests/storage/test_event_metrics.py | 4 +- 20 files changed, 399 insertions(+), 50 deletions(-) create mode 100644 changelog.d/5636.misc create mode 100644 synapse/metrics/_exposition.py delete mode 100644 synapse/metrics/resource.py (limited to 'tests') diff --git a/UPGRADE.rst b/UPGRADE.rst index 72064accf3..cf228c7c52 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -49,6 +49,13 @@ returned by the Client-Server API: # configured on port 443. curl -kv https:///_matrix/client/versions 2>&1 | grep "Server:" +Upgrading to v1.2.0 +=================== + +Some counter metrics have been renamed, with the old names deprecated. See +`the metrics documentation `_ +for details. + Upgrading to v1.1.0 =================== diff --git a/changelog.d/5636.misc b/changelog.d/5636.misc new file mode 100644 index 0000000000..3add990283 --- /dev/null +++ b/changelog.d/5636.misc @@ -0,0 +1 @@ +Some counter metrics exposed over Prometheus have been renamed, with the old names preserved for backwards compatibility and deprecated. See `docs/metrics-howto.rst` for details. \ No newline at end of file diff --git a/docs/metrics-howto.rst b/docs/metrics-howto.rst index 32b064e2da..973641f3dc 100644 --- a/docs/metrics-howto.rst +++ b/docs/metrics-howto.rst @@ -59,6 +59,108 @@ How to monitor Synapse metrics using Prometheus Restart Prometheus. +Renaming of metrics & deprecation of old names in 1.2 +----------------------------------------------------- + +Synapse 1.2 updates the Prometheus metrics to match the naming convention of the +upstream ``prometheus_client``. The old names are considered deprecated and will +be removed in a future version of Synapse. + ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| New Name | Old Name | ++=============================================================================+=======================================================================+ +| python_gc_objects_collected_total | python_gc_objects_collected | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| python_gc_objects_uncollectable_total | python_gc_objects_uncollectable | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| python_gc_collections_total | python_gc_collections | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| process_cpu_seconds_total | process_cpu_seconds | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_federation_client_sent_transactions_total | synapse_federation_client_sent_transactions | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_federation_client_events_processed_total | synapse_federation_client_events_processed | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_event_processing_loop_count_total | synapse_event_processing_loop_count | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_event_processing_loop_room_count_total | synapse_event_processing_loop_room_count | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_util_metrics_block_count_total | synapse_util_metrics_block_count | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_util_metrics_block_time_seconds_total | synapse_util_metrics_block_time_seconds | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_util_metrics_block_ru_utime_seconds_total | synapse_util_metrics_block_ru_utime_seconds | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_util_metrics_block_ru_stime_seconds_total | synapse_util_metrics_block_ru_stime_seconds | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_util_metrics_block_db_txn_count_total | synapse_util_metrics_block_db_txn_count | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_util_metrics_block_db_txn_duration_seconds_total | synapse_util_metrics_block_db_txn_duration_seconds | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_util_metrics_block_db_sched_duration_seconds_total | synapse_util_metrics_block_db_sched_duration_seconds | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_background_process_start_count_total | synapse_background_process_start_count | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_background_process_ru_utime_seconds_total | synapse_background_process_ru_utime_seconds | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_background_process_ru_stime_seconds_total | synapse_background_process_ru_stime_seconds | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_background_process_db_txn_count_total | synapse_background_process_db_txn_count | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_background_process_db_txn_duration_seconds_total | synapse_background_process_db_txn_duration_seconds | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_background_process_db_sched_duration_seconds_total | synapse_background_process_db_sched_duration_seconds | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_storage_events_persisted_events_total | synapse_storage_events_persisted_events | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_storage_events_persisted_events_sep_total | synapse_storage_events_persisted_events_sep | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_storage_events_state_delta_total | synapse_storage_events_state_delta | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_storage_events_state_delta_single_event_total | synapse_storage_events_state_delta_single_event | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_storage_events_state_delta_reuse_delta_total | synapse_storage_events_state_delta_reuse_delta | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_federation_server_received_pdus_total | synapse_federation_server_received_pdus | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_federation_server_received_edus_total | synapse_federation_server_received_edus | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_handler_presence_notified_presence_total | synapse_handler_presence_notified_presence | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_handler_presence_federation_presence_out_total | synapse_handler_presence_federation_presence_out | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_handler_presence_presence_updates_total | synapse_handler_presence_presence_updates | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_handler_presence_timers_fired_total | synapse_handler_presence_timers_fired | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_handler_presence_federation_presence_total | synapse_handler_presence_federation_presence | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_handler_presence_bump_active_time_total | synapse_handler_presence_bump_active_time | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_federation_client_sent_edus_total | synapse_federation_client_sent_edus | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_federation_client_sent_pdu_destinations_count_total | synapse_federation_client_sent_pdu_destinations:count | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_federation_client_sent_pdu_destinations_total | synapse_federation_client_sent_pdu_destinations:total | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_handlers_appservice_events_processed_total | synapse_handlers_appservice_events_processed | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_notifier_notified_events_total | synapse_notifier_notified_events | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter_total | synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter_total | synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_http_httppusher_http_pushes_processed_total | synapse_http_httppusher_http_pushes_processed | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_http_httppusher_http_pushes_failed_total | synapse_http_httppusher_http_pushes_failed | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_http_httppusher_badge_updates_processed_total | synapse_http_httppusher_badge_updates_processed | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ +| synapse_http_httppusher_badge_updates_failed_total | synapse_http_httppusher_badge_updates_failed | ++-----------------------------------------------------------------------------+-----------------------------------------------------------------------+ + + Removal of deprecated metrics & time based counters becoming histograms in 0.31.0 --------------------------------------------------------------------------------- diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 807f320b46..540dbd9236 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -149,8 +149,7 @@ def listen_metrics(bind_addresses, port): """ Start Prometheus metrics server. """ - from synapse.metrics import RegistryProxy - from prometheus_client import start_http_server + from synapse.metrics import RegistryProxy, start_http_server for host in bind_addresses: logger.info("Starting metrics listener on %s:%d", host, port) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index be44249ed6..e01f3e5f3b 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -27,8 +27,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index ff11beca82..29bddc4823 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -28,8 +28,7 @@ from synapse.config.logger import setup_logging from synapse.http.server import JsonResource from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index cacad25eac..042cfd04af 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -28,8 +28,7 @@ from synapse.config.logger import setup_logging from synapse.http.server import JsonResource from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 11e80dbae0..76a97f8f32 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -29,8 +29,7 @@ from synapse.config.logger import setup_logging from synapse.federation.transport.server import TransportLayerServer from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 97da7bdcbf..fec49d5092 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -28,9 +28,8 @@ from synapse.config.logger import setup_logging from synapse.federation import send_queue from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import RegistryProxy +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.events import SlavedEventStore diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 417a10bbd2..1f1f1df78e 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -30,8 +30,7 @@ from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 639b1429c0..0c075cb3f1 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -55,9 +55,8 @@ from synapse.http.additional_resource import AdditionalResource from synapse.http.server import RootRedirect from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.module_api import ModuleApi from synapse.python_dependencies import check_requirements from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index f23b9b6eda..d70780e9d5 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -28,8 +28,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 4f929edf86..070de7d0b0 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -27,8 +27,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import __func__ from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.events import SlavedEventStore diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index de4797fddc..315c030694 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -32,8 +32,7 @@ from synapse.handlers.presence import PresenceHandler, get_interested_parties from synapse.http.server import JsonResource from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 1177ddd72e..03ef21bd01 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -29,8 +29,7 @@ from synapse.config.logger import setup_logging from synapse.http.server import JsonResource from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import RegistryProxy -from synapse.metrics.resource import METRICS_PREFIX, MetricsResource +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index eaf0aaa86e..488280b4a6 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -29,8 +29,16 @@ from prometheus_client.core import REGISTRY, GaugeMetricFamily, HistogramMetricF from twisted.internet import reactor +from synapse.metrics._exposition import ( + MetricsResource, + generate_latest, + start_http_server, +) + logger = logging.getLogger(__name__) +METRICS_PREFIX = "/_synapse/metrics" + running_on_pypy = platform.python_implementation() == "PyPy" all_metrics = [] all_collectors = [] @@ -470,3 +478,12 @@ try: gc.disable() except AttributeError: pass + +__all__ = [ + "MetricsResource", + "generate_latest", + "start_http_server", + "LaterGauge", + "InFlightGauge", + "BucketCollector", +] diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py new file mode 100644 index 0000000000..1933ecd3e3 --- /dev/null +++ b/synapse/metrics/_exposition.py @@ -0,0 +1,258 @@ +# -*- coding: utf-8 -*- +# Copyright 2015-2019 Prometheus Python Client Developers +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code is based off `prometheus_client/exposition.py` from version 0.7.1. + +Due to the renaming of metrics in prometheus_client 0.4.0, this customised +vendoring of the code will emit both the old versions that Synapse dashboards +expect, and the newer "best practice" version of the up-to-date official client. +""" + +import math +import threading +from collections import namedtuple +from http.server import BaseHTTPRequestHandler, HTTPServer +from socketserver import ThreadingMixIn +from urllib.parse import parse_qs, urlparse + +from prometheus_client import REGISTRY + +from twisted.web.resource import Resource + +try: + from prometheus_client.samples import Sample +except ImportError: + Sample = namedtuple("Sample", ["name", "labels", "value", "timestamp", "exemplar"]) + + +CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") + + +INF = float("inf") +MINUS_INF = float("-inf") + + +def floatToGoString(d): + d = float(d) + if d == INF: + return "+Inf" + elif d == MINUS_INF: + return "-Inf" + elif math.isnan(d): + return "NaN" + else: + s = repr(d) + dot = s.find(".") + # Go switches to exponents sooner than Python. + # We only need to care about positive values for le/quantile. + if d > 0 and dot > 6: + mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.") + return "{0}e+0{1}".format(mantissa, dot - 1) + return s + + +def sample_line(line, name): + if line.labels: + labelstr = "{{{0}}}".format( + ",".join( + [ + '{0}="{1}"'.format( + k, + v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""), + ) + for k, v in sorted(line.labels.items()) + ] + ) + ) + else: + labelstr = "" + timestamp = "" + if line.timestamp is not None: + # Convert to milliseconds. + timestamp = " {0:d}".format(int(float(line.timestamp) * 1000)) + return "{0}{1} {2}{3}\n".format( + name, labelstr, floatToGoString(line.value), timestamp + ) + + +def nameify_sample(sample): + """ + If we get a prometheus_client<0.4.0 sample as a tuple, transform it into a + namedtuple which has the names we expect. + """ + if not isinstance(sample, Sample): + sample = Sample(*sample, None, None) + + return sample + + +def generate_latest(registry, emit_help=False): + output = [] + + for metric in registry.collect(): + + if metric.name.startswith("__unused"): + continue + + if not metric.samples: + # No samples, don't bother. + continue + + mname = metric.name + mnewname = metric.name + mtype = metric.type + + # OpenMetrics -> Prometheus + if mtype == "counter": + mnewname = mnewname + "_total" + elif mtype == "info": + mtype = "gauge" + mnewname = mnewname + "_info" + elif mtype == "stateset": + mtype = "gauge" + elif mtype == "gaugehistogram": + mtype = "histogram" + elif mtype == "unknown": + mtype = "untyped" + + # Output in the old format for compatibility. + if emit_help: + output.append( + "# HELP {0} {1}\n".format( + mname, + metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), + ) + ) + output.append("# TYPE {0} {1}\n".format(mname, mtype)) + for sample in map(nameify_sample, metric.samples): + # Get rid of the OpenMetrics specific samples + for suffix in ["_created", "_gsum", "_gcount"]: + if sample.name.endswith(suffix): + break + else: + newname = sample.name.replace(mnewname, mname) + if ":" in newname and newname.endswith("_total"): + newname = newname[: -len("_total")] + output.append(sample_line(sample, newname)) + + # Get rid of the weird colon things while we're at it + if mtype == "counter": + mnewname = mnewname.replace(":total", "") + mnewname = mnewname.replace(":", "_") + + if mname == mnewname: + continue + + # Also output in the new format, if it's different. + if emit_help: + output.append( + "# HELP {0} {1}\n".format( + mnewname, + metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), + ) + ) + output.append("# TYPE {0} {1}\n".format(mnewname, mtype)) + for sample in map(nameify_sample, metric.samples): + # Get rid of the OpenMetrics specific samples + for suffix in ["_created", "_gsum", "_gcount"]: + if sample.name.endswith(suffix): + break + else: + output.append( + sample_line( + sample, sample.name.replace(":total", "").replace(":", "_") + ) + ) + + return "".join(output).encode("utf-8") + + +class MetricsHandler(BaseHTTPRequestHandler): + """HTTP handler that gives metrics from ``REGISTRY``.""" + + registry = REGISTRY + + def do_GET(self): + registry = self.registry + params = parse_qs(urlparse(self.path).query) + + if "help" in params: + emit_help = True + else: + emit_help = False + + try: + output = generate_latest(registry, emit_help=emit_help) + except Exception: + self.send_error(500, "error generating metric output") + raise + self.send_response(200) + self.send_header("Content-Type", CONTENT_TYPE_LATEST) + self.end_headers() + self.wfile.write(output) + + def log_message(self, format, *args): + """Log nothing.""" + + @classmethod + def factory(cls, registry): + """Returns a dynamic MetricsHandler class tied + to the passed registry. + """ + # This implementation relies on MetricsHandler.registry + # (defined above and defaulted to REGISTRY). + + # As we have unicode_literals, we need to create a str() + # object for type(). + cls_name = str(cls.__name__) + MyMetricsHandler = type(cls_name, (cls, object), {"registry": registry}) + return MyMetricsHandler + + +class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer): + """Thread per request HTTP server.""" + + # Make worker threads "fire and forget". Beginning with Python 3.7 this + # prevents a memory leak because ``ThreadingMixIn`` starts to gather all + # non-daemon threads in a list in order to join on them at server close. + # Enabling daemon threads virtually makes ``_ThreadingSimpleServer`` the + # same as Python 3.7's ``ThreadingHTTPServer``. + daemon_threads = True + + +def start_http_server(port, addr="", registry=REGISTRY): + """Starts an HTTP server for prometheus metrics as a daemon thread""" + CustomMetricsHandler = MetricsHandler.factory(registry) + httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler) + t = threading.Thread(target=httpd.serve_forever) + t.daemon = True + t.start() + + +class MetricsResource(Resource): + """ + Twisted ``Resource`` that serves prometheus metrics. + """ + + isLeaf = True + + def __init__(self, registry=REGISTRY): + self.registry = registry + + def render_GET(self, request): + request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) + return generate_latest(self.registry) diff --git a/synapse/metrics/resource.py b/synapse/metrics/resource.py deleted file mode 100644 index 9789359077..0000000000 --- a/synapse/metrics/resource.py +++ /dev/null @@ -1,20 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from prometheus_client.twisted import MetricsResource - -METRICS_PREFIX = "/_synapse/metrics" - -__all__ = ["MetricsResource", "METRICS_PREFIX"] diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index e7618057be..c6465c0386 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -65,9 +65,7 @@ REQUIREMENTS = [ "msgpack>=0.5.2", "phonenumbers>=8.2.0", "six>=1.10", - # prometheus_client 0.4.0 changed the format of counter metrics - # (cf https://github.com/matrix-org/synapse/issues/4001) - "prometheus_client>=0.0.18,<0.4.0", + "prometheus_client>=0.0.18,<0.8.0", # we use attr.s(slots), which arrived in 16.0.0 # Twisted 18.7.0 requires attrs>=17.4.0 "attrs>=17.4.0", diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index d44359ff93..f26ff57a18 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from prometheus_client.exposition import generate_latest - -from synapse.metrics import REGISTRY +from synapse.metrics import REGISTRY, generate_latest from synapse.types import Requester, UserID from tests.unittest import HomeserverTestCase -- cgit 1.5.1 From c7095be9136825efc5bd85181b0395b833f96aee Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 19 Jul 2019 17:49:19 +0100 Subject: Refactor Keyring._start_key_lookups There's an awful lot of deferreds and dictionaries flying around here. The whole thing can be made much simpler and achieve the same effect. --- synapse/crypto/keyring.py | 86 ++++++++++++++++++-------------------------- tests/crypto/test_keyring.py | 29 --------------- 2 files changed, 35 insertions(+), 80 deletions(-) (limited to 'tests') diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 341c863152..efa72dc5fc 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -238,27 +238,9 @@ class Keyring(object): """ try: - # create a deferred for each server we're going to look up the keys - # for; we'll resolve them once we have completed our lookups. - # These will be passed into wait_for_previous_lookups to block - # any other lookups until we have finished. - # The deferreds are called with no logcontext. - server_to_deferred = { - rq.server_name: defer.Deferred() for rq in verify_requests - } - - # We want to wait for any previous lookups to complete before - # proceeding. - yield self.wait_for_previous_lookups(server_to_deferred) + ctx = LoggingContext.current_context() - # Actually start fetching keys. - self._get_server_verify_keys(verify_requests) - - # When we've finished fetching all the keys for a given server_name, - # resolve the deferred passed to `wait_for_previous_lookups` so that - # any lookups waiting will proceed. - # - # map from server name to a set of request ids + # map from server name to a set of outstanding request ids server_to_request_ids = {} for verify_request in verify_requests: @@ -266,40 +248,55 @@ class Keyring(object): request_id = id(verify_request) server_to_request_ids.setdefault(server_name, set()).add(request_id) - def remove_deferreds(res, verify_request): + # Wait for any previous lookups to complete before proceeding. + yield self.wait_for_previous_lookups(server_to_request_ids.keys()) + + # take out a lock on each of the servers by sticking a Deferred in + # key_downloads + for server_name in server_to_request_ids.keys(): + self.key_downloads[server_name] = defer.Deferred() + logger.debug("Got key lookup lock on %s", server_name) + + # When we've finished fetching all the keys for a given server_name, + # drop the lock by resolving the deferred in key_downloads. + def lookup_done(res, verify_request): server_name = verify_request.server_name - request_id = id(verify_request) - server_to_request_ids[server_name].discard(request_id) - if not server_to_request_ids[server_name]: - d = server_to_deferred.pop(server_name, None) - if d: - d.callback(None) + server_requests = server_to_request_ids[server_name] + server_requests.remove(id(verify_request)) + + # if there are no more requests for this server, we can drop the lock. + if not server_requests: + with PreserveLoggingContext(ctx): + logger.debug("Releasing key lookup lock on %s", server_name) + + d = self.key_downloads.pop(server_name) + d.callback(None) return res for verify_request in verify_requests: - verify_request.key_ready.addBoth(remove_deferreds, verify_request) + verify_request.key_ready.addBoth(lookup_done, verify_request) + + # Actually start fetching keys. + self._get_server_verify_keys(verify_requests) except Exception: logger.exception("Error starting key lookups") @defer.inlineCallbacks - def wait_for_previous_lookups(self, server_to_deferred): + def wait_for_previous_lookups(self, server_names): """Waits for any previous key lookups for the given servers to finish. Args: - server_to_deferred (dict[str, Deferred]): server_name to deferred which gets - resolved once we've finished looking up keys for that server. - The Deferreds should be regular twisted ones which call their - callbacks with no logcontext. - - Returns: a Deferred which resolves once all key lookups for the given - servers have completed. Follows the synapse rules of logcontext - preservation. + server_names (Iterable[str]): list of servers which we want to look up + + Returns: + Deferred[None]: resolves once all key lookups for the given servers have + completed. Follows the synapse rules of logcontext preservation. """ loop_count = 1 while True: wait_on = [ (server_name, self.key_downloads[server_name]) - for server_name in server_to_deferred.keys() + for server_name in server_names if server_name in self.key_downloads ] if not wait_on: @@ -314,19 +311,6 @@ class Keyring(object): loop_count += 1 - ctx = LoggingContext.current_context() - - def rm(r, server_name_): - with PreserveLoggingContext(ctx): - logger.debug("Releasing key lookup lock on %s", server_name_) - self.key_downloads.pop(server_name_, None) - return r - - for server_name, deferred in server_to_deferred.items(): - logger.debug("Got key lookup lock on %s", server_name) - self.key_downloads[server_name] = deferred - deferred.addBoth(rm, server_name) - def _get_server_verify_keys(self, verify_requests): """Tries to find at least one key for each verify request diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 795703967d..8d94a503d6 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -86,35 +86,6 @@ class KeyringTestCase(unittest.HomeserverTestCase): getattr(LoggingContext.current_context(), "request", None), expected ) - def test_wait_for_previous_lookups(self): - kr = keyring.Keyring(self.hs) - - lookup_1_deferred = defer.Deferred() - lookup_2_deferred = defer.Deferred() - - # we run the lookup in a logcontext so that the patched inlineCallbacks can check - # it is doing the right thing with logcontexts. - wait_1_deferred = run_in_context( - kr.wait_for_previous_lookups, {"server1": lookup_1_deferred} - ) - - # there were no previous lookups, so the deferred should be ready - self.successResultOf(wait_1_deferred) - - # set off another wait. It should block because the first lookup - # hasn't yet completed. - wait_2_deferred = run_in_context( - kr.wait_for_previous_lookups, {"server1": lookup_2_deferred} - ) - - self.assertFalse(wait_2_deferred.called) - - # let the first lookup complete (in the sentinel context) - lookup_1_deferred.callback(None) - - # now the second wait should complete. - self.successResultOf(wait_2_deferred) - def test_verify_json_objects_for_server_awaits_previous_requests(self): key1 = signedjson.key.generate_signing_key(1) -- cgit 1.5.1 From 4806651744616bf48abf408034ab9560e33f60ce Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 23 Jul 2019 23:00:55 +1000 Subject: Replace returnValue with return (#5736) --- changelog.d/5736.misc | 1 + docs/log_contexts.rst | 2 +- synapse/api/auth.py | 44 ++++---- synapse/api/filtering.py | 2 +- synapse/app/frontend_proxy.py | 8 +- synapse/app/homeserver.py | 2 +- synapse/appservice/__init__.py | 28 +++--- synapse/appservice/api.py | 38 +++---- synapse/appservice/scheduler.py | 4 +- synapse/crypto/keyring.py | 14 +-- synapse/events/builder.py | 16 ++- synapse/events/snapshot.py | 28 +++--- synapse/events/third_party_rules.py | 8 +- synapse/events/utils.py | 4 +- synapse/federation/federation_base.py | 6 +- synapse/federation/federation_client.py | 46 ++++----- synapse/federation/federation_server.py | 77 ++++++-------- synapse/federation/sender/per_destination_queue.py | 4 +- synapse/federation/sender/transaction_manager.py | 2 +- synapse/federation/transport/client.py | 30 +++--- synapse/groups/attestations.py | 2 +- synapse/groups/groups_server.py | 92 ++++++++--------- synapse/handlers/account_data.py | 4 +- synapse/handlers/account_validity.py | 6 +- synapse/handlers/acme.py | 2 +- synapse/handlers/admin.py | 10 +- synapse/handlers/appservice.py | 22 ++-- synapse/handlers/auth.py | 44 ++++---- synapse/handlers/deactivate_account.py | 2 +- synapse/handlers/device.py | 22 ++-- synapse/handlers/directory.py | 14 +-- synapse/handlers/e2e_keys.py | 10 +- synapse/handlers/e2e_room_keys.py | 8 +- synapse/handlers/events.py | 6 +- synapse/handlers/federation.py | 82 ++++++++------- synapse/handlers/groups_local.py | 26 ++--- synapse/handlers/identity.py | 18 ++-- synapse/handlers/initial_sync.py | 54 +++++----- synapse/handlers/message.py | 32 +++--- synapse/handlers/pagination.py | 14 ++- synapse/handlers/presence.py | 54 +++++----- synapse/handlers/profile.py | 18 ++-- synapse/handlers/receipts.py | 14 +-- synapse/handlers/register.py | 16 +-- synapse/handlers/room.py | 16 +-- synapse/handlers/room_list.py | 10 +- synapse/handlers/room_member.py | 42 ++++---- synapse/handlers/room_member_worker.py | 2 +- synapse/handlers/search.py | 14 ++- synapse/handlers/state_deltas.py | 8 +- synapse/handlers/stats.py | 6 +- synapse/handlers/sync.py | 112 ++++++++++----------- synapse/handlers/typing.py | 4 +- synapse/handlers/user_directory.py | 2 +- synapse/http/client.py | 28 +++--- synapse/http/federation/matrix_federation_agent.py | 46 ++++----- synapse/http/federation/srv_resolver.py | 8 +- synapse/http/matrixfederationclient.py | 16 +-- synapse/logging/opentracing.py | 6 +- synapse/module_api/__init__.py | 2 +- synapse/notifier.py | 18 ++-- synapse/push/bulk_push_rule_evaluator.py | 10 +- synapse/push/httppusher.py | 18 ++-- synapse/push/mailer.py | 84 +++++++--------- synapse/push/presentable_names.py | 25 +++-- synapse/push/push_tools.py | 4 +- synapse/push/pusherpool.py | 6 +- synapse/replication/http/_base.py | 2 +- synapse/replication/http/federation.py | 10 +- synapse/replication/http/login.py | 2 +- synapse/replication/http/membership.py | 4 +- synapse/replication/http/register.py | 4 +- synapse/replication/http/send_event.py | 4 +- synapse/replication/tcp/streams/_base.py | 12 +-- synapse/replication/tcp/streams/events.py | 2 +- synapse/rest/admin/__init__.py | 48 +++++---- synapse/rest/admin/server_notice_servlet.py | 2 +- synapse/rest/client/v1/directory.py | 18 ++-- synapse/rest/client/v1/events.py | 6 +- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/login.py | 14 +-- synapse/rest/client/v1/logout.py | 4 +- synapse/rest/client/v1/presence.py | 4 +- synapse/rest/client/v1/profile.py | 14 +-- synapse/rest/client/v1/push_rule.py | 10 +- synapse/rest/client/v1/pusher.py | 8 +- synapse/rest/client/v1/room.py | 46 ++++----- synapse/rest/client/v1/voip.py | 22 ++-- synapse/rest/client/v2_alpha/account.py | 32 +++--- synapse/rest/client/v2_alpha/account_data.py | 8 +- synapse/rest/client/v2_alpha/account_validity.py | 4 +- synapse/rest/client/v2_alpha/auth.py | 4 +- synapse/rest/client/v2_alpha/capabilities.py | 2 +- synapse/rest/client/v2_alpha/devices.py | 10 +- synapse/rest/client/v2_alpha/filter.py | 4 +- synapse/rest/client/v2_alpha/groups.py | 64 ++++++------ synapse/rest/client/v2_alpha/keys.py | 8 +- synapse/rest/client/v2_alpha/notifications.py | 4 +- synapse/rest/client/v2_alpha/openid.py | 18 ++-- synapse/rest/client/v2_alpha/read_marker.py | 2 +- synapse/rest/client/v2_alpha/receipts.py | 2 +- synapse/rest/client/v2_alpha/register.py | 38 ++++--- synapse/rest/client/v2_alpha/relations.py | 8 +- synapse/rest/client/v2_alpha/report_event.py | 2 +- synapse/rest/client/v2_alpha/room_keys.py | 14 +-- .../client/v2_alpha/room_upgrade_rest_servlet.py | 2 +- synapse/rest/client/v2_alpha/sendtodevice.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 48 ++++----- synapse/rest/client/v2_alpha/tags.py | 6 +- synapse/rest/client/v2_alpha/thirdparty.py | 10 +- synapse/rest/client/v2_alpha/user_directory.py | 4 +- synapse/rest/media/v1/media_repository.py | 18 ++-- synapse/rest/media/v1/media_storage.py | 12 +-- synapse/rest/media/v1/preview_url_resource.py | 34 +++---- .../resource_limits_server_notices.py | 2 +- synapse/server_notices/server_notices_manager.py | 6 +- synapse/state/__init__.py | 38 ++++--- synapse/state/v1.py | 8 +- synapse/state/v2.py | 26 ++--- synapse/storage/__init__.py | 2 +- synapse/storage/_base.py | 14 +-- synapse/storage/account_data.py | 14 ++- synapse/storage/appservice.py | 14 ++- synapse/storage/background_updates.py | 20 ++-- synapse/storage/client_ips.py | 26 +++-- synapse/storage/deviceinbox.py | 10 +- synapse/storage/devices.py | 32 +++--- synapse/storage/directory.py | 10 +- synapse/storage/e2e_room_keys.py | 6 +- synapse/storage/end_to_end_keys.py | 8 +- synapse/storage/event_federation.py | 12 +-- synapse/storage/event_push_actions.py | 16 +-- synapse/storage/events.py | 34 +++---- synapse/storage/events_bg_updates.py | 6 +- synapse/storage/events_worker.py | 20 ++-- synapse/storage/filtering.py | 4 +- synapse/storage/group_server.py | 38 ++++--- synapse/storage/monthly_active_users.py | 2 +- synapse/storage/presence.py | 6 +- synapse/storage/profile.py | 12 +-- synapse/storage/push_rule.py | 18 ++-- synapse/storage/pusher.py | 10 +- synapse/storage/receipts.py | 36 +++---- synapse/storage/registration.py | 34 +++---- synapse/storage/relations.py | 4 +- synapse/storage/room.py | 10 +- synapse/storage/roommember.py | 42 ++++---- synapse/storage/search.py | 56 +++++------ synapse/storage/signatures.py | 2 +- synapse/storage/state.py | 54 +++++----- synapse/storage/stats.py | 16 +-- synapse/storage/stream.py | 44 ++++---- synapse/storage/tags.py | 12 +-- synapse/storage/transactions.py | 4 +- synapse/storage/user_directory.py | 30 +++--- synapse/storage/user_erasure_store.py | 5 +- synapse/streams/events.py | 4 +- synapse/util/__init__.py | 2 +- synapse/util/async_helpers.py | 4 +- synapse/util/caches/descriptors.py | 2 +- synapse/util/caches/response_cache.py | 2 +- synapse/util/metrics.py | 2 +- synapse/util/retryutils.py | 16 ++- synapse/visibility.py | 8 +- tests/crypto/test_keyring.py | 6 +- tests/handlers/test_register.py | 2 +- .../federation/test_matrix_federation_agent.py | 4 +- tests/http/federation/test_srv_resolver.py | 2 +- tests/http/test_fedclient.py | 2 +- tests/rest/client/test_transactions.py | 2 +- tests/storage/test_background_update.py | 4 +- tests/storage/test_redaction.py | 4 +- tests/storage/test_roommember.py | 2 +- tests/storage/test_state.py | 2 +- tests/test_visibility.py | 6 +- tests/util/caches/test_descriptors.py | 8 +- tests/utils.py | 6 +- 177 files changed, 1360 insertions(+), 1514 deletions(-) create mode 100644 changelog.d/5736.misc mode change 100755 => 100644 synapse/app/homeserver.py (limited to 'tests') diff --git a/changelog.d/5736.misc b/changelog.d/5736.misc new file mode 100644 index 0000000000..5713b8b32d --- /dev/null +++ b/changelog.d/5736.misc @@ -0,0 +1 @@ +Replace uses of returnValue with plain return, as returnValue is not needed on Python 3. diff --git a/docs/log_contexts.rst b/docs/log_contexts.rst index f5cd5de8ab..4502cd9454 100644 --- a/docs/log_contexts.rst +++ b/docs/log_contexts.rst @@ -148,7 +148,7 @@ call any other functions. d = more_stuff() result = yield d # also fine, of course - defer.returnValue(result) + return result def nonInlineCallbacksFun(): logger.debug("just a wrapper really") diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 7ce6540bdd..351790cca4 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -128,7 +128,7 @@ class Auth(object): ) self._check_joined_room(member, user_id, room_id) - defer.returnValue(member) + return member @defer.inlineCallbacks def check_user_was_in_room(self, room_id, user_id): @@ -156,13 +156,13 @@ class Auth(object): if forgot: raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) - defer.returnValue(member) + return member @defer.inlineCallbacks def check_host_in_room(self, room_id, host): with Measure(self.clock, "check_host_in_room"): latest_event_ids = yield self.store.is_host_joined(room_id, host) - defer.returnValue(latest_event_ids) + return latest_event_ids def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: @@ -219,9 +219,7 @@ class Auth(object): device_id="dummy-device", # stubbed ) - defer.returnValue( - synapse.types.create_requester(user_id, app_service=app_service) - ) + return synapse.types.create_requester(user_id, app_service=app_service) user_info = yield self.get_user_by_access_token(access_token, rights) user = user_info["user"] @@ -262,10 +260,8 @@ class Auth(object): request.authenticated_entity = user.to_string() - defer.returnValue( - synapse.types.create_requester( - user, token_id, is_guest, device_id, app_service=app_service - ) + return synapse.types.create_requester( + user, token_id, is_guest, device_id, app_service=app_service ) except KeyError: raise MissingClientTokenError() @@ -276,25 +272,25 @@ class Auth(object): self.get_access_token_from_request(request) ) if app_service is None: - defer.returnValue((None, None)) + return (None, None) if app_service.ip_range_whitelist: ip_address = IPAddress(self.hs.get_ip_from_request(request)) if ip_address not in app_service.ip_range_whitelist: - defer.returnValue((None, None)) + return (None, None) if b"user_id" not in request.args: - defer.returnValue((app_service.sender, app_service)) + return (app_service.sender, app_service) user_id = request.args[b"user_id"][0].decode("utf8") if app_service.sender == user_id: - defer.returnValue((app_service.sender, app_service)) + return (app_service.sender, app_service) if not app_service.is_interested_in_user(user_id): raise AuthError(403, "Application service cannot masquerade as this user.") if not (yield self.store.get_user_by_id(user_id)): raise AuthError(403, "Application service has not registered this user") - defer.returnValue((user_id, app_service)) + return (user_id, app_service) @defer.inlineCallbacks def get_user_by_access_token(self, token, rights="access"): @@ -330,7 +326,7 @@ class Auth(object): msg="Access token has expired", soft_logout=True ) - defer.returnValue(r) + return r # otherwise it needs to be a valid macaroon try: @@ -378,7 +374,7 @@ class Auth(object): } else: raise RuntimeError("Unknown rights setting %s", rights) - defer.returnValue(ret) + return ret except ( _InvalidMacaroonException, pymacaroons.exceptions.MacaroonException, @@ -506,7 +502,7 @@ class Auth(object): def _look_up_user_by_access_token(self, token): ret = yield self.store.get_user_by_access_token(token) if not ret: - defer.returnValue(None) + return None # we use ret.get() below because *lots* of unit tests stub out # get_user_by_access_token in a way where it only returns a couple of @@ -518,7 +514,7 @@ class Auth(object): "device_id": ret.get("device_id"), "valid_until_ms": ret.get("valid_until_ms"), } - defer.returnValue(user_info) + return user_info def get_appservice_by_req(self, request): token = self.get_access_token_from_request(request) @@ -543,7 +539,7 @@ class Auth(object): @defer.inlineCallbacks def compute_auth_events(self, event, current_state_ids, for_verification=False): if event.type == EventTypes.Create: - defer.returnValue([]) + return [] auth_ids = [] @@ -604,7 +600,7 @@ class Auth(object): if member_event.content["membership"] == Membership.JOIN: auth_ids.append(member_event.event_id) - defer.returnValue(auth_ids) + return auth_ids @defer.inlineCallbacks def check_can_change_room_list(self, room_id, user): @@ -618,7 +614,7 @@ class Auth(object): is_admin = yield self.is_server_admin(user) if is_admin: - defer.returnValue(True) + return True user_id = user.to_string() yield self.check_joined_room(room_id, user_id) @@ -712,7 +708,7 @@ class Auth(object): # * The user is a guest user, and has joined the room # else it will throw. member_event = yield self.check_user_was_in_room(room_id, user_id) - defer.returnValue((member_event.membership, member_event.event_id)) + return (member_event.membership, member_event.event_id) except AuthError: visibility = yield self.state.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" @@ -721,7 +717,7 @@ class Auth(object): visibility and visibility.content["history_visibility"] == "world_readable" ): - defer.returnValue((Membership.JOIN, None)) + return (Membership.JOIN, None) return raise AuthError( 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 9b3daca29b..9f06556bd2 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -132,7 +132,7 @@ class Filtering(object): @defer.inlineCallbacks def get_user_filter(self, user_localpart, filter_id): result = yield self.store.get_user_filter(user_localpart, filter_id) - defer.returnValue(FilterCollection(result)) + return FilterCollection(result) def add_user_filter(self, user_localpart, user_filter): self.check_valid_filter(user_filter) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 5b563c2778..e2822ca848 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -70,12 +70,12 @@ class PresenceStatusStubServlet(RestServlet): except HttpResponseException as e: raise e.to_synapse_error() - defer.returnValue((200, result)) + return (200, result) @defer.inlineCallbacks def on_PUT(self, request, user_id): yield self.auth.get_user_by_req(request) - defer.returnValue((200, {})) + return (200, {}) class KeyUploadServlet(RestServlet): @@ -126,11 +126,11 @@ class KeyUploadServlet(RestServlet): self.main_uri + request.uri.decode("ascii"), body, headers=headers ) - defer.returnValue((200, result)) + return (200, result) else: # Just interested in counts. result = yield self.store.count_e2e_one_time_keys(user_id, device_id) - defer.returnValue((200, {"one_time_key_counts": result})) + return (200, {"one_time_key_counts": result}) class FrontendProxySlavedStore( diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py old mode 100755 new mode 100644 index 34c3f5ee99..7d6b51b5bc --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -406,7 +406,7 @@ def setup(config_options): if provision: yield acme.provision_certificate() - defer.returnValue(provision) + return provision @defer.inlineCallbacks def reprovision_acme(): diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index b26a31dd54..33b3579425 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -175,21 +175,21 @@ class ApplicationService(object): @defer.inlineCallbacks def _matches_user(self, event, store): if not event: - defer.returnValue(False) + return False if self.is_interested_in_user(event.sender): - defer.returnValue(True) + return True # also check m.room.member state key if event.type == EventTypes.Member and self.is_interested_in_user( event.state_key ): - defer.returnValue(True) + return True if not store: - defer.returnValue(False) + return False does_match = yield self._matches_user_in_member_list(event.room_id, store) - defer.returnValue(does_match) + return does_match @cachedInlineCallbacks(num_args=1, cache_context=True) def _matches_user_in_member_list(self, room_id, store, cache_context): @@ -200,8 +200,8 @@ class ApplicationService(object): # check joined member events for user_id in member_list: if self.is_interested_in_user(user_id): - defer.returnValue(True) - defer.returnValue(False) + return True + return False def _matches_room_id(self, event): if hasattr(event, "room_id"): @@ -211,13 +211,13 @@ class ApplicationService(object): @defer.inlineCallbacks def _matches_aliases(self, event, store): if not store or not event: - defer.returnValue(False) + return False alias_list = yield store.get_aliases_for_room(event.room_id) for alias in alias_list: if self.is_interested_in_alias(alias): - defer.returnValue(True) - defer.returnValue(False) + return True + return False @defer.inlineCallbacks def is_interested(self, event, store=None): @@ -231,15 +231,15 @@ class ApplicationService(object): """ # Do cheap checks first if self._matches_room_id(event): - defer.returnValue(True) + return True if (yield self._matches_aliases(event, store)): - defer.returnValue(True) + return True if (yield self._matches_user(event, store)): - defer.returnValue(True) + return True - defer.returnValue(False) + return False def is_interested_in_user(self, user_id): return ( diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 571881775b..007ca75a94 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -97,40 +97,40 @@ class ApplicationServiceApi(SimpleHttpClient): @defer.inlineCallbacks def query_user(self, service, user_id): if service.url is None: - defer.returnValue(False) + return False uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) response = None try: response = yield self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object - defer.returnValue(True) + return True except CodeMessageException as e: if e.code == 404: - defer.returnValue(False) + return False return logger.warning("query_user to %s received %s", uri, e.code) except Exception as ex: logger.warning("query_user to %s threw exception %s", uri, ex) - defer.returnValue(False) + return False @defer.inlineCallbacks def query_alias(self, service, alias): if service.url is None: - defer.returnValue(False) + return False uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) response = None try: response = yield self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object - defer.returnValue(True) + return True except CodeMessageException as e: logger.warning("query_alias to %s received %s", uri, e.code) if e.code == 404: - defer.returnValue(False) + return False return except Exception as ex: logger.warning("query_alias to %s threw exception %s", uri, ex) - defer.returnValue(False) + return False @defer.inlineCallbacks def query_3pe(self, service, kind, protocol, fields): @@ -141,7 +141,7 @@ class ApplicationServiceApi(SimpleHttpClient): else: raise ValueError("Unrecognised 'kind' argument %r to query_3pe()", kind) if service.url is None: - defer.returnValue([]) + return [] uri = "%s%s/thirdparty/%s/%s" % ( service.url, @@ -155,7 +155,7 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning( "query_3pe to %s returned an invalid response %r", uri, response ) - defer.returnValue([]) + return [] ret = [] for r in response: @@ -166,14 +166,14 @@ class ApplicationServiceApi(SimpleHttpClient): "query_3pe to %s returned an invalid result %r", uri, r ) - defer.returnValue(ret) + return ret except Exception as ex: logger.warning("query_3pe to %s threw exception %s", uri, ex) - defer.returnValue([]) + return [] def get_3pe_protocol(self, service, protocol): if service.url is None: - defer.returnValue({}) + return {} @defer.inlineCallbacks def _get(): @@ -189,7 +189,7 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning( "query_3pe_protocol to %s did not return a" " valid result", uri ) - defer.returnValue(None) + return None for instance in info.get("instances", []): network_id = instance.get("network_id", None) @@ -198,10 +198,10 @@ class ApplicationServiceApi(SimpleHttpClient): service.id, network_id ).to_string() - defer.returnValue(info) + return info except Exception as ex: logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex) - defer.returnValue(None) + return None key = (service.id, protocol) return self.protocol_meta_cache.wrap(key, _get) @@ -209,7 +209,7 @@ class ApplicationServiceApi(SimpleHttpClient): @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): if service.url is None: - defer.returnValue(True) + return True events = self._serialize(events) @@ -229,14 +229,14 @@ class ApplicationServiceApi(SimpleHttpClient): ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(events)) - defer.returnValue(True) + return True return except CodeMessageException as e: logger.warning("push_bulk to %s received %s", uri, e.code) except Exception as ex: logger.warning("push_bulk to %s threw exception %s", uri, ex) failed_transactions_counter.labels(service.id).inc() - defer.returnValue(False) + return False def _serialize(self, events): time_now = self.clock.time_msec() diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index e5b36494f5..42a350bff8 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -193,7 +193,7 @@ class _TransactionController(object): @defer.inlineCallbacks def _is_service_up(self, service): state = yield self.store.get_appservice_state(service) - defer.returnValue(state == ApplicationServiceState.UP or state is None) + return state == ApplicationServiceState.UP or state is None class _Recoverer(object): @@ -208,7 +208,7 @@ class _Recoverer(object): r.service.id, ) r.recover() - defer.returnValue(recoverers) + return recoverers def __init__(self, clock, store, as_api, service, callback): self.clock = clock diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index e8bb420ad1..6c3e885e72 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -462,7 +462,7 @@ class StoreKeyFetcher(KeyFetcher): keys = {} for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key - defer.returnValue(keys) + return keys class BaseV2KeyFetcher(object): @@ -566,7 +566,7 @@ class BaseV2KeyFetcher(object): ).addErrback(unwrapFirstError) ) - defer.returnValue(verify_keys) + return verify_keys class PerspectivesKeyFetcher(BaseV2KeyFetcher): @@ -588,7 +588,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): result = yield self.get_server_verify_key_v2_indirect( keys_to_fetch, key_server ) - defer.returnValue(result) + return result except KeyLookupError as e: logger.warning( "Key lookup failed from %r: %s", key_server.server_name, e @@ -601,7 +601,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): str(e), ) - defer.returnValue({}) + return {} results = yield make_deferred_yieldable( defer.gatherResults( @@ -615,7 +615,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): for server_name, keys in result.items(): union_of_keys.setdefault(server_name, {}).update(keys) - defer.returnValue(union_of_keys) + return union_of_keys @defer.inlineCallbacks def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): @@ -701,7 +701,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): perspective_name, time_now_ms, added_keys ) - defer.returnValue(keys) + return keys def _validate_perspectives_response(self, key_server, response): """Optionally check the signature on the result of a /key/query request @@ -843,7 +843,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): ) keys.update(response_keys) - defer.returnValue(keys) + return keys @defer.inlineCallbacks diff --git a/synapse/events/builder.py b/synapse/events/builder.py index db011e0407..3997751337 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -144,15 +144,13 @@ class EventBuilder(object): if self._origin_server_ts is not None: event_dict["origin_server_ts"] = self._origin_server_ts - defer.returnValue( - create_local_event_from_event_dict( - clock=self._clock, - hostname=self._hostname, - signing_key=self._signing_key, - format_version=self.format_version, - event_dict=event_dict, - internal_metadata_dict=self.internal_metadata.get_dict(), - ) + return create_local_event_from_event_dict( + clock=self._clock, + hostname=self._hostname, + signing_key=self._signing_key, + format_version=self.format_version, + event_dict=event_dict, + internal_metadata_dict=self.internal_metadata.get_dict(), ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index a9545e6c1b..acbcbeeced 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -133,19 +133,17 @@ class EventContext(object): else: prev_state_id = None - defer.returnValue( - { - "prev_state_id": prev_state_id, - "event_type": event.type, - "event_state_key": event.state_key if event.is_state() else None, - "state_group": self.state_group, - "rejected": self.rejected, - "prev_group": self.prev_group, - "delta_ids": _encode_state_dict(self.delta_ids), - "prev_state_events": self.prev_state_events, - "app_service_id": self.app_service.id if self.app_service else None, - } - ) + return { + "prev_state_id": prev_state_id, + "event_type": event.type, + "event_state_key": event.state_key if event.is_state() else None, + "state_group": self.state_group, + "rejected": self.rejected, + "prev_group": self.prev_group, + "delta_ids": _encode_state_dict(self.delta_ids), + "prev_state_events": self.prev_state_events, + "app_service_id": self.app_service.id if self.app_service else None, + } @staticmethod def deserialize(store, input): @@ -202,7 +200,7 @@ class EventContext(object): yield make_deferred_yieldable(self._fetching_state_deferred) - defer.returnValue(self._current_state_ids) + return self._current_state_ids @defer.inlineCallbacks def get_prev_state_ids(self, store): @@ -222,7 +220,7 @@ class EventContext(object): yield make_deferred_yieldable(self._fetching_state_deferred) - defer.returnValue(self._prev_state_ids) + return self._prev_state_ids def get_cached_current_state_ids(self): """Gets the current state IDs if we have them already cached. diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 8f5d95696b..714a9b1579 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -51,7 +51,7 @@ class ThirdPartyEventRules(object): defer.Deferred[bool]: True if the event should be allowed, False if not. """ if self.third_party_rules is None: - defer.returnValue(True) + return True prev_state_ids = yield context.get_prev_state_ids(self.store) @@ -61,7 +61,7 @@ class ThirdPartyEventRules(object): state_events[key] = yield self.store.get_event(event_id, allow_none=True) ret = yield self.third_party_rules.check_event_allowed(event, state_events) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def on_create_room(self, requester, config, is_requester_admin): @@ -98,7 +98,7 @@ class ThirdPartyEventRules(object): """ if self.third_party_rules is None: - defer.returnValue(True) + return True state_ids = yield self.store.get_filtered_current_state_ids(room_id) room_state_events = yield self.store.get_events(state_ids.values()) @@ -110,4 +110,4 @@ class ThirdPartyEventRules(object): ret = yield self.third_party_rules.check_threepid_can_be_invited( medium, address, state_events ) - defer.returnValue(ret) + return ret diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 9487a886f5..07d1c5bcf0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -360,7 +360,7 @@ class EventClientSerializer(object): """ # To handle the case of presence events and the like if not isinstance(event, EventBase): - defer.returnValue(event) + return event event_id = event.event_id serialized_event = serialize_event(event, time_now, **kwargs) @@ -406,7 +406,7 @@ class EventClientSerializer(object): "sender": edit.sender, } - defer.returnValue(serialized_event) + return serialized_event def serialize_events(self, events, time_now, **kwargs): """Serializes multiple events. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index f7bb806ae7..5a1e23a145 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -106,7 +106,7 @@ class FederationBase(object): "Failed to find copy of %s with valid signature", pdu.event_id ) - defer.returnValue(res) + return res handle = preserve_fn(handle_check_result) deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] @@ -116,9 +116,9 @@ class FederationBase(object): ).addErrback(unwrapFirstError) if include_none: - defer.returnValue(valid_pdus) + return valid_pdus else: - defer.returnValue([p for p in valid_pdus if p]) + return [p for p in valid_pdus if p] def _check_sigs_and_hash(self, room_version, pdu): return make_deferred_yieldable( diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 3cb4b94420..25ed1257f1 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -213,7 +213,7 @@ class FederationClient(FederationBase): ).addErrback(unwrapFirstError) ) - defer.returnValue(pdus) + return pdus @defer.inlineCallbacks @log_function @@ -245,7 +245,7 @@ class FederationClient(FederationBase): ev = self._get_pdu_cache.get(event_id) if ev: - defer.returnValue(ev) + return ev pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) @@ -307,7 +307,7 @@ class FederationClient(FederationBase): if signed_pdu: self._get_pdu_cache[event_id] = signed_pdu - defer.returnValue(signed_pdu) + return signed_pdu @defer.inlineCallbacks @log_function @@ -355,7 +355,7 @@ class FederationClient(FederationBase): auth_chain.sort(key=lambda e: e.depth) - defer.returnValue((pdus, auth_chain)) + return (pdus, auth_chain) except HttpResponseException as e: if e.code == 400 or e.code == 404: logger.info("Failed to use get_room_state_ids API, falling back") @@ -404,7 +404,7 @@ class FederationClient(FederationBase): signed_auth.sort(key=lambda e: e.depth) - defer.returnValue((signed_pdus, signed_auth)) + return (signed_pdus, signed_auth) @defer.inlineCallbacks def get_events_from_store_or_dest(self, destination, room_id, event_ids): @@ -429,7 +429,7 @@ class FederationClient(FederationBase): missing_events.discard(k) if not missing_events: - defer.returnValue((signed_events, failed_to_fetch)) + return (signed_events, failed_to_fetch) logger.debug( "Fetching unknown state/auth events %s for room %s", @@ -465,7 +465,7 @@ class FederationClient(FederationBase): # We removed all events we successfully fetched from `batch` failed_to_fetch.update(batch) - defer.returnValue((signed_events, failed_to_fetch)) + return (signed_events, failed_to_fetch) @defer.inlineCallbacks @log_function @@ -485,7 +485,7 @@ class FederationClient(FederationBase): signed_auth.sort(key=lambda e: e.depth) - defer.returnValue(signed_auth) + return signed_auth @defer.inlineCallbacks def _try_destination_list(self, description, destinations, callback): @@ -521,7 +521,7 @@ class FederationClient(FederationBase): try: res = yield callback(destination) - defer.returnValue(res) + return res except InvalidResponseError as e: logger.warn("Failed to %s via %s: %s", description, destination, e) except HttpResponseException as e: @@ -615,7 +615,7 @@ class FederationClient(FederationBase): event_dict=pdu_dict, ) - defer.returnValue((destination, ev, event_format)) + return (destination, ev, event_format) return self._try_destination_list( "make_" + membership, destinations, send_request @@ -728,13 +728,11 @@ class FederationClient(FederationBase): check_authchain_validity(signed_auth) - defer.returnValue( - { - "state": signed_state, - "auth_chain": signed_auth, - "origin": destination, - } - ) + return { + "state": signed_state, + "auth_chain": signed_auth, + "origin": destination, + } return self._try_destination_list("send_join", destinations, send_request) @@ -758,7 +756,7 @@ class FederationClient(FederationBase): # FIXME: We should handle signature failures more gracefully. - defer.returnValue(pdu) + return pdu @defer.inlineCallbacks def _do_send_invite(self, destination, pdu, room_version): @@ -786,7 +784,7 @@ class FederationClient(FederationBase): "invite_room_state": pdu.unsigned.get("invite_room_state", []), }, ) - defer.returnValue(content) + return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -821,7 +819,7 @@ class FederationClient(FederationBase): event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - defer.returnValue(content) + return content def send_leave(self, destinations, pdu): """Sends a leave event to one of a list of homeservers. @@ -856,7 +854,7 @@ class FederationClient(FederationBase): ) logger.debug("Got content: %s", content) - defer.returnValue(None) + return None return self._try_destination_list("send_leave", destinations, send_request) @@ -917,7 +915,7 @@ class FederationClient(FederationBase): "missing": content.get("missing", []), } - defer.returnValue(ret) + return ret @defer.inlineCallbacks def get_missing_events( @@ -974,7 +972,7 @@ class FederationClient(FederationBase): # get_missing_events signed_events = [] - defer.returnValue(signed_events) + return signed_events @defer.inlineCallbacks def forward_third_party_invite(self, destinations, room_id, event_dict): @@ -986,7 +984,7 @@ class FederationClient(FederationBase): yield self.transport_layer.exchange_third_party_invite( destination=destination, room_id=room_id, event_dict=event_dict ) - defer.returnValue(None) + return None except CodeMessageException: raise except Exception as e: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8c0a18b120..b4b9a05ca6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -99,7 +99,7 @@ class FederationServer(FederationBase): res = self._transaction_from_pdus(pdus).get_dict() - defer.returnValue((200, res)) + return (200, res) @defer.inlineCallbacks @log_function @@ -126,7 +126,7 @@ class FederationServer(FederationBase): origin, transaction, request_time ) - defer.returnValue(result) + return result @defer.inlineCallbacks def _handle_incoming_transaction(self, origin, transaction, request_time): @@ -147,8 +147,7 @@ class FederationServer(FederationBase): "[%s] We've already responded to this request", transaction.transaction_id, ) - defer.returnValue(response) - return + return response logger.debug("[%s] Transaction is new", transaction.transaction_id) @@ -163,7 +162,7 @@ class FederationServer(FederationBase): yield self.transaction_actions.set_response( origin, transaction, 400, response ) - defer.returnValue((400, response)) + return (400, response) received_pdus_counter.inc(len(transaction.pdus)) @@ -265,7 +264,7 @@ class FederationServer(FederationBase): logger.debug("Returning: %s", str(response)) yield self.transaction_actions.set_response(origin, transaction, 200, response) - defer.returnValue((200, response)) + return (200, response) @defer.inlineCallbacks def received_edu(self, origin, edu_type, content): @@ -298,7 +297,7 @@ class FederationServer(FederationBase): event_id, ) - defer.returnValue((200, resp)) + return (200, resp) @defer.inlineCallbacks def on_state_ids_request(self, origin, room_id, event_id): @@ -315,9 +314,7 @@ class FederationServer(FederationBase): state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) - defer.returnValue( - (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}) - ) + return (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}) @defer.inlineCallbacks def _on_context_state_request_compute(self, room_id, event_id): @@ -336,12 +333,10 @@ class FederationServer(FederationBase): ) ) - defer.returnValue( - { - "pdus": [pdu.get_pdu_json() for pdu in pdus], - "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], - } - ) + return { + "pdus": [pdu.get_pdu_json() for pdu in pdus], + "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], + } @defer.inlineCallbacks @log_function @@ -349,15 +344,15 @@ class FederationServer(FederationBase): pdu = yield self.handler.get_persisted_pdu(origin, event_id) if pdu: - defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict())) + return (200, self._transaction_from_pdus([pdu]).get_dict()) else: - defer.returnValue((404, "")) + return (404, "") @defer.inlineCallbacks def on_query_request(self, query_type, args): received_queries_counter.labels(query_type).inc() resp = yield self.registry.on_query(query_type, args) - defer.returnValue((200, resp)) + return (200, resp) @defer.inlineCallbacks def on_make_join_request(self, origin, room_id, user_id, supported_versions): @@ -371,9 +366,7 @@ class FederationServer(FederationBase): pdu = yield self.handler.on_make_join_request(room_id, user_id) time_now = self._clock.time_msec() - defer.returnValue( - {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - ) + return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} @defer.inlineCallbacks def on_invite_request(self, origin, content, room_version): @@ -391,7 +384,7 @@ class FederationServer(FederationBase): yield self.check_server_matches_acl(origin_host, pdu.room_id) ret_pdu = yield self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() - defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)}) + return {"event": ret_pdu.get_pdu_json(time_now)} @defer.inlineCallbacks def on_send_join_request(self, origin, content, room_id): @@ -407,16 +400,14 @@ class FederationServer(FederationBase): logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) res_pdus = yield self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() - defer.returnValue( - ( - 200, - { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - }, - ) + return ( + 200, + { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [ + p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] + ], + }, ) @defer.inlineCallbacks @@ -428,9 +419,7 @@ class FederationServer(FederationBase): room_version = yield self.store.get_room_version(room_id) time_now = self._clock.time_msec() - defer.returnValue( - {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - ) + return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} @defer.inlineCallbacks def on_send_leave_request(self, origin, content, room_id): @@ -445,7 +434,7 @@ class FederationServer(FederationBase): logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) yield self.handler.on_send_leave_request(origin, pdu) - defer.returnValue((200, {})) + return (200, {}) @defer.inlineCallbacks def on_event_auth(self, origin, room_id, event_id): @@ -456,7 +445,7 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() auth_pdus = yield self.handler.on_event_auth(event_id) res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} - defer.returnValue((200, res)) + return (200, res) @defer.inlineCallbacks def on_query_auth_request(self, origin, content, room_id, event_id): @@ -509,7 +498,7 @@ class FederationServer(FederationBase): "missing": ret.get("missing", []), } - defer.returnValue((200, send_content)) + return (200, send_content) @log_function def on_query_client_keys(self, origin, content): @@ -548,7 +537,7 @@ class FederationServer(FederationBase): ), ) - defer.returnValue({"one_time_keys": json_result}) + return {"one_time_keys": json_result} @defer.inlineCallbacks @log_function @@ -580,9 +569,7 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() - defer.returnValue( - {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} - ) + return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} @log_function def on_openid_userinfo(self, token): @@ -676,14 +663,14 @@ class FederationServer(FederationBase): ret = yield self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): ret = yield self.handler.on_exchange_third_party_invite_request( origin, room_id, event_dict ) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def check_server_matches_acl(self, server_name, room_id): diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 9aab12c0d3..fad980b893 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -374,7 +374,7 @@ class PerDestinationQueue(object): assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" - defer.returnValue((edus, now_stream_id)) + return (edus, now_stream_id) @defer.inlineCallbacks def _get_to_device_message_edus(self, limit): @@ -393,4 +393,4 @@ class PerDestinationQueue(object): for content in contents ] - defer.returnValue((edus, stream_id)) + return (edus, stream_id) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 0460a8c4ac..52706302f2 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -133,4 +133,4 @@ class TransactionManager(object): ) success = False - defer.returnValue(success) + return success diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 1aae9ec9e7..2a6709ff48 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -183,7 +183,7 @@ class TransportLayerClient(object): try_trailing_slash_on_400=True, ) - defer.returnValue(response) + return response @defer.inlineCallbacks @log_function @@ -201,7 +201,7 @@ class TransportLayerClient(object): ignore_backoff=ignore_backoff, ) - defer.returnValue(content) + return content @defer.inlineCallbacks @log_function @@ -259,7 +259,7 @@ class TransportLayerClient(object): ignore_backoff=ignore_backoff, ) - defer.returnValue(content) + return content @defer.inlineCallbacks @log_function @@ -270,7 +270,7 @@ class TransportLayerClient(object): destination=destination, path=path, data=content ) - defer.returnValue(response) + return response @defer.inlineCallbacks @log_function @@ -288,7 +288,7 @@ class TransportLayerClient(object): ignore_backoff=True, ) - defer.returnValue(response) + return response @defer.inlineCallbacks @log_function @@ -299,7 +299,7 @@ class TransportLayerClient(object): destination=destination, path=path, data=content, ignore_backoff=True ) - defer.returnValue(response) + return response @defer.inlineCallbacks @log_function @@ -310,7 +310,7 @@ class TransportLayerClient(object): destination=destination, path=path, data=content, ignore_backoff=True ) - defer.returnValue(response) + return response @defer.inlineCallbacks @log_function @@ -339,7 +339,7 @@ class TransportLayerClient(object): destination=remote_server, path=path, args=args, ignore_backoff=True ) - defer.returnValue(response) + return response @defer.inlineCallbacks @log_function @@ -350,7 +350,7 @@ class TransportLayerClient(object): destination=destination, path=path, data=event_dict ) - defer.returnValue(response) + return response @defer.inlineCallbacks @log_function @@ -359,7 +359,7 @@ class TransportLayerClient(object): content = yield self.client.get_json(destination=destination, path=path) - defer.returnValue(content) + return content @defer.inlineCallbacks @log_function @@ -370,7 +370,7 @@ class TransportLayerClient(object): destination=destination, path=path, data=content ) - defer.returnValue(content) + return content @defer.inlineCallbacks @log_function @@ -402,7 +402,7 @@ class TransportLayerClient(object): content = yield self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) - defer.returnValue(content) + return content @defer.inlineCallbacks @log_function @@ -426,7 +426,7 @@ class TransportLayerClient(object): content = yield self.client.get_json( destination=destination, path=path, timeout=timeout ) - defer.returnValue(content) + return content @defer.inlineCallbacks @log_function @@ -460,7 +460,7 @@ class TransportLayerClient(object): content = yield self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) - defer.returnValue(content) + return content @defer.inlineCallbacks @log_function @@ -488,7 +488,7 @@ class TransportLayerClient(object): timeout=timeout, ) - defer.returnValue(content) + return content @log_function def get_group_profile(self, destination, group_id, requester_user_id): diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index f497711133..dfd7ae041b 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -157,7 +157,7 @@ class GroupAttestionRenewer(object): yield self.store.update_remote_attestion(group_id, user_id, attestation) - defer.returnValue({}) + return {} def _start_renew_attestations(self): return run_as_background_process("renew_attestations", self._renew_attestations) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 168c9e3f84..d50e691436 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -85,7 +85,7 @@ class GroupsServerHandler(object): if not is_admin: raise SynapseError(403, "User is not admin in group") - defer.returnValue(group) + return group @defer.inlineCallbacks def get_group_summary(self, group_id, requester_user_id): @@ -151,22 +151,20 @@ class GroupsServerHandler(object): group_id, requester_user_id ) - defer.returnValue( - { - "profile": profile, - "users_section": { - "users": users, - "roles": roles, - "total_user_count_estimate": 0, # TODO - }, - "rooms_section": { - "rooms": rooms, - "categories": categories, - "total_room_count_estimate": 0, # TODO - }, - "user": membership_info, - } - ) + return { + "profile": profile, + "users_section": { + "users": users, + "roles": roles, + "total_user_count_estimate": 0, # TODO + }, + "rooms_section": { + "rooms": rooms, + "categories": categories, + "total_room_count_estimate": 0, # TODO + }, + "user": membership_info, + } @defer.inlineCallbacks def update_group_summary_room( @@ -192,7 +190,7 @@ class GroupsServerHandler(object): is_public=is_public, ) - defer.returnValue({}) + return {} @defer.inlineCallbacks def delete_group_summary_room( @@ -208,7 +206,7 @@ class GroupsServerHandler(object): group_id=group_id, room_id=room_id, category_id=category_id ) - defer.returnValue({}) + return {} @defer.inlineCallbacks def set_group_join_policy(self, group_id, requester_user_id, content): @@ -228,7 +226,7 @@ class GroupsServerHandler(object): yield self.store.set_group_join_policy(group_id, join_policy=join_policy) - defer.returnValue({}) + return {} @defer.inlineCallbacks def get_group_categories(self, group_id, requester_user_id): @@ -237,7 +235,7 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) categories = yield self.store.get_group_categories(group_id=group_id) - defer.returnValue({"categories": categories}) + return {"categories": categories} @defer.inlineCallbacks def get_group_category(self, group_id, requester_user_id, category_id): @@ -249,7 +247,7 @@ class GroupsServerHandler(object): group_id=group_id, category_id=category_id ) - defer.returnValue(res) + return res @defer.inlineCallbacks def update_group_category(self, group_id, requester_user_id, category_id, content): @@ -269,7 +267,7 @@ class GroupsServerHandler(object): profile=profile, ) - defer.returnValue({}) + return {} @defer.inlineCallbacks def delete_group_category(self, group_id, requester_user_id, category_id): @@ -283,7 +281,7 @@ class GroupsServerHandler(object): group_id=group_id, category_id=category_id ) - defer.returnValue({}) + return {} @defer.inlineCallbacks def get_group_roles(self, group_id, requester_user_id): @@ -292,7 +290,7 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) roles = yield self.store.get_group_roles(group_id=group_id) - defer.returnValue({"roles": roles}) + return {"roles": roles} @defer.inlineCallbacks def get_group_role(self, group_id, requester_user_id, role_id): @@ -301,7 +299,7 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) res = yield self.store.get_group_role(group_id=group_id, role_id=role_id) - defer.returnValue(res) + return res @defer.inlineCallbacks def update_group_role(self, group_id, requester_user_id, role_id, content): @@ -319,7 +317,7 @@ class GroupsServerHandler(object): group_id=group_id, role_id=role_id, is_public=is_public, profile=profile ) - defer.returnValue({}) + return {} @defer.inlineCallbacks def delete_group_role(self, group_id, requester_user_id, role_id): @@ -331,7 +329,7 @@ class GroupsServerHandler(object): yield self.store.remove_group_role(group_id=group_id, role_id=role_id) - defer.returnValue({}) + return {} @defer.inlineCallbacks def update_group_summary_user( @@ -355,7 +353,7 @@ class GroupsServerHandler(object): is_public=is_public, ) - defer.returnValue({}) + return {} @defer.inlineCallbacks def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id): @@ -369,7 +367,7 @@ class GroupsServerHandler(object): group_id=group_id, user_id=user_id, role_id=role_id ) - defer.returnValue({}) + return {} @defer.inlineCallbacks def get_group_profile(self, group_id, requester_user_id): @@ -391,7 +389,7 @@ class GroupsServerHandler(object): group_description = {key: group[key] for key in cols} group_description["is_openly_joinable"] = group["join_policy"] == "open" - defer.returnValue(group_description) + return group_description else: raise SynapseError(404, "Unknown group") @@ -461,9 +459,7 @@ class GroupsServerHandler(object): # TODO: If admin add lists of users whose attestations have timed out - defer.returnValue( - {"chunk": chunk, "total_user_count_estimate": len(user_results)} - ) + return {"chunk": chunk, "total_user_count_estimate": len(user_results)} @defer.inlineCallbacks def get_invited_users_in_group(self, group_id, requester_user_id): @@ -494,9 +490,7 @@ class GroupsServerHandler(object): logger.warn("Error getting profile for %s: %s", user_id, e) user_profiles.append(user_profile) - defer.returnValue( - {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} - ) + return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} @defer.inlineCallbacks def get_rooms_in_group(self, group_id, requester_user_id): @@ -533,9 +527,7 @@ class GroupsServerHandler(object): chunk.sort(key=lambda e: -e["num_joined_members"]) - defer.returnValue( - {"chunk": chunk, "total_room_count_estimate": len(room_results)} - ) + return {"chunk": chunk, "total_room_count_estimate": len(room_results)} @defer.inlineCallbacks def add_room_to_group(self, group_id, requester_user_id, room_id, content): @@ -551,7 +543,7 @@ class GroupsServerHandler(object): yield self.store.add_room_to_group(group_id, room_id, is_public=is_public) - defer.returnValue({}) + return {} @defer.inlineCallbacks def update_room_in_group( @@ -574,7 +566,7 @@ class GroupsServerHandler(object): else: raise SynapseError(400, "Uknown config option") - defer.returnValue({}) + return {} @defer.inlineCallbacks def remove_room_from_group(self, group_id, requester_user_id, room_id): @@ -586,7 +578,7 @@ class GroupsServerHandler(object): yield self.store.remove_room_from_group(group_id, room_id) - defer.returnValue({}) + return {} @defer.inlineCallbacks def invite_to_group(self, group_id, user_id, requester_user_id, content): @@ -644,9 +636,9 @@ class GroupsServerHandler(object): ) elif res["state"] == "invite": yield self.store.add_group_invite(group_id, user_id) - defer.returnValue({"state": "invite"}) + return {"state": "invite"} elif res["state"] == "reject": - defer.returnValue({"state": "reject"}) + return {"state": "reject"} else: raise SynapseError(502, "Unknown state returned by HS") @@ -679,7 +671,7 @@ class GroupsServerHandler(object): remote_attestation=remote_attestation, ) - defer.returnValue(local_attestation) + return local_attestation @defer.inlineCallbacks def accept_invite(self, group_id, requester_user_id, content): @@ -699,7 +691,7 @@ class GroupsServerHandler(object): local_attestation = yield self._add_user(group_id, requester_user_id, content) - defer.returnValue({"state": "join", "attestation": local_attestation}) + return {"state": "join", "attestation": local_attestation} @defer.inlineCallbacks def join_group(self, group_id, requester_user_id, content): @@ -716,7 +708,7 @@ class GroupsServerHandler(object): local_attestation = yield self._add_user(group_id, requester_user_id, content) - defer.returnValue({"state": "join", "attestation": local_attestation}) + return {"state": "join", "attestation": local_attestation} @defer.inlineCallbacks def knock(self, group_id, requester_user_id, content): @@ -769,7 +761,7 @@ class GroupsServerHandler(object): if not self.hs.is_mine_id(user_id): yield self.store.maybe_delete_remote_profile_cache(user_id) - defer.returnValue({}) + return {} @defer.inlineCallbacks def create_group(self, group_id, requester_user_id, content): @@ -845,7 +837,7 @@ class GroupsServerHandler(object): avatar_url=user_profile.get("avatar_url"), ) - defer.returnValue({"group_id": group_id}) + return {"group_id": group_id} @defer.inlineCallbacks def delete_group(self, group_id, requester_user_id): diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index e62e6cab77..8acd9f9a83 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -51,8 +51,8 @@ class AccountDataEventSource(object): {"type": account_data_type, "content": content, "room_id": room_id} ) - defer.returnValue((results, current_stream_id)) + return (results, current_stream_id) @defer.inlineCallbacks def get_pagination_rows(self, user, config, key): - defer.returnValue(([], config.to_id)) + return ([], config.to_id) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 1f1708ba7d..930204e2d0 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -193,7 +193,7 @@ class AccountValidityHandler(object): if threepid["medium"] == "email": addresses.append(threepid["address"]) - defer.returnValue(addresses) + return addresses @defer.inlineCallbacks def _get_renewal_token(self, user_id): @@ -214,7 +214,7 @@ class AccountValidityHandler(object): try: renewal_token = stringutils.random_string(32) yield self.store.set_renewal_token_for_user(user_id, renewal_token) - defer.returnValue(renewal_token) + return renewal_token except StoreError: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") @@ -254,4 +254,4 @@ class AccountValidityHandler(object): user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) - defer.returnValue(expiration_ts) + return expiration_ts diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index fbef2f3d38..46ac73106d 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -100,4 +100,4 @@ class AcmeHandler(object): logger.exception("Failed saving!") raise - defer.returnValue(True) + return True diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index e8a651e231..2f22f56ca4 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -49,7 +49,7 @@ class AdminHandler(BaseHandler): "devices": {"": {"sessions": [{"connections": connections}]}}, } - defer.returnValue(ret) + return ret @defer.inlineCallbacks def get_users(self): @@ -61,7 +61,7 @@ class AdminHandler(BaseHandler): """ ret = yield self.store.get_users() - defer.returnValue(ret) + return ret @defer.inlineCallbacks def get_users_paginate(self, order, start, limit): @@ -78,7 +78,7 @@ class AdminHandler(BaseHandler): """ ret = yield self.store.get_users_paginate(order, start, limit) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def search_users(self, term): @@ -92,7 +92,7 @@ class AdminHandler(BaseHandler): """ ret = yield self.store.search_users(term) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def export_user_data(self, user_id, writer): @@ -225,7 +225,7 @@ class AdminHandler(BaseHandler): state = yield self.store.get_state_for_event(event_id) writer.write_state(room_id, event_id, state) - defer.returnValue(writer.finished()) + return writer.finished() class ExfiltrationWriter(object): diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 8f089f0e33..d1a51df6f9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -167,8 +167,8 @@ class ApplicationServicesHandler(object): for user_service in user_query_services: is_known_user = yield self.appservice_api.query_user(user_service, user_id) if is_known_user: - defer.returnValue(True) - defer.returnValue(False) + return True + return False @defer.inlineCallbacks def query_room_alias_exists(self, room_alias): @@ -192,7 +192,7 @@ class ApplicationServicesHandler(object): if is_known_alias: # the alias exists now so don't query more ASes. result = yield self.store.get_association_from_room_alias(room_alias) - defer.returnValue(result) + return result @defer.inlineCallbacks def query_3pe(self, kind, protocol, fields): @@ -215,7 +215,7 @@ class ApplicationServicesHandler(object): if success: ret.extend(result) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def get_3pe_protocols(self, only_protocol=None): @@ -254,7 +254,7 @@ class ApplicationServicesHandler(object): for p in protocols.keys(): protocols[p] = _merge_instances(protocols[p]) - defer.returnValue(protocols) + return protocols @defer.inlineCallbacks def _get_services_for_event(self, event): @@ -276,7 +276,7 @@ class ApplicationServicesHandler(object): if (yield s.is_interested(event, self.store)): interested_list.append(s) - defer.returnValue(interested_list) + return interested_list def _get_services_for_user(self, user_id): services = self.store.get_app_services() @@ -293,23 +293,23 @@ class ApplicationServicesHandler(object): if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. - defer.returnValue(False) + return False return user_info = yield self.store.get_user_by_id(user_id) if user_info: - defer.returnValue(False) + return False return # user not found; could be the AS though, so check. services = self.store.get_app_services() service_list = [s for s in services if s.sender == user_id] - defer.returnValue(len(service_list) == 0) + return len(service_list) == 0 @defer.inlineCallbacks def _check_user_exists(self, user_id): unknown_user = yield self._is_unknown_user(user_id) if unknown_user: exists = yield self.query_user_exists(user_id) - defer.returnValue(exists) - defer.returnValue(True) + return exists + return True diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d4d6574975..05be5b7c48 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -155,7 +155,7 @@ class AuthHandler(BaseHandler): if user_id != requester.user.to_string(): raise AuthError(403, "Invalid auth") - defer.returnValue(params) + return params @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip, password_servlet=False): @@ -280,7 +280,7 @@ class AuthHandler(BaseHandler): creds, list(clientdict), ) - defer.returnValue((creds, clientdict, session["id"])) + return (creds, clientdict, session["id"]) ret = self._auth_dict_for_flows(flows, session) ret["completed"] = list(creds) @@ -307,8 +307,8 @@ class AuthHandler(BaseHandler): if result: creds[stagetype] = result self._save_session(sess) - defer.returnValue(True) - defer.returnValue(False) + return True + return False def get_session_id(self, clientdict): """ @@ -379,7 +379,7 @@ class AuthHandler(BaseHandler): res = yield checker( authdict, clientip=clientip, password_servlet=password_servlet ) - defer.returnValue(res) + return res # build a v1-login-style dict out of the authdict and fall back to the # v1 code @@ -389,7 +389,7 @@ class AuthHandler(BaseHandler): raise SynapseError(400, "", Codes.MISSING_PARAM) (canonical_id, callback) = yield self.validate_login(user_id, authdict) - defer.returnValue(canonical_id) + return canonical_id @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip, **kwargs): @@ -433,7 +433,7 @@ class AuthHandler(BaseHandler): resp_body.get("hostname"), ) if resp_body["success"]: - defer.returnValue(True) + return True raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) def _check_email_identity(self, authdict, **kwargs): @@ -502,7 +502,7 @@ class AuthHandler(BaseHandler): threepid["threepid_creds"] = authdict["threepid_creds"] - defer.returnValue(threepid) + return threepid def _get_params_recaptcha(self): return {"public_key": self.hs.config.recaptcha_public_key} @@ -606,7 +606,7 @@ class AuthHandler(BaseHandler): yield self.store.delete_access_token(access_token) raise StoreError(400, "Login raced against device deletion") - defer.returnValue(access_token) + return access_token @defer.inlineCallbacks def check_user_exists(self, user_id): @@ -629,8 +629,8 @@ class AuthHandler(BaseHandler): self.ratelimit_login_per_account(user_id) res = yield self._find_user_id_and_pwd_hash(user_id) if res is not None: - defer.returnValue(res[0]) - defer.returnValue(None) + return res[0] + return None @defer.inlineCallbacks def _find_user_id_and_pwd_hash(self, user_id): @@ -661,7 +661,7 @@ class AuthHandler(BaseHandler): user_id, user_infos.keys(), ) - defer.returnValue(result) + return result def get_supported_login_types(self): """Get a the login types supported for the /login API @@ -722,7 +722,7 @@ class AuthHandler(BaseHandler): known_login_type = True is_valid = yield provider.check_password(qualified_user_id, password) if is_valid: - defer.returnValue((qualified_user_id, None)) + return (qualified_user_id, None) if not hasattr(provider, "get_supported_login_types") or not hasattr( provider, "check_auth" @@ -756,7 +756,7 @@ class AuthHandler(BaseHandler): if result: if isinstance(result, str): result = (result, None) - defer.returnValue(result) + return result if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: known_login_type = True @@ -766,7 +766,7 @@ class AuthHandler(BaseHandler): ) if canonical_user_id: - defer.returnValue((canonical_user_id, None)) + return (canonical_user_id, None) if not known_login_type: raise SynapseError(400, "Unknown login type %s" % login_type) @@ -814,9 +814,9 @@ class AuthHandler(BaseHandler): if isinstance(result, str): # If it's a str, set callback function to None result = (result, None) - defer.returnValue(result) + return result - defer.returnValue((None, None)) + return (None, None) @defer.inlineCallbacks def _check_local_password(self, user_id, password): @@ -838,7 +838,7 @@ class AuthHandler(BaseHandler): """ lookupres = yield self._find_user_id_and_pwd_hash(user_id) if not lookupres: - defer.returnValue(None) + return None (user_id, password_hash) = lookupres # If the password hash is None, the account has likely been deactivated @@ -850,8 +850,8 @@ class AuthHandler(BaseHandler): result = yield self.validate_hash(password, password_hash) if not result: logger.warn("Failed password login for user %s", user_id) - defer.returnValue(None) - defer.returnValue(user_id) + return None + return user_id @defer.inlineCallbacks def validate_short_term_login_token_and_get_user_id(self, login_token): @@ -865,7 +865,7 @@ class AuthHandler(BaseHandler): raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) self.ratelimit_login_per_account(user_id) yield self.auth.check_auth_blocking(user_id) - defer.returnValue(user_id) + return user_id @defer.inlineCallbacks def delete_access_token(self, access_token): @@ -976,7 +976,7 @@ class AuthHandler(BaseHandler): ) yield self.store.user_delete_threepid(user_id, medium, address) - defer.returnValue(result) + return result def _save_session(self, session): # TODO: Persistent storage diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index e8f9da6098..5f804d1f13 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -125,7 +125,7 @@ class DeactivateAccountHandler(BaseHandler): # Mark the user as deactivated. yield self.store.set_user_deactivated_status(user_id, True) - defer.returnValue(identity_server_supports_unbinding) + return identity_server_supports_unbinding def _start_user_parting(self): """ diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 99e8413092..d6ab337783 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -64,7 +64,7 @@ class DeviceWorkerHandler(BaseHandler): for device in devices: _update_device_from_client_ips(device, ips) - defer.returnValue(devices) + return devices @defer.inlineCallbacks def get_device(self, user_id, device_id): @@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler): raise errors.NotFoundError ips = yield self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) - defer.returnValue(device) + return device @measure_func("device.get_user_ids_changed") @defer.inlineCallbacks @@ -200,9 +200,7 @@ class DeviceWorkerHandler(BaseHandler): possibly_joined = [] possibly_left = [] - defer.returnValue( - {"changed": list(possibly_joined), "left": list(possibly_left)} - ) + return {"changed": list(possibly_joined), "left": list(possibly_left)} class DeviceHandler(DeviceWorkerHandler): @@ -250,7 +248,7 @@ class DeviceHandler(DeviceWorkerHandler): ) if new_device: yield self.notify_device_update(user_id, [device_id]) - defer.returnValue(device_id) + return device_id # if the device id is not specified, we'll autogen one, but loop a few # times in case of a clash. @@ -264,7 +262,7 @@ class DeviceHandler(DeviceWorkerHandler): ) if new_device: yield self.notify_device_update(user_id, [device_id]) - defer.returnValue(device_id) + return device_id attempts += 1 raise errors.StoreError(500, "Couldn't generate a device ID.") @@ -411,9 +409,7 @@ class DeviceHandler(DeviceWorkerHandler): @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - defer.returnValue( - {"user_id": user_id, "stream_id": stream_id, "devices": devices} - ) + return {"user_id": user_id, "stream_id": stream_id, "devices": devices} @defer.inlineCallbacks def user_left_room(self, user, room_id): @@ -623,7 +619,7 @@ class DeviceListEduUpdater(object): for _, stream_id, prev_ids, _ in updates: if not prev_ids: # We always do a resync if there are no previous IDs - defer.returnValue(True) + return True for prev_id in prev_ids: if prev_id == extremity: @@ -633,8 +629,8 @@ class DeviceListEduUpdater(object): elif prev_id in stream_id_in_updates: continue else: - defer.returnValue(True) + return True stream_id_in_updates.add(stream_id) - defer.returnValue(False) + return False diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 42d5b3db30..0fd423197c 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -210,7 +210,7 @@ class DirectoryHandler(BaseHandler): except AuthError as e: logger.info("Failed to update alias events: %s", e) - defer.returnValue(room_id) + return room_id @defer.inlineCallbacks def delete_appservice_association(self, service, room_alias): @@ -229,7 +229,7 @@ class DirectoryHandler(BaseHandler): room_id = yield self.store.delete_room_alias(room_alias) - defer.returnValue(room_id) + return room_id @defer.inlineCallbacks def get_association(self, room_alias): @@ -277,7 +277,7 @@ class DirectoryHandler(BaseHandler): else: servers = list(servers) - defer.returnValue({"room_id": room_id, "servers": servers}) + return {"room_id": room_id, "servers": servers} return @defer.inlineCallbacks @@ -289,7 +289,7 @@ class DirectoryHandler(BaseHandler): result = yield self.get_association_from_room_alias(room_alias) if result is not None: - defer.returnValue({"room_id": result.room_id, "servers": result.servers}) + return {"room_id": result.room_id, "servers": result.servers} else: raise SynapseError( 404, @@ -342,7 +342,7 @@ class DirectoryHandler(BaseHandler): # Query AS to see if it exists as_handler = self.appservice_handler result = yield as_handler.query_room_alias_exists(room_alias) - defer.returnValue(result) + return result def can_modify_alias(self, alias, user_id=None): # Any application service "interested" in an alias they are regexing on @@ -369,10 +369,10 @@ class DirectoryHandler(BaseHandler): creator = yield self.store.get_room_alias_creator(alias.to_string()) if creator is not None and creator == user_id: - defer.returnValue(True) + return True is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) - defer.returnValue(is_admin) + return is_admin @defer.inlineCallbacks def edit_published_room_list(self, requester, room_id, visibility): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index fdfe8611b6..1300b540e3 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -144,7 +144,7 @@ class E2eKeysHandler(object): ) ) - defer.returnValue({"device_keys": results, "failures": failures}) + return {"device_keys": results, "failures": failures} @defer.inlineCallbacks def query_local_devices(self, query): @@ -189,7 +189,7 @@ class E2eKeysHandler(object): r["unsigned"]["device_display_name"] = display_name result_dict[user_id][device_id] = r - defer.returnValue(result_dict) + return result_dict @defer.inlineCallbacks def on_federation_query_client_keys(self, query_body): @@ -197,7 +197,7 @@ class E2eKeysHandler(object): """ device_keys_query = query_body.get("device_keys", {}) res = yield self.query_local_devices(device_keys_query) - defer.returnValue({"device_keys": res}) + return {"device_keys": res} @defer.inlineCallbacks def claim_one_time_keys(self, query, timeout): @@ -259,7 +259,7 @@ class E2eKeysHandler(object): ), ) - defer.returnValue({"one_time_keys": json_result, "failures": failures}) + return {"one_time_keys": json_result, "failures": failures} @defer.inlineCallbacks def upload_keys_for_user(self, user_id, device_id, keys): @@ -297,7 +297,7 @@ class E2eKeysHandler(object): result = yield self.store.count_e2e_one_time_keys(user_id, device_id) - defer.returnValue({"one_time_key_counts": result}) + return {"one_time_key_counts": result} @defer.inlineCallbacks def _upload_one_time_keys_for_user( diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index ebd807bca6..41b871fc59 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -84,7 +84,7 @@ class E2eRoomKeysHandler(object): user_id, version, room_id, session_id ) - defer.returnValue(results) + return results @defer.inlineCallbacks def delete_room_keys(self, user_id, version, room_id=None, session_id=None): @@ -262,7 +262,7 @@ class E2eRoomKeysHandler(object): new_version = yield self.store.create_e2e_room_keys_version( user_id, version_info ) - defer.returnValue(new_version) + return new_version @defer.inlineCallbacks def get_version_info(self, user_id, version=None): @@ -292,7 +292,7 @@ class E2eRoomKeysHandler(object): raise NotFoundError("Unknown backup version") else: raise - defer.returnValue(res) + return res @defer.inlineCallbacks def delete_version(self, user_id, version=None): @@ -350,4 +350,4 @@ class E2eRoomKeysHandler(object): user_id, version, version_info ) - defer.returnValue({}) + return {} diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 6a38328af3..2f1f10a9af 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -143,7 +143,7 @@ class EventStreamHandler(BaseHandler): "end": tokens[1].to_string(), } - defer.returnValue(chunk) + return chunk class EventHandler(BaseHandler): @@ -166,7 +166,7 @@ class EventHandler(BaseHandler): event = yield self.store.get_event(event_id, check_room_id=room_id) if not event: - defer.returnValue(None) + return None return users = yield self.store.get_users_in_room(event.room_id) @@ -179,4 +179,4 @@ class EventHandler(BaseHandler): if not filtered: raise AuthError(403, "You don't have permission to access that event.") - defer.returnValue(event) + return event diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 57be968c67..2aa208a2b8 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -210,7 +210,7 @@ class FederationHandler(BaseHandler): event_id, origin, ) - defer.returnValue(None) + return None state = None auth_chain = [] @@ -676,7 +676,7 @@ class FederationHandler(BaseHandler): events = [e for e in events if e.event_id not in seen_events] if not events: - defer.returnValue([]) + return [] event_map = {e.event_id: e for e in events} @@ -838,7 +838,7 @@ class FederationHandler(BaseHandler): # TODO: We can probably do something more clever here. yield self._handle_new_event(dest, event, backfilled=True) - defer.returnValue(events) + return events @defer.inlineCallbacks def maybe_backfill(self, room_id, current_depth): @@ -894,7 +894,7 @@ class FederationHandler(BaseHandler): ) if not filtered_extremities: - defer.returnValue(False) + return False # Check if we reached a point where we should start backfilling. sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1])) @@ -965,7 +965,7 @@ class FederationHandler(BaseHandler): # If this succeeded then we probably already have the # appropriate stuff. # TODO: We can probably do something more intelligent here. - defer.returnValue(True) + return True except SynapseError as e: logger.info("Failed to backfill from %s because %s", dom, e) continue @@ -985,11 +985,11 @@ class FederationHandler(BaseHandler): logger.exception("Failed to backfill from %s because %s", dom, e) continue - defer.returnValue(False) + return False success = yield try_backfill(likely_domains) if success: - defer.returnValue(True) + return True # Huh, well *those* domains didn't work out. Lets try some domains # from the time. @@ -1031,11 +1031,11 @@ class FederationHandler(BaseHandler): [dom for dom, _ in likely_domains if dom not in tried_domains] ) if success: - defer.returnValue(True) + return True tried_domains.update(dom for dom, _ in likely_domains) - defer.returnValue(False) + return False def _sanity_check_event(self, ev): """ @@ -1082,7 +1082,7 @@ class FederationHandler(BaseHandler): pdu=event, ) - defer.returnValue(pdu) + return pdu @defer.inlineCallbacks def on_event_auth(self, event_id): @@ -1090,7 +1090,7 @@ class FederationHandler(BaseHandler): auth = yield self.store.get_auth_chain( [auth_id for auth_id in event.auth_event_ids()], include_given=True ) - defer.returnValue([e for e in auth]) + return [e for e in auth] @log_function @defer.inlineCallbacks @@ -1177,7 +1177,7 @@ class FederationHandler(BaseHandler): run_in_background(self._handle_queued_pdus, room_queue) - defer.returnValue(True) + return True @defer.inlineCallbacks def _handle_queued_pdus(self, room_queue): @@ -1247,7 +1247,7 @@ class FederationHandler(BaseHandler): room_version, event, context, do_sig_check=False ) - defer.returnValue(event) + return event @defer.inlineCallbacks @log_function @@ -1308,7 +1308,7 @@ class FederationHandler(BaseHandler): state = yield self.store.get_events(list(prev_state_ids.values())) - defer.returnValue({"state": list(state.values()), "auth_chain": auth_chain}) + return {"state": list(state.values()), "auth_chain": auth_chain} @defer.inlineCallbacks def on_invite_request(self, origin, pdu): @@ -1364,7 +1364,7 @@ class FederationHandler(BaseHandler): context = yield self.state_handler.compute_event_context(event) yield self.persist_events_and_notify([(event, context)]) - defer.returnValue(event) + return event @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): @@ -1389,7 +1389,7 @@ class FederationHandler(BaseHandler): context = yield self.state_handler.compute_event_context(event) yield self.persist_events_and_notify([(event, context)]) - defer.returnValue(event) + return event @defer.inlineCallbacks def _make_and_verify_event( @@ -1407,7 +1407,7 @@ class FederationHandler(BaseHandler): assert event.user_id == user_id assert event.state_key == user_id assert event.room_id == room_id - defer.returnValue((origin, event, format_ver)) + return (origin, event, format_ver) @defer.inlineCallbacks @log_function @@ -1451,7 +1451,7 @@ class FederationHandler(BaseHandler): logger.warn("Failed to create new leave %r because %s", event, e) raise e - defer.returnValue(event) + return event @defer.inlineCallbacks @log_function @@ -1484,7 +1484,7 @@ class FederationHandler(BaseHandler): event.signatures, ) - defer.returnValue(None) + return None @defer.inlineCallbacks def get_state_for_pdu(self, room_id, event_id): @@ -1512,9 +1512,9 @@ class FederationHandler(BaseHandler): del results[(event.type, event.state_key)] res = list(results.values()) - defer.returnValue(res) + return res else: - defer.returnValue([]) + return [] @defer.inlineCallbacks def get_state_ids_for_pdu(self, room_id, event_id): @@ -1539,9 +1539,9 @@ class FederationHandler(BaseHandler): else: results.pop((event.type, event.state_key), None) - defer.returnValue(list(results.values())) + return list(results.values()) else: - defer.returnValue([]) + return [] @defer.inlineCallbacks @log_function @@ -1554,7 +1554,7 @@ class FederationHandler(BaseHandler): events = yield filter_events_for_server(self.store, origin, events) - defer.returnValue(events) + return events @defer.inlineCallbacks @log_function @@ -1584,9 +1584,9 @@ class FederationHandler(BaseHandler): events = yield filter_events_for_server(self.store, origin, [event]) event = events[0] - defer.returnValue(event) + return event else: - defer.returnValue(None) + return None def get_min_depth_for_context(self, context): return self.store.get_min_depth(context) @@ -1618,7 +1618,7 @@ class FederationHandler(BaseHandler): self.store.remove_push_actions_from_staging, event.event_id ) - defer.returnValue(context) + return context @defer.inlineCallbacks def _handle_new_events(self, origin, event_infos, backfilled=False): @@ -1641,7 +1641,7 @@ class FederationHandler(BaseHandler): auth_events=ev_info.get("auth_events"), backfilled=backfilled, ) - defer.returnValue(res) + return res contexts = yield make_deferred_yieldable( defer.gatherResults( @@ -1800,7 +1800,7 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.GuestAccess and not context.rejected: yield self.maybe_kick_guest_users(event) - defer.returnValue(context) + return context @defer.inlineCallbacks def _check_for_soft_fail(self, event, state, backfilled): @@ -1919,7 +1919,7 @@ class FederationHandler(BaseHandler): logger.debug("on_query_auth returning: %s", ret) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def on_get_missing_events( @@ -1942,7 +1942,7 @@ class FederationHandler(BaseHandler): self.store, origin, missing_events ) - defer.returnValue(missing_events) + return missing_events @defer.inlineCallbacks @log_function @@ -2418,16 +2418,14 @@ class FederationHandler(BaseHandler): logger.debug("construct_auth_difference returning") - defer.returnValue( - { - "auth_chain": local_auth, - "rejects": { - e.event_id: {"reason": reason_map[e.event_id], "proof": None} - for e in base_remote_rejected - }, - "missing": [e.event_id for e in missing_locals], - } - ) + return { + "auth_chain": local_auth, + "rejects": { + e.event_id: {"reason": reason_map[e.event_id], "proof": None} + for e in base_remote_rejected + }, + "missing": [e.event_id for e in missing_locals], + } @defer.inlineCallbacks @log_function @@ -2575,7 +2573,7 @@ class FederationHandler(BaseHandler): builder=builder ) EventValidator().validate_new(event) - defer.returnValue((event, context)) + return (event, context) @defer.inlineCallbacks def _check_signature(self, event, context): diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 7da63bb643..7b67c8ae0f 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -162,7 +162,7 @@ class GroupsLocalHandler(object): res.setdefault("user", {})["is_publicised"] = is_publicised - defer.returnValue(res) + return res @defer.inlineCallbacks def create_group(self, group_id, user_id, content): @@ -207,7 +207,7 @@ class GroupsLocalHandler(object): ) self.notifier.on_new_event("groups_key", token, users=[user_id]) - defer.returnValue(res) + return res @defer.inlineCallbacks def get_users_in_group(self, group_id, requester_user_id): @@ -217,7 +217,7 @@ class GroupsLocalHandler(object): res = yield self.groups_server_handler.get_users_in_group( group_id, requester_user_id ) - defer.returnValue(res) + return res group_server_name = get_domain_from_id(group_id) @@ -244,7 +244,7 @@ class GroupsLocalHandler(object): res["chunk"] = valid_entries - defer.returnValue(res) + return res @defer.inlineCallbacks def join_group(self, group_id, user_id, content): @@ -285,7 +285,7 @@ class GroupsLocalHandler(object): ) self.notifier.on_new_event("groups_key", token, users=[user_id]) - defer.returnValue({}) + return {} @defer.inlineCallbacks def accept_invite(self, group_id, user_id, content): @@ -326,7 +326,7 @@ class GroupsLocalHandler(object): ) self.notifier.on_new_event("groups_key", token, users=[user_id]) - defer.returnValue({}) + return {} @defer.inlineCallbacks def invite(self, group_id, user_id, requester_user_id, config): @@ -346,7 +346,7 @@ class GroupsLocalHandler(object): content, ) - defer.returnValue(res) + return res @defer.inlineCallbacks def on_invite(self, group_id, user_id, content): @@ -377,7 +377,7 @@ class GroupsLocalHandler(object): logger.warn("No profile for user %s: %s", user_id, e) user_profile = {} - defer.returnValue({"state": "invite", "user_profile": user_profile}) + return {"state": "invite", "user_profile": user_profile} @defer.inlineCallbacks def remove_user_from_group(self, group_id, user_id, requester_user_id, content): @@ -406,7 +406,7 @@ class GroupsLocalHandler(object): content, ) - defer.returnValue(res) + return res @defer.inlineCallbacks def user_removed_from_group(self, group_id, user_id, content): @@ -421,7 +421,7 @@ class GroupsLocalHandler(object): @defer.inlineCallbacks def get_joined_groups(self, user_id): group_ids = yield self.store.get_joined_groups(user_id) - defer.returnValue({"groups": group_ids}) + return {"groups": group_ids} @defer.inlineCallbacks def get_publicised_groups_for_user(self, user_id): @@ -433,14 +433,14 @@ class GroupsLocalHandler(object): for app_service in self.store.get_app_services(): result.extend(app_service.get_groups_for_user(user_id)) - defer.returnValue({"groups": result}) + return {"groups": result} else: bulk_result = yield self.transport_client.bulk_get_publicised_groups( get_domain_from_id(user_id), [user_id] ) result = bulk_result.get("users", {}).get(user_id) # TODO: Verify attestations - defer.returnValue({"groups": result}) + return {"groups": result} @defer.inlineCallbacks def bulk_get_publicised_groups(self, user_ids, proxy=True): @@ -475,4 +475,4 @@ class GroupsLocalHandler(object): for app_service in self.store.get_app_services(): results[uid].extend(app_service.get_groups_for_user(uid)) - defer.returnValue({"users": results}) + return {"users": results} diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 546d6169e9..d199521b58 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -82,7 +82,7 @@ class IdentityHandler(BaseHandler): "%s is not a trusted ID server: rejecting 3pid " + "credentials", id_server, ) - defer.returnValue(None) + return None try: data = yield self.http_client.get_json( @@ -95,8 +95,8 @@ class IdentityHandler(BaseHandler): raise e.to_synapse_error() if "medium" in data: - defer.returnValue(data) - defer.returnValue(None) + return data + return None @defer.inlineCallbacks def bind_threepid(self, creds, mxid): @@ -133,7 +133,7 @@ class IdentityHandler(BaseHandler): ) except CodeMessageException as e: data = json.loads(e.msg) # XXX WAT? - defer.returnValue(data) + return data @defer.inlineCallbacks def try_unbind_threepid(self, mxid, threepid): @@ -161,7 +161,7 @@ class IdentityHandler(BaseHandler): # We don't know where to unbind, so we don't have a choice but to return if not id_servers: - defer.returnValue(False) + return False changed = True for id_server in id_servers: @@ -169,7 +169,7 @@ class IdentityHandler(BaseHandler): mxid, threepid, id_server ) - defer.returnValue(changed) + return changed @defer.inlineCallbacks def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server): @@ -224,7 +224,7 @@ class IdentityHandler(BaseHandler): id_server=id_server, ) - defer.returnValue(changed) + return changed @defer.inlineCallbacks def requestEmailToken( @@ -250,7 +250,7 @@ class IdentityHandler(BaseHandler): % (id_server, "/_matrix/identity/api/v1/validate/email/requestToken"), params, ) - defer.returnValue(data) + return data except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) raise e.to_synapse_error() @@ -278,7 +278,7 @@ class IdentityHandler(BaseHandler): % (id_server, "/_matrix/identity/api/v1/validate/msisdn/requestToken"), params, ) - defer.returnValue(data) + return data except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) raise e.to_synapse_error() diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 54c966c8a6..42d6650ed9 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -250,7 +250,7 @@ class InitialSyncHandler(BaseHandler): "end": now_token.to_string(), } - defer.returnValue(ret) + return ret @defer.inlineCallbacks def room_initial_sync(self, requester, room_id, pagin_config=None): @@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler): result["account_data"] = account_data_events - defer.returnValue(result) + return result @defer.inlineCallbacks def _room_initial_sync_parted( @@ -330,28 +330,24 @@ class InitialSyncHandler(BaseHandler): time_now = self.clock.time_msec() - defer.returnValue( - { - "membership": membership, - "room_id": room_id, - "messages": { - "chunk": ( - yield self._event_serializer.serialize_events( - messages, time_now - ) - ), - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": ( - yield self._event_serializer.serialize_events( - room_state.values(), time_now - ) + return { + "membership": membership, + "room_id": room_id, + "messages": { + "chunk": ( + yield self._event_serializer.serialize_events(messages, time_now) ), - "presence": [], - "receipts": [], - } - ) + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": ( + yield self._event_serializer.serialize_events( + room_state.values(), time_now + ) + ), + "presence": [], + "receipts": [], + } @defer.inlineCallbacks def _room_initial_sync_joined( @@ -384,13 +380,13 @@ class InitialSyncHandler(BaseHandler): def get_presence(): # If presence is disabled, return an empty list if not self.hs.config.use_presence: - defer.returnValue([]) + return [] states = yield presence_handler.get_states( [m.user_id for m in room_members], as_event=True ) - defer.returnValue(states) + return states @defer.inlineCallbacks def get_receipts(): @@ -399,7 +395,7 @@ class InitialSyncHandler(BaseHandler): ) if not receipts: receipts = [] - defer.returnValue(receipts) + return receipts presence, receipts, (messages, token) = yield make_deferred_yieldable( defer.gatherResults( @@ -442,7 +438,7 @@ class InitialSyncHandler(BaseHandler): if not is_peeking: ret["membership"] = membership - defer.returnValue(ret) + return ret @defer.inlineCallbacks def _check_in_room_or_world_readable(self, room_id, user_id): @@ -453,7 +449,7 @@ class InitialSyncHandler(BaseHandler): # * The user is a guest user, and has joined the room # else it will throw. member_event = yield self.auth.check_user_was_in_room(room_id, user_id) - defer.returnValue((member_event.membership, member_event.event_id)) + return (member_event.membership, member_event.event_id) return except AuthError: visibility = yield self.state_handler.get_current_state( @@ -463,7 +459,7 @@ class InitialSyncHandler(BaseHandler): visibility and visibility.content["history_visibility"] == "world_readable" ): - defer.returnValue((Membership.JOIN, None)) + return (Membership.JOIN, None) return raise AuthError( 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6d7a987f13..8b27e23378 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -87,7 +87,7 @@ class MessageHandler(object): ) data = room_state[membership_event_id].get(key) - defer.returnValue(data) + return data @defer.inlineCallbacks def get_state_events( @@ -174,7 +174,7 @@ class MessageHandler(object): # events, as clients won't use them. bundle_aggregations=False, ) - defer.returnValue(events) + return events @defer.inlineCallbacks def get_joined_members(self, requester, room_id): @@ -213,15 +213,13 @@ class MessageHandler(object): # Loop fell through, AS has no interested users in room raise AuthError(403, "Appservice not in room") - defer.returnValue( - { - user_id: { - "avatar_url": profile.avatar_url, - "display_name": profile.display_name, - } - for user_id, profile in iteritems(users_with_profile) + return { + user_id: { + "avatar_url": profile.avatar_url, + "display_name": profile.display_name, } - ) + for user_id, profile in iteritems(users_with_profile) + } class EventCreationHandler(object): @@ -398,7 +396,7 @@ class EventCreationHandler(object): self.validator.validate_new(event) - defer.returnValue((event, context)) + return (event, context) def _is_exempt_from_privacy_policy(self, builder, requester): """"Determine if an event to be sent is exempt from having to consent @@ -425,9 +423,9 @@ class EventCreationHandler(object): @defer.inlineCallbacks def _is_server_notices_room(self, room_id): if self.config.server_notices_mxid is None: - defer.returnValue(False) + return False user_ids = yield self.store.get_users_in_room(room_id) - defer.returnValue(self.config.server_notices_mxid in user_ids) + return self.config.server_notices_mxid in user_ids @defer.inlineCallbacks def assert_accepted_privacy_policy(self, requester): @@ -507,7 +505,7 @@ class EventCreationHandler(object): event.event_id, prev_state.event_id, ) - defer.returnValue(prev_state) + return prev_state yield self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit @@ -531,7 +529,7 @@ class EventCreationHandler(object): prev_content = encode_canonical_json(prev_event.content) next_content = encode_canonical_json(event.content) if prev_content == next_content: - defer.returnValue(prev_event) + return prev_event return @defer.inlineCallbacks @@ -563,7 +561,7 @@ class EventCreationHandler(object): yield self.send_nonmember_event( requester, event, context, ratelimit=ratelimit ) - defer.returnValue(event) + return event @measure_func("create_new_client_event") @defer.inlineCallbacks @@ -626,7 +624,7 @@ class EventCreationHandler(object): logger.debug("Created event %s", event.event_id) - defer.returnValue((event, context)) + return (event, context) @measure_func("handle_new_client_event") @defer.inlineCallbacks diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 20bcfed334..d83aab3f74 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -242,13 +242,11 @@ class PaginationHandler(object): ) if not events: - defer.returnValue( - { - "chunk": [], - "start": pagin_config.from_token.to_string(), - "end": next_token.to_string(), - } - ) + return { + "chunk": [], + "start": pagin_config.from_token.to_string(), + "end": next_token.to_string(), + } state = None if event_filter and event_filter.lazy_load_members() and len(events) > 0: @@ -286,4 +284,4 @@ class PaginationHandler(object): ) ) - defer.returnValue(chunk) + return chunk diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6f3537e435..ea54d0b991 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -461,7 +461,7 @@ class PresenceHandler(object): if affect_presence: run_in_background(_end) - defer.returnValue(_user_syncing()) + return _user_syncing() def get_currently_syncing_users(self): """Get the set of user ids that are currently syncing on this HS. @@ -556,7 +556,7 @@ class PresenceHandler(object): """Get the current presence state for a user. """ res = yield self.current_state_for_users([user_id]) - defer.returnValue(res[user_id]) + return res[user_id] @defer.inlineCallbacks def current_state_for_users(self, user_ids): @@ -585,7 +585,7 @@ class PresenceHandler(object): states.update(new) self.user_to_current_state.update(new) - defer.returnValue(states) + return states @defer.inlineCallbacks def _persist_and_notify(self, states): @@ -681,7 +681,7 @@ class PresenceHandler(object): def get_state(self, target_user, as_event=False): results = yield self.get_states([target_user.to_string()], as_event=as_event) - defer.returnValue(results[0]) + return results[0] @defer.inlineCallbacks def get_states(self, target_user_ids, as_event=False): @@ -703,17 +703,15 @@ class PresenceHandler(object): now = self.clock.time_msec() if as_event: - defer.returnValue( - [ - { - "type": "m.presence", - "content": format_user_presence_state(state, now), - } - for state in updates - ] - ) + return [ + { + "type": "m.presence", + "content": format_user_presence_state(state, now), + } + for state in updates + ] else: - defer.returnValue(updates) + return updates @defer.inlineCallbacks def set_state(self, target_user, state, ignore_status_msg=False): @@ -757,9 +755,9 @@ class PresenceHandler(object): ) if observer_room_ids & observed_room_ids: - defer.returnValue(True) + return True - defer.returnValue(False) + return False @defer.inlineCallbacks def get_all_presence_updates(self, last_id, current_id): @@ -778,7 +776,7 @@ class PresenceHandler(object): # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. rows = yield self.store.get_all_presence_updates(last_id, current_id) - defer.returnValue(rows) + return rows def notify_new_event(self): """Called when new events have happened. Handles users and servers @@ -1034,7 +1032,7 @@ class PresenceEventSource(object): # # Hence this guard where we just return nothing so that the sync # doesn't return. C.f. #5503. - defer.returnValue(([], max_token)) + return ([], max_token) presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache @@ -1068,17 +1066,11 @@ class PresenceEventSource(object): updates = yield presence.current_state_for_users(user_ids_changed) if include_offline: - defer.returnValue((list(updates.values()), max_token)) + return (list(updates.values()), max_token) else: - defer.returnValue( - ( - [ - s - for s in itervalues(updates) - if s.state != PresenceState.OFFLINE - ], - max_token, - ) + return ( + [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE], + max_token, ) def get_current_key(self): @@ -1107,7 +1099,7 @@ class PresenceEventSource(object): ) users_interested_in.update(user_ids) - defer.returnValue(users_interested_in) + return users_interested_in def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): @@ -1287,7 +1279,7 @@ def get_interested_parties(store, states): # Always notify self users_to_states.setdefault(state.user_id, []).append(state) - defer.returnValue((room_ids_to_states, users_to_states)) + return (room_ids_to_states, users_to_states) @defer.inlineCallbacks @@ -1321,4 +1313,4 @@ def get_interested_remotes(store, states, state_handler): host = get_domain_from_id(user_id) hosts_and_states.append(([host], states)) - defer.returnValue(hosts_and_states) + return hosts_and_states diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index a2388a7091..2cc237e6a5 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -73,7 +73,7 @@ class BaseProfileHandler(BaseHandler): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue({"displayname": displayname, "avatar_url": avatar_url}) + return {"displayname": displayname, "avatar_url": avatar_url} else: try: result = yield self.federation.make_query( @@ -82,7 +82,7 @@ class BaseProfileHandler(BaseHandler): args={"user_id": user_id}, ignore_backoff=True, ) - defer.returnValue(result) + return result except RequestSendFailed as e: raise_from(SynapseError(502, "Failed to fetch profile"), e) except HttpResponseException as e: @@ -108,10 +108,10 @@ class BaseProfileHandler(BaseHandler): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue({"displayname": displayname, "avatar_url": avatar_url}) + return {"displayname": displayname, "avatar_url": avatar_url} else: profile = yield self.store.get_from_remote_profile_cache(user_id) - defer.returnValue(profile or {}) + return profile or {} @defer.inlineCallbacks def get_displayname(self, target_user): @@ -125,7 +125,7 @@ class BaseProfileHandler(BaseHandler): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue(displayname) + return displayname else: try: result = yield self.federation.make_query( @@ -139,7 +139,7 @@ class BaseProfileHandler(BaseHandler): except HttpResponseException as e: raise e.to_synapse_error() - defer.returnValue(result["displayname"]) + return result["displayname"] @defer.inlineCallbacks def set_displayname(self, target_user, requester, new_displayname, by_admin=False): @@ -186,7 +186,7 @@ class BaseProfileHandler(BaseHandler): if e.code == 404: raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue(avatar_url) + return avatar_url else: try: result = yield self.federation.make_query( @@ -200,7 +200,7 @@ class BaseProfileHandler(BaseHandler): except HttpResponseException as e: raise e.to_synapse_error() - defer.returnValue(result["avatar_url"]) + return result["avatar_url"] @defer.inlineCallbacks def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): @@ -251,7 +251,7 @@ class BaseProfileHandler(BaseHandler): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue(response) + return response @defer.inlineCallbacks def _update_join_states(self, requester, target_user): diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a85dd8cdee..218d60f0c3 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -84,7 +84,7 @@ class ReceiptsHandler(BaseHandler): if min_batch_id is None: # no new receipts - defer.returnValue(False) + return False affected_room_ids = list(set([r.room_id for r in receipts])) @@ -94,7 +94,7 @@ class ReceiptsHandler(BaseHandler): min_batch_id, max_batch_id, affected_room_ids ) - defer.returnValue(True) + return True @defer.inlineCallbacks def received_client_receipt(self, room_id, receipt_type, user_id, event_id): @@ -124,9 +124,9 @@ class ReceiptsHandler(BaseHandler): ) if not result: - defer.returnValue([]) + return [] - defer.returnValue(result) + return result class ReceiptEventSource(object): @@ -139,13 +139,13 @@ class ReceiptEventSource(object): to_key = yield self.get_current_key() if from_key == to_key: - defer.returnValue(([], to_key)) + return ([], to_key) events = yield self.store.get_linearized_receipts_for_rooms( room_ids, from_key=from_key, to_key=to_key ) - defer.returnValue((events, to_key)) + return (events, to_key) def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() @@ -164,4 +164,4 @@ class ReceiptEventSource(object): room_ids, from_key=from_key, to_key=to_key ) - defer.returnValue((events, to_key)) + return (events, to_key) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index bb7cfd71b9..4631fab94e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -265,7 +265,7 @@ class RegistrationHandler(BaseHandler): # Bind email to new account yield self._register_email_threepid(user_id, threepid_dict, None, False) - defer.returnValue(user_id) + return user_id @defer.inlineCallbacks def _auto_join_rooms(self, user_id): @@ -360,7 +360,7 @@ class RegistrationHandler(BaseHandler): appservice_id=service_id, create_profile_with_displayname=user.localpart, ) - defer.returnValue(user_id) + return user_id @defer.inlineCallbacks def check_recaptcha(self, ip, private_key, challenge, response): @@ -461,7 +461,7 @@ class RegistrationHandler(BaseHandler): id = self._next_generated_user_id self._next_generated_user_id += 1 - defer.returnValue(str(id)) + return str(id) @defer.inlineCallbacks def _validate_captcha(self, ip_addr, private_key, challenge, response): @@ -481,7 +481,7 @@ class RegistrationHandler(BaseHandler): "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" + "error=%s" % lines[1], } - defer.returnValue(json) + return json @defer.inlineCallbacks def _submit_captcha(self, ip_addr, private_key, challenge, response): @@ -497,7 +497,7 @@ class RegistrationHandler(BaseHandler): "response": response, }, ) - defer.returnValue(data) + return data @defer.inlineCallbacks def _join_user_to_room(self, requester, room_identifier): @@ -622,7 +622,7 @@ class RegistrationHandler(BaseHandler): initial_display_name=initial_display_name, is_guest=is_guest, ) - defer.returnValue((r["device_id"], r["access_token"])) + return (r["device_id"], r["access_token"]) valid_until_ms = None if self.session_lifetime is not None: @@ -645,7 +645,7 @@ class RegistrationHandler(BaseHandler): user_id, device_id=device_id, valid_until_ms=valid_until_ms ) - defer.returnValue((device_id, access_token)) + return (device_id, access_token) @defer.inlineCallbacks def post_registration_actions( @@ -798,7 +798,7 @@ class RegistrationHandler(BaseHandler): if ex.errcode == Codes.MISSING_PARAM: # This will only happen if the ID server returns a malformed response logger.info("Can't add incomplete 3pid") - defer.returnValue(None) + return None raise yield self._auth_handler.add_threepid( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index db3f8cb76b..5caa90c3b7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -128,7 +128,7 @@ class RoomCreationHandler(BaseHandler): old_room_id, new_version, # args for _upgrade_room ) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def _upgrade_room(self, requester, old_room_id, new_version): @@ -193,7 +193,7 @@ class RoomCreationHandler(BaseHandler): requester, old_room_id, new_room_id, old_room_state ) - defer.returnValue(new_room_id) + return new_room_id @defer.inlineCallbacks def _update_upgraded_room_pls( @@ -671,7 +671,7 @@ class RoomCreationHandler(BaseHandler): result["room_alias"] = room_alias.to_string() yield directory_handler.send_room_alias_update_event(requester, room_id) - defer.returnValue(result) + return result @defer.inlineCallbacks def _send_events_for_new_room( @@ -796,7 +796,7 @@ class RoomCreationHandler(BaseHandler): room_creator_user_id=creator_id, is_public=is_public, ) - defer.returnValue(gen_room_id) + return gen_room_id except StoreError: attempts += 1 raise StoreError(500, "Couldn't generate a room ID.") @@ -839,7 +839,7 @@ class RoomContextHandler(object): event_id, get_prev_content=True, allow_none=True ) if not event: - defer.returnValue(None) + return None return filtered = yield (filter_evts([event])) @@ -890,7 +890,7 @@ class RoomContextHandler(object): results["end"] = token.copy_and_replace("room_key", results["end"]).to_string() - defer.returnValue(results) + return results class RoomEventSource(object): @@ -941,7 +941,7 @@ class RoomEventSource(object): else: end_key = to_key - defer.returnValue((events, end_key)) + return (events, end_key) def get_current_key(self): return self.store.get_room_events_max_id() @@ -959,4 +959,4 @@ class RoomEventSource(object): limit=config.limit, ) - defer.returnValue((events, next_key)) + return (events, next_key) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index aae696a7e8..e9094ad02b 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -325,7 +325,7 @@ class RoomListHandler(BaseHandler): current_limit=since_token.current_limit - 1, ).to_token() - defer.returnValue(results) + return results @defer.inlineCallbacks def _append_room_entry_to_chunk( @@ -420,7 +420,7 @@ class RoomListHandler(BaseHandler): if join_rules_event: join_rule = join_rules_event.content.get("join_rule", None) if not allow_private and join_rule and join_rule != JoinRules.PUBLIC: - defer.returnValue(None) + return None # Return whether this room is open to federation users or not create_event = current_state.get((EventTypes.Create, "")) @@ -469,7 +469,7 @@ class RoomListHandler(BaseHandler): if avatar_url: result["avatar_url"] = avatar_url - defer.returnValue(result) + return result @defer.inlineCallbacks def get_remote_public_room_list( @@ -482,7 +482,7 @@ class RoomListHandler(BaseHandler): third_party_instance_id=None, ): if not self.enable_room_list_search: - defer.returnValue({"chunk": [], "total_room_count_estimate": 0}) + return {"chunk": [], "total_room_count_estimate": 0} if search_filter: # We currently don't support searching across federation, so we have @@ -507,7 +507,7 @@ class RoomListHandler(BaseHandler): ] } - defer.returnValue(res) + return res def _get_remote_list_cached( self, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e0196ef83e..baea08ddd0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -191,7 +191,7 @@ class RoomMemberHandler(object): ) if duplicate is not None: # Discard the new event since this membership change is a no-op. - defer.returnValue(duplicate) + return duplicate yield self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit @@ -233,7 +233,7 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: yield self._user_left_room(target, room_id) - defer.returnValue(event) + return event @defer.inlineCallbacks def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id): @@ -303,7 +303,7 @@ class RoomMemberHandler(object): require_consent=require_consent, ) - defer.returnValue(result) + return result @defer.inlineCallbacks def _update_membership( @@ -423,7 +423,7 @@ class RoomMemberHandler(object): same_membership = old_membership == effective_membership_state same_sender = requester.user.to_string() == old_state.sender if same_sender and same_membership and same_content: - defer.returnValue(old_state) + return old_state if old_membership in ["ban", "leave"] and action == "kick": raise AuthError(403, "The target user is not in the room") @@ -473,7 +473,7 @@ class RoomMemberHandler(object): ret = yield self._remote_join( requester, remote_room_hosts, room_id, target, content ) - defer.returnValue(ret) + return ret elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: @@ -495,7 +495,7 @@ class RoomMemberHandler(object): res = yield self._remote_reject_invite( requester, remote_room_hosts, room_id, target ) - defer.returnValue(res) + return res res = yield self._local_membership_update( requester=requester, @@ -508,7 +508,7 @@ class RoomMemberHandler(object): content=content, require_consent=require_consent, ) - defer.returnValue(res) + return res @defer.inlineCallbacks def send_membership_event( @@ -596,11 +596,11 @@ class RoomMemberHandler(object): """ guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None) if not guest_access_id: - defer.returnValue(False) + return False guest_access = yield self.store.get_event(guest_access_id) - defer.returnValue( + return ( guest_access and guest_access.content and "guest_access" in guest_access.content @@ -635,7 +635,7 @@ class RoomMemberHandler(object): servers.remove(room_alias.domain) servers.insert(0, room_alias.domain) - defer.returnValue((RoomID.from_string(room_id), servers)) + return (RoomID.from_string(room_id), servers) @defer.inlineCallbacks def _get_inviter(self, user_id, room_id): @@ -643,7 +643,7 @@ class RoomMemberHandler(object): user_id=user_id, room_id=room_id ) if invite: - defer.returnValue(UserID.from_string(invite.sender)) + return UserID.from_string(invite.sender) @defer.inlineCallbacks def do_3pid_invite( @@ -708,11 +708,11 @@ class RoomMemberHandler(object): if "signatures" not in data: raise AuthError(401, "No signatures on 3pid binding") yield self._verify_any_signature(data, id_server) - defer.returnValue(data["mxid"]) + return data["mxid"] except IOError as e: logger.warn("Error from identity server lookup: %s" % (e,)) - defer.returnValue(None) + return None @defer.inlineCallbacks def _verify_any_signature(self, data, server_hostname): @@ -904,7 +904,7 @@ class RoomMemberHandler(object): if not public_keys: public_keys.append(fallback_public_key) display_name = data["display_name"] - defer.returnValue((token, public_keys, fallback_public_key, display_name)) + return (token, public_keys, fallback_public_key, display_name) @defer.inlineCallbacks def _is_host_in_room(self, current_state_ids): @@ -913,7 +913,7 @@ class RoomMemberHandler(object): create_event_id = current_state_ids.get(("m.room.create", "")) if len(current_state_ids) == 1 and create_event_id: # We can only get here if we're in the process of creating the room - defer.returnValue(True) + return True for etype, state_key in current_state_ids: if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): @@ -925,16 +925,16 @@ class RoomMemberHandler(object): continue if event.membership == Membership.JOIN: - defer.returnValue(True) + return True - defer.returnValue(False) + return False @defer.inlineCallbacks def _is_server_notice_room(self, room_id): if self._server_notices_mxid is None: - defer.returnValue(False) + return False user_ids = yield self.store.get_users_in_room(room_id) - defer.returnValue(self._server_notices_mxid in user_ids) + return self._server_notices_mxid in user_ids class RoomMemberMasterHandler(RoomMemberHandler): @@ -978,7 +978,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ret = yield fed_handler.do_remotely_reject_invite( remote_room_hosts, room_id, target.to_string() ) - defer.returnValue(ret) + return ret except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -989,7 +989,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): logger.warn("Failed to reject invite: %s", e) yield self.store.locally_reject_invite(target.to_string(), room_id) - defer.returnValue({}) + return {} def _user_joined_room(self, target, room_id): """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index fc873a3ba6..75e96ae1a2 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -53,7 +53,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): yield self._user_joined_room(user, room_id) - defer.returnValue(ret) + return ret def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): """Implements RoomMemberHandler._remote_reject_invite diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index ddc4430d03..cd5e90bacb 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -69,7 +69,7 @@ class SearchHandler(BaseHandler): # Scan through the old room for further predecessors room_id = predecessor["room_id"] - defer.returnValue(historical_room_ids) + return historical_room_ids @defer.inlineCallbacks def search(self, user, content, batch=None): @@ -186,13 +186,11 @@ class SearchHandler(BaseHandler): room_ids.intersection_update({batch_group_key}) if not room_ids: - defer.returnValue( - { - "search_categories": { - "room_events": {"results": [], "count": 0, "highlights": []} - } + return { + "search_categories": { + "room_events": {"results": [], "count": 0, "highlights": []} } - ) + } rank_map = {} # event_id -> rank of event allowed_events = [] @@ -455,4 +453,4 @@ class SearchHandler(BaseHandler): if global_next_batch: rooms_cat_res["next_batch"] = global_next_batch - defer.returnValue({"search_categories": {"room_events": rooms_cat_res}}) + return {"search_categories": {"room_events": rooms_cat_res}} diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index 6b364befd5..f065970c40 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -48,7 +48,7 @@ class StateDeltasHandler(object): if not event and not prev_event: logger.debug("Neither event exists: %r %r", prev_event_id, event_id) - defer.returnValue(None) + return None prev_value = None value = None @@ -62,8 +62,8 @@ class StateDeltasHandler(object): logger.debug("prev_value: %r -> value: %r", prev_value, value) if value == public_value and prev_value != public_value: - defer.returnValue(True) + return True elif value != public_value and prev_value == public_value: - defer.returnValue(False) + return False else: - defer.returnValue(None) + return None diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index a0ee8db988..4449da6669 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -86,7 +86,7 @@ class StatsHandler(StateDeltasHandler): # If still None then the initial background update hasn't happened yet if self.pos is None: - defer.returnValue(None) + return None # Loop round handling deltas until we're up to date while True: @@ -328,6 +328,6 @@ class StatsHandler(StateDeltasHandler): == "world_readable" ) ): - defer.returnValue(True) + return True else: - defer.returnValue(False) + return False diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index cd1ac0a27a..4007284e5b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -263,7 +263,7 @@ class SyncHandler(object): timeout, full_state, ) - defer.returnValue(res) + return res @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): @@ -303,7 +303,7 @@ class SyncHandler(object): lazy_loaded = "false" non_empty_sync_counter.labels(sync_type, lazy_loaded).inc() - defer.returnValue(result) + return result def current_sync_for_user(self, sync_config, since_token=None, full_state=False): """Get the sync for client needed to match what the server has now. @@ -317,7 +317,7 @@ class SyncHandler(object): user_id = user.to_string() rules = yield self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(user, rules) - defer.returnValue(rules) + return rules @defer.inlineCallbacks def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): @@ -378,7 +378,7 @@ class SyncHandler(object): event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) - defer.returnValue((now_token, ephemeral_by_room)) + return (now_token, ephemeral_by_room) @defer.inlineCallbacks def _load_filtered_recents( @@ -426,8 +426,8 @@ class SyncHandler(object): recents = [] if not limited or block_all_timeline: - defer.returnValue( - TimelineBatch(events=recents, prev_batch=now_token, limited=False) + return TimelineBatch( + events=recents, prev_batch=now_token, limited=False ) filtering_factor = 2 @@ -490,12 +490,10 @@ class SyncHandler(object): prev_batch_token = now_token.copy_and_replace("room_key", room_key) - defer.returnValue( - TimelineBatch( - events=recents, - prev_batch=prev_batch_token, - limited=limited or newly_joined_room, - ) + return TimelineBatch( + events=recents, + prev_batch=prev_batch_token, + limited=limited or newly_joined_room, ) @defer.inlineCallbacks @@ -517,7 +515,7 @@ class SyncHandler(object): if event.is_state(): state_ids = state_ids.copy() state_ids[(event.type, event.state_key)] = event.event_id - defer.returnValue(state_ids) + return state_ids @defer.inlineCallbacks def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): @@ -549,7 +547,7 @@ class SyncHandler(object): else: # no events in this room - so presumably no state state = {} - defer.returnValue(state) + return state @defer.inlineCallbacks def compute_summary(self, room_id, sync_config, batch, state, now_token): @@ -579,7 +577,7 @@ class SyncHandler(object): ) if not last_events: - defer.returnValue(None) + return None return last_event = last_events[-1] @@ -611,14 +609,14 @@ class SyncHandler(object): if name_id: name = yield self.store.get_event(name_id, allow_none=True) if name and name.content.get("name"): - defer.returnValue(summary) + return summary if canonical_alias_id: canonical_alias = yield self.store.get_event( canonical_alias_id, allow_none=True ) if canonical_alias and canonical_alias.content.get("alias"): - defer.returnValue(summary) + return summary me = sync_config.user.to_string() @@ -652,7 +650,7 @@ class SyncHandler(object): summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5] if not sync_config.filter_collection.lazy_load_members(): - defer.returnValue(summary) + return summary # ensure we send membership events for heroes if needed cache_key = (sync_config.user.to_string(), sync_config.device_id) @@ -686,7 +684,7 @@ class SyncHandler(object): cache.set(s.state_key, s.event_id) state[(EventTypes.Member, s.state_key)] = s - defer.returnValue(summary) + return summary def get_lazy_loaded_members_cache(self, cache_key): cache = self.lazy_loaded_members_cache.get(cache_key) @@ -871,14 +869,12 @@ class SyncHandler(object): if state_ids: state = yield self.store.get_events(list(state_ids.values())) - defer.returnValue( - { - (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state( - list(state.values()) - ) - } - ) + return { + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state( + list(state.values()) + ) + } @defer.inlineCallbacks def unread_notifs_for_room_id(self, room_id, sync_config): @@ -894,11 +890,11 @@ class SyncHandler(object): notifs = yield self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), last_unread_event_id ) - defer.returnValue(notifs) + return notifs # There is no new information in this period, so your notification # count is whatever it was last time. - defer.returnValue(None) + return None @defer.inlineCallbacks def generate_sync_result(self, sync_config, since_token=None, full_state=False): @@ -989,19 +985,17 @@ class SyncHandler(object): "Sync result for newly joined room %s: %r", room_id, joined_room ) - defer.returnValue( - SyncResult( - presence=sync_result_builder.presence, - account_data=sync_result_builder.account_data, - joined=sync_result_builder.joined, - invited=sync_result_builder.invited, - archived=sync_result_builder.archived, - to_device=sync_result_builder.to_device, - device_lists=device_lists, - groups=sync_result_builder.groups, - device_one_time_keys_count=one_time_key_counts, - next_batch=sync_result_builder.now_token, - ) + return SyncResult( + presence=sync_result_builder.presence, + account_data=sync_result_builder.account_data, + joined=sync_result_builder.joined, + invited=sync_result_builder.invited, + archived=sync_result_builder.archived, + to_device=sync_result_builder.to_device, + device_lists=device_lists, + groups=sync_result_builder.groups, + device_one_time_keys_count=one_time_key_counts, + next_batch=sync_result_builder.now_token, ) @measure_func("_generate_sync_entry_for_groups") @@ -1124,11 +1118,9 @@ class SyncHandler(object): # Remove any users that we still share a room with. newly_left_users -= users_who_share_room - defer.returnValue( - DeviceLists(changed=users_that_have_changed, left=newly_left_users) - ) + return DeviceLists(changed=users_that_have_changed, left=newly_left_users) else: - defer.returnValue(DeviceLists(changed=[], left=[])) + return DeviceLists(changed=[], left=[]) @defer.inlineCallbacks def _generate_sync_entry_for_to_device(self, sync_result_builder): @@ -1225,7 +1217,7 @@ class SyncHandler(object): sync_result_builder.account_data = account_data_for_user - defer.returnValue(account_data_by_room) + return account_data_by_room @defer.inlineCallbacks def _generate_sync_entry_for_presence( @@ -1325,7 +1317,7 @@ class SyncHandler(object): ) if not tags_by_room: logger.debug("no-oping sync") - defer.returnValue(([], [], [], [])) + return ([], [], [], []) ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id @@ -1388,13 +1380,11 @@ class SyncHandler(object): newly_left_users -= newly_joined_or_invited_users - defer.returnValue( - ( - newly_joined_rooms, - newly_joined_or_invited_users, - newly_left_rooms, - newly_left_users, - ) + return ( + newly_joined_rooms, + newly_joined_or_invited_users, + newly_left_rooms, + newly_left_users, ) @defer.inlineCallbacks @@ -1414,13 +1404,13 @@ class SyncHandler(object): ) if rooms_changed: - defer.returnValue(True) + return True stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream for room_id in sync_result_builder.joined_room_ids: if self.store.has_room_changed_since(room_id, stream_id): - defer.returnValue(True) - defer.returnValue(False) + return True + return False @defer.inlineCallbacks def _get_rooms_changed(self, sync_result_builder, ignored_users): @@ -1637,7 +1627,7 @@ class SyncHandler(object): ) room_entries.append(entry) - defer.returnValue((room_entries, invited, newly_joined_rooms, newly_left_rooms)) + return (room_entries, invited, newly_joined_rooms, newly_left_rooms) @defer.inlineCallbacks def _get_all_rooms(self, sync_result_builder, ignored_users): @@ -1711,7 +1701,7 @@ class SyncHandler(object): ) ) - defer.returnValue((room_entries, invited, [])) + return (room_entries, invited, []) @defer.inlineCallbacks def _generate_room_entry( @@ -1912,7 +1902,7 @@ class SyncHandler(object): joined_room_ids.add(room_id) joined_room_ids = frozenset(joined_room_ids) - defer.returnValue(joined_room_ids) + return joined_room_ids def _action_has_highlight(actions): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index c3e0c8fc7e..6b661aa93d 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -140,7 +140,7 @@ class TypingHandler(object): if was_present: # No point sending another notification - defer.returnValue(None) + return None self._push_update(member=member, typing=True) @@ -173,7 +173,7 @@ class TypingHandler(object): def _stopped_typing(self, member): if member.user_id not in self._room_typing.get(member.room_id, set()): # No point - defer.returnValue(None) + return None self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 5de9630950..e53669e40d 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -133,7 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler): # If still None then the initial background update hasn't happened yet if self.pos is None: - defer.returnValue(None) + return None # Loop round handling deltas until we're up to date while True: diff --git a/synapse/http/client.py b/synapse/http/client.py index 45d5010952..0ac20ebefc 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -294,7 +294,7 @@ class SimpleHttpClient(object): logger.info( "Received response to %s %s: %s", method, redact_uri(uri), response.code ) - defer.returnValue(response) + return response except Exception as e: incoming_responses_counter.labels(method, "ERR").inc() logger.info( @@ -345,7 +345,7 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - defer.returnValue(json.loads(body)) + return json.loads(body) else: raise HttpResponseException(response.code, response.phrase, body) @@ -385,7 +385,7 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - defer.returnValue(json.loads(body)) + return json.loads(body) else: raise HttpResponseException(response.code, response.phrase, body) @@ -410,7 +410,7 @@ class SimpleHttpClient(object): ValueError: if the response was not JSON """ body = yield self.get_raw(uri, args, headers=headers) - defer.returnValue(json.loads(body)) + return json.loads(body) @defer.inlineCallbacks def put_json(self, uri, json_body, args={}, headers=None): @@ -453,7 +453,7 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - defer.returnValue(json.loads(body)) + return json.loads(body) else: raise HttpResponseException(response.code, response.phrase, body) @@ -488,7 +488,7 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - defer.returnValue(body) + return body else: raise HttpResponseException(response.code, response.phrase, body) @@ -545,13 +545,11 @@ class SimpleHttpClient(object): except Exception as e: raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e) - defer.returnValue( - ( - length, - resp_headers, - response.request.absoluteURI.decode("ascii"), - response.code, - ) + return ( + length, + resp_headers, + response.request.absoluteURI.decode("ascii"), + response.code, ) @@ -627,10 +625,10 @@ class CaptchaServerHttpClient(SimpleHttpClient): try: body = yield make_deferred_yieldable(readBody(response)) - defer.returnValue(body) + return body except PartialDownloadError as e: # twisted dislikes google's response, no content length. - defer.returnValue(e.response) + return e.response def encode_urlencode_args(args): diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 054c321a20..c03ddb724f 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -177,7 +177,7 @@ class MatrixFederationAgent(object): res = yield make_deferred_yieldable( agent.request(method, uri, headers, bodyProducer) ) - defer.returnValue(res) + return res @defer.inlineCallbacks def _route_matrix_uri(self, parsed_uri, lookup_well_known=True): @@ -205,24 +205,20 @@ class MatrixFederationAgent(object): port = parsed_uri.port if port == -1: port = 8448 - defer.returnValue( - _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=port, - ) + return _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=port, ) if parsed_uri.port != -1: # there is an explicit port - defer.returnValue( - _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=parsed_uri.port, - ) + return _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=parsed_uri.port, ) if lookup_well_known: @@ -259,7 +255,7 @@ class MatrixFederationAgent(object): ) res = yield self._route_matrix_uri(new_uri, lookup_well_known=False) - defer.returnValue(res) + return res # try a SRV lookup service_name = b"_matrix._tcp.%s" % (parsed_uri.host,) @@ -283,13 +279,11 @@ class MatrixFederationAgent(object): parsed_uri.host.decode("ascii"), ) - defer.returnValue( - _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=target_host, - target_port=port, - ) + return _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=target_host, + target_port=port, ) @defer.inlineCallbacks @@ -314,7 +308,7 @@ class MatrixFederationAgent(object): if cache_period > 0: self._well_known_cache.set(server_name, result, cache_period) - defer.returnValue(result) + return result @defer.inlineCallbacks def _do_get_well_known(self, server_name): @@ -354,7 +348,7 @@ class MatrixFederationAgent(object): # after startup cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) - defer.returnValue((None, cache_period)) + return (None, cache_period) result = parsed_body["m.server"].encode("ascii") @@ -369,7 +363,7 @@ class MatrixFederationAgent(object): else: cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD) - defer.returnValue((result, cache_period)) + return (result, cache_period) @implementer(IStreamClientEndpoint) diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index ecc88f9b96..b32188766d 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -120,7 +120,7 @@ class SrvResolver(object): if cache_entry: if all(s.expires > now for s in cache_entry): servers = list(cache_entry) - defer.returnValue(servers) + return servers try: answers, _, _ = yield make_deferred_yieldable( @@ -129,7 +129,7 @@ class SrvResolver(object): except DNSNameError: # TODO: cache this. We can get the SOA out of the exception, and use # the negative-TTL value. - defer.returnValue([]) + return [] except DomainError as e: # We failed to resolve the name (other than a NameError) # Try something in the cache, else rereaise @@ -138,7 +138,7 @@ class SrvResolver(object): logger.warn( "Failed to resolve %r, falling back to cache. %r", service_name, e ) - defer.returnValue(list(cache_entry)) + return list(cache_entry) else: raise e @@ -169,4 +169,4 @@ class SrvResolver(object): ) self._cache[service_name] = list(servers) - defer.returnValue(servers) + return servers diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index e60334547e..d07d356464 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -158,7 +158,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): response.code, response.phrase.decode("ascii", errors="replace"), ) - defer.returnValue(body) + return body class MatrixFederationHttpClient(object): @@ -256,7 +256,7 @@ class MatrixFederationHttpClient(object): response = yield self._send_request(request, **send_request_args) - defer.returnValue(response) + return response @defer.inlineCallbacks def _send_request( @@ -520,7 +520,7 @@ class MatrixFederationHttpClient(object): _flatten_response_never_received(e), ) raise - defer.returnValue(response) + return response def build_auth_headers( self, destination, method, url_bytes, content=None, destination_is=None @@ -644,7 +644,7 @@ class MatrixFederationHttpClient(object): self.reactor, self.default_timeout, request, response ) - defer.returnValue(body) + return body @defer.inlineCallbacks def post_json( @@ -713,7 +713,7 @@ class MatrixFederationHttpClient(object): body = yield _handle_json_response( self.reactor, _sec_timeout, request, response ) - defer.returnValue(body) + return body @defer.inlineCallbacks def get_json( @@ -778,7 +778,7 @@ class MatrixFederationHttpClient(object): self.reactor, self.default_timeout, request, response ) - defer.returnValue(body) + return body @defer.inlineCallbacks def delete_json( @@ -836,7 +836,7 @@ class MatrixFederationHttpClient(object): body = yield _handle_json_response( self.reactor, self.default_timeout, request, response ) - defer.returnValue(body) + return body @defer.inlineCallbacks def get_file( @@ -902,7 +902,7 @@ class MatrixFederationHttpClient(object): response.phrase.decode("ascii", errors="replace"), length, ) - defer.returnValue((length, headers)) + return (length, headers) class _ReadBodyToFileProtocol(protocol.Protocol): diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 96a4714d82..fb338ca223 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -89,7 +89,7 @@ the function becomes the operation name for the span. # We start yield we_wait # we finish - defer.returnValue(something_usual_and_useful) + return something_usual_and_useful Operation names can be explicitly set for functions by using ``trace_using_operation_name`` and @@ -113,7 +113,7 @@ Operation names can be explicitly set for functions by using # We start yield we_wait # we finish - defer.returnValue(something_usual_and_useful) + return something_usual_and_useful Contexts and carriers --------------------- @@ -694,7 +694,7 @@ def trace_servlet(servlet_name, func): }, ): result = yield defer.maybeDeferred(func, request, *args, **kwargs) - defer.returnValue(result) + return result return _trace_servlet_inner diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 7bb020cb45..41147d4292 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -101,7 +101,7 @@ class ModuleApi(object): ) user_id = yield self.register_user(localpart, displayname, emails) _, access_token = yield self.register_device(user_id) - defer.returnValue((user_id, access_token)) + return (user_id, access_token) def register_user(self, localpart, displayname=None, emails=[]): """Registers a new user with given localpart and optional displayname, emails. diff --git a/synapse/notifier.py b/synapse/notifier.py index 918ef64897..bd80c801b6 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -365,7 +365,7 @@ class Notifier(object): current_token = user_stream.current_token result = yield callback(prev_token, current_token) - defer.returnValue(result) + return result @defer.inlineCallbacks def get_events_for( @@ -400,7 +400,7 @@ class Notifier(object): @defer.inlineCallbacks def check_for_updates(before_token, after_token): if not after_token.is_after(before_token): - defer.returnValue(EventStreamResult([], (from_token, from_token))) + return EventStreamResult([], (from_token, from_token)) events = [] end_token = from_token @@ -440,7 +440,7 @@ class Notifier(object): events.extend(new_events) end_token = end_token.copy_and_replace(keyname, new_key) - defer.returnValue(EventStreamResult(events, (from_token, end_token))) + return EventStreamResult(events, (from_token, end_token)) user_id_for_stream = user.to_string() if is_peeking: @@ -465,18 +465,18 @@ class Notifier(object): from_token=from_token, ) - defer.returnValue(result) + return result @defer.inlineCallbacks def _get_room_ids(self, user, explicit_room_id): joined_room_ids = yield self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: - defer.returnValue(([explicit_room_id], True)) + return ([explicit_room_id], True) if (yield self._is_world_readable(explicit_room_id)): - defer.returnValue(([explicit_room_id], False)) + return ([explicit_room_id], False) raise AuthError(403, "Non-joined access not allowed") - defer.returnValue((joined_room_ids, True)) + return (joined_room_ids, True) @defer.inlineCallbacks def _is_world_readable(self, room_id): @@ -484,9 +484,9 @@ class Notifier(object): room_id, EventTypes.RoomHistoryVisibility, "" ) if state and "history_visibility" in state.content: - defer.returnValue(state.content["history_visibility"] == "world_readable") + return state.content["history_visibility"] == "world_readable" else: - defer.returnValue(False) + return False @log_function def remove_expired_streams(self): diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index c8a5b381da..c831975635 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -95,7 +95,7 @@ class BulkPushRuleEvaluator(object): invited ) - defer.returnValue(rules_by_user) + return rules_by_user @cached() def _get_rules_for_room(self, room_id): @@ -134,7 +134,7 @@ class BulkPushRuleEvaluator(object): pl_event = auth_events.get(POWER_KEY) - defer.returnValue((pl_event.content if pl_event else {}, sender_level)) + return (pl_event.content if pl_event else {}, sender_level) @defer.inlineCallbacks def action_for_event_by_user(self, event, context): @@ -283,13 +283,13 @@ class RulesForRoom(object): if state_group and self.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() - defer.returnValue(self.rules_by_user) + return self.rules_by_user with (yield self.linearizer.queue(())): if state_group and self.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() - defer.returnValue(self.rules_by_user) + return self.rules_by_user self.room_push_rule_cache_metrics.inc_misses() @@ -366,7 +366,7 @@ class RulesForRoom(object): logger.debug( "Returning push rules for %r %r", self.room_id, ret_rules_by_user.keys() ) - defer.returnValue(ret_rules_by_user) + return ret_rules_by_user @defer.inlineCallbacks def _update_rules_with_member_event_ids( diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 4e7b6a5531..5b15b0dbe7 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -258,17 +258,17 @@ class HttpPusher(object): @defer.inlineCallbacks def _process_one(self, push_action): if "notify" not in push_action["actions"]: - defer.returnValue(True) + return True tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) event = yield self.store.get_event(push_action["event_id"], allow_none=True) if event is None: - defer.returnValue(True) # It's been redacted + return True # It's been redacted rejected = yield self.dispatch_push(event, tweaks, badge) if rejected is False: - defer.returnValue(False) + return False if isinstance(rejected, list) or isinstance(rejected, tuple): for pk in rejected: @@ -282,7 +282,7 @@ class HttpPusher(object): else: logger.info("Pushkey %s was rejected: removing", pk) yield self.hs.remove_pusher(self.app_id, pk, self.user_id) - defer.returnValue(True) + return True @defer.inlineCallbacks def _build_notification_dict(self, event, tweaks, badge): @@ -302,7 +302,7 @@ class HttpPusher(object): ], } } - defer.returnValue(d) + return d ctx = yield push_tools.get_context_for_event( self.store, self.state_handler, event, self.user_id @@ -345,13 +345,13 @@ class HttpPusher(object): if "name" in ctx and len(ctx["name"]) > 0: d["notification"]["room_name"] = ctx["name"] - defer.returnValue(d) + return d @defer.inlineCallbacks def dispatch_push(self, event, tweaks, badge): notification_dict = yield self._build_notification_dict(event, tweaks, badge) if not notification_dict: - defer.returnValue([]) + return [] try: resp = yield self.http_client.post_json_get_json( self.url, notification_dict @@ -364,11 +364,11 @@ class HttpPusher(object): type(e), e, ) - defer.returnValue(False) + return False rejected = [] if "rejected" in resp: rejected = resp["rejected"] - defer.returnValue(rejected) + return rejected @defer.inlineCallbacks def _send_badge(self, badge): diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 521c6e2cd7..4245ce26f3 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -316,7 +316,7 @@ class Mailer(object): if not merge: room_vars["notifs"].append(notifvars) - defer.returnValue(room_vars) + return room_vars @defer.inlineCallbacks def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): @@ -343,7 +343,7 @@ class Mailer(object): if messagevars is not None: ret["messages"].append(messagevars) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def get_message_vars(self, notif, event, room_state_ids): @@ -379,7 +379,7 @@ class Mailer(object): if "body" in event.content: ret["body_text_plain"] = event.content["body"] - defer.returnValue(ret) + return ret def add_text_message_vars(self, messagevars, event): msgformat = event.content.get("format") @@ -428,19 +428,16 @@ class Mailer(object): inviter_name = name_from_member_event(inviter_member_event) if room_name is None: - defer.returnValue( - INVITE_FROM_PERSON - % {"person": inviter_name, "app": self.app_name} - ) + return INVITE_FROM_PERSON % { + "person": inviter_name, + "app": self.app_name, + } else: - defer.returnValue( - INVITE_FROM_PERSON_TO_ROOM - % { - "person": inviter_name, - "room": room_name, - "app": self.app_name, - } - ) + return INVITE_FROM_PERSON_TO_ROOM % { + "person": inviter_name, + "room": room_name, + "app": self.app_name, + } sender_name = None if len(notifs_by_room[room_id]) == 1: @@ -454,26 +451,21 @@ class Mailer(object): sender_name = name_from_member_event(state_event) if sender_name is not None and room_name is not None: - defer.returnValue( - MESSAGE_FROM_PERSON_IN_ROOM - % { - "person": sender_name, - "room": room_name, - "app": self.app_name, - } - ) + return MESSAGE_FROM_PERSON_IN_ROOM % { + "person": sender_name, + "room": room_name, + "app": self.app_name, + } elif sender_name is not None: - defer.returnValue( - MESSAGE_FROM_PERSON - % {"person": sender_name, "app": self.app_name} - ) + return MESSAGE_FROM_PERSON % { + "person": sender_name, + "app": self.app_name, + } else: # There's more than one notification for this room, so just # say there are several if room_name is not None: - defer.returnValue( - MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name} - ) + return MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name} else: # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" @@ -493,24 +485,19 @@ class Mailer(object): ] ) - defer.returnValue( - MESSAGES_FROM_PERSON - % { - "person": descriptor_from_member_events( - member_events.values() - ), - "app": self.app_name, - } - ) + return MESSAGES_FROM_PERSON % { + "person": descriptor_from_member_events(member_events.values()), + "app": self.app_name, + } else: # Stuff's happened in multiple different rooms # ...but we still refer to the 'reason' room which triggered the mail if reason["room_name"] is not None: - defer.returnValue( - MESSAGES_IN_ROOM_AND_OTHERS - % {"room": reason["room_name"], "app": self.app_name} - ) + return MESSAGES_IN_ROOM_AND_OTHERS % { + "room": reason["room_name"], + "app": self.app_name, + } else: # If the reason room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" @@ -527,13 +514,10 @@ class Mailer(object): [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids] ) - defer.returnValue( - MESSAGES_FROM_PERSON_AND_OTHERS - % { - "person": descriptor_from_member_events(member_events.values()), - "app": self.app_name, - } - ) + return MESSAGES_FROM_PERSON_AND_OTHERS % { + "person": descriptor_from_member_events(member_events.values()), + "app": self.app_name, + } def make_room_link(self, room_id): if self.hs.config.email_riot_base_url: diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 06056fbf4f..16a7e8e31d 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -55,7 +55,7 @@ def calculate_room_name( room_state_ids[("m.room.name", "")], allow_none=True ) if m_room_name and m_room_name.content and m_room_name.content["name"]: - defer.returnValue(m_room_name.content["name"]) + return m_room_name.content["name"] # does it have a canonical alias? if ("m.room.canonical_alias", "") in room_state_ids: @@ -68,7 +68,7 @@ def calculate_room_name( and canon_alias.content["alias"] and _looks_like_an_alias(canon_alias.content["alias"]) ): - defer.returnValue(canon_alias.content["alias"]) + return canon_alias.content["alias"] # at this point we're going to need to search the state by all state keys # for an event type, so rearrange the data structure @@ -82,10 +82,10 @@ def calculate_room_name( if alias_event and alias_event.content.get("aliases"): the_aliases = alias_event.content["aliases"] if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]): - defer.returnValue(the_aliases[0]) + return the_aliases[0] if not fallback_to_members: - defer.returnValue(None) + return None my_member_event = None if ("m.room.member", user_id) in room_state_ids: @@ -104,14 +104,13 @@ def calculate_room_name( ) if inviter_member_event: if fallback_to_single_member: - defer.returnValue( - "Invite from %s" - % (name_from_member_event(inviter_member_event),) + return "Invite from %s" % ( + name_from_member_event(inviter_member_event), ) else: return else: - defer.returnValue("Room Invite") + return "Room Invite" # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. @@ -154,17 +153,17 @@ def calculate_room_name( # return "Inviting %s" % ( # descriptor_from_member_events(third_party_invites) # ) - defer.returnValue("Inviting email address") + return "Inviting email address" else: - defer.returnValue(ALL_ALONE) + return ALL_ALONE else: - defer.returnValue(name_from_member_event(all_members[0])) + return name_from_member_event(all_members[0]) else: - defer.returnValue(ALL_ALONE) + return ALL_ALONE elif len(other_members) == 1 and not fallback_to_single_member: return else: - defer.returnValue(descriptor_from_member_events(other_members)) + return descriptor_from_member_events(other_members) def descriptor_from_member_events(member_events): diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index e37269cdb9..a54051a726 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -39,7 +39,7 @@ def get_badge_count(store, user_id): # return one badge count per conversation, as count per # message is so noisy as to be almost useless badge += 1 if notifs["notify_count"] else 0 - defer.returnValue(badge) + return badge @defer.inlineCallbacks @@ -61,4 +61,4 @@ def get_context_for_event(store, state_handler, ev, user_id): sender_state_event = yield store.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) - defer.returnValue(ctx) + return ctx diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index df6f670740..08e840fdc2 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -123,7 +123,7 @@ class PusherPool: ) pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id) - defer.returnValue(pusher) + return pusher @defer.inlineCallbacks def remove_pushers_by_app_id_and_pushkey_not_user( @@ -224,7 +224,7 @@ class PusherPool: if pusher_dict: pusher = yield self._start_pusher(pusher_dict) - defer.returnValue(pusher) + return pusher @defer.inlineCallbacks def _start_pushers(self): @@ -293,7 +293,7 @@ class PusherPool: p.on_started(have_notifs) - defer.returnValue(p) + return p @defer.inlineCallbacks def remove_pusher(self, app_id, pushkey, user_id): diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index fe482e279f..f5074b101a 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -185,7 +185,7 @@ class ReplicationEndpoint(object): except RequestSendFailed as e: raise_from(SynapseError(502, "Failed to talk to master"), e) - defer.returnValue(result) + return result return send_request diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 61eafbe708..fed4f08820 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -80,7 +80,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): payload = {"events": event_payloads, "backfilled": backfilled} - defer.returnValue(payload) + return payload @defer.inlineCallbacks def _handle_request(self, request): @@ -113,7 +113,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event_and_contexts, backfilled ) - defer.returnValue((200, {})) + return (200, {}) class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): @@ -156,7 +156,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): result = yield self.registry.on_edu(edu_type, origin, edu_content) - defer.returnValue((200, result)) + return (200, result) class ReplicationGetQueryRestServlet(ReplicationEndpoint): @@ -204,7 +204,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): result = yield self.registry.on_query(query_type, args) - defer.returnValue((200, result)) + return (200, result) class ReplicationCleanRoomRestServlet(ReplicationEndpoint): @@ -238,7 +238,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): def _handle_request(self, request, room_id): yield self.store.clean_room_for_join(room_id) - defer.returnValue((200, {})) + return (200, {}) def register_servlets(hs, http_server): diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 7c1197e5dd..f17d3a2da4 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -64,7 +64,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): user_id, device_id, initial_display_name, is_guest ) - defer.returnValue((200, {"device_id": device_id, "access_token": access_token})) + return (200, {"device_id": device_id, "access_token": access_token}) def register_servlets(hs, http_server): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 2d9cbbaefc..4217335d88 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -83,7 +83,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): remote_room_hosts, room_id, user_id, event_content ) - defer.returnValue((200, {})) + return (200, {}) class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): @@ -153,7 +153,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): yield self.store.locally_reject_invite(user_id, room_id) ret = {} - defer.returnValue((200, ret)) + return (200, ret) class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 2bf2173895..3341320a87 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -90,7 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): address=content["address"], ) - defer.returnValue((200, {})) + return (200, {}) class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): @@ -143,7 +143,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): bind_msisdn=bind_msisdn, ) - defer.returnValue((200, {})) + return (200, {}) def register_servlets(hs, http_server): diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 034763fe99..eff7bd7305 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -85,7 +85,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "extra_users": [u.to_string() for u in extra_users], } - defer.returnValue(payload) + return payload @defer.inlineCallbacks def _handle_request(self, request, event_id): @@ -117,7 +117,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - defer.returnValue((200, {})) + return (200, {}) def register_servlets(hs, http_server): diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 7ef67a5a73..c10b85d2ff 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -158,7 +158,7 @@ class Stream(object): updates, current_token = yield self.get_updates_since(self.last_token) self.last_token = current_token - defer.returnValue((updates, current_token)) + return (updates, current_token) @defer.inlineCallbacks def get_updates_since(self, from_token): @@ -172,14 +172,14 @@ class Stream(object): sent over the replication steam. """ if from_token in ("NOW", "now"): - defer.returnValue(([], self.upto_token)) + return ([], self.upto_token) current_token = self.upto_token from_token = int(from_token) if from_token == current_token: - defer.returnValue(([], current_token)) + return ([], current_token) if self._LIMITED: rows = yield self.update_function( @@ -198,7 +198,7 @@ class Stream(object): if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: raise Exception("stream %s has fallen behind" % (self.NAME)) - defer.returnValue((updates, current_token)) + return (updates, current_token) def current_token(self): """Gets the current token of the underlying streams. Should be provided @@ -297,7 +297,7 @@ class PushRulesStream(Stream): @defer.inlineCallbacks def update_function(self, from_token, to_token, limit): rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) - defer.returnValue([(row[0], row[2]) for row in rows]) + return [(row[0], row[2]) for row in rows] class PushersStream(Stream): @@ -424,7 +424,7 @@ class AccountDataStream(Stream): for stream_id, user_id, account_data_type, content in global_results ) - defer.returnValue(results) + return results class GroupServerStream(Stream): diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 3d0694bb11..d97669c886 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -134,7 +134,7 @@ class EventsStream(Stream): all_updates = heapq.merge(event_updates, state_updates) - defer.returnValue(all_updates) + return all_updates @classmethod def parse_row(cls, row): diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6888ae5590..0a7d9b81b2 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -84,7 +84,7 @@ class UsersRestServlet(RestServlet): ret = yield self.handlers.admin_handler.get_users() - defer.returnValue((200, ret)) + return (200, ret) class VersionServlet(RestServlet): @@ -227,7 +227,7 @@ class UserRegisterServlet(RestServlet): ) result = yield register._create_registration_details(user_id, body) - defer.returnValue((200, result)) + return (200, result) class WhoisRestServlet(RestServlet): @@ -252,7 +252,7 @@ class WhoisRestServlet(RestServlet): ret = yield self.handlers.admin_handler.get_whois(target_user) - defer.returnValue((200, ret)) + return (200, ret) class PurgeMediaCacheRestServlet(RestServlet): @@ -271,7 +271,7 @@ class PurgeMediaCacheRestServlet(RestServlet): ret = yield self.media_repository.delete_old_remote_media(before_ts) - defer.returnValue((200, ret)) + return (200, ret) class PurgeHistoryRestServlet(RestServlet): @@ -356,7 +356,7 @@ class PurgeHistoryRestServlet(RestServlet): room_id, token, delete_local_events=delete_local_events ) - defer.returnValue((200, {"purge_id": purge_id})) + return (200, {"purge_id": purge_id}) class PurgeHistoryStatusRestServlet(RestServlet): @@ -381,7 +381,7 @@ class PurgeHistoryStatusRestServlet(RestServlet): if purge_status is None: raise NotFoundError("purge id '%s' not found" % purge_id) - defer.returnValue((200, purge_status.asdict())) + return (200, purge_status.asdict()) class DeactivateAccountRestServlet(RestServlet): @@ -413,7 +413,7 @@ class DeactivateAccountRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) + return (200, {"id_server_unbind_result": id_server_unbind_result}) class ShutdownRoomRestServlet(RestServlet): @@ -531,16 +531,14 @@ class ShutdownRoomRestServlet(RestServlet): room_id, new_room_id, requester_user_id ) - defer.returnValue( - ( - 200, - { - "kicked_users": kicked_users, - "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, - "new_room_id": new_room_id, - }, - ) + return ( + 200, + { + "kicked_users": kicked_users, + "failed_to_kick_users": failed_to_kick_users, + "local_aliases": aliases_for_room, + "new_room_id": new_room_id, + }, ) @@ -564,7 +562,7 @@ class QuarantineMediaInRoom(RestServlet): room_id, requester.user.to_string() ) - defer.returnValue((200, {"num_quarantined": num_quarantined})) + return (200, {"num_quarantined": num_quarantined}) class ListMediaInRoom(RestServlet): @@ -585,7 +583,7 @@ class ListMediaInRoom(RestServlet): local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id) - defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs})) + return (200, {"local": local_mxcs, "remote": remote_mxcs}) class ResetPasswordRestServlet(RestServlet): @@ -629,7 +627,7 @@ class ResetPasswordRestServlet(RestServlet): yield self._set_password_handler.set_password( target_user_id, new_password, requester ) - defer.returnValue((200, {})) + return (200, {}) class GetUsersPaginatedRestServlet(RestServlet): @@ -671,7 +669,7 @@ class GetUsersPaginatedRestServlet(RestServlet): logger.info("limit: %s, start: %s", limit, start) ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) - defer.returnValue((200, ret)) + return (200, ret) @defer.inlineCallbacks def on_POST(self, request, target_user_id): @@ -699,7 +697,7 @@ class GetUsersPaginatedRestServlet(RestServlet): logger.info("limit: %s, start: %s", limit, start) ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) - defer.returnValue((200, ret)) + return (200, ret) class SearchUsersRestServlet(RestServlet): @@ -742,7 +740,7 @@ class SearchUsersRestServlet(RestServlet): logger.info("term: %s ", term) ret = yield self.handlers.admin_handler.search_users(term) - defer.returnValue((200, ret)) + return (200, ret) class DeleteGroupAdminRestServlet(RestServlet): @@ -765,7 +763,7 @@ class DeleteGroupAdminRestServlet(RestServlet): raise SynapseError(400, "Can only delete local groups") yield self.group_server.delete_group(group_id, requester.user.to_string()) - defer.returnValue((200, {})) + return (200, {}) class AccountValidityRenewServlet(RestServlet): @@ -796,7 +794,7 @@ class AccountValidityRenewServlet(RestServlet): ) res = {"expiration_ts": expiration_ts} - defer.returnValue((200, res)) + return (200, res) ######################################################################################## diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index ee66838a0d..90c0ee15dc 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -87,7 +87,7 @@ class SendServerNoticeServlet(RestServlet): event_content=body["content"], ) - defer.returnValue((200, {"event_id": event.event_id})) + return (200, {"event_id": event.event_id}) def on_PUT(self, request, txn_id): return self.txns.fetch_or_execute_request( diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 57542c2b4b..4284738021 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -54,7 +54,7 @@ class ClientDirectoryServer(RestServlet): dir_handler = self.handlers.directory_handler res = yield dir_handler.get_association(room_alias) - defer.returnValue((200, res)) + return (200, res) @defer.inlineCallbacks def on_PUT(self, request, room_alias): @@ -87,7 +87,7 @@ class ClientDirectoryServer(RestServlet): requester, room_alias, room_id, servers ) - defer.returnValue((200, {})) + return (200, {}) @defer.inlineCallbacks def on_DELETE(self, request, room_alias): @@ -102,7 +102,7 @@ class ClientDirectoryServer(RestServlet): service.url, room_alias.to_string(), ) - defer.returnValue((200, {})) + return (200, {}) except InvalidClientCredentialsError: # fallback to default user behaviour if they aren't an AS pass @@ -118,7 +118,7 @@ class ClientDirectoryServer(RestServlet): "User %s deleted alias %s", user.to_string(), room_alias.to_string() ) - defer.returnValue((200, {})) + return (200, {}) class ClientDirectoryListServer(RestServlet): @@ -136,9 +136,7 @@ class ClientDirectoryListServer(RestServlet): if room is None: raise NotFoundError("Unknown room") - defer.returnValue( - (200, {"visibility": "public" if room["is_public"] else "private"}) - ) + return (200, {"visibility": "public" if room["is_public"] else "private"}) @defer.inlineCallbacks def on_PUT(self, request, room_id): @@ -151,7 +149,7 @@ class ClientDirectoryListServer(RestServlet): requester, room_id, visibility ) - defer.returnValue((200, {})) + return (200, {}) @defer.inlineCallbacks def on_DELETE(self, request, room_id): @@ -161,7 +159,7 @@ class ClientDirectoryListServer(RestServlet): requester, room_id, "private" ) - defer.returnValue((200, {})) + return (200, {}) class ClientAppserviceDirectoryListServer(RestServlet): @@ -195,4 +193,4 @@ class ClientAppserviceDirectoryListServer(RestServlet): requester.app_service.id, network_id, room_id, visibility ) - defer.returnValue((200, {})) + return (200, {}) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index d6de2b7360..53ebed2203 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -67,7 +67,7 @@ class EventStreamRestServlet(RestServlet): is_guest=is_guest, ) - defer.returnValue((200, chunk)) + return (200, chunk) def on_OPTIONS(self, request): return (200, {}) @@ -91,9 +91,9 @@ class EventRestServlet(RestServlet): time_now = self.clock.time_msec() if event: event = yield self._event_serializer.serialize_event(event, time_now) - defer.returnValue((200, event)) + return (200, event) else: - defer.returnValue((404, "Event not found.")) + return (404, "Event not found.") def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 0fe5f2d79b..70b8478e90 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -42,7 +42,7 @@ class InitialSyncRestServlet(RestServlet): include_archived=include_archived, ) - defer.returnValue((200, content)) + return (200, content) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 0d05945f0a..5762b9fd06 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -152,7 +152,7 @@ class LoginRestServlet(RestServlet): well_known_data = self._well_known_builder.get_well_known() if well_known_data: result["well_known"] = well_known_data - defer.returnValue((200, result)) + return (200, result) @defer.inlineCallbacks def _do_other_login(self, login_submission): @@ -212,7 +212,7 @@ class LoginRestServlet(RestServlet): result = yield self._register_device_with_callback( canonical_user_id, login_submission, callback_3pid ) - defer.returnValue(result) + return result # No password providers were able to handle this 3pid # Check local store @@ -241,7 +241,7 @@ class LoginRestServlet(RestServlet): result = yield self._register_device_with_callback( canonical_user_id, login_submission, callback ) - defer.returnValue(result) + return result @defer.inlineCallbacks def _register_device_with_callback(self, user_id, login_submission, callback=None): @@ -273,7 +273,7 @@ class LoginRestServlet(RestServlet): if callback is not None: yield callback(result) - defer.returnValue(result) + return result @defer.inlineCallbacks def do_token_login(self, login_submission): @@ -284,7 +284,7 @@ class LoginRestServlet(RestServlet): ) result = yield self._register_device_with_callback(user_id, login_submission) - defer.returnValue(result) + return result @defer.inlineCallbacks def do_jwt_login(self, login_submission): @@ -321,7 +321,7 @@ class LoginRestServlet(RestServlet): result = yield self._register_device_with_callback( registered_user_id, login_submission ) - defer.returnValue(result) + return result class BaseSSORedirectServlet(RestServlet): @@ -395,7 +395,7 @@ class CasTicketServlet(RestServlet): # even if that's being used old-http style to signal end-of-data body = pde.response result = yield self.handle_cas_response(request, body, client_redirect_url) - defer.returnValue(result) + return result def handle_cas_response(self, request, cas_response_body, client_redirect_url): user, attributes = self.parse_cas_response(cas_response_body) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index cd711be519..2769f3a189 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -49,7 +49,7 @@ class LogoutRestServlet(RestServlet): requester.user.to_string(), requester.device_id ) - defer.returnValue((200, {})) + return (200, {}) class LogoutAllRestServlet(RestServlet): @@ -75,7 +75,7 @@ class LogoutAllRestServlet(RestServlet): # .. and then delete any access tokens which weren't associated with # devices. yield self._auth_handler.delete_access_tokens_for_user(user_id) - defer.returnValue((200, {})) + return (200, {}) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 3e87f0fdb3..1eb1068c98 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -56,7 +56,7 @@ class PresenceStatusRestServlet(RestServlet): state = yield self.presence_handler.get_state(target_user=user) state = format_user_presence_state(state, self.clock.time_msec()) - defer.returnValue((200, state)) + return (200, state) @defer.inlineCallbacks def on_PUT(self, request, user_id): @@ -88,7 +88,7 @@ class PresenceStatusRestServlet(RestServlet): if self.hs.config.use_presence: yield self.presence_handler.set_state(user, state) - defer.returnValue((200, {})) + return (200, {}) def on_OPTIONS(self, request): return (200, {}) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 4d8ab1f47e..2657ae45bb 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -48,7 +48,7 @@ class ProfileDisplaynameRestServlet(RestServlet): if displayname is not None: ret["displayname"] = displayname - defer.returnValue((200, ret)) + return (200, ret) @defer.inlineCallbacks def on_PUT(self, request, user_id): @@ -61,11 +61,11 @@ class ProfileDisplaynameRestServlet(RestServlet): try: new_name = content["displayname"] except Exception: - defer.returnValue((400, "Unable to parse name")) + return (400, "Unable to parse name") yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) - defer.returnValue((200, {})) + return (200, {}) def on_OPTIONS(self, request, user_id): return (200, {}) @@ -98,7 +98,7 @@ class ProfileAvatarURLRestServlet(RestServlet): if avatar_url is not None: ret["avatar_url"] = avatar_url - defer.returnValue((200, ret)) + return (200, ret) @defer.inlineCallbacks def on_PUT(self, request, user_id): @@ -110,11 +110,11 @@ class ProfileAvatarURLRestServlet(RestServlet): try: new_name = content["avatar_url"] except Exception: - defer.returnValue((400, "Unable to parse name")) + return (400, "Unable to parse name") yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) - defer.returnValue((200, {})) + return (200, {}) def on_OPTIONS(self, request, user_id): return (200, {}) @@ -150,7 +150,7 @@ class ProfileRestServlet(RestServlet): if avatar_url is not None: ret["avatar_url"] = avatar_url - defer.returnValue((200, ret)) + return (200, ret) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index e635efb420..c3ae8b98a8 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -69,7 +69,7 @@ class PushRuleRestServlet(RestServlet): if "attr" in spec: yield self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) - defer.returnValue((200, {})) + return (200, {}) if spec["rule_id"].startswith("."): # Rule ids starting with '.' are reserved for server default rules. @@ -106,7 +106,7 @@ class PushRuleRestServlet(RestServlet): except RuleNotFoundException as e: raise SynapseError(400, str(e)) - defer.returnValue((200, {})) + return (200, {}) @defer.inlineCallbacks def on_DELETE(self, request, path): @@ -123,7 +123,7 @@ class PushRuleRestServlet(RestServlet): try: yield self.store.delete_push_rule(user_id, namespaced_rule_id) self.notify_user(user_id) - defer.returnValue((200, {})) + return (200, {}) except StoreError as e: if e.code == 404: raise NotFoundError() @@ -151,10 +151,10 @@ class PushRuleRestServlet(RestServlet): ) if path[0] == "": - defer.returnValue((200, rules)) + return (200, rules) elif path[0] == "global": result = _filter_ruleset_with_path(rules["global"], path[1:]) - defer.returnValue((200, result)) + return (200, result) else: raise UnrecognizedRequestError() diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index e9246018df..ebc3dec516 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -62,7 +62,7 @@ class PushersRestServlet(RestServlet): if k not in allowed_keys: del p[k] - defer.returnValue((200, {"pushers": pushers})) + return (200, {"pushers": pushers}) def on_OPTIONS(self, _): return 200, {} @@ -94,7 +94,7 @@ class PushersSetRestServlet(RestServlet): yield self.pusher_pool.remove_pusher( content["app_id"], content["pushkey"], user_id=user.to_string() ) - defer.returnValue((200, {})) + return (200, {}) assert_params_in_dict( content, @@ -143,7 +143,7 @@ class PushersSetRestServlet(RestServlet): self.notifier.on_new_replication_data() - defer.returnValue((200, {})) + return (200, {}) def on_OPTIONS(self, _): return 200, {} @@ -190,7 +190,7 @@ class PushersRemoveRestServlet(RestServlet): ) request.write(PushersRemoveRestServlet.SUCCESS_HTML) finish_request(request) - defer.returnValue(None) + return None def on_OPTIONS(self, _): return 200, {} diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 7709c2d705..012e7a44a6 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -85,7 +85,7 @@ class RoomCreateRestServlet(TransactionRestServlet): requester, self.get_room_config(request) ) - defer.returnValue((200, info)) + return (200, info) def get_room_config(self, request): user_supplied_config = parse_json_object_from_request(request) @@ -155,9 +155,9 @@ class RoomStateEventRestServlet(TransactionRestServlet): if format == "event": event = format_event_for_client_v2(data.get_dict()) - defer.returnValue((200, event)) + return (200, event) elif format == "content": - defer.returnValue((200, data.get_dict()["content"])) + return (200, data.get_dict()["content"]) @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): @@ -192,7 +192,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): ret = {} if event: ret = {"event_id": event.event_id} - defer.returnValue((200, ret)) + return (200, ret) # TODO: Needs unit testing for generic events + feedback @@ -226,7 +226,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): requester, event_dict, txn_id=txn_id ) - defer.returnValue((200, {"event_id": event.event_id})) + return (200, {"event_id": event.event_id}) def on_GET(self, request, room_id, event_type, txn_id): return (200, "Not implemented") @@ -289,7 +289,7 @@ class JoinRoomAliasServlet(TransactionRestServlet): third_party_signed=content.get("third_party_signed", None), ) - defer.returnValue((200, {"room_id": room_id})) + return (200, {"room_id": room_id}) def on_PUT(self, request, room_identifier, txn_id): return self.txns.fetch_or_execute_request( @@ -342,7 +342,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit=limit, since_token=since_token ) - defer.returnValue((200, data)) + return (200, data) @defer.inlineCallbacks def on_POST(self, request): @@ -387,7 +387,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): network_tuple=network_tuple, ) - defer.returnValue((200, data)) + return (200, data) # TODO: Needs unit testing @@ -438,7 +438,7 @@ class RoomMemberListRestServlet(RestServlet): continue chunk.append(event) - defer.returnValue((200, {"chunk": chunk})) + return (200, {"chunk": chunk}) # deprecated in favour of /members?membership=join? @@ -459,7 +459,7 @@ class JoinedRoomMemberListRestServlet(RestServlet): requester, room_id ) - defer.returnValue((200, {"joined": users_with_profile})) + return (200, {"joined": users_with_profile}) # TODO: Needs better unit testing @@ -492,7 +492,7 @@ class RoomMessageListRestServlet(RestServlet): event_filter=event_filter, ) - defer.returnValue((200, msgs)) + return (200, msgs) # TODO: Needs unit testing @@ -513,7 +513,7 @@ class RoomStateRestServlet(RestServlet): user_id=requester.user.to_string(), is_guest=requester.is_guest, ) - defer.returnValue((200, events)) + return (200, events) # TODO: Needs unit testing @@ -532,7 +532,7 @@ class RoomInitialSyncRestServlet(RestServlet): content = yield self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) - defer.returnValue((200, content)) + return (200, content) class RoomEventServlet(RestServlet): @@ -555,9 +555,9 @@ class RoomEventServlet(RestServlet): time_now = self.clock.time_msec() if event: event = yield self._event_serializer.serialize_event(event, time_now) - defer.returnValue((200, event)) + return (200, event) else: - defer.returnValue((404, "Event not found.")) + return (404, "Event not found.") class RoomEventContextServlet(RestServlet): @@ -607,7 +607,7 @@ class RoomEventContextServlet(RestServlet): results["state"], time_now ) - defer.returnValue((200, results)) + return (200, results) class RoomForgetRestServlet(TransactionRestServlet): @@ -626,7 +626,7 @@ class RoomForgetRestServlet(TransactionRestServlet): yield self.room_member_handler.forget(user=requester.user, room_id=room_id) - defer.returnValue((200, {})) + return (200, {}) def on_PUT(self, request, room_id, txn_id): return self.txns.fetch_or_execute_request( @@ -676,7 +676,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): requester, txn_id, ) - defer.returnValue((200, {})) + return (200, {}) return target = requester.user @@ -703,7 +703,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): if membership_action == "join": return_value["room_id"] = room_id - defer.returnValue((200, return_value)) + return (200, return_value) def _has_3pid_invite_keys(self, content): for key in {"id_server", "medium", "address"}: @@ -745,7 +745,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): txn_id=txn_id, ) - defer.returnValue((200, {"event_id": event.event_id})) + return (200, {"event_id": event.event_id}) def on_PUT(self, request, room_id, event_id, txn_id): return self.txns.fetch_or_execute_request( @@ -790,7 +790,7 @@ class RoomTypingRestServlet(RestServlet): target_user=target_user, auth_user=requester.user, room_id=room_id ) - defer.returnValue((200, {})) + return (200, {}) class SearchRestServlet(RestServlet): @@ -812,7 +812,7 @@ class SearchRestServlet(RestServlet): requester.user, content, batch ) - defer.returnValue((200, results)) + return (200, results) class JoinedRoomsRestServlet(RestServlet): @@ -828,7 +828,7 @@ class JoinedRoomsRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) - defer.returnValue((200, {"joined_rooms": list(room_ids)})) + return (200, {"joined_rooms": list(room_ids)}) def register_txn_path(servlet, regex_string, http_server, with_get=False): diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 41b3171ac8..497cddf8b8 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -60,18 +60,16 @@ class VoipRestServlet(RestServlet): password = turnPassword else: - defer.returnValue((200, {})) - - defer.returnValue( - ( - 200, - { - "username": username, - "password": password, - "ttl": userLifetime / 1000, - "uris": turnUris, - }, - ) + return (200, {}) + + return ( + 200, + { + "username": username, + "password": password, + "ttl": userLifetime / 1000, + "uris": turnUris, + }, ) def on_OPTIONS(self, request): diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index f143d8b85c..7ac456812a 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -117,7 +117,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Wrap the session id in a JSON object ret = {"sid": sid} - defer.returnValue((200, ret)) + return (200, ret) @defer.inlineCallbacks def send_password_reset(self, email, client_secret, send_attempt, next_link=None): @@ -149,7 +149,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Check that the send_attempt is higher than previous attempts if send_attempt <= last_send_attempt: # If not, just return a success without sending an email - defer.returnValue(session_id) + return session_id else: # An non-validated session does not exist yet. # Generate a session id @@ -185,7 +185,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): token_expires, ) - defer.returnValue(session_id) + return session_id class MsisdnPasswordRequestTokenRestServlet(RestServlet): @@ -221,7 +221,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND) ret = yield self.identity_handler.requestMsisdnToken(**body) - defer.returnValue((200, ret)) + return (200, ret) class PasswordResetSubmitTokenServlet(RestServlet): @@ -279,7 +279,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): request.setResponseCode(302) request.setHeader("Location", next_link) finish_request(request) - defer.returnValue(None) + return None # Otherwise show the success template html = self.config.email_password_reset_success_html_content @@ -295,7 +295,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): request.write(html.encode("utf-8")) finish_request(request) - defer.returnValue(None) + return None def load_jinja2_template(self, template_dir, template_filename, template_vars): """Loads a jinja2 template with variables to insert @@ -330,7 +330,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): ) response_code = 200 if valid else 400 - defer.returnValue((response_code, {"success": valid})) + return (response_code, {"success": valid}) class PasswordRestServlet(RestServlet): @@ -399,7 +399,7 @@ class PasswordRestServlet(RestServlet): yield self._set_password_handler.set_password(user_id, new_password, requester) - defer.returnValue((200, {})) + return (200, {}) def on_OPTIONS(self, _): return 200, {} @@ -434,7 +434,7 @@ class DeactivateAccountRestServlet(RestServlet): yield self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase ) - defer.returnValue((200, {})) + return (200, {}) yield self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) @@ -447,7 +447,7 @@ class DeactivateAccountRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) + return (200, {"id_server_unbind_result": id_server_unbind_result}) class EmailThreepidRequestTokenRestServlet(RestServlet): @@ -481,7 +481,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) ret = yield self.identity_handler.requestEmailToken(**body) - defer.returnValue((200, ret)) + return (200, ret) class MsisdnThreepidRequestTokenRestServlet(RestServlet): @@ -516,7 +516,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) ret = yield self.identity_handler.requestMsisdnToken(**body) - defer.returnValue((200, ret)) + return (200, ret) class ThreepidRestServlet(RestServlet): @@ -536,7 +536,7 @@ class ThreepidRestServlet(RestServlet): threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) - defer.returnValue((200, {"threepids": threepids})) + return (200, {"threepids": threepids}) @defer.inlineCallbacks def on_POST(self, request): @@ -568,7 +568,7 @@ class ThreepidRestServlet(RestServlet): logger.debug("Binding threepid %s to %s", threepid, user_id) yield self.identity_handler.bind_threepid(threePidCreds, user_id) - defer.returnValue((200, {})) + return (200, {}) class ThreepidDeleteRestServlet(RestServlet): @@ -603,7 +603,7 @@ class ThreepidDeleteRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) + return (200, {"id_server_unbind_result": id_server_unbind_result}) class WhoamiRestServlet(RestServlet): @@ -617,7 +617,7 @@ class WhoamiRestServlet(RestServlet): def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - defer.returnValue((200, {"user_id": requester.user.to_string()})) + return (200, {"user_id": requester.user.to_string()}) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index f155c26259..98f2f6f4b5 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -55,7 +55,7 @@ class AccountDataServlet(RestServlet): self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - defer.returnValue((200, {})) + return (200, {}) @defer.inlineCallbacks def on_GET(self, request, user_id, account_data_type): @@ -70,7 +70,7 @@ class AccountDataServlet(RestServlet): if event is None: raise NotFoundError("Account data not found") - defer.returnValue((200, event)) + return (200, event) class RoomAccountDataServlet(RestServlet): @@ -112,7 +112,7 @@ class RoomAccountDataServlet(RestServlet): self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - defer.returnValue((200, {})) + return (200, {}) @defer.inlineCallbacks def on_GET(self, request, user_id, room_id, account_data_type): @@ -127,7 +127,7 @@ class RoomAccountDataServlet(RestServlet): if event is None: raise NotFoundError("Room account data not found") - defer.returnValue((200, event)) + return (200, event) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index d29c10b83d..133c61900a 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -58,7 +58,7 @@ class AccountValidityRenewServlet(RestServlet): ) request.write(AccountValidityRenewServlet.SUCCESS_HTML) finish_request(request) - defer.returnValue(None) + return None class AccountValiditySendMailServlet(RestServlet): @@ -87,7 +87,7 @@ class AccountValiditySendMailServlet(RestServlet): user_id = requester.user.to_string() yield self.account_activity_handler.send_renewal_email_to_user(user_id) - defer.returnValue((200, {})) + return (200, {}) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index bebc2951e7..f21aff39e5 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -207,7 +207,7 @@ class AuthRestServlet(RestServlet): request.write(html_bytes) finish_request(request) - defer.returnValue(None) + return None elif stagetype == LoginType.TERMS: if ("session" not in request.args or len(request.args["session"])) == 0: raise SynapseError(400, "No session supplied") @@ -239,7 +239,7 @@ class AuthRestServlet(RestServlet): request.write(html_bytes) finish_request(request) - defer.returnValue(None) + return None else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index fc7e2f4dd5..a4fa45fe11 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -58,7 +58,7 @@ class CapabilitiesRestServlet(RestServlet): "m.change_password": {"enabled": change_password}, } } - defer.returnValue((200, response)) + return (200, response) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index d279229d74..9adf76cc0c 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -48,7 +48,7 @@ class DevicesRestServlet(RestServlet): devices = yield self.device_handler.get_devices_by_user( requester.user.to_string() ) - defer.returnValue((200, {"devices": devices})) + return (200, {"devices": devices}) class DeleteDevicesRestServlet(RestServlet): @@ -91,7 +91,7 @@ class DeleteDevicesRestServlet(RestServlet): yield self.device_handler.delete_devices( requester.user.to_string(), body["devices"] ) - defer.returnValue((200, {})) + return (200, {}) class DeviceRestServlet(RestServlet): @@ -114,7 +114,7 @@ class DeviceRestServlet(RestServlet): device = yield self.device_handler.get_device( requester.user.to_string(), device_id ) - defer.returnValue((200, device)) + return (200, device) @interactive_auth_handler @defer.inlineCallbacks @@ -137,7 +137,7 @@ class DeviceRestServlet(RestServlet): ) yield self.device_handler.delete_device(requester.user.to_string(), device_id) - defer.returnValue((200, {})) + return (200, {}) @defer.inlineCallbacks def on_PUT(self, request, device_id): @@ -147,7 +147,7 @@ class DeviceRestServlet(RestServlet): yield self.device_handler.update_device( requester.user.to_string(), device_id, body ) - defer.returnValue((200, {})) + return (200, {}) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 3f0adf4a21..22be0ee3c5 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -56,7 +56,7 @@ class GetFilterRestServlet(RestServlet): user_localpart=target_user.localpart, filter_id=filter_id ) - defer.returnValue((200, filter.get_filter_json())) + return (200, filter.get_filter_json()) except (KeyError, StoreError): raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND) @@ -89,7 +89,7 @@ class CreateFilterRestServlet(RestServlet): user_localpart=target_user.localpart, user_filter=content ) - defer.returnValue((200, {"filter_id": str(filter_id)})) + return (200, {"filter_id": str(filter_id)}) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index a312dd2593..e629c4256d 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -47,7 +47,7 @@ class GroupServlet(RestServlet): group_id, requester_user_id ) - defer.returnValue((200, group_description)) + return (200, group_description) @defer.inlineCallbacks def on_POST(self, request, group_id): @@ -59,7 +59,7 @@ class GroupServlet(RestServlet): group_id, requester_user_id, content ) - defer.returnValue((200, {})) + return (200, {}) class GroupSummaryServlet(RestServlet): @@ -83,7 +83,7 @@ class GroupSummaryServlet(RestServlet): group_id, requester_user_id ) - defer.returnValue((200, get_group_summary)) + return (200, get_group_summary) class GroupSummaryRoomsCatServlet(RestServlet): @@ -120,7 +120,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): content=content, ) - defer.returnValue((200, resp)) + return (200, resp) @defer.inlineCallbacks def on_DELETE(self, request, group_id, category_id, room_id): @@ -131,7 +131,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): group_id, requester_user_id, room_id=room_id, category_id=category_id ) - defer.returnValue((200, resp)) + return (200, resp) class GroupCategoryServlet(RestServlet): @@ -157,7 +157,7 @@ class GroupCategoryServlet(RestServlet): group_id, requester_user_id, category_id=category_id ) - defer.returnValue((200, category)) + return (200, category) @defer.inlineCallbacks def on_PUT(self, request, group_id, category_id): @@ -169,7 +169,7 @@ class GroupCategoryServlet(RestServlet): group_id, requester_user_id, category_id=category_id, content=content ) - defer.returnValue((200, resp)) + return (200, resp) @defer.inlineCallbacks def on_DELETE(self, request, group_id, category_id): @@ -180,7 +180,7 @@ class GroupCategoryServlet(RestServlet): group_id, requester_user_id, category_id=category_id ) - defer.returnValue((200, resp)) + return (200, resp) class GroupCategoriesServlet(RestServlet): @@ -204,7 +204,7 @@ class GroupCategoriesServlet(RestServlet): group_id, requester_user_id ) - defer.returnValue((200, category)) + return (200, category) class GroupRoleServlet(RestServlet): @@ -228,7 +228,7 @@ class GroupRoleServlet(RestServlet): group_id, requester_user_id, role_id=role_id ) - defer.returnValue((200, category)) + return (200, category) @defer.inlineCallbacks def on_PUT(self, request, group_id, role_id): @@ -240,7 +240,7 @@ class GroupRoleServlet(RestServlet): group_id, requester_user_id, role_id=role_id, content=content ) - defer.returnValue((200, resp)) + return (200, resp) @defer.inlineCallbacks def on_DELETE(self, request, group_id, role_id): @@ -251,7 +251,7 @@ class GroupRoleServlet(RestServlet): group_id, requester_user_id, role_id=role_id ) - defer.returnValue((200, resp)) + return (200, resp) class GroupRolesServlet(RestServlet): @@ -275,7 +275,7 @@ class GroupRolesServlet(RestServlet): group_id, requester_user_id ) - defer.returnValue((200, category)) + return (200, category) class GroupSummaryUsersRoleServlet(RestServlet): @@ -312,7 +312,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): content=content, ) - defer.returnValue((200, resp)) + return (200, resp) @defer.inlineCallbacks def on_DELETE(self, request, group_id, role_id, user_id): @@ -323,7 +323,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): group_id, requester_user_id, user_id=user_id, role_id=role_id ) - defer.returnValue((200, resp)) + return (200, resp) class GroupRoomServlet(RestServlet): @@ -347,7 +347,7 @@ class GroupRoomServlet(RestServlet): group_id, requester_user_id ) - defer.returnValue((200, result)) + return (200, result) class GroupUsersServlet(RestServlet): @@ -371,7 +371,7 @@ class GroupUsersServlet(RestServlet): group_id, requester_user_id ) - defer.returnValue((200, result)) + return (200, result) class GroupInvitedUsersServlet(RestServlet): @@ -395,7 +395,7 @@ class GroupInvitedUsersServlet(RestServlet): group_id, requester_user_id ) - defer.returnValue((200, result)) + return (200, result) class GroupSettingJoinPolicyServlet(RestServlet): @@ -420,7 +420,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): group_id, requester_user_id, content ) - defer.returnValue((200, result)) + return (200, result) class GroupCreateServlet(RestServlet): @@ -450,7 +450,7 @@ class GroupCreateServlet(RestServlet): group_id, requester_user_id, content ) - defer.returnValue((200, result)) + return (200, result) class GroupAdminRoomsServlet(RestServlet): @@ -477,7 +477,7 @@ class GroupAdminRoomsServlet(RestServlet): group_id, requester_user_id, room_id, content ) - defer.returnValue((200, result)) + return (200, result) @defer.inlineCallbacks def on_DELETE(self, request, group_id, room_id): @@ -488,7 +488,7 @@ class GroupAdminRoomsServlet(RestServlet): group_id, requester_user_id, room_id ) - defer.returnValue((200, result)) + return (200, result) class GroupAdminRoomsConfigServlet(RestServlet): @@ -516,7 +516,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): group_id, requester_user_id, room_id, config_key, content ) - defer.returnValue((200, result)) + return (200, result) class GroupAdminUsersInviteServlet(RestServlet): @@ -546,7 +546,7 @@ class GroupAdminUsersInviteServlet(RestServlet): group_id, user_id, requester_user_id, config ) - defer.returnValue((200, result)) + return (200, result) class GroupAdminUsersKickServlet(RestServlet): @@ -573,7 +573,7 @@ class GroupAdminUsersKickServlet(RestServlet): group_id, user_id, requester_user_id, content ) - defer.returnValue((200, result)) + return (200, result) class GroupSelfLeaveServlet(RestServlet): @@ -598,7 +598,7 @@ class GroupSelfLeaveServlet(RestServlet): group_id, requester_user_id, requester_user_id, content ) - defer.returnValue((200, result)) + return (200, result) class GroupSelfJoinServlet(RestServlet): @@ -623,7 +623,7 @@ class GroupSelfJoinServlet(RestServlet): group_id, requester_user_id, content ) - defer.returnValue((200, result)) + return (200, result) class GroupSelfAcceptInviteServlet(RestServlet): @@ -648,7 +648,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): group_id, requester_user_id, content ) - defer.returnValue((200, result)) + return (200, result) class GroupSelfUpdatePublicityServlet(RestServlet): @@ -672,7 +672,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): publicise = content["publicise"] yield self.store.update_group_publicity(group_id, requester_user_id, publicise) - defer.returnValue((200, {})) + return (200, {}) class PublicisedGroupsForUserServlet(RestServlet): @@ -694,7 +694,7 @@ class PublicisedGroupsForUserServlet(RestServlet): result = yield self.groups_handler.get_publicised_groups_for_user(user_id) - defer.returnValue((200, result)) + return (200, result) class PublicisedGroupsForUsersServlet(RestServlet): @@ -719,7 +719,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) - defer.returnValue((200, result)) + return (200, result) class GroupsForUserServlet(RestServlet): @@ -741,7 +741,7 @@ class GroupsForUserServlet(RestServlet): result = yield self.groups_handler.get_joined_groups(requester_user_id) - defer.returnValue((200, result)) + return (200, result) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 45c9928b65..6008adec7c 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -95,7 +95,7 @@ class KeyUploadServlet(RestServlet): result = yield self.e2e_keys_handler.upload_keys_for_user( user_id, device_id, body ) - defer.returnValue((200, result)) + return (200, result) class KeyQueryServlet(RestServlet): @@ -149,7 +149,7 @@ class KeyQueryServlet(RestServlet): timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) result = yield self.e2e_keys_handler.query_devices(body, timeout) - defer.returnValue((200, result)) + return (200, result) class KeyChangesServlet(RestServlet): @@ -189,7 +189,7 @@ class KeyChangesServlet(RestServlet): results = yield self.device_handler.get_user_ids_changed(user_id, from_token) - defer.returnValue((200, results)) + return (200, results) class OneTimeKeyServlet(RestServlet): @@ -224,7 +224,7 @@ class OneTimeKeyServlet(RestServlet): timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) - defer.returnValue((200, result)) + return (200, result) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 728a52328f..d034863a3c 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -88,9 +88,7 @@ class NotificationsServlet(RestServlet): returned_push_actions.append(returned_pa) next_token = str(pa["stream_ordering"]) - defer.returnValue( - (200, {"notifications": returned_push_actions, "next_token": next_token}) - ) + return (200, {"notifications": returned_push_actions, "next_token": next_token}) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index b1b5385b09..b4925c0f59 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -83,16 +83,14 @@ class IdTokenServlet(RestServlet): yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) - defer.returnValue( - ( - 200, - { - "access_token": token, - "token_type": "Bearer", - "matrix_server_name": self.server_name, - "expires_in": self.EXPIRES_MS / 1000, - }, - ) + return ( + 200, + { + "access_token": token, + "token_type": "Bearer", + "matrix_server_name": self.server_name, + "expires_in": self.EXPIRES_MS / 1000, + }, ) diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index e75664279b..d93d6a9f24 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -59,7 +59,7 @@ class ReadMarkerRestServlet(RestServlet): event_id=read_marker_event_id, ) - defer.returnValue((200, {})) + return (200, {}) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 488905626a..98a97b7059 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -52,7 +52,7 @@ class ReceiptRestServlet(RestServlet): room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) - defer.returnValue((200, {})) + return (200, {}) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index f327999e59..05ea1459e3 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -95,7 +95,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) ret = yield self.identity_handler.requestEmailToken(**body) - defer.returnValue((200, ret)) + return (200, ret) class MsisdnRegisterRequestTokenRestServlet(RestServlet): @@ -138,7 +138,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) ret = yield self.identity_handler.requestMsisdnToken(**body) - defer.returnValue((200, ret)) + return (200, ret) class UsernameAvailabilityRestServlet(RestServlet): @@ -178,7 +178,7 @@ class UsernameAvailabilityRestServlet(RestServlet): yield self.registration_handler.check_username(username) - defer.returnValue((200, {"available": True})) + return (200, {"available": True}) class RegisterRestServlet(RestServlet): @@ -230,7 +230,7 @@ class RegisterRestServlet(RestServlet): if kind == b"guest": ret = yield self._do_guest_registration(body, address=client_addr) - defer.returnValue(ret) + return ret return elif kind != b"user": raise UnrecognizedRequestError( @@ -282,7 +282,7 @@ class RegisterRestServlet(RestServlet): result = yield self._do_appservice_registration( desired_username, access_token, body ) - defer.returnValue((200, result)) # we throw for non 200 responses + return (200, result) # we throw for non 200 responses return # for either shared secret or regular registration, downcase the @@ -301,7 +301,7 @@ class RegisterRestServlet(RestServlet): result = yield self._do_shared_secret_registration( desired_username, desired_password, body ) - defer.returnValue((200, result)) # we throw for non 200 responses + return (200, result) # we throw for non 200 responses return # == Normal User Registration == (everyone else) @@ -500,7 +500,7 @@ class RegisterRestServlet(RestServlet): bind_msisdn=params.get("bind_msisdn"), ) - defer.returnValue((200, return_dict)) + return (200, return_dict) def on_OPTIONS(self, _): return 200, {} @@ -510,7 +510,7 @@ class RegisterRestServlet(RestServlet): user_id = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue((yield self._create_registration_details(user_id, body))) + return (yield self._create_registration_details(user_id, body)) @defer.inlineCallbacks def _do_shared_secret_registration(self, username, password, body): @@ -546,7 +546,7 @@ class RegisterRestServlet(RestServlet): ) result = yield self._create_registration_details(user_id, body) - defer.returnValue(result) + return result @defer.inlineCallbacks def _create_registration_details(self, user_id, params): @@ -570,7 +570,7 @@ class RegisterRestServlet(RestServlet): ) result.update({"access_token": access_token, "device_id": device_id}) - defer.returnValue(result) + return result @defer.inlineCallbacks def _do_guest_registration(self, params, address=None): @@ -588,16 +588,14 @@ class RegisterRestServlet(RestServlet): user_id, device_id, initial_display_name, is_guest=True ) - defer.returnValue( - ( - 200, - { - "user_id": user_id, - "device_id": device_id, - "access_token": access_token, - "home_server": self.hs.hostname, - }, - ) + return ( + 200, + { + "user_id": user_id, + "device_id": device_id, + "access_token": access_token, + "home_server": self.hs.hostname, + }, ) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 6e52f6d284..6fde3decdb 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -116,7 +116,7 @@ class RelationSendServlet(RestServlet): requester, event_dict=event_dict, txn_id=txn_id ) - defer.returnValue((200, {"event_id": event.event_id})) + return (200, {"event_id": event.event_id}) class RelationPaginationServlet(RestServlet): @@ -196,7 +196,7 @@ class RelationPaginationServlet(RestServlet): return_value["chunk"] = events return_value["original_event"] = original_event - defer.returnValue((200, return_value)) + return (200, return_value) class RelationAggregationPaginationServlet(RestServlet): @@ -268,7 +268,7 @@ class RelationAggregationPaginationServlet(RestServlet): to_token=to_token, ) - defer.returnValue((200, pagination_chunk.to_dict())) + return (200, pagination_chunk.to_dict()) class RelationAggregationGroupPaginationServlet(RestServlet): @@ -354,7 +354,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): return_value = result.to_dict() return_value["chunk"] = events - defer.returnValue((200, return_value)) + return (200, return_value) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index e7578af804..3fdd4584a3 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -72,7 +72,7 @@ class ReportEventRestServlet(RestServlet): received_ts=self.clock.time_msec(), ) - defer.returnValue((200, {})) + return (200, {}) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 8d1b810565..10dec96208 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -135,7 +135,7 @@ class RoomKeysServlet(RestServlet): body = {"rooms": {room_id: body}} yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - defer.returnValue((200, {})) + return (200, {}) @defer.inlineCallbacks def on_GET(self, request, room_id, session_id): @@ -218,7 +218,7 @@ class RoomKeysServlet(RestServlet): else: room_keys = room_keys["rooms"][room_id] - defer.returnValue((200, room_keys)) + return (200, room_keys) @defer.inlineCallbacks def on_DELETE(self, request, room_id, session_id): @@ -242,7 +242,7 @@ class RoomKeysServlet(RestServlet): yield self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id ) - defer.returnValue((200, {})) + return (200, {}) class RoomKeysNewVersionServlet(RestServlet): @@ -293,7 +293,7 @@ class RoomKeysNewVersionServlet(RestServlet): info = parse_json_object_from_request(request) new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) - defer.returnValue((200, {"version": new_version})) + return (200, {"version": new_version}) # we deliberately don't have a PUT /version, as these things really should # be immutable to avoid people footgunning @@ -338,7 +338,7 @@ class RoomKeysVersionServlet(RestServlet): except SynapseError as e: if e.code == 404: raise SynapseError(404, "No backup found", Codes.NOT_FOUND) - defer.returnValue((200, info)) + return (200, info) @defer.inlineCallbacks def on_DELETE(self, request, version): @@ -358,7 +358,7 @@ class RoomKeysVersionServlet(RestServlet): user_id = requester.user.to_string() yield self.e2e_room_keys_handler.delete_version(user_id, version) - defer.returnValue((200, {})) + return (200, {}) @defer.inlineCallbacks def on_PUT(self, request, version): @@ -392,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet): ) yield self.e2e_room_keys_handler.update_version(user_id, version, info) - defer.returnValue((200, {})) + return (200, {}) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index d7f7faa029..14ba61a63e 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -80,7 +80,7 @@ class RoomUpgradeRestServlet(RestServlet): ret = {"replacement_room": new_room_id} - defer.returnValue((200, ret)) + return (200, ret) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index 78075b8fc0..2613648d82 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -60,7 +60,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): ) response = (200, {}) - defer.returnValue(response) + return response def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 02d56dee6c..7b32dd2212 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -174,7 +174,7 @@ class SyncRestServlet(RestServlet): time_now, sync_result, requester.access_token_id, filter ) - defer.returnValue((200, response_content)) + return (200, response_content) @defer.inlineCallbacks def encode_response(self, time_now, sync_result, access_token_id, filter): @@ -205,27 +205,23 @@ class SyncRestServlet(RestServlet): event_formatter, ) - defer.returnValue( - { - "account_data": {"events": sync_result.account_data}, - "to_device": {"events": sync_result.to_device}, - "device_lists": { - "changed": list(sync_result.device_lists.changed), - "left": list(sync_result.device_lists.left), - }, - "presence": SyncRestServlet.encode_presence( - sync_result.presence, time_now - ), - "rooms": {"join": joined, "invite": invited, "leave": archived}, - "groups": { - "join": sync_result.groups.join, - "invite": sync_result.groups.invite, - "leave": sync_result.groups.leave, - }, - "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "next_batch": sync_result.next_batch.to_string(), - } - ) + return { + "account_data": {"events": sync_result.account_data}, + "to_device": {"events": sync_result.to_device}, + "device_lists": { + "changed": list(sync_result.device_lists.changed), + "left": list(sync_result.device_lists.left), + }, + "presence": SyncRestServlet.encode_presence(sync_result.presence, time_now), + "rooms": {"join": joined, "invite": invited, "leave": archived}, + "groups": { + "join": sync_result.groups.join, + "invite": sync_result.groups.invite, + "leave": sync_result.groups.leave, + }, + "device_one_time_keys_count": sync_result.device_one_time_keys_count, + "next_batch": sync_result.next_batch.to_string(), + } @staticmethod def encode_presence(events, time_now): @@ -273,7 +269,7 @@ class SyncRestServlet(RestServlet): event_formatter=event_formatter, ) - defer.returnValue(joined) + return joined @defer.inlineCallbacks def encode_invited(self, rooms, time_now, token_id, event_formatter): @@ -309,7 +305,7 @@ class SyncRestServlet(RestServlet): invited_state.append(invite) invited[room.room_id] = {"invite_state": {"events": invited_state}} - defer.returnValue(invited) + return invited @defer.inlineCallbacks def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter): @@ -342,7 +338,7 @@ class SyncRestServlet(RestServlet): event_formatter=event_formatter, ) - defer.returnValue(joined) + return joined @defer.inlineCallbacks def encode_room( @@ -414,7 +410,7 @@ class SyncRestServlet(RestServlet): result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary - defer.returnValue(result) + return result def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 07b6ede603..d173544355 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -45,7 +45,7 @@ class TagListServlet(RestServlet): tags = yield self.store.get_tags_for_room(user_id, room_id) - defer.returnValue((200, {"tags": tags})) + return (200, {"tags": tags}) class TagServlet(RestServlet): @@ -76,7 +76,7 @@ class TagServlet(RestServlet): self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - defer.returnValue((200, {})) + return (200, {}) @defer.inlineCallbacks def on_DELETE(self, request, user_id, room_id, tag): @@ -88,7 +88,7 @@ class TagServlet(RestServlet): self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - defer.returnValue((200, {})) + return (200, {}) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 1e66662a05..158e686b01 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -40,7 +40,7 @@ class ThirdPartyProtocolsServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols() - defer.returnValue((200, protocols)) + return (200, protocols) class ThirdPartyProtocolServlet(RestServlet): @@ -60,9 +60,9 @@ class ThirdPartyProtocolServlet(RestServlet): only_protocol=protocol ) if protocol in protocols: - defer.returnValue((200, protocols[protocol])) + return (200, protocols[protocol]) else: - defer.returnValue((404, {"error": "Unknown protocol"})) + return (404, {"error": "Unknown protocol"}) class ThirdPartyUserServlet(RestServlet): @@ -85,7 +85,7 @@ class ThirdPartyUserServlet(RestServlet): ThirdPartyEntityKind.USER, protocol, fields ) - defer.returnValue((200, results)) + return (200, results) class ThirdPartyLocationServlet(RestServlet): @@ -108,7 +108,7 @@ class ThirdPartyLocationServlet(RestServlet): ThirdPartyEntityKind.LOCATION, protocol, fields ) - defer.returnValue((200, results)) + return (200, results) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index e19fb6d583..7ab2b80e46 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -60,7 +60,7 @@ class UserDirectorySearchRestServlet(RestServlet): user_id = requester.user.to_string() if not self.hs.config.user_directory_search_enabled: - defer.returnValue((200, {"limited": False, "results": []})) + return (200, {"limited": False, "results": []}) body = parse_json_object_from_request(request) @@ -76,7 +76,7 @@ class UserDirectorySearchRestServlet(RestServlet): user_id, search_term, limit ) - defer.returnValue((200, results)) + return (200, results) def register_servlets(hs, http_server): diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 65afffbb42..92beefa176 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -171,7 +171,7 @@ class MediaRepository(object): yield self._generate_thumbnails(None, media_id, media_id, media_type) - defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) + return "mxc://%s/%s" % (self.server_name, media_id) @defer.inlineCallbacks def get_local_media(self, request, media_id, name): @@ -282,7 +282,7 @@ class MediaRepository(object): with responder: pass - defer.returnValue(media_info) + return media_info @defer.inlineCallbacks def _get_remote_media_impl(self, server_name, media_id): @@ -317,14 +317,14 @@ class MediaRepository(object): responder = yield self.media_storage.fetch_media(file_info) if responder: - defer.returnValue((responder, media_info)) + return (responder, media_info) # Failed to find the file anywhere, lets download it. media_info = yield self._download_remote_file(server_name, media_id, file_id) responder = yield self.media_storage.fetch_media(file_info) - defer.returnValue((responder, media_info)) + return (responder, media_info) @defer.inlineCallbacks def _download_remote_file(self, server_name, media_id, file_id): @@ -421,7 +421,7 @@ class MediaRepository(object): yield self._generate_thumbnails(server_name, media_id, file_id, media_type) - defer.returnValue(media_info) + return media_info def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) @@ -500,7 +500,7 @@ class MediaRepository(object): media_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue(output_path) + return output_path @defer.inlineCallbacks def generate_remote_exact_thumbnail( @@ -554,7 +554,7 @@ class MediaRepository(object): t_len, ) - defer.returnValue(output_path) + return output_path @defer.inlineCallbacks def _generate_thumbnails( @@ -667,7 +667,7 @@ class MediaRepository(object): media_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue({"width": m_width, "height": m_height}) + return {"width": m_width, "height": m_height} @defer.inlineCallbacks def delete_old_remote_media(self, before_ts): @@ -704,7 +704,7 @@ class MediaRepository(object): yield self.store.delete_remote_media(origin, media_id) deleted += 1 - defer.returnValue({"deleted": deleted}) + return {"deleted": deleted} class MediaRepositoryResource(Resource): diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 25e5ac2848..3b87717a5a 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -69,7 +69,7 @@ class MediaStorage(object): ) yield finish_cb() - defer.returnValue(fname) + return fname @contextlib.contextmanager def store_into_file(self, file_info): @@ -143,14 +143,14 @@ class MediaStorage(object): path = self._file_info_to_path(file_info) local_path = os.path.join(self.local_media_directory, path) if os.path.exists(local_path): - defer.returnValue(FileResponder(open(local_path, "rb"))) + return FileResponder(open(local_path, "rb")) for provider in self.storage_providers: res = yield provider.fetch(path, file_info) if res: - defer.returnValue(res) + return res - defer.returnValue(None) + return None @defer.inlineCallbacks def ensure_media_is_in_local_cache(self, file_info): @@ -166,7 +166,7 @@ class MediaStorage(object): path = self._file_info_to_path(file_info) local_path = os.path.join(self.local_media_directory, path) if os.path.exists(local_path): - defer.returnValue(local_path) + return local_path dirname = os.path.dirname(local_path) if not os.path.exists(dirname): @@ -181,7 +181,7 @@ class MediaStorage(object): ) yield res.write_to_consumer(consumer) yield consumer.wait() - defer.returnValue(local_path) + return local_path raise Exception("file could not be found") diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 5871737bfd..bd40891a7f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -182,7 +182,7 @@ class PreviewUrlResource(DirectServeResource): og = cache_result["og"] if isinstance(og, six.text_type): og = og.encode("utf8") - defer.returnValue(og) + return og return media_info = yield self._download_url(url, user) @@ -284,7 +284,7 @@ class PreviewUrlResource(DirectServeResource): media_info["created_ts"], ) - defer.returnValue(jsonog) + return jsonog @defer.inlineCallbacks def _download_url(self, url, user): @@ -354,22 +354,20 @@ class PreviewUrlResource(DirectServeResource): # therefore not expire it. raise - defer.returnValue( - { - "media_type": media_type, - "media_length": length, - "download_name": download_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - "filename": fname, - "uri": uri, - "response_code": code, - # FIXME: we should calculate a proper expiration based on the - # Cache-Control and Expire headers. But for now, assume 1 hour. - "expires": 60 * 60 * 1000, - "etag": headers["ETag"][0] if "ETag" in headers else None, - } - ) + return { + "media_type": media_type, + "media_length": length, + "download_name": download_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + "filename": fname, + "uri": uri, + "response_code": code, + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + "expires": 60 * 60 * 1000, + "etag": headers["ETag"][0] if "ETag" in headers else None, + } def _start_expire_url_cache_data(self): return run_as_background_process( diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index f183743f31..729c097e6d 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -193,4 +193,4 @@ class ResourceLimitsServerNotices(object): if event_id in referenced_events: referenced_events.remove(event.event_id) - defer.returnValue((currently_blocked, referenced_events)) + return (currently_blocked, referenced_events) diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 71e7e75320..2dac90578c 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -86,7 +86,7 @@ class ServerNoticesManager(object): res = yield self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, ratelimit=False ) - defer.returnValue(res) + return res @cachedInlineCallbacks() def get_notice_room_for_user(self, user_id): @@ -120,7 +120,7 @@ class ServerNoticesManager(object): # we found a room which our user shares with the system notice # user logger.info("Using room %s", room.room_id) - defer.returnValue(room.room_id) + return room.room_id # apparently no existing notice room: create a new one logger.info("Creating server notices room for %s", user_id) @@ -158,4 +158,4 @@ class ServerNoticesManager(object): self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) logger.info("Created server notices room %s for %s", room_id, user_id) - defer.returnValue(room_id) + return room_id diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 9f708fa205..a0d34f16ea 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -135,7 +135,7 @@ class StateHandler(object): event = None if event_id: event = yield self.store.get_event(event_id, allow_none=True) - defer.returnValue(event) + return event return state_map = yield self.store.get_events( @@ -145,7 +145,7 @@ class StateHandler(object): key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map } - defer.returnValue(state) + return state @defer.inlineCallbacks def get_current_state_ids(self, room_id, latest_event_ids=None): @@ -169,7 +169,7 @@ class StateHandler(object): ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state - defer.returnValue(state) + return state @defer.inlineCallbacks def get_current_users_in_room(self, room_id, latest_event_ids=None): @@ -189,7 +189,7 @@ class StateHandler(object): logger.debug("calling resolve_state_groups from get_current_users_in_room") entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) joined_users = yield self.store.get_joined_users_from_state(room_id, entry) - defer.returnValue(joined_users) + return joined_users @defer.inlineCallbacks def get_current_hosts_in_room(self, room_id, latest_event_ids=None): @@ -198,7 +198,7 @@ class StateHandler(object): logger.debug("calling resolve_state_groups from get_current_hosts_in_room") entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) joined_hosts = yield self.store.get_joined_hosts(room_id, entry) - defer.returnValue(joined_hosts) + return joined_hosts @defer.inlineCallbacks def compute_event_context(self, event, old_state=None): @@ -241,7 +241,7 @@ class StateHandler(object): prev_state_ids=prev_state_ids, ) - defer.returnValue(context) + return context if old_state: # We already have the state, so we don't need to calculate it. @@ -275,7 +275,7 @@ class StateHandler(object): prev_state_ids=prev_state_ids, ) - defer.returnValue(context) + return context logger.debug("calling resolve_state_groups from compute_event_context") @@ -343,7 +343,7 @@ class StateHandler(object): delta_ids=delta_ids, ) - defer.returnValue(context) + return context @defer.inlineCallbacks def resolve_state_groups_for_events(self, room_id, event_ids): @@ -368,19 +368,17 @@ class StateHandler(object): state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) if len(state_groups_ids) == 0: - defer.returnValue(_StateCacheEntry(state={}, state_group=None)) + return _StateCacheEntry(state={}, state_group=None) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() prev_group, delta_ids = yield self.store.get_state_group_delta(name) - defer.returnValue( - _StateCacheEntry( - state=state_list, - state_group=name, - prev_group=prev_group, - delta_ids=delta_ids, - ) + return _StateCacheEntry( + state=state_list, + state_group=name, + prev_group=prev_group, + delta_ids=delta_ids, ) room_version = yield self.store.get_room_version(room_id) @@ -392,7 +390,7 @@ class StateHandler(object): None, state_res_store=StateResolutionStore(self.store), ) - defer.returnValue(result) + return result @defer.inlineCallbacks def resolve_events(self, room_version, state_sets, event): @@ -415,7 +413,7 @@ class StateHandler(object): new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)} - defer.returnValue(new_state) + return new_state class StateResolutionHandler(object): @@ -479,7 +477,7 @@ class StateResolutionHandler(object): if self._state_cache is not None: cache = self._state_cache.get(group_names, None) if cache: - defer.returnValue(cache) + return cache logger.info( "Resolving state for %s with %d groups", room_id, len(state_groups_ids) @@ -525,7 +523,7 @@ class StateResolutionHandler(object): if self._state_cache is not None: self._state_cache[group_names] = cache - defer.returnValue(cache) + return cache def _make_state_cache_entry(new_state, state_groups_ids): diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 88acd4817e..a2f92d9ff9 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -55,7 +55,7 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): a map from (type, state_key) to event_id. """ if len(state_sets) == 1: - defer.returnValue(state_sets[0]) + return state_sets[0] unconflicted_state, conflicted_state = _seperate(state_sets) @@ -97,10 +97,8 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): state_map_new = yield state_map_factory(new_needed_events) state_map.update(state_map_new) - defer.returnValue( - _resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - ) + return _resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map ) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index db969e8997..b327c86f40 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -63,7 +63,7 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto unconflicted_state, conflicted_state = _seperate(state_sets) if not conflicted_state: - defer.returnValue(unconflicted_state) + return unconflicted_state logger.debug("%d conflicted state entries", len(conflicted_state)) logger.debug("Calculating auth chain difference") @@ -137,7 +137,7 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto logger.debug("done") - defer.returnValue(resolved_state) + return resolved_state @defer.inlineCallbacks @@ -168,18 +168,18 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store): aev = yield _get_event(aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.Create, ""): if aev.content.get("creator") == event.sender: - defer.returnValue(100) + return 100 break - defer.returnValue(0) + return 0 level = pl.content.get("users", {}).get(event.sender) if level is None: level = pl.content.get("users_default", 0) if level is None: - defer.returnValue(0) + return 0 else: - defer.returnValue(int(level)) + return int(level) @defer.inlineCallbacks @@ -224,7 +224,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): intersection = set(auth_sets[0]).intersection(*auth_sets[1:]) union = set().union(*auth_sets) - defer.returnValue(union - intersection) + return union - intersection def _seperate(state_sets): @@ -343,7 +343,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ it = lexicographical_topological_sort(graph, key=_get_power_order) sorted_events = list(it) - defer.returnValue(sorted_events) + return sorted_events @defer.inlineCallbacks @@ -396,7 +396,7 @@ def _iterative_auth_checks( except AuthError: pass - defer.returnValue(resolved_state) + return resolved_state @defer.inlineCallbacks @@ -439,7 +439,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor event_ids.sort(key=lambda ev_id: order_map[ev_id]) - defer.returnValue(event_ids) + return event_ids @defer.inlineCallbacks @@ -462,7 +462,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor while event: depth = mainline_map.get(event.event_id) if depth is not None: - defer.returnValue(depth) + return depth auth_events = event.auth_event_ids() event = None @@ -474,7 +474,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor break # Didn't find a power level auth event, so we just return 0 - defer.returnValue(0) + return 0 @defer.inlineCallbacks @@ -493,7 +493,7 @@ def _get_event(event_id, event_map, state_res_store): if event_id not in event_map: events = yield state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) - defer.returnValue(event_map[event_id]) + return event_map[event_id] def lexicographical_topological_sort(graph, key): diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 86a333a919..e7f6ea7286 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -498,7 +498,7 @@ class DataStore( ) count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) retval = {"users": users, "total": count} - defer.returnValue(retval) + return retval def search_users(self, term): """Function to search users list for one or more users with diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index a7c93efa46..489ce82fae 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -513,7 +513,7 @@ class SQLBaseStore(object): after_callback(*after_args, **after_kwargs) raise - defer.returnValue(result) + return result @defer.inlineCallbacks def runWithConnection(self, func, *args, **kwargs): @@ -553,7 +553,7 @@ class SQLBaseStore(object): with PreserveLoggingContext(): result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs) - defer.returnValue(result) + return result @staticmethod def cursor_to_dict(cursor): @@ -615,8 +615,8 @@ class SQLBaseStore(object): # a cursor after we receive an error from the db. if not or_ignore: raise - defer.returnValue(False) - defer.returnValue(True) + return False + return True @staticmethod def _simple_insert_txn(txn, table, values): @@ -708,7 +708,7 @@ class SQLBaseStore(object): insertion_values, lock=lock, ) - defer.returnValue(result) + return result except self.database_engine.module.IntegrityError as e: attempts += 1 if attempts >= 5: @@ -1121,7 +1121,7 @@ class SQLBaseStore(object): results = [] if not iterable: - defer.returnValue(results) + return results # iterables can not be sliced, so convert it to a list first it_list = list(iterable) @@ -1142,7 +1142,7 @@ class SQLBaseStore(object): results.extend(rows) - defer.returnValue(results) + return results @classmethod def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 8394389073..9fa5b4f3d6 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -111,9 +111,9 @@ class AccountDataWorkerStore(SQLBaseStore): ) if result: - defer.returnValue(json.loads(result)) + return json.loads(result) else: - defer.returnValue(None) + return None @cached(num_args=2) def get_account_data_for_room(self, user_id, room_id): @@ -264,11 +264,9 @@ class AccountDataWorkerStore(SQLBaseStore): on_invalidate=cache_context.invalidate, ) if not ignored_account_data: - defer.returnValue(False) + return False - defer.returnValue( - ignored_user_id in ignored_account_data.get("ignored_users", {}) - ) + return ignored_user_id in ignored_account_data.get("ignored_users", {}) class AccountDataStore(AccountDataWorkerStore): @@ -332,7 +330,7 @@ class AccountDataStore(AccountDataWorkerStore): ) result = self._account_data_id_gen.get_current_token() - defer.returnValue(result) + return result @defer.inlineCallbacks def add_account_data_for_user(self, user_id, account_data_type, content): @@ -373,7 +371,7 @@ class AccountDataStore(AccountDataWorkerStore): ) result = self._account_data_id_gen.get_current_token() - defer.returnValue(result) + return result def _update_max_stream_id(self, next_id): """Update the max stream_id diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index eb329ebd8b..05d9c05c3f 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -145,7 +145,7 @@ class ApplicationServiceTransactionWorkerStore( for service in as_list: if service.id == res["as_id"]: services.append(service) - defer.returnValue(services) + return services @defer.inlineCallbacks def get_appservice_state(self, service): @@ -164,9 +164,9 @@ class ApplicationServiceTransactionWorkerStore( desc="get_appservice_state", ) if result: - defer.returnValue(result.get("state")) + return result.get("state") return - defer.returnValue(None) + return None def set_appservice_state(self, service, state): """Set the application service state. @@ -298,15 +298,13 @@ class ApplicationServiceTransactionWorkerStore( ) if not entry: - defer.returnValue(None) + return None event_ids = json.loads(entry["event_ids"]) events = yield self.get_events_as_list(event_ids) - defer.returnValue( - AppServiceTransaction(service=service, id=entry["txn_id"], events=events) - ) + return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) def _get_last_txn(self, txn, service_id): txn.execute( @@ -360,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore( events = yield self.get_events_as_list(event_ids) - defer.returnValue((upper_bound, events)) + return (upper_bound, events) class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 50f913a414..e5f0668f09 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -115,7 +115,7 @@ class BackgroundUpdateStore(SQLBaseStore): " Unscheduling background update task." ) self._all_done = True - defer.returnValue(None) + return None @defer.inlineCallbacks def has_completed_background_updates(self): @@ -127,11 +127,11 @@ class BackgroundUpdateStore(SQLBaseStore): # if we've previously determined that there is nothing left to do, that # is easy if self._all_done: - defer.returnValue(True) + return True # obviously, if we have things in our queue, we're not done. if self._background_update_queue: - defer.returnValue(False) + return False # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates @@ -144,9 +144,9 @@ class BackgroundUpdateStore(SQLBaseStore): ) if not updates: self._all_done = True - defer.returnValue(True) + return True - defer.returnValue(False) + return False @defer.inlineCallbacks def do_next_background_update(self, desired_duration_ms): @@ -173,14 +173,14 @@ class BackgroundUpdateStore(SQLBaseStore): if not self._background_update_queue: # no work left to do - defer.returnValue(None) + return None # pop from the front, and add back to the back update_name = self._background_update_queue.pop(0) self._background_update_queue.append(update_name) res = yield self._do_background_update(update_name, desired_duration_ms) - defer.returnValue(res) + return res @defer.inlineCallbacks def _do_background_update(self, update_name, desired_duration_ms): @@ -231,7 +231,7 @@ class BackgroundUpdateStore(SQLBaseStore): performance.update(items_updated, duration_ms) - defer.returnValue(len(self._background_update_performance)) + return len(self._background_update_performance) def register_background_update_handler(self, update_name, update_handler): """Register a handler for doing a background update. @@ -266,7 +266,7 @@ class BackgroundUpdateStore(SQLBaseStore): @defer.inlineCallbacks def noop_update(progress, batch_size): yield self._end_background_update(update_name) - defer.returnValue(1) + return 1 self.register_background_update_handler(update_name, noop_update) @@ -370,7 +370,7 @@ class BackgroundUpdateStore(SQLBaseStore): logger.info("Adding index %s to %s", index_name, table) yield self.runWithConnection(runner) yield self._end_background_update(update_name) - defer.returnValue(1) + return 1 self.register_background_update_handler(update_name, updater) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index bda68de5be..6db8c54077 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -104,7 +104,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): yield self.runWithConnection(f) yield self._end_background_update("user_ips_drop_nonunique_index") - defer.returnValue(1) + return 1 @defer.inlineCallbacks def _analyze_user_ip(self, progress, batch_size): @@ -121,7 +121,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): yield self._end_background_update("user_ips_analyze") - defer.returnValue(1) + return 1 @defer.inlineCallbacks def _remove_user_ip_dupes(self, progress, batch_size): @@ -291,7 +291,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): if last: yield self._end_background_update("user_ips_remove_dupes") - defer.returnValue(batch_size) + return batch_size @defer.inlineCallbacks def insert_client_ip( @@ -401,7 +401,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "device_id": did, "last_seen": last_seen, } - defer.returnValue(ret) + return ret @classmethod def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols): @@ -461,14 +461,12 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) for row in rows ) - defer.returnValue( - list( - { - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "last_seen": last_seen, - } - for (access_token, ip), (user_agent, last_seen) in iteritems(results) - ) + return list( + { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } + for (access_token, ip), (user_agent, last_seen) in iteritems(results) ) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 4ea0deea4f..79bb0ea46d 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -92,7 +92,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): user_id, last_deleted_stream_id ) if not has_changed: - defer.returnValue(0) + return 0 def delete_messages_for_device_txn(txn): sql = ( @@ -115,7 +115,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): last_deleted_stream_id, up_to_stream_id ) - defer.returnValue(count) + return count def get_new_device_msgs_for_remote( self, destination, last_stream_id, current_stream_id, limit @@ -263,7 +263,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): destination, stream_id ) - defer.returnValue(self._device_inbox_id_gen.get_current_token()) + return self._device_inbox_id_gen.get_current_token() @defer.inlineCallbacks def add_messages_from_remote_to_device_inbox( @@ -312,7 +312,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): for user_id in local_messages_by_user_then_device.keys(): self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) - defer.returnValue(stream_id) + return stream_id def _add_messages_to_local_device_inbox_txn( self, txn, stream_id, messages_by_user_then_device @@ -426,4 +426,4 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) - defer.returnValue(1) + return 1 diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d2b113a4e7..8f72d92895 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -71,7 +71,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="get_devices_by_user", ) - defer.returnValue({d["device_id"]: d for d in devices}) + return {d["device_id"]: d for d in devices} @defer.inlineCallbacks def get_devices_by_remote(self, destination, from_stream_id, limit): @@ -88,7 +88,7 @@ class DeviceWorkerStore(SQLBaseStore): destination, int(from_stream_id) ) if not has_changed: - defer.returnValue((now_stream_id, [])) + return (now_stream_id, []) # We retrieve n+1 devices from the list of outbound pokes where n is # our outbound device update limit. We then check if the very last @@ -111,7 +111,7 @@ class DeviceWorkerStore(SQLBaseStore): # Return an empty list if there are no updates if not updates: - defer.returnValue((now_stream_id, [])) + return (now_stream_id, []) # if we have exceeded the limit, we need to exclude any results with the # same stream_id as the last row. @@ -147,13 +147,13 @@ class DeviceWorkerStore(SQLBaseStore): # skip that stream_id and return an empty list, and continue with the next # stream_id next time. if not query_map: - defer.returnValue((stream_id_cutoff, [])) + return (stream_id_cutoff, []) results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) - defer.returnValue((now_stream_id, results)) + return (now_stream_id, results) def _get_devices_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id, limit @@ -232,7 +232,7 @@ class DeviceWorkerStore(SQLBaseStore): results.append(result) - defer.returnValue(results) + return results def _get_last_device_update_for_remote_user( self, destination, user_id, from_stream_id @@ -330,7 +330,7 @@ class DeviceWorkerStore(SQLBaseStore): else: results[user_id] = yield self._get_cached_devices_for_user(user_id) - defer.returnValue((user_ids_not_in_cache, results)) + return (user_ids_not_in_cache, results) @cachedInlineCallbacks(num_args=2, tree=True) def _get_cached_user_device(self, user_id, device_id): @@ -340,7 +340,7 @@ class DeviceWorkerStore(SQLBaseStore): retcol="content", desc="_get_cached_user_device", ) - defer.returnValue(db_to_json(content)) + return db_to_json(content) @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): @@ -350,9 +350,9 @@ class DeviceWorkerStore(SQLBaseStore): retcols=("device_id", "content"), desc="_get_cached_devices_for_user", ) - defer.returnValue( - {device["device_id"]: db_to_json(device["content"]) for device in devices} - ) + return { + device["device_id"]: db_to_json(device["content"]) for device in devices + } def get_devices_with_keys_by_user(self, user_id): """Get all devices (with any device keys) for a user @@ -482,7 +482,7 @@ class DeviceWorkerStore(SQLBaseStore): results = {user_id: None for user_id in user_ids} results.update({row["user_id"]: row["stream_id"] for row in rows}) - defer.returnValue(results) + return results class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): @@ -543,7 +543,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): """ key = (user_id, device_id) if self.device_id_exists_cache.get(key, None): - defer.returnValue(False) + return False try: inserted = yield self._simple_insert( @@ -557,7 +557,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): or_ignore=True, ) self.device_id_exists_cache.prefill(key, True) - defer.returnValue(inserted) + return inserted except Exception as e: logger.error( "store_device with device_id=%s(%r) user_id=%s(%r)" @@ -780,7 +780,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): hosts, stream_id, ) - defer.returnValue(stream_id) + return stream_id def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): now = self._clock.time_msec() @@ -889,4 +889,4 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): yield self.runWithConnection(f) yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) - defer.returnValue(1) + return 1 diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 201bbd430c..e966a73f3d 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -46,7 +46,7 @@ class DirectoryWorkerStore(SQLBaseStore): ) if not room_id: - defer.returnValue(None) + return None return servers = yield self._simple_select_onecol( @@ -57,10 +57,10 @@ class DirectoryWorkerStore(SQLBaseStore): ) if not servers: - defer.returnValue(None) + return None return - defer.returnValue(RoomAliasMapping(room_id, room_alias.to_string(), servers)) + return RoomAliasMapping(room_id, room_alias.to_string(), servers) def get_room_alias_creator(self, room_alias): return self._simple_select_one_onecol( @@ -125,7 +125,7 @@ class DirectoryStore(DirectoryWorkerStore): raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() ) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def delete_room_alias(self, room_alias): @@ -133,7 +133,7 @@ class DirectoryStore(DirectoryWorkerStore): "delete_room_alias", self._delete_room_alias_txn, room_alias ) - defer.returnValue(room_id) + return room_id def _delete_room_alias_txn(self, txn, room_alias): txn.execute( diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index f40ef2ab64..99128f2df7 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -61,7 +61,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): row["session_data"] = json.loads(row["session_data"]) - defer.returnValue(row) + return row @defer.inlineCallbacks def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): @@ -118,7 +118,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): try: version = int(version) except ValueError: - defer.returnValue({"rooms": {}}) + return {"rooms": {}} keyvalues = {"user_id": user_id, "version": version} if room_id: @@ -151,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "session_data": json.loads(row["session_data"]), } - defer.returnValue(sessions) + return sessions @defer.inlineCallbacks def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2fabb9e2cb..1e07474e70 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -41,7 +41,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): dict containing "key_json", "device_display_name". """ if not query_list: - defer.returnValue({}) + return {} results = yield self.runInteraction( "get_e2e_device_keys", @@ -55,7 +55,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for device_id, device_info in iteritems(device_keys): device_info["keys"] = db_to_json(device_info.pop("key_json")) - defer.returnValue(results) + return results def _get_e2e_device_keys_txn( self, txn, query_list, include_all_devices=False, include_deleted_devices=False @@ -130,9 +130,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): desc="add_e2e_one_time_keys_check", ) - defer.returnValue( - {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} - ) + return {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} @defer.inlineCallbacks def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index cb4478342f..4f500d893e 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -131,9 +131,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) if not rows: - defer.returnValue(0) + return 0 else: - defer.returnValue(max(row["depth"] for row in rows)) + return max(row["depth"] for row in rows) def _get_oldest_events_in_room_txn(self, txn, room_id): return self._simple_select_onecol_txn( @@ -169,7 +169,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # make sure that we don't completely ignore the older events. res = res[0:5] + random.sample(res[5:], 5) - defer.returnValue(res) + return res def get_latest_event_ids_and_hashes_in_room(self, room_id): """ @@ -411,7 +411,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas limit, ) events = yield self.get_events_as_list(ids) - defer.returnValue(events) + return events def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): @@ -463,7 +463,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas desc="get_successor_events", ) - defer.returnValue([row["event_id"] for row in rows]) + return [row["event_id"] for row in rows] class EventFederationStore(EventFederationWorkerStore): @@ -654,4 +654,4 @@ class EventFederationStore(EventFederationWorkerStore): if not result: yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) - defer.returnValue(batch_size) + return batch_size diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index dcfb67e029..22025effbc 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -100,7 +100,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): user_id, last_read_event_id, ) - defer.returnValue(ret) + return ret def _get_unread_counts_by_receipt_txn( self, txn, room_id, user_id, last_read_event_id @@ -178,7 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): return [r[0] for r in txn] ret = yield self.runInteraction("get_push_action_users_in_range", f) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def get_unread_push_actions_for_user_in_range_for_http( @@ -279,7 +279,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # Take only up to the limit. We have to stop at the limit because # one of the subqueries may have hit the limit. - defer.returnValue(notifs[:limit]) + return notifs[:limit] @defer.inlineCallbacks def get_unread_push_actions_for_user_in_range_for_email( @@ -380,7 +380,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): notifs.sort(key=lambda r: -(r["received_ts"] or 0)) # Now return the first `limit` - defer.returnValue(notifs[:limit]) + return notifs[:limit] def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): """A fast check to see if there might be something to push for the @@ -477,7 +477,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", ) - defer.returnValue(res) + return res except Exception: # this method is called from an exception handler, so propagating # another exception here really isn't helpful - there's nothing @@ -732,7 +732,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): push_actions = yield self.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) - defer.returnValue(push_actions) + return push_actions @defer.inlineCallbacks def get_time_of_last_push_action_before(self, stream_ordering): @@ -749,7 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return txn.fetchone() result = yield self.runInteraction("get_time_of_last_push_action_before", f) - defer.returnValue(result[0] if result else None) + return result[0] if result else None @defer.inlineCallbacks def get_latest_push_action_stream_ordering(self): @@ -758,7 +758,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return txn.fetchone() result = yield self.runInteraction("get_latest_push_action_stream_ordering", f) - defer.returnValue(result[0] or 0) + return result[0] or 0 def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): # Sad that we have to blow away the cache for the whole room here diff --git a/synapse/storage/events.py b/synapse/storage/events.py index b70457bfc6..88c0180116 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -223,7 +223,7 @@ def _retry_on_integrity_error(func): except self.database_engine.module.IntegrityError: logger.exception("IntegrityError, retrying.") res = yield func(self, *args, delete_existing=True, **kwargs) - defer.returnValue(res) + return res return f @@ -309,7 +309,7 @@ class EventsStore( max_persisted_id = yield self._stream_id_gen.get_current_token() - defer.returnValue(max_persisted_id) + return max_persisted_id @defer.inlineCallbacks @log_function @@ -334,7 +334,7 @@ class EventsStore( yield make_deferred_yieldable(deferred) max_persisted_id = yield self._stream_id_gen.get_current_token() - defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) + return (event.internal_metadata.stream_ordering, max_persisted_id) def _maybe_start_persisting(self, room_id): @defer.inlineCallbacks @@ -595,7 +595,7 @@ class EventsStore( stale = latest_event_ids & result stale_forward_extremities_counter.observe(len(stale)) - defer.returnValue(result) + return result @defer.inlineCallbacks def _get_events_which_are_prevs(self, event_ids): @@ -633,7 +633,7 @@ class EventsStore( "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk ) - defer.returnValue(results) + return results @defer.inlineCallbacks def _get_prevs_before_rejected(self, event_ids): @@ -695,7 +695,7 @@ class EventsStore( "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk ) - defer.returnValue(existing_prevs) + return existing_prevs @defer.inlineCallbacks def _get_new_state_after_events( @@ -796,7 +796,7 @@ class EventsStore( # If they old and new groups are the same then we don't need to do # anything. if old_state_groups == new_state_groups: - defer.returnValue((None, None)) + return (None, None) if len(new_state_groups) == 1 and len(old_state_groups) == 1: # If we're going from one state group to another, lets check if @@ -813,7 +813,7 @@ class EventsStore( # the current state in memory then lets also return that, # but it doesn't matter if we don't. new_state = state_groups_map.get(new_state_group) - defer.returnValue((new_state, delta_ids)) + return (new_state, delta_ids) # Now that we have calculated new_state_groups we need to get # their state IDs so we can resolve to a single state set. @@ -825,7 +825,7 @@ class EventsStore( if len(new_state_groups) == 1: # If there is only one state group, then we know what the current # state is. - defer.returnValue((state_groups_map[new_state_groups.pop()], None)) + return (state_groups_map[new_state_groups.pop()], None) # Ok, we need to defer to the state handler to resolve our state sets. @@ -854,7 +854,7 @@ class EventsStore( state_res_store=StateResolutionStore(self), ) - defer.returnValue((res.state, None)) + return (res.state, None) @defer.inlineCallbacks def _calculate_state_delta(self, room_id, current_state): @@ -877,7 +877,7 @@ class EventsStore( if ev_id != existing_state.get(key) } - defer.returnValue((to_delete, to_insert)) + return (to_delete, to_insert) @log_function def _persist_events_txn( @@ -1564,7 +1564,7 @@ class EventsStore( return count ret = yield self.runInteraction("count_messages", _count_messages) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def count_daily_sent_messages(self): @@ -1585,7 +1585,7 @@ class EventsStore( return count ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def count_daily_active_rooms(self): @@ -1600,7 +1600,7 @@ class EventsStore( return count ret = yield self.runInteraction("count_daily_active_rooms", _count) - defer.returnValue(ret) + return ret def get_current_backfill_token(self): """The current minimum token that backfilled events have reached""" @@ -2183,7 +2183,7 @@ class EventsStore( """ to_1, so_1 = yield self._get_event_ordering(event_id1) to_2, so_2 = yield self._get_event_ordering(event_id2) - defer.returnValue((to_1, so_1) > (to_2, so_2)) + return (to_1, so_1) > (to_2, so_2) @cachedInlineCallbacks(max_entries=5000) def _get_event_ordering(self, event_id): @@ -2197,9 +2197,7 @@ class EventsStore( if not res: raise SynapseError(404, "Could not find event %s" % (event_id,)) - defer.returnValue( - (int(res["topological_ordering"]), int(res["stream_ordering"])) - ) + return (int(res["topological_ordering"]), int(res["stream_ordering"])) def get_all_updated_current_state_deltas(self, from_token, to_token, limit): def get_all_updated_current_state_deltas_txn(txn): diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py index 1ce21d190c..6587f31e2b 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/events_bg_updates.py @@ -135,7 +135,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if not result: yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) - defer.returnValue(result) + return result @defer.inlineCallbacks def _background_reindex_origin_server_ts(self, progress, batch_size): @@ -212,7 +212,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if not result: yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) - defer.returnValue(result) + return result @defer.inlineCallbacks def _cleanup_extremities_bg_update(self, progress, batch_size): @@ -396,4 +396,4 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) - defer.returnValue(num_handled) + return num_handled diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 858fc755a1..44441957db 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -157,7 +157,7 @@ class EventsWorkerStore(SQLBaseStore): if event is None and not allow_none: raise NotFoundError("Could not find event %s" % (event_id,)) - defer.returnValue(event) + return event @defer.inlineCallbacks def get_events( @@ -187,7 +187,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected=allow_rejected, ) - defer.returnValue({e.event_id: e for e in events}) + return {e.event_id: e for e in events} @defer.inlineCallbacks def get_events_as_list( @@ -217,7 +217,7 @@ class EventsWorkerStore(SQLBaseStore): """ if not event_ids: - defer.returnValue([]) + return [] # there may be duplicates so we cast the list to a set event_entry_map = yield self._get_events_from_cache_or_db( @@ -305,7 +305,7 @@ class EventsWorkerStore(SQLBaseStore): event.unsigned["prev_content"] = prev.content event.unsigned["prev_sender"] = prev.sender - defer.returnValue(events) + return events @defer.inlineCallbacks def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): @@ -452,7 +452,7 @@ class EventsWorkerStore(SQLBaseStore): without having to create a new transaction for each request for events. """ if not events: - defer.returnValue({}) + return {} events_d = defer.Deferred() with self._event_fetch_lock: @@ -496,7 +496,7 @@ class EventsWorkerStore(SQLBaseStore): ) ) - defer.returnValue({e.event.event_id: e for e in res if e}) + return {e.event.event_id: e for e in res if e} def _fetch_event_rows(self, txn, event_ids): """Fetch event rows from the database @@ -609,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore): self._get_event_cache.prefill((original_ev.event_id,), cache_entry) - defer.returnValue(cache_entry) + return cache_entry @defer.inlineCallbacks def _maybe_redact_event_row(self, original_ev, redactions): @@ -679,7 +679,7 @@ class EventsWorkerStore(SQLBaseStore): desc="have_events_in_timeline", ) - defer.returnValue(set(r["event_id"] for r in rows)) + return set(r["event_id"] for r in rows) @defer.inlineCallbacks def have_seen_events(self, event_ids): @@ -705,7 +705,7 @@ class EventsWorkerStore(SQLBaseStore): input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk) - defer.returnValue(results) + return results def get_seen_events_with_rejections(self, event_ids): """Given a list of event ids, check if we rejected them. @@ -816,4 +816,4 @@ class EventsWorkerStore(SQLBaseStore): # it. complexity_v1 = round(state_events / 500, 2) - defer.returnValue({"v1": complexity_v1}) + return {"v1": complexity_v1} diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index b195dc66a0..23b48f6cea 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -15,8 +15,6 @@ from canonicaljson import encode_canonical_json -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.util.caches.descriptors import cachedInlineCallbacks @@ -41,7 +39,7 @@ class FilteringStore(SQLBaseStore): desc="get_user_filter", ) - defer.returnValue(db_to_json(def_json)) + return db_to_json(def_json) def add_user_filter(self, user_localpart, user_filter): def_json = encode_canonical_json(user_filter) diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py index 73e6fc6de2..15b01c6958 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/group_server.py @@ -307,15 +307,13 @@ class GroupServerStore(SQLBaseStore): desc="get_group_categories", ) - defer.returnValue( - { - row["category_id"]: { - "is_public": row["is_public"], - "profile": json.loads(row["profile"]), - } - for row in rows + return { + row["category_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), } - ) + for row in rows + } @defer.inlineCallbacks def get_group_category(self, group_id, category_id): @@ -328,7 +326,7 @@ class GroupServerStore(SQLBaseStore): category["profile"] = json.loads(category["profile"]) - defer.returnValue(category) + return category def upsert_group_category(self, group_id, category_id, profile, is_public): """Add/update room category for group @@ -370,15 +368,13 @@ class GroupServerStore(SQLBaseStore): desc="get_group_roles", ) - defer.returnValue( - { - row["role_id"]: { - "is_public": row["is_public"], - "profile": json.loads(row["profile"]), - } - for row in rows + return { + row["role_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), } - ) + for row in rows + } @defer.inlineCallbacks def get_group_role(self, group_id, role_id): @@ -391,7 +387,7 @@ class GroupServerStore(SQLBaseStore): role["profile"] = json.loads(role["profile"]) - defer.returnValue(role) + return role def upsert_group_role(self, group_id, role_id, profile, is_public): """Add/remove user role @@ -960,7 +956,7 @@ class GroupServerStore(SQLBaseStore): _register_user_group_membership_txn, next_id, ) - defer.returnValue(res) + return res @defer.inlineCallbacks def create_group( @@ -1057,9 +1053,9 @@ class GroupServerStore(SQLBaseStore): now = int(self._clock.time_msec()) if row and now < row["valid_until_ms"]: - defer.returnValue(json.loads(row["attestation_json"])) + return json.loads(row["attestation_json"]) - defer.returnValue(None) + return None def get_joined_groups(self, user_id): return self._simple_select_onecol( diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 081564360f..752e9788a2 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -173,7 +173,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): ) if user_id: count = count + 1 - defer.returnValue(count) + return count @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 42ec8c6bb8..1a0f2d5768 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -90,9 +90,7 @@ class PresenceStore(SQLBaseStore): presence_states, ) - defer.returnValue( - (stream_orderings[-1], self._presence_id_gen.get_current_token()) - ) + return (stream_orderings[-1], self._presence_id_gen.get_current_token()) def _update_presence_txn(self, txn, stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states): @@ -180,7 +178,7 @@ class PresenceStore(SQLBaseStore): for row in rows: row["currently_active"] = bool(row["currently_active"]) - defer.returnValue({row["user_id"]: UserPresenceState(**row) for row in rows}) + return {row["user_id"]: UserPresenceState(**row) for row in rows} def get_current_presence_token(self): return self._presence_id_gen.get_current_token() diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 0ff392bdb4..8a5d8e9b18 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -34,15 +34,13 @@ class ProfileWorkerStore(SQLBaseStore): except StoreError as e: if e.code == 404: # no match - defer.returnValue(ProfileInfo(None, None)) + return ProfileInfo(None, None) return else: raise - defer.returnValue( - ProfileInfo( - avatar_url=profile["avatar_url"], display_name=profile["displayname"] - ) + return ProfileInfo( + avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) def get_profile_displayname(self, user_localpart): @@ -168,7 +166,7 @@ class ProfileStore(ProfileWorkerStore): ) if res: - defer.returnValue(True) + return True res = yield self._simple_select_one_onecol( table="group_invites", @@ -179,4 +177,4 @@ class ProfileStore(ProfileWorkerStore): ) if res: - defer.returnValue(True) + return True diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 98cec8c82b..a6517c4cf3 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -120,7 +120,7 @@ class PushRulesWorkerStore( rules = _load_rules(rows, enabled_map) - defer.returnValue(rules) + return rules @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): @@ -130,9 +130,7 @@ class PushRulesWorkerStore( retcols=("user_name", "rule_id", "enabled"), desc="get_push_rules_enabled_for_user", ) - defer.returnValue( - {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - ) + return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} def have_push_rules_changed_for_user(self, user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): @@ -160,7 +158,7 @@ class PushRulesWorkerStore( ) def bulk_get_push_rules(self, user_ids): if not user_ids: - defer.returnValue({}) + return {} results = {user_id: [] for user_id in user_ids} @@ -182,7 +180,7 @@ class PushRulesWorkerStore( for user_id, rules in results.items(): results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) - defer.returnValue(results) + return results @defer.inlineCallbacks def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule): @@ -253,7 +251,7 @@ class PushRulesWorkerStore( result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) - defer.returnValue(result) + return result @cachedInlineCallbacks(num_args=2, cache_context=True) def _bulk_get_push_rules_for_room( @@ -312,7 +310,7 @@ class PushRulesWorkerStore( rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - defer.returnValue(rules_by_user) + return rules_by_user @cachedList( cached_method_name="get_push_rules_enabled_for_user", @@ -322,7 +320,7 @@ class PushRulesWorkerStore( ) def bulk_get_push_rules_enabled(self, user_ids): if not user_ids: - defer.returnValue({}) + return {} results = {user_id: {} for user_id in user_ids} @@ -336,7 +334,7 @@ class PushRulesWorkerStore( for row in rows: enabled = bool(row["enabled"]) results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled - defer.returnValue(results) + return results class PushRuleStore(PushRulesWorkerStore): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index cfe0a94330..be3d4d9ded 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -63,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore): ret = yield self._simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) - defer.returnValue(ret is not None) + return ret is not None def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) @@ -95,7 +95,7 @@ class PusherWorkerStore(SQLBaseStore): ], desc="get_pushers_by", ) - defer.returnValue(self._decode_pushers_rows(ret)) + return self._decode_pushers_rows(ret) @defer.inlineCallbacks def get_all_pushers(self): @@ -106,7 +106,7 @@ class PusherWorkerStore(SQLBaseStore): return self._decode_pushers_rows(rows) rows = yield self.runInteraction("get_all_pushers", get_pushers) - defer.returnValue(rows) + return rows def get_all_updated_pushers(self, last_id, current_id, limit): if last_id == current_id: @@ -205,7 +205,7 @@ class PusherWorkerStore(SQLBaseStore): result = {user_id: False for user_id in user_ids} result.update({r["user_name"]: True for r in rows}) - defer.returnValue(result) + return result class PusherStore(PusherWorkerStore): @@ -343,7 +343,7 @@ class PusherStore(PusherWorkerStore): "throttle_ms": row["throttle_ms"], } - defer.returnValue(params_by_room) + return params_by_room @defer.inlineCallbacks def set_throttle_params(self, pusher_id, room_id, params): diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index b477da12b1..6aa6d98ebb 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") - defer.returnValue(set(r["user_id"] for r in receipts)) + return set(r["user_id"] for r in receipts) @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): @@ -92,7 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore): desc="get_receipts_for_user", ) - defer.returnValue({row["room_id"]: row["event_id"] for row in rows}) + return {row["room_id"]: row["event_id"] for row in rows} @defer.inlineCallbacks def get_receipts_for_user_with_orderings(self, user_id, receipt_type): @@ -110,16 +110,14 @@ class ReceiptsWorkerStore(SQLBaseStore): return txn.fetchall() rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f) - defer.returnValue( - { - row[0]: { - "event_id": row[1], - "topological_ordering": row[2], - "stream_ordering": row[3], - } - for row in rows + return { + row[0]: { + "event_id": row[1], + "topological_ordering": row[2], + "stream_ordering": row[3], } - ) + for row in rows + } @defer.inlineCallbacks def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): @@ -147,7 +145,7 @@ class ReceiptsWorkerStore(SQLBaseStore): room_ids, to_key, from_key=from_key ) - defer.returnValue([ev for res in results.values() for ev in res]) + return [ev for res in results.values() for ev in res] def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """Get receipts for a single room for sending to clients. @@ -197,7 +195,7 @@ class ReceiptsWorkerStore(SQLBaseStore): rows = yield self.runInteraction("get_linearized_receipts_for_room", f) if not rows: - defer.returnValue([]) + return [] content = {} for row in rows: @@ -205,9 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore): row["user_id"] ] = json.loads(row["data"]) - defer.returnValue( - [{"type": "m.receipt", "room_id": room_id, "content": content}] - ) + return [{"type": "m.receipt", "room_id": room_id, "content": content}] @cachedList( cached_method_name="_get_linearized_receipts_for_room", @@ -217,7 +213,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: - defer.returnValue({}) + return {} def f(txn): if from_key: @@ -264,7 +260,7 @@ class ReceiptsWorkerStore(SQLBaseStore): room_id: [results[room_id]] if room_id in results else [] for room_id in room_ids } - defer.returnValue(results) + return results def get_all_updated_receipts(self, last_id, current_id, limit=None): if last_id == current_id: @@ -468,7 +464,7 @@ class ReceiptsStore(ReceiptsWorkerStore): ) if event_ts is None: - defer.returnValue(None) + return None now = self._clock.time_msec() logger.debug( @@ -482,7 +478,7 @@ class ReceiptsStore(ReceiptsWorkerStore): max_persisted_id = self._receipts_id_gen.get_current_token() - defer.returnValue((stream_id, max_persisted_id)) + return (stream_id, max_persisted_id) def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): return self.runInteraction( diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 8b2c2a97ab..999c10a308 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -75,12 +75,12 @@ class RegistrationWorkerStore(SQLBaseStore): info = yield self.get_user_by_id(user_id) if not info: - defer.returnValue(False) + return False now = self.clock.time_msec() trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms - defer.returnValue(is_trial) + return is_trial @cached() def get_user_by_access_token(self, token): @@ -115,7 +115,7 @@ class RegistrationWorkerStore(SQLBaseStore): allow_none=True, desc="get_expiration_ts_for_user", ) - defer.returnValue(res) + return res @defer.inlineCallbacks def set_account_validity_for_user( @@ -190,7 +190,7 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_user_from_renewal_token", ) - defer.returnValue(res) + return res @defer.inlineCallbacks def get_renewal_token_for_user(self, user_id): @@ -209,7 +209,7 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_renewal_token_for_user", ) - defer.returnValue(res) + return res @defer.inlineCallbacks def get_users_expiring_soon(self): @@ -237,7 +237,7 @@ class RegistrationWorkerStore(SQLBaseStore): self.config.account_validity.renew_at, ) - defer.returnValue(res) + return res @defer.inlineCallbacks def set_renewal_mail_status(self, user_id, email_sent): @@ -280,7 +280,7 @@ class RegistrationWorkerStore(SQLBaseStore): desc="is_server_admin", ) - defer.returnValue(res if res else False) + return res if res else False def _query_for_auth(self, txn, token): sql = ( @@ -311,7 +311,7 @@ class RegistrationWorkerStore(SQLBaseStore): res = yield self.runInteraction( "is_support_user", self.is_support_user_txn, user_id ) - defer.returnValue(res) + return res def is_support_user_txn(self, txn, user_id): res = self._simple_select_one_onecol_txn( @@ -349,7 +349,7 @@ class RegistrationWorkerStore(SQLBaseStore): return 0 ret = yield self.runInteraction("count_users", _count_users) - defer.returnValue(ret) + return ret def count_daily_user_type(self): """ @@ -395,7 +395,7 @@ class RegistrationWorkerStore(SQLBaseStore): return count ret = yield self.runInteraction("count_users", _count_users) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def find_next_generated_user_id_localpart(self): @@ -425,7 +425,7 @@ class RegistrationWorkerStore(SQLBaseStore): if i not in found: return i - defer.returnValue( + return ( ( yield self.runInteraction( "find_next_generated_user_id", _find_next_generated_user_id @@ -447,7 +447,7 @@ class RegistrationWorkerStore(SQLBaseStore): user_id = yield self.runInteraction( "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address ) - defer.returnValue(user_id) + return user_id def get_user_id_by_threepid_txn(self, txn, medium, address): """Returns user id from threepid @@ -487,7 +487,7 @@ class RegistrationWorkerStore(SQLBaseStore): ["medium", "address", "validated_at", "added_at"], "user_get_threepids", ) - defer.returnValue(ret) + return ret def user_delete_threepid(self, user_id, medium, address): return self._simple_delete( @@ -677,7 +677,7 @@ class RegistrationStore( if end: yield self._end_background_update("users_set_deactivated_flag") - defer.returnValue(batch_size) + return batch_size @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): @@ -957,7 +957,7 @@ class RegistrationStore( desc="is_guest", ) - defer.returnValue(res if res else False) + return res if res else False def add_user_pending_deactivation(self, user_id): """ @@ -1024,7 +1024,7 @@ class RegistrationStore( yield self._end_background_update("user_threepids_grandfather") - defer.returnValue(1) + return 1 def get_threepid_validation_session( self, medium, client_secret, address=None, sid=None, validated=True @@ -1337,4 +1337,4 @@ class RegistrationStore( ) # Convert the integer into a boolean. - defer.returnValue(res == 1) + return res == 1 diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 9954bc094f..fcb5f2f23a 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -17,8 +17,6 @@ import logging import attr -from twisted.internet import defer - from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore @@ -363,7 +361,7 @@ class RelationsWorkerStore(SQLBaseStore): return edit_event = yield self.get_event(edit_id, allow_none=True) - defer.returnValue(edit_event) + return edit_event def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): """Check if a user has already annotated an event with the same key diff --git a/synapse/storage/room.py b/synapse/storage/room.py index fe9d79d792..bc606292b8 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -193,14 +193,12 @@ class RoomWorkerStore(SQLBaseStore): ) if row: - defer.returnValue( - RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - ) + return RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], ) else: - defer.returnValue(None) + return None class RoomStore(RoomWorkerStore, SearchStore): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index b3c002b9eb..cb88e49b51 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -108,7 +108,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): room_id, on_invalidate=cache_context.invalidate ) hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) - defer.returnValue(hosts) + return hosts @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): @@ -253,8 +253,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): invites = yield self.get_invited_rooms_for_user(user_id) for invite in invites: if invite.room_id == room_id: - defer.returnValue(invite) - defer.returnValue(None) + return invite + return None def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user @@ -347,11 +347,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): rooms = yield self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN] ) - defer.returnValue( - frozenset( - GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) - for r in rooms - ) + return frozenset( + GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) + for r in rooms ) @defer.inlineCallbacks @@ -361,7 +359,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): rooms = yield self.get_rooms_for_user_with_stream_ordering( user_id, on_invalidate=on_invalidate ) - defer.returnValue(frozenset(r.room_id for r in rooms)) + return frozenset(r.room_id for r in rooms) @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) def get_users_who_share_room_with_user(self, user_id, cache_context): @@ -378,7 +376,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) user_who_share_room.update(user_ids) - defer.returnValue(user_who_share_room) + return user_who_share_room @defer.inlineCallbacks def get_joined_users_from_context(self, event, context): @@ -394,7 +392,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) - defer.returnValue(result) + return result def get_joined_users_from_state(self, room_id, state_entry): state_group = state_entry.state_group @@ -508,7 +506,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): avatar_url=to_ascii(event.content.get("avatar_url", None)), ) - defer.returnValue(users_in_room) + return users_in_room @cachedInlineCallbacks(max_entries=10000) def is_host_joined(self, room_id, host): @@ -533,14 +531,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) if not rows: - defer.returnValue(False) + return False user_id = rows[0][0] if get_domain_from_id(user_id) != host: # This can only happen if the host name has something funky in it raise Exception("Invalid host name") - defer.returnValue(True) + return True @cachedInlineCallbacks() def was_host_joined(self, room_id, host): @@ -573,14 +571,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) if not rows: - defer.returnValue(False) + return False user_id = rows[0][0] if get_domain_from_id(user_id) != host: # This can only happen if the host name has something funky in it raise Exception("Invalid host name") - defer.returnValue(True) + return True def get_joined_hosts(self, room_id, state_entry): state_group = state_entry.state_group @@ -607,7 +605,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): cache = self._get_joined_hosts_cache(room_id) joined_hosts = yield cache.get_destinations(state_entry) - defer.returnValue(joined_hosts) + return joined_hosts @cached(max_entries=10000) def _get_joined_hosts_cache(self, room_id): @@ -637,7 +635,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return rows[0][0] count = yield self.runInteraction("did_forget_membership", f) - defer.returnValue(count == 0) + return count == 0 @defer.inlineCallbacks def get_rooms_user_has_been_in(self, user_id): @@ -847,7 +845,7 @@ class RoomMemberStore(RoomMemberWorkerStore): if not result: yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) - defer.returnValue(result) + return result @defer.inlineCallbacks def _background_current_state_membership(self, progress, batch_size): @@ -905,7 +903,7 @@ class RoomMemberStore(RoomMemberWorkerStore): if finished: yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) - defer.returnValue(row_count) + return row_count class _JoinedHostsCache(object): @@ -933,7 +931,7 @@ class _JoinedHostsCache(object): state_entry(synapse.state._StateCacheEntry) """ if state_entry.state_group == self.state_group: - defer.returnValue(frozenset(self.hosts_to_joined_users)) + return frozenset(self.hosts_to_joined_users) with (yield self.linearizer.queue(())): if state_entry.state_group == self.state_group: @@ -970,7 +968,7 @@ class _JoinedHostsCache(object): else: self.state_group = object() self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) - defer.returnValue(frozenset(self.hosts_to_joined_users)) + return frozenset(self.hosts_to_joined_users) def __len__(self): return self._len diff --git a/synapse/storage/search.py b/synapse/storage/search.py index f3b1cec933..df87ab6a6d 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -166,7 +166,7 @@ class SearchStore(BackgroundUpdateStore): if not result: yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) - defer.returnValue(result) + return result @defer.inlineCallbacks def _background_reindex_gin_search(self, progress, batch_size): @@ -209,7 +209,7 @@ class SearchStore(BackgroundUpdateStore): yield self.runWithConnection(create_index) yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) - defer.returnValue(1) + return 1 @defer.inlineCallbacks def _background_reindex_search_order(self, progress, batch_size): @@ -287,7 +287,7 @@ class SearchStore(BackgroundUpdateStore): if not finished: yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME) - defer.returnValue(num_rows) + return num_rows def store_event_search_txn(self, txn, event, key, value): """Add event to the search table @@ -454,17 +454,15 @@ class SearchStore(BackgroundUpdateStore): count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - defer.returnValue( - { - "results": [ - {"event": event_map[r["event_id"]], "rank": r["rank"]} - for r in results - if r["event_id"] in event_map - ], - "highlights": highlights, - "count": count, - } - ) + return { + "results": [ + {"event": event_map[r["event_id"]], "rank": r["rank"]} + for r in results + if r["event_id"] in event_map + ], + "highlights": highlights, + "count": count, + } @defer.inlineCallbacks def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): @@ -599,22 +597,20 @@ class SearchStore(BackgroundUpdateStore): count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - defer.returnValue( - { - "results": [ - { - "event": event_map[r["event_id"]], - "rank": r["rank"], - "pagination_token": "%s,%s" - % (r["origin_server_ts"], r["stream_ordering"]), - } - for r in results - if r["event_id"] in event_map - ], - "highlights": highlights, - "count": count, - } - ) + return { + "results": [ + { + "event": event_map[r["event_id"]], + "rank": r["rank"], + "pagination_token": "%s,%s" + % (r["origin_server_ts"], r["stream_ordering"]), + } + for r in results + if r["event_id"] in event_map + ], + "highlights": highlights, + "count": count, + } def _find_highlights_in_postgres(self, search_query, events): """Given a list of events and a search term, return a list of words diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 6bd81e84ad..fb83218f90 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -59,7 +59,7 @@ class SignatureWorkerStore(SQLBaseStore): for e_id, h in hashes.items() } - defer.returnValue(list(hashes.items())) + return list(hashes.items()) def _get_event_reference_hashes_txn(self, txn, event_id): """Get all the hashes for a given PDU. diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a35289876d..1980a87108 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -422,7 +422,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # Retrieve the room's create event create_event = yield self.get_create_event_for_room(room_id) - defer.returnValue(create_event.content.get("room_version", "1")) + return create_event.content.get("room_version", "1") @defer.inlineCallbacks def get_room_predecessor(self, room_id): @@ -442,7 +442,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): create_event = yield self.get_create_event_for_room(room_id) # Return predecessor if present - defer.returnValue(create_event.content.get("predecessor", None)) + return create_event.content.get("predecessor", None) @defer.inlineCallbacks def get_create_event_for_room(self, room_id): @@ -466,7 +466,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # Retrieve the room's create event and return create_event = yield self.get_event(create_id) - defer.returnValue(create_event) + return create_event @cached(max_entries=100000, iterable=True) def get_current_state_ids(self, room_id): @@ -563,7 +563,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not event: return - defer.returnValue(event.content.get("canonical_alias")) + return event.content.get("canonical_alias") @cached(max_entries=10000, iterable=True) def get_state_group_delta(self, state_group): @@ -613,14 +613,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: - defer.returnValue({}) + return {} event_to_groups = yield self._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups) - defer.returnValue(group_to_state) + return group_to_state @defer.inlineCallbacks def get_state_ids_for_group(self, state_group): @@ -634,7 +634,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """ group_to_state = yield self._get_state_for_groups((state_group,)) - defer.returnValue(group_to_state[state_group]) + return group_to_state[state_group] @defer.inlineCallbacks def get_state_groups(self, room_id, event_ids): @@ -645,7 +645,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): dict of state_group_id -> list of state events. """ if not event_ids: - defer.returnValue({}) + return {} group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) @@ -658,16 +658,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): get_prev_content=False, ) - defer.returnValue( - { - group: [ - state_event_map[v] - for v in itervalues(event_id_map) - if v in state_event_map - ] - for group, event_id_map in iteritems(group_to_ids) - } - ) + return { + group: [ + state_event_map[v] + for v in itervalues(event_id_map) + if v in state_event_map + ] + for group, event_id_map in iteritems(group_to_ids) + } @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, state_filter): @@ -694,7 +692,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) results.update(res) - defer.returnValue(results) + return results def _get_state_groups_from_groups_txn( self, txn, groups, state_filter=StateFilter.all() @@ -829,7 +827,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for event_id, group in iteritems(event_to_groups) } - defer.returnValue({event: event_to_state[event] for event in event_ids}) + return {event: event_to_state[event] for event in event_ids} @defer.inlineCallbacks def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): @@ -855,7 +853,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for event_id, group in iteritems(event_to_groups) } - defer.returnValue({event: event_to_state[event] for event in event_ids}) + return {event: event_to_state[event] for event in event_ids} @defer.inlineCallbacks def get_state_for_event(self, event_id, state_filter=StateFilter.all()): @@ -871,7 +869,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): A deferred dict from (type, state_key) -> state_event """ state_map = yield self.get_state_for_events([event_id], state_filter) - defer.returnValue(state_map[event_id]) + return state_map[event_id] @defer.inlineCallbacks def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): @@ -887,7 +885,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): A deferred dict from (type, state_key) -> state_event """ state_map = yield self.get_state_ids_for_events([event_id], state_filter) - defer.returnValue(state_map[event_id]) + return state_map[event_id] @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): @@ -917,7 +915,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="_get_state_group_for_events", ) - defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) + return {row["event_id"]: row["state_group"] for row in rows} def _get_state_for_group_using_cache(self, cache, group, state_filter): """Checks if group is in cache. See `_get_state_for_groups` @@ -997,7 +995,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): incomplete_groups = incomplete_groups_m | incomplete_groups_nm if not incomplete_groups: - defer.returnValue(state) + return state cache_sequence_nm = self._state_group_cache.sequence cache_sequence_m = self._state_group_members_cache.sequence @@ -1024,7 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # everything we need from the database anyway. state[group] = state_filter.filter_state(group_state_dict) - defer.returnValue(state) + return state def _get_state_for_groups_using_cache(self, groups, cache, state_filter): """Gets the state at each of a list of state groups, optionally @@ -1498,7 +1496,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME ) - defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR) + return result * BATCH_SIZE_SCALE_FACTOR @defer.inlineCallbacks def _background_index_state(self, progress, batch_size): @@ -1528,4 +1526,4 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - defer.returnValue(1) + return 1 diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index 1cec84ee2e..e893b05ee7 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -66,7 +66,7 @@ class StatsStore(StateDeltasStore): if not self.stats_enabled: yield self._end_background_update("populate_stats_createtables") - defer.returnValue(1) + return 1 # Get all the rooms that we want to process. def _make_staging_area(txn): @@ -120,7 +120,7 @@ class StatsStore(StateDeltasStore): self.get_earliest_token_for_room_stats.invalidate_all() yield self._end_background_update("populate_stats_createtables") - defer.returnValue(1) + return 1 @defer.inlineCallbacks def _populate_stats_cleanup(self, progress, batch_size): @@ -129,7 +129,7 @@ class StatsStore(StateDeltasStore): """ if not self.stats_enabled: yield self._end_background_update("populate_stats_cleanup") - defer.returnValue(1) + return 1 position = yield self._simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" @@ -143,14 +143,14 @@ class StatsStore(StateDeltasStore): yield self.runInteraction("populate_stats_cleanup", _delete_staging_area) yield self._end_background_update("populate_stats_cleanup") - defer.returnValue(1) + return 1 @defer.inlineCallbacks def _populate_stats_process_rooms(self, progress, batch_size): if not self.stats_enabled: yield self._end_background_update("populate_stats_process_rooms") - defer.returnValue(1) + return 1 # If we don't have progress filed, delete everything. if not progress: @@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore): # No more rooms -- complete the transaction. if not rooms_to_work_on: yield self._end_background_update("populate_stats_process_rooms") - defer.returnValue(1) + return 1 logger.info( "Processing the next %d rooms of %d remaining", @@ -303,9 +303,9 @@ class StatsStore(StateDeltasStore): if processed_event_count > batch_size: # Don't process any more rooms, we've hit our batch size. - defer.returnValue(processed_event_count) + return processed_event_count - defer.returnValue(processed_event_count) + return processed_event_count def delete_all_stats(self): """ diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index a0465484df..856c2ee8d8 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -300,7 +300,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) if not room_ids: - defer.returnValue({}) + return {} results = {} room_ids = list(room_ids) @@ -323,7 +323,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) results.update(dict(zip(rm_ids, res))) - defer.returnValue(results) + return results def get_rooms_that_changed(self, room_ids, from_key): """Given a list of rooms and a token, return rooms where there may have @@ -364,7 +364,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): the chunk of events returned. """ if from_key == to_key: - defer.returnValue(([], from_key)) + return ([], from_key) from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream @@ -374,7 +374,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) if not has_changed: - defer.returnValue(([], from_key)) + return ([], from_key) def f(txn): sql = ( @@ -407,7 +407,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # get. key = from_key - defer.returnValue((ret, key)) + return (ret, key) @defer.inlineCallbacks def get_membership_changes_for_user(self, user_id, from_key, to_key): @@ -415,14 +415,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): to_id = RoomStreamToken.parse_stream_token(to_key).stream if from_key == to_key: - defer.returnValue([]) + return [] if from_id: has_changed = self._membership_stream_cache.has_entity_changed( user_id, int(from_id) ) if not has_changed: - defer.returnValue([]) + return [] def f(txn): sql = ( @@ -447,7 +447,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self._set_before_and_after(ret, rows, topo_order=False) - defer.returnValue(ret) + return ret @defer.inlineCallbacks def get_recent_events_for_room(self, room_id, limit, end_token): @@ -477,7 +477,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self._set_before_and_after(events, rows) - defer.returnValue((events, token)) + return (events, token) @defer.inlineCallbacks def get_recent_event_ids_for_room(self, room_id, limit, end_token): @@ -496,7 +496,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """ # Allow a zero limit here, and no-op. if limit == 0: - defer.returnValue(([], end_token)) + return ([], end_token) end_token = RoomStreamToken.parse(end_token) @@ -511,7 +511,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # We want to return the results in ascending order. rows.reverse() - defer.returnValue((rows, token)) + return (rows, token) def get_room_event_after_stream_ordering(self, room_id, stream_ordering): """Gets details of the first event in a room at or after a stream ordering @@ -549,12 +549,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """ token = yield self.get_room_max_stream_ordering() if room_id is None: - defer.returnValue("s%d" % (token,)) + return "s%d" % (token,) else: topo = yield self.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) - defer.returnValue("t%d-%d" % (topo, token)) + return "t%d-%d" % (topo, token) def get_stream_token_for_event(self, event_id): """The stream token for an event @@ -674,14 +674,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): [e for e in results["after"]["event_ids"]], get_prev_content=True ) - defer.returnValue( - { - "events_before": events_before, - "events_after": events_after, - "start": results["before"]["token"], - "end": results["after"]["token"], - } - ) + return { + "events_before": events_before, + "events_after": events_after, + "start": results["before"]["token"], + "end": results["after"]["token"], + } def _get_events_around_txn( self, txn, room_id, event_id, before_limit, after_limit, event_filter @@ -785,7 +783,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): events = yield self.get_events_as_list(event_ids) - defer.returnValue((upper_bound, events)) + return (upper_bound, events) def get_federation_out_pos(self, typ): return self._simple_select_one_onecol( @@ -939,7 +937,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self._set_before_and_after(events, rows) - defer.returnValue((events, token)) + return (events, token) class StreamStore(StreamWorkerStore): diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index e88f8ea35f..20dd6bd53d 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -66,7 +66,7 @@ class TagsWorkerStore(AccountDataWorkerStore): room_id string, tag string and content string. """ if last_id == current_id: - defer.returnValue([]) + return [] def get_all_updated_tags_txn(txn): sql = ( @@ -107,7 +107,7 @@ class TagsWorkerStore(AccountDataWorkerStore): ) results.extend(tags) - defer.returnValue(results) + return results @defer.inlineCallbacks def get_updated_tags(self, user_id, stream_id): @@ -135,7 +135,7 @@ class TagsWorkerStore(AccountDataWorkerStore): user_id, int(stream_id) ) if not changed: - defer.returnValue({}) + return {} room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn) @@ -145,7 +145,7 @@ class TagsWorkerStore(AccountDataWorkerStore): for room_id in room_ids: results[room_id] = tags_by_room.get(room_id, {}) - defer.returnValue(results) + return results def get_tags_for_room(self, user_id, room_id): """Get all the tags for the given room @@ -194,7 +194,7 @@ class TagsStore(TagsWorkerStore): self.get_tags_for_user.invalidate((user_id,)) result = self._account_data_id_gen.get_current_token() - defer.returnValue(result) + return result @defer.inlineCallbacks def remove_tag_from_room(self, user_id, room_id, tag): @@ -217,7 +217,7 @@ class TagsStore(TagsWorkerStore): self.get_tags_for_user.invalidate((user_id,)) result = self._account_data_id_gen.get_current_token() - defer.returnValue(result) + return result def _update_revision_txn(self, txn, user_id, room_id, next_id): """Update the latest revision of the tags for the given user and room. diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index c585cf6cf7..b3c3bf55bc 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -147,7 +147,7 @@ class TransactionStore(SQLBaseStore): result = self._destination_retry_cache.get(destination, SENTINEL) if result is not SENTINEL: - defer.returnValue(result) + return result result = yield self.runInteraction( "get_destination_retry_timings", @@ -158,7 +158,7 @@ class TransactionStore(SQLBaseStore): # We don't hugely care about race conditions between getting and # invalidating the cache, since we time out fairly quickly anyway. self._destination_retry_cache[destination] = result - defer.returnValue(result) + return result def _get_destination_retry_timings(self, txn, destination): result = self._simple_select_one_txn( diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 7fd16fe65e..b5188d9bee 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -109,7 +109,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) yield self._end_background_update("populate_user_directory_createtables") - defer.returnValue(1) + return 1 @defer.inlineCallbacks def _populate_user_directory_cleanup(self, progress, batch_size): @@ -131,7 +131,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): ) yield self._end_background_update("populate_user_directory_cleanup") - defer.returnValue(1) + return 1 @defer.inlineCallbacks def _populate_user_directory_process_rooms(self, progress, batch_size): @@ -177,7 +177,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): # No more rooms -- complete the transaction. if not rooms_to_work_on: yield self._end_background_update("populate_user_directory_process_rooms") - defer.returnValue(1) + return 1 logger.info( "Processing the next %d rooms of %d remaining" @@ -257,9 +257,9 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): if processed_event_count > batch_size: # Don't process any more rooms, we've hit our batch size. - defer.returnValue(processed_event_count) + return processed_event_count - defer.returnValue(processed_event_count) + return processed_event_count @defer.inlineCallbacks def _populate_user_directory_process_users(self, progress, batch_size): @@ -268,7 +268,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): """ if not self.hs.config.user_directory_search_all_users: yield self._end_background_update("populate_user_directory_process_users") - defer.returnValue(1) + return 1 def _get_next_batch(txn): sql = "SELECT user_id FROM %s LIMIT %s" % ( @@ -298,7 +298,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): # No more users -- complete the transaction. if not users_to_work_on: yield self._end_background_update("populate_user_directory_process_users") - defer.returnValue(1) + return 1 logger.info( "Processing the next %d users of %d remaining" @@ -322,7 +322,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): progress, ) - defer.returnValue(len(users_to_work_on)) + return len(users_to_work_on) @defer.inlineCallbacks def is_room_world_readable_or_publicly_joinable(self, room_id): @@ -344,16 +344,16 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) if join_rule_ev: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: - defer.returnValue(True) + return True hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) if hist_vis_id: hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) if hist_vis_ev: if hist_vis_ev.content.get("history_visibility") == "world_readable": - defer.returnValue(True) + return True - defer.returnValue(False) + return False def update_profile_in_user_dir(self, user_id, display_name, avatar_url): """ @@ -499,7 +499,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): user_ids = set(user_ids_share_pub) user_ids.update(user_ids_share_priv) - defer.returnValue(user_ids) + return user_ids def add_users_who_share_private_room(self, room_id, user_id_tuples): """Insert entries into the users_who_share_private_rooms table. The first @@ -609,7 +609,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): users = set(pub_rows) users.update(rows) - defer.returnValue(list(users)) + return list(users) @defer.inlineCallbacks def get_rooms_in_common_for_users(self, user_id, other_user_id): @@ -635,7 +635,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): "get_rooms_in_common_for_users", None, sql, user_id, other_user_id ) - defer.returnValue([room_id for room_id, in rows]) + return [room_id for room_id, in rows] def delete_all_from_user_dir(self): """Delete the entire user directory @@ -782,7 +782,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): limited = len(results) > limit - defer.returnValue({"limited": limited, "results": results}) + return {"limited": limited, "results": results} def _parse_query_sqlite(search_term): diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py index 1815fdc0dd..05cabc2282 100644 --- a/synapse/storage/user_erasure_store.py +++ b/synapse/storage/user_erasure_store.py @@ -12,9 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import operator -from twisted.internet import defer +import operator from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList @@ -67,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore): erased_users = yield self.runInteraction("are_users_erased", _get_erased_users) res = dict((u, u in erased_users) for u in user_ids) - defer.returnValue(res) + return res class UserErasureStore(UserErasureWorkerStore): diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 488c49747a..b91fb2db7b 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -56,7 +56,7 @@ class EventSources(object): device_list_key=device_list_key, groups_key=groups_key, ) - defer.returnValue(token) + return token @defer.inlineCallbacks def get_current_token_for_pagination(self): @@ -80,4 +80,4 @@ class EventSources(object): device_list_key=0, groups_key=0, ) - defer.returnValue(token) + return token diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index f506b2a695..841625a991 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -49,7 +49,7 @@ class Clock(object): with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) res = yield d - defer.returnValue(res) + return res def time(self): """Returns the current system time in seconds since epoch.""" diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 58a6b8764f..f1c46836b1 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -366,7 +366,7 @@ class ReadWriteLock(object): new_defer.callback(None) self.key_to_current_readers.get(key, set()).discard(new_defer) - defer.returnValue(_ctx_manager()) + return _ctx_manager() @defer.inlineCallbacks def write(self, key): @@ -396,7 +396,7 @@ class ReadWriteLock(object): if self.key_to_current_writer[key] == new_defer: self.key_to_current_writer.pop(key) - defer.returnValue(_ctx_manager()) + return _ctx_manager() def _cancelled_to_timed_out_error(value, timeout): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 675db2f448..a1acacbde9 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -289,7 +289,7 @@ class CacheDescriptor(_CacheDescriptorBase): def foo(self, key, cache_context): r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) - defer.returnValue(r1 + r2) + return r1 + r2 Args: num_args (int): number of positional arguments (excluding ``self`` and diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index d6908e169d..82d3eefe0e 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -121,7 +121,7 @@ class ResponseCache(object): @defer.inlineCallbacks def handle_request(request): # etc - defer.returnValue(result) + return result result = yield response_cache.wrap( key, diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index c30b6de19c..0910930c21 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -67,7 +67,7 @@ def measure_func(name): def measured_func(self, *args, **kwargs): with Measure(self.clock, name): r = yield func(self, *args, **kwargs) - defer.returnValue(r) + return r return measured_func diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index d8d0ceae51..0862b5ca5a 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -95,15 +95,13 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) # maximum backoff even though it might only have been down briefly backoff_on_failure = not ignore_backoff - defer.returnValue( - RetryDestinationLimiter( - destination, - clock, - store, - retry_interval, - backoff_on_failure=backoff_on_failure, - **kwargs - ) + return RetryDestinationLimiter( + destination, + clock, + store, + retry_interval, + backoff_on_failure=backoff_on_failure, + **kwargs ) diff --git a/synapse/visibility.py b/synapse/visibility.py index 2a11c83596..bf0f1eebd8 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -208,7 +208,7 @@ def filter_events_for_client( filtered_events = filter(operator.truth, filtered_events) # we turn it into a list before returning it. - defer.returnValue(list(filtered_events)) + return list(filtered_events) @defer.inlineCallbacks @@ -317,11 +317,11 @@ def filter_events_for_server( elif redact: to_return.append(prune_event(e)) - defer.returnValue(to_return) + return to_return # If there are no erased users then we can just return the given list # of events without having to copy it. - defer.returnValue(events) + return events # Ok, so we're dealing with events that have non-trivial visibility # rules, so we need to also get the memberships of the room. @@ -384,4 +384,4 @@ def filter_events_for_server( elif redact: to_return.append(prune_event(e)) - defer.returnValue(to_return) + return to_return diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 8d94a503d6..c4f0bbd3dd 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -107,7 +107,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.assertEquals(LoggingContext.current_context().request, "11") with PreserveLoggingContext(): yield persp_deferred - defer.returnValue(persp_resp) + return persp_resp self.http_client.post_json.side_effect = get_perspectives @@ -554,7 +554,7 @@ def run_in_context(f, *args, **kwargs): # logs. ctx.request = "testctx" rv = yield f(*args, **kwargs) - defer.returnValue(rv) + return rv def _verify_json_for_server(kr, *args): @@ -565,6 +565,6 @@ def _verify_json_for_server(kr, *args): @defer.inlineCallbacks def v(): rv1 = yield kr.verify_json_for_server(*args) - defer.returnValue(rv1) + return rv1 return run_in_context(v) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 90d0129374..99dce45cfe 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -283,4 +283,4 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user, requester, displayname, by_admin=True ) - defer.returnValue((user_id, token)) + return (user_id, token) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index a49f9b3224..b906686b49 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -145,7 +145,7 @@ class MatrixFederationAgentTests(TestCase): try: fetch_res = yield fetch_d - defer.returnValue(fetch_res) + return fetch_res except Exception as e: logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e) raise @@ -936,7 +936,7 @@ class MatrixFederationAgentTests(TestCase): except Exception as e: logger.warning("Error fetching well-known: %s", e) raise - defer.returnValue(result) + return result def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 65b51dc981..3b885ef64b 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -61,7 +61,7 @@ class SrvResolverTestCase(unittest.TestCase): # should have restored our context self.assertIs(LoggingContext.current_context(), ctx) - defer.returnValue(result) + return result test_d = do_lookup() self.assertNoResult(test_d) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index b9d6d7ad1c..2b01f40a42 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -68,7 +68,7 @@ class FederationClientTests(HomeserverTestCase): try: fetch_res = yield fetch_d - defer.returnValue(fetch_res) + return fetch_res finally: check_logcontext(context) diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index a8adc9a61d..a3d7e3c046 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def cb(): yield Clock(reactor).sleep(0) - defer.returnValue("yay") + return "yay" @defer.inlineCallbacks def test(): diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index fbb9302694..9fabe3fbc0 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -43,7 +43,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): "test_update", progress, ) - defer.returnValue(count) + return count self.update_handler.side_effect = update @@ -60,7 +60,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def update(progress, count): yield self.store._end_background_update("test_update") - defer.returnValue(count) + return count self.update_handler.side_effect = update self.update_handler.reset_mock() diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 732a778fab..1cb471205b 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -69,7 +69,7 @@ class RedactionTestCase(unittest.TestCase): yield self.store.persist_event(event, context) - defer.returnValue(event) + return event @defer.inlineCallbacks def inject_message(self, room, user, body): @@ -92,7 +92,7 @@ class RedactionTestCase(unittest.TestCase): yield self.store.persist_event(event, context) - defer.returnValue(event) + return event @defer.inlineCallbacks def inject_redaction(self, room, event_id, user, reason): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 73ed943f5a..c6e8196b91 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -67,7 +67,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): yield self.store.persist_event(event, context) - defer.returnValue(event) + return event @defer.inlineCallbacks def test_one_member(self): diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 212a7ae765..5c2cf3c2db 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -65,7 +65,7 @@ class StateStoreTestCase(tests.unittest.TestCase): yield self.store.persist_event(event, context) - defer.returnValue(event) + return event def assertStateMapEqual(self, s1, s2): for t in s1: diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 118c3bd238..e0605dac2f 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -139,7 +139,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): builder ) yield self.hs.get_datastore().persist_event(event, context) - defer.returnValue(event) + return event @defer.inlineCallbacks def inject_room_member(self, user_id, membership="join", extra_content={}): @@ -161,7 +161,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): ) yield self.hs.get_datastore().persist_event(event, context) - defer.returnValue(event) + return event @defer.inlineCallbacks def inject_message(self, user_id, content=None): @@ -182,7 +182,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): ) yield self.hs.get_datastore().persist_event(event, context) - defer.returnValue(event) + return event @defer.inlineCallbacks def test_large_room(self): diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 7807328e2f..56320bbaf9 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -159,7 +159,7 @@ class DescriptorTestCase(unittest.TestCase): def inner_fn(): with PreserveLoggingContext(): yield complete_lookup - defer.returnValue(1) + return 1 return inner_fn() @@ -169,7 +169,7 @@ class DescriptorTestCase(unittest.TestCase): c1.name = "c1" r = yield obj.fn(1) self.assertEqual(LoggingContext.current_context(), c1) - defer.returnValue(r) + return r def check_result(r): self.assertEqual(r, 1) @@ -286,7 +286,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): # we want this to behave like an asynchronous function yield run_on_reactor() assert LoggingContext.current_context().request == "c1" - defer.returnValue(self.mock(args1, arg2)) + return self.mock(args1, arg2) with LoggingContext() as c1: c1.request = "c1" @@ -334,7 +334,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def list_fn(self, args1, arg2): # we want this to behave like an asynchronous function yield run_on_reactor() - defer.returnValue(self.mock(args1, arg2)) + return self.mock(args1, arg2) obj = Cls() invalidate0 = mock.Mock() diff --git a/tests/utils.py b/tests/utils.py index 8a94ce0b47..425e3387db 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -361,7 +361,7 @@ def setup_test_homeserver( if fed: register_federation_servlets(hs, fed) - defer.returnValue(hs) + return hs def register_federation_servlets(hs, resource): @@ -465,9 +465,9 @@ class MockHttpResource(HttpServer): args = [urlparse.unquote(u) for u in matcher.groups()] (code, response) = yield func(mock_request, *args) - defer.returnValue((code, response)) + return (code, response) except CodeMessageException as e: - defer.returnValue((e.code, cs_error(e.msg, code=e.errcode))) + return (e.code, cs_error(e.msg, code=e.errcode)) raise KeyError("No event can handle %s" % path) -- cgit 1.5.1 From 73bbaf2bc6962df2c25443cbc70286318601af5a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Jul 2019 16:55:45 +0100 Subject: Add unit test for current state membership bg update --- tests/storage/test_roommember.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 73ed943f5a..b04be921f4 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions -from synapse.types import RoomID, UserID +from synapse.types import Requester, RoomID, UserID from tests import unittest from tests.utils import create_room, setup_test_homeserver @@ -84,3 +84,38 @@ class RoomMemberStoreTestCase(unittest.TestCase): ) ], ) + + +class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() + self.room_creator = homeserver.get_room_creation_handler() + + def test_can_rerun_update(self): + # First make sure we have completed all updates. + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + # Now let's create a room, which will insert a membership + user = UserID("alice", "test") + requester = Requester(user, None, False, None, None) + self.get_success(self.room_creator.create_room(requester, {})) + + # Register the background update to run again. + self.get_success( + self.store._simple_insert( + table="background_updates", + values={ + "update_name": "current_state_events_membership", + "progress_json": "{}", + "depends_on": None, + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store._all_done = False + + # Now let's actually drive the updates to completion + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) -- cgit 1.5.1 From cf2972c818344214244961e6175f559a6b59123b Mon Sep 17 00:00:00 2001 From: Jorik Schellekens Date: Wed, 24 Jul 2019 13:07:35 +0100 Subject: Fix servlet metric names (#5734) * Fix servlet metric names Co-Authored-By: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> * Remove redundant check * Cover all return paths --- changelog.d/5734.bugfix | 1 + synapse/federation/transport/server.py | 4 ++- synapse/http/server.py | 41 ++++++++++++++++++++--------- synapse/http/servlet.py | 4 ++- synapse/replication/http/_base.py | 2 +- synapse/rest/admin/server_notice_servlet.py | 9 +++++-- synapse/rest/client/v1/room.py | 37 +++++++++++++++++++++----- synapse/rest/client/v2_alpha/relations.py | 2 ++ tests/test_server.py | 21 +++++++++++---- tests/utils.py | 2 +- 10 files changed, 92 insertions(+), 31 deletions(-) create mode 100644 changelog.d/5734.bugfix (limited to 'tests') diff --git a/changelog.d/5734.bugfix b/changelog.d/5734.bugfix new file mode 100644 index 0000000000..33aea5a94c --- /dev/null +++ b/changelog.d/5734.bugfix @@ -0,0 +1 @@ +Fix a regression introduced in v1.2.0rc1 which led to incorrect labels on some prometheus metrics. diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 663264dec4..ea4e1b6d0f 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -325,7 +325,9 @@ class BaseFederationServlet(object): if code is None: continue - server.register_paths(method, (pattern,), self._wrap(code)) + server.register_paths( + method, (pattern,), self._wrap(code), self.__class__.__name__ + ) class FederationSendServlet(BaseFederationServlet): diff --git a/synapse/http/server.py b/synapse/http/server.py index 72a3d67eb6..e6f351ba3b 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -245,7 +245,9 @@ class JsonResource(HttpServer, resource.Resource): isLeaf = True - _PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"]) + _PathEntry = collections.namedtuple( + "_PathEntry", ["pattern", "callback", "servlet_classname"] + ) def __init__(self, hs, canonical_json=True): resource.Resource.__init__(self) @@ -255,12 +257,28 @@ class JsonResource(HttpServer, resource.Resource): self.path_regexs = {} self.hs = hs - def register_paths(self, method, path_patterns, callback): + def register_paths(self, method, path_patterns, callback, servlet_classname): + """ + Registers a request handler against a regular expression. Later request URLs are + checked against these regular expressions in order to identify an appropriate + handler for that request. + + Args: + method (str): GET, POST etc + + path_patterns (Iterable[str]): A list of regular expressions to which + the request URLs are compared. + + callback (function): The handler for the request. Usually a Servlet + + servlet_classname (str): The name of the handler to be used in prometheus + and opentracing logs. + """ method = method.encode("utf-8") # method is bytes on py3 for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) self.path_regexs.setdefault(method, []).append( - self._PathEntry(path_pattern, callback) + self._PathEntry(path_pattern, callback, servlet_classname) ) def render(self, request): @@ -275,13 +293,9 @@ class JsonResource(HttpServer, resource.Resource): This checks if anyone has registered a callback for that method and path. """ - callback, group_dict = self._get_handler_for_request(request) + callback, servlet_classname, group_dict = self._get_handler_for_request(request) - servlet_instance = getattr(callback, "__self__", None) - if servlet_instance is not None: - servlet_classname = servlet_instance.__class__.__name__ - else: - servlet_classname = "%r" % callback + # Make sure we have a name for this handler in prometheus. request.request_metrics.name = servlet_classname # Now trigger the callback. If it returns a response, we send it @@ -311,7 +325,8 @@ class JsonResource(HttpServer, resource.Resource): request (twisted.web.http.Request): Returns: - Tuple[Callable, dict[unicode, unicode]]: callback method, and the + Tuple[Callable, str, dict[unicode, unicode]]: callback method, the + label to use for that method in prometheus metrics, and the dict mapping keys to path components as specified in the handler's path match regexp. @@ -320,7 +335,7 @@ class JsonResource(HttpServer, resource.Resource): None, or a tuple of (http code, response body). """ if request.method == b"OPTIONS": - return _options_handler, {} + return _options_handler, "options_request_handler", {} # Loop through all the registered callbacks to check if the method # and path regex match @@ -328,10 +343,10 @@ class JsonResource(HttpServer, resource.Resource): m = path_entry.pattern.match(request.path.decode("ascii")) if m: # We found a match! - return path_entry.callback, m.groupdict() + return path_entry.callback, path_entry.servlet_classname, m.groupdict() # Huh. No one wanted to handle that? Fiiiiiine. Send 400. - return _unrecognised_request_handler, {} + return _unrecognised_request_handler, "unrecognised_request_handler", {} def _send_response( self, request, code, response_json_object, response_code_message=None diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 889038ff25..f0ca7d9aba 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -290,11 +290,13 @@ class RestServlet(object): for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): if hasattr(self, "on_%s" % (method,)): + servlet_classname = self.__class__.__name__ method_handler = getattr(self, "on_%s" % (method,)) http_server.register_paths( method, patterns, - trace_servlet(self.__class__.__name__, method_handler), + trace_servlet(servlet_classname, method_handler), + servlet_classname, ) else: diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index fe482e279f..43c89e36dd 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -205,7 +205,7 @@ class ReplicationEndpoint(object): args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) - http_server.register_paths(method, [pattern], handler) + http_server.register_paths(method, [pattern], handler, self.__class__.__name__) def _cached_handler(self, request, txn_id, **kwargs): """Called on new incoming requests when caching is enabled. Checks diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index ee66838a0d..d9c71261f2 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -59,9 +59,14 @@ class SendServerNoticeServlet(RestServlet): def register(self, json_resource): PATTERN = "^/_synapse/admin/v1/send_server_notice" - json_resource.register_paths("POST", (re.compile(PATTERN + "$"),), self.on_POST) json_resource.register_paths( - "PUT", (re.compile(PATTERN + "/(?P[^/]*)$"),), self.on_PUT + "POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__ + ) + json_resource.register_paths( + "PUT", + (re.compile(PATTERN + "/(?P[^/]*)$"),), + self.on_PUT, + self.__class__.__name__, ) @defer.inlineCallbacks diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 7709c2d705..6276e97f89 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -67,11 +67,17 @@ class RoomCreateRestServlet(TransactionRestServlet): register_txn_path(self, PATTERNS, http_server) # define CORS for all of /rooms in RoomCreateRestServlet for simplicity http_server.register_paths( - "OPTIONS", client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS + "OPTIONS", + client_patterns("/rooms(?:/.*)?$", v1=True), + self.on_OPTIONS, + self.__class__.__name__, ) # define CORS for /createRoom[/txnid] http_server.register_paths( - "OPTIONS", client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS + "OPTIONS", + client_patterns("/createRoom(?:/.*)?$", v1=True), + self.on_OPTIONS, + self.__class__.__name__, ) def on_PUT(self, request, txn_id): @@ -116,16 +122,28 @@ class RoomStateEventRestServlet(TransactionRestServlet): ) http_server.register_paths( - "GET", client_patterns(state_key, v1=True), self.on_GET + "GET", + client_patterns(state_key, v1=True), + self.on_GET, + self.__class__.__name__, ) http_server.register_paths( - "PUT", client_patterns(state_key, v1=True), self.on_PUT + "PUT", + client_patterns(state_key, v1=True), + self.on_PUT, + self.__class__.__name__, ) http_server.register_paths( - "GET", client_patterns(no_state_key, v1=True), self.on_GET_no_state_key + "GET", + client_patterns(no_state_key, v1=True), + self.on_GET_no_state_key, + self.__class__.__name__, ) http_server.register_paths( - "PUT", client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key + "PUT", + client_patterns(no_state_key, v1=True), + self.on_PUT_no_state_key, + self.__class__.__name__, ) def on_GET_no_state_key(self, request, room_id, event_type): @@ -845,18 +863,23 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): with_get: True to also register respective GET paths for the PUTs. """ http_server.register_paths( - "POST", client_patterns(regex_string + "$", v1=True), servlet.on_POST + "POST", + client_patterns(regex_string + "$", v1=True), + servlet.on_POST, + servlet.__class__.__name__, ) http_server.register_paths( "PUT", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), servlet.on_PUT, + servlet.__class__.__name__, ) if with_get: http_server.register_paths( "GET", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), servlet.on_GET, + servlet.__class__.__name__, ) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 6e52f6d284..9e9a639055 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -72,11 +72,13 @@ class RelationSendServlet(RestServlet): "POST", client_patterns(self.PATTERN + "$", releases=()), self.on_PUT_or_POST, + self.__class__.__name__, ) http_server.register_paths( "PUT", client_patterns(self.PATTERN + "/(?P[^/]*)$", releases=()), self.on_PUT, + self.__class__.__name__, ) def on_PUT(self, request, *args, **kwargs): diff --git a/tests/test_server.py b/tests/test_server.py index ba08483a4b..2a7d407c98 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -61,7 +61,10 @@ class JsonResourceTests(unittest.TestCase): res = JsonResource(self.homeserver) res.register_paths( - "GET", [re.compile("^/_matrix/foo/(?P[^/]*)$")], _callback + "GET", + [re.compile("^/_matrix/foo/(?P[^/]*)$")], + _callback, + "test_servlet", ) request, channel = make_request( @@ -82,7 +85,9 @@ class JsonResourceTests(unittest.TestCase): raise Exception("boo") res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) + res.register_paths( + "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" + ) request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) @@ -105,7 +110,9 @@ class JsonResourceTests(unittest.TestCase): return make_deferred_yieldable(d) res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) + res.register_paths( + "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" + ) request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) @@ -122,7 +129,9 @@ class JsonResourceTests(unittest.TestCase): raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN) res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) + res.register_paths( + "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" + ) request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) @@ -143,7 +152,9 @@ class JsonResourceTests(unittest.TestCase): self.fail("shouldn't ever get here") res = JsonResource(self.homeserver) - res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) + res.register_paths( + "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" + ) request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") render(request, res, self.reactor) diff --git a/tests/utils.py b/tests/utils.py index 8a94ce0b47..99a3deae21 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -471,7 +471,7 @@ class MockHttpResource(HttpServer): raise KeyError("No event can handle %s" % path) - def register_paths(self, method, path_patterns, callback): + def register_paths(self, method, path_patterns, callback, servlet_name): for path_pattern in path_patterns: self.callbacks.append((method, path_pattern, callback)) -- cgit 1.5.1 From 618bd1ee76a83bd29beb208e9b7097ffcd787099 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 25 Jul 2019 15:59:45 +0100 Subject: Fix some error cases in the caching layer. (#5749) There was some inconsistent behaviour in the caching layer around how exceptions were handled - particularly synchronously-thrown ones. This seems to be most easily handled by pushing the creation of ObservableDeferreds down from CacheDescriptor to the Cache. --- changelog.d/5749.misc | 1 + synapse/util/caches/descriptors.py | 74 +++++++++++++++------------- tests/util/caches/test_descriptors.py | 90 +++++++++++++++++++++++++++++++++-- 3 files changed, 130 insertions(+), 35 deletions(-) create mode 100644 changelog.d/5749.misc (limited to 'tests') diff --git a/changelog.d/5749.misc b/changelog.d/5749.misc new file mode 100644 index 0000000000..48dd61f461 --- /dev/null +++ b/changelog.d/5749.misc @@ -0,0 +1 @@ +Fix some error cases in the caching layer. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 7e69cf55fb..43f66ec4be 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -19,8 +19,7 @@ import logging import threading from collections import namedtuple -import six -from six import itervalues, string_types +from six import itervalues from prometheus_client import Gauge @@ -32,7 +31,6 @@ from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry -from synapse.util.stringutils import to_ascii from . import register_cache @@ -124,7 +122,7 @@ class Cache(object): update_metrics (bool): whether to update the cache hit rate metrics Returns: - Either a Deferred or the raw result + Either an ObservableDeferred or the raw result """ callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _CacheSentinel) @@ -148,9 +146,14 @@ class Cache(object): return default def set(self, key, value, callback=None): + if not isinstance(value, defer.Deferred): + raise TypeError("not a Deferred") + callbacks = [callback] if callback else [] self.check_thread() - entry = CacheEntry(deferred=value, callbacks=callbacks) + observable = ObservableDeferred(value, consumeErrors=True) + observer = defer.maybeDeferred(observable.observe) + entry = CacheEntry(deferred=observable, callbacks=callbacks) existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: @@ -158,20 +161,31 @@ class Cache(object): self._pending_deferred_cache[key] = entry - def shuffle(result): + def compare_and_pop(): + """Check if our entry is still the one in _pending_deferred_cache, and + if so, pop it. + + Returns true if the entries matched. + """ existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry is entry: + return True + + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. (We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + return False + + def cb(result): + if compare_and_pop(): self.cache.set(key, result, entry.callbacks) else: - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. (We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - # we're not going to put this entry into the cache, so need # to make sure that the invalidation callbacks are called. # That was probably done when _pending_deferred_cache was @@ -179,9 +193,16 @@ class Cache(object): # `invalidate` being previously called, in which case it may # not have been. Either way, let's double-check now. entry.invalidate() - return result - entry.deferred.addCallback(shuffle) + def eb(_fail): + compare_and_pop() + entry.invalidate() + + # once the deferred completes, we can move the entry from the + # _pending_deferred_cache to the real cache. + # + observer.addCallbacks(cb, eb) + return observable def prefill(self, key, value, callback=None): callbacks = [callback] if callback else [] @@ -414,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase): ret.addErrback(onErr) - # If our cache_key is a string on py2, try to convert to ascii - # to save a bit of space in large caches. Py3 does this - # internally automatically. - if six.PY2 and isinstance(cache_key, string_types): - cache_key = to_ascii(cache_key) - - result_d = ObservableDeferred(ret, consumeErrors=True) - cache.set(cache_key, result_d, callback=invalidate_callback) + result_d = cache.set(cache_key, ret, callback=invalidate_callback) observer = result_d.observe() - if isinstance(observer, defer.Deferred): - return make_deferred_yieldable(observer) - else: - return observer + return make_deferred_yieldable(observer) if self.num_args == 1: wrapped.invalidate = lambda key: cache.invalidate(key[0]) @@ -543,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase): missing.add(arg) if missing: - # we need an observable deferred for each entry in the list, + # we need a deferred for each entry in the list, # which we put in the cache. Each deferred resolves with the # relevant result for that key. deferreds_map = {} @@ -551,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase): deferred = defer.Deferred() deferreds_map[arg] = deferred key = arg_to_cache_key(arg) - observable = ObservableDeferred(deferred) - cache.set(key, observable, callback=invalidate_callback) + cache.set(key, deferred, callback=invalidate_callback) def complete_all(res): # the wrapped function has completed. It returns a diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 56320bbaf9..5713870f48 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -27,6 +27,7 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors +from synapse.util.caches.descriptors import cached from tests import unittest @@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase): d2 = defer.Deferred() cache.set("key2", d2, partial(record_callback, 1)) - # lookup should return the deferreds - self.assertIs(cache.get("key1"), d1) - self.assertIs(cache.get("key2"), d2) + # lookup should return observable deferreds + self.assertFalse(cache.get("key1").has_called()) + self.assertFalse(cache.get("key2").has_called()) # let one of the lookups complete d2.callback("result2") + + # for now at least, the cache will return real results rather than an + # observabledeferred self.assertEqual(cache.get("key2"), "result2") # now do the invalidation @@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(r, "chips") obj.mock.assert_not_called() + def test_cache_with_sync_exception(self): + """If the wrapped function throws synchronously, things should continue to work + """ + + class Cls(object): + @cached() + def fn(self, arg1): + raise SynapseError(100, "mai spoon iz too big!!1") + + obj = Cls() + + # this should fail immediately + d = obj.fn(1) + self.failureResultOf(d, SynapseError) + + # ... leaving the cache empty + self.assertEqual(len(obj.fn.cache.cache), 0) + + # and a second call should result in a second exception + d = obj.fn(1) + self.failureResultOf(d, SynapseError) + def test_cache_logcontexts(self): """Check that logcontexts are set and restored correctly when using the cache.""" @@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(LoggingContext.current_context(), c1) + # the cache should now be empty + self.assertEqual(len(obj.fn.cache.cache), 0) + obj = Cls() # set off a deferred which will do a cache lookup @@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(r, "chips") obj.mock.assert_not_called() + def test_cache_iterable(self): + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached(iterable=True) + def fn(self, arg1, arg2): + return self.mock(arg1, arg2) + + obj = Cls() + + obj.mock.return_value = ["spam", "eggs"] + r = obj.fn(1, 2) + self.assertEqual(r, ["spam", "eggs"]) + obj.mock.assert_called_once_with(1, 2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = ["chips"] + r = obj.fn(1, 3) + self.assertEqual(r, ["chips"]) + obj.mock.assert_called_once_with(1, 3) + obj.mock.reset_mock() + + # the two values should now be cached + self.assertEqual(len(obj.fn.cache.cache), 3) + + r = obj.fn(1, 2) + self.assertEqual(r, ["spam", "eggs"]) + r = obj.fn(1, 3) + self.assertEqual(r, ["chips"]) + obj.mock.assert_not_called() + + def test_cache_iterable_with_sync_exception(self): + """If the wrapped function throws synchronously, things should continue to work + """ + + class Cls(object): + @descriptors.cached(iterable=True) + def fn(self, arg1): + raise SynapseError(100, "mai spoon iz too big!!1") + + obj = Cls() + + # this should fail immediately + d = obj.fn(1) + self.failureResultOf(d, SynapseError) + + # ... leaving the cache empty + self.assertEqual(len(obj.fn.cache.cache), 0) + + # and a second call should result in a second exception + d = obj.fn(1) + self.failureResultOf(d, SynapseError) + class CachedListDescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks -- cgit 1.5.1 From c659b9f94fff29adfb2abe4f6b345710b65e8741 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 25 Jul 2019 11:08:24 -0400 Subject: allow uploading keys for cross-signing --- synapse/api/errors.py | 1 + synapse/handlers/device.py | 17 ++ synapse/handlers/e2e_keys.py | 198 ++++++++++++++++++++++- synapse/handlers/sync.py | 7 +- synapse/rest/client/v2_alpha/keys.py | 46 +++++- synapse/storage/__init__.py | 5 +- synapse/storage/devices.py | 57 +++++++ synapse/storage/end_to_end_keys.py | 174 +++++++++++++++++++- synapse/storage/schema/delta/56/signing_keys.sql | 41 +++++ synapse/types.py | 24 +++ tests/handlers/test_e2e_keys.py | 63 ++++++++ 11 files changed, 621 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index ad3e262041..be15921bc6 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -61,6 +61,7 @@ class Codes(object): INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION" EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT" + INVALID_SIGNATURE = "M_INVALID_SIGNATURE" class CodeMessageException(RuntimeError): diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 99e8413092..2a8fa9c818 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -408,6 +410,21 @@ class DeviceHandler(DeviceWorkerHandler): for host in hosts: self.federation_sender.send_device_messages(host) + @defer.inlineCallbacks + def notify_user_signature_update(self, from_user_id, user_ids): + """Notify a user that they have made new signatures of other users. + + Args: + from_user_id (str): the user who made the signature + user_ids (list[str]): the users IDs that have new signatures + """ + + position = yield self.store.add_user_signature_change_to_streams( + from_user_id, user_ids + ) + + self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) + @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index fdfe8611b6..6187f879ef 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,12 +20,17 @@ import logging from six import iteritems from canonicaljson import encode_canonical_json, json +from signedjson.sign import SignatureVerifyException, verify_signed_json from twisted.internet import defer -from synapse.api.errors import CodeMessageException, SynapseError +from synapse.api.errors import CodeMessageException, Codes, SynapseError from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.types import UserID, get_domain_from_id +from synapse.types import ( + UserID, + get_domain_from_id, + get_verify_key_from_cross_signing_key, +) from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -46,7 +52,7 @@ class E2eKeysHandler(object): ) @defer.inlineCallbacks - def query_devices(self, query_body, timeout): + def query_devices(self, query_body, timeout, from_user_id): """ Handle a device key query from a client { @@ -64,6 +70,11 @@ class E2eKeysHandler(object): } } } + + Args: + from_user_id (str): the user making the query. This is used when + adding cross-signing signatures to limit what signatures users + can see. """ device_keys_query = query_body.get("device_keys", {}) @@ -118,6 +129,11 @@ class E2eKeysHandler(object): r = remote_queries_not_in_cache.setdefault(domain, {}) r[user_id] = remote_queries[user_id] + # Get cached cross-signing keys + cross_signing_keys = yield self.query_cross_signing_keys( + device_keys_query, from_user_id + ) + # Now fetch any devices that we don't have in our cache @defer.inlineCallbacks def do_remote_query(destination): @@ -131,6 +147,14 @@ class E2eKeysHandler(object): if user_id in destination_query: results[user_id] = keys + for user_id, key in remote_result["master_keys"].items(): + if user_id in destination_query: + cross_signing_keys["master"][user_id] = key + + for user_id, key in remote_result["self_signing_keys"].items(): + if user_id in destination_query: + cross_signing_keys["self_signing"][user_id] = key + except Exception as e: failures[destination] = _exception_to_failure(e) @@ -144,7 +168,73 @@ class E2eKeysHandler(object): ) ) - defer.returnValue({"device_keys": results, "failures": failures}) + ret = {"device_keys": results, "failures": failures} + + for key, value in iteritems(cross_signing_keys): + ret[key + "_keys"] = value + + defer.returnValue(ret) + + @defer.inlineCallbacks + def query_cross_signing_keys(self, query, from_user_id): + """Get cross-signing keys for users + + Args: + query (Iterable[string]) an iterable of user IDs. A dict whose keys + are user IDs satisfies this, so the query format used for + query_devices can be used here. + from_user_id (str): the user making the query. This is used when + adding cross-signing signatures to limit what signatures users + can see. + + Returns: + defer.Deferred[dict[str, dict[str, dict]]]: map from + (master|self_signing|user_signing) -> user_id -> key + """ + master_keys = {} + self_signing_keys = {} + user_signing_keys = {} + + for user_id in query: + # XXX: consider changing the store functions to allow querying + # multiple users simultaneously. + try: + key = yield self.store.get_e2e_cross_signing_key( + user_id, "master", from_user_id + ) + if key: + master_keys[user_id] = key + except Exception as e: + logger.info("Error getting master key: %s", e) + + try: + key = yield self.store.get_e2e_cross_signing_key( + user_id, "self_signing", from_user_id + ) + if key: + self_signing_keys[user_id] = key + except Exception as e: + logger.info("Error getting self-signing key: %s", e) + + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + if from_user_id == user_id: + try: + key = yield self.store.get_e2e_cross_signing_key( + user_id, "user_signing", from_user_id + ) + if key: + user_signing_keys[user_id] = key + except Exception as e: + logger.info("Error getting user-signing key: %s", e) + + defer.returnValue( + { + "master": master_keys, + "self_signing": self_signing_keys, + "user_signing": user_signing_keys, + } + ) @defer.inlineCallbacks def query_local_devices(self, query): @@ -342,6 +432,104 @@ class E2eKeysHandler(object): yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) + @defer.inlineCallbacks + def upload_signing_keys_for_user(self, user_id, keys): + """Upload signing keys for cross-signing + + Args: + user_id (string): the user uploading the keys + keys (dict[string, dict]): the signing keys + """ + + # if a master key is uploaded, then check it. Otherwise, load the + # stored master key, to check signatures on other keys + if "master_key" in keys: + master_key = keys["master_key"] + + _check_cross_signing_key(master_key, user_id, "master") + else: + master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + + # if there is no master key, then we can't do anything, because all the + # other cross-signing keys need to be signed by the master key + if not master_key: + raise SynapseError(400, "No master key available", Codes.MISSING_PARAM) + + master_key_id, master_verify_key = get_verify_key_from_cross_signing_key( + master_key + ) + + # for the other cross-signing keys, make sure that they have valid + # signatures from the master key + if "self_signing_key" in keys: + self_signing_key = keys["self_signing_key"] + + _check_cross_signing_key( + self_signing_key, user_id, "self_signing", master_verify_key + ) + + if "user_signing_key" in keys: + user_signing_key = keys["user_signing_key"] + + _check_cross_signing_key( + user_signing_key, user_id, "user_signing", master_verify_key + ) + + # if everything checks out, then store the keys and send notifications + deviceids = [] + if "master_key" in keys: + yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + deviceids.append(master_verify_key.version) + if "self_signing_key" in keys: + yield self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + deviceids.append( + get_verify_key_from_cross_signing_key(self_signing_key)[1].version + ) + if "user_signing_key" in keys: + yield self.store.set_e2e_cross_signing_key( + user_id, "user_signing", user_signing_key + ) + # the signature stream matches the semantics that we want for + # user-signing key updates: only the user themselves is notified of + # their own user-signing key updates + yield self.device_handler.notify_user_signature_update(user_id, [user_id]) + + # master key and self-signing key updates match the semantics of device + # list updates: all users who share an encrypted room are notified + if len(deviceids): + yield self.device_handler.notify_device_update(user_id, deviceids) + + defer.returnValue({}) + + +def _check_cross_signing_key(key, user_id, key_type, signing_key=None): + """Check a cross-signing key uploaded by a user. Performs some basic sanity + checking, and ensures that it is signed, if a signature is required. + + Args: + key (dict): the key data to verify + user_id (str): the user whose key is being checked + key_type (str): the type of key that the key should be + signing_key (VerifyKey): (optional) the signing key that the key should + be signed with. If omitted, signatures will not be checked. + """ + if ( + key.get("user_id") != user_id + or key_type not in key.get("usage", []) + or len(key.get("keys", {})) != 1 + ): + raise SynapseError(400, ("Invalid %s key" % (key_type,)), Codes.INVALID_PARAM) + + if signing_key: + try: + verify_signed_json(key, user_id, signing_key) + except SignatureVerifyException: + raise SynapseError( + 400, ("Invalid signature on %s key" % key_type), Codes.INVALID_SIGNATURE + ) + def _exception_to_failure(e): if isinstance(e, CodeMessageException): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index cd1ac0a27a..c1c28a5fa1 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018, 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1116,6 +1116,11 @@ class SyncHandler(object): # weren't in the previous sync *or* they left and rejoined. users_that_have_changed.update(newly_joined_or_invited_users) + user_signatures_changed = yield self.store.get_users_whose_signatures_changed( + user_id, since_token.device_list_key + ) + users_that_have_changed.update(user_signatures_changed) + # Now find users that we no longer track for room_id in newly_left_rooms: left_users = yield self.state.get_current_users_in_room(room_id) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 45c9928b65..3eaf1fd8a4 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,7 +27,7 @@ from synapse.http.servlet import ( ) from synapse.types import StreamToken -from ._base import client_patterns +from ._base import client_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -145,10 +146,11 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.query_devices(body, timeout) + result = yield self.e2e_keys_handler.query_devices(body, timeout, user_id) defer.returnValue((200, result)) @@ -227,8 +229,46 @@ class OneTimeKeyServlet(RestServlet): defer.returnValue((200, result)) +class SigningKeyUploadServlet(RestServlet): + """ + POST /keys/device_signing/upload HTTP/1.1 + Content-Type: application/json + + { + } + """ + + PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(SigningKeyUploadServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request) + ) + + result = yield self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) + defer.returnValue((200, result)) + + def register_servlets(hs, http_server): KeyUploadServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) + SigningKeyUploadServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 6b0ca80087..c20ba1001c 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018,2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -207,6 +207,9 @@ class DataStore( self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max ) + self._user_signature_stream_cache = StreamChangeCache( + "UserSignatureStreamChangeCache", device_list_max + ) self._device_list_federation_stream_cache = StreamChangeCache( "DeviceListFederationStreamChangeCache", device_list_max ) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index b73401bc26..ed372e2fc4 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -302,6 +302,41 @@ class DeviceWorkerStore(SQLBaseStore): """ txn.execute(sql, (destination, stream_id)) + @defer.inlineCallbacks + def add_user_signature_change_to_streams(self, from_user_id, user_ids): + """Persist that a user has made new signatures + + Args: + from_user_id (str): the user who made the signatures + user_ids (list[str]): the users who were signed + """ + + with self._device_list_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_user_sig_change_to_streams", + self._add_user_signature_change_txn, + from_user_id, + user_ids, + stream_id, + ) + defer.returnValue(stream_id) + + def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): + txn.call_after( + self._user_signature_stream_cache.entity_has_changed, + from_user_id, + stream_id, + ) + self._simple_insert_txn( + txn, + "user_signature_stream", + values={ + "stream_id": stream_id, + "from_user_id": from_user_id, + "user_ids": json.dumps(user_ids), + }, + ) + def get_device_stream_token(self): return self._device_list_id_gen.get_current_token() @@ -440,6 +475,28 @@ class DeviceWorkerStore(SQLBaseStore): "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) + @defer.inlineCallbacks + def get_users_whose_signatures_changed(self, user_id, from_key): + """Get the users who have new cross-signing signatures made by `user_id` since + `from_key`. + + Args: + user_id (str): the user who made the signatures + from_key (str): The device lists stream token + """ + from_key = int(from_key) + if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): + sql = """ + SELECT DISTINCT user_ids FROM user_signature_stream + WHERE from_user_id = ? AND stream_id > ? + """ + rows = yield self._execute( + "get_users_whose_signatures_changed", None, sql, user_id, from_key + ) + defer.returnValue(set(user for row in rows for user in json.loads(row[0]))) + else: + defer.returnValue(set()) + def get_all_device_list_changes_for_remotes(self, from_key, to_key): """Return a list of `(stream_id, user_id, destination)` which is the combined list of changes to devices, and which destinations need to be diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2fabb9e2cb..bb5f7d94eb 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import time + from six import iteritems -from canonicaljson import encode_canonical_json +from canonicaljson import encode_canonical_json, json from twisted.internet import defer @@ -85,11 +89,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore): " k.key_json" " FROM devices d" " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" - " WHERE %s" + " WHERE (%s) AND NOT COALESCE(d.hidden, ?)" ) % ( "LEFT" if include_all_devices else "INNER", " OR ".join("(" + q + ")" for q in query_clauses), ) + query_params.append(False) txn.execute(sql, query_params) rows = self.cursor_to_dict(txn) @@ -281,3 +286,168 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): return self.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) + + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + """Set a user's cross-signing key. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_id (str): the user to set the signing key for + key_type (str): the type of key that is being set: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + key (dict): the key data + """ + # the cross-signing keys need to occupy the same namespace as devices, + # since signatures are identified by device ID. So add an entry to the + # device table to make sure that we don't have a collision with device + # IDs + + # the 'key' dict will look something like: + # { + # "user_id": "@alice:example.com", + # "usage": ["self_signing"], + # "keys": { + # "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key", + # }, + # "signatures": { + # "@alice:example.com": { + # "ed25519:base64+master+public+key": "base64+signature" + # } + # } + # } + # The "keys" property must only have one entry, which will be the public + # key, so we just grab the first value in there + pubkey = next(iter(key["keys"].values())) + self._simple_insert( + "devices", + values={ + "user_id": user_id, + "device_id": pubkey, + "display_name": key_type + " signing key", + "hidden": True, + }, + desc="store_master_key_device", + ) + + # and finally, store the key itself + self._simple_insert( + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json.dumps(key), + "added_ts": time.time() * 1000, + }, + desc="store_master_key", + ) + + def set_e2e_cross_signing_key(self, user_id, key_type, key): + """Set a user's cross-signing key. + + Args: + user_id (str): the user to set the user-signing key for + key_type (str): the type of cross-signing key to set + key (dict): the key data + """ + return self.runInteraction( + "add_e2e_cross_signing_key", + self._set_e2e_cross_signing_key_txn, + user_id, + key_type, + key, + ) + + def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being set: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + from_user_id (str): if specified, signatures made by this user on + the key will be included in the result + + Returns: + dict of the key data + """ + sql = ( + "SELECT keydata " + " FROM e2e_cross_signing_keys " + " WHERE user_id = ? AND keytype = ? ORDER BY added_ts DESC LIMIT 1" + ) + txn.execute(sql, (user_id, key_type)) + row = txn.fetchone() + if not row: + return None + key = json.loads(row[0]) + + device_id = None + for k in key["keys"].values(): + device_id = k + + if from_user_id is not None: + sql = ( + "SELECT key_id, signature " + " FROM e2e_cross_signing_signatures " + " WHERE user_id = ? " + " AND target_user_id = ? " + " AND target_device_id = ? " + ) + txn.execute(sql, (from_user_id, user_id, device_id)) + row = txn.fetchone() + if row: + key.setdefault("signatures", {}).setdefault(from_user_id, {})[ + row[0] + ] = row[1] + + return key + + def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + user_id (str): the user whose self-signing key is being requested + key_type (str): the type of cross-signing key to get + from_user_id (str): if specified, signatures made by this user on + the self-signing key will be included in the result + + Returns: + dict of the key data + """ + return self.runInteraction( + "get_e2e_cross_signing_key", + self._get_e2e_cross_signing_key_txn, + user_id, + key_type, + from_user_id, + ) + + def store_e2e_cross_signing_signatures(self, user_id, signatures): + """Stores cross-signing signatures. + + Args: + user_id (str): the user who made the signatures + signatures (iterable[(str, str, str, str)]): signatures to add - each + a tuple of (key_id, target_user_id, target_device_id, signature), + where key_id is the ID of the key (including the signature + algorithm) that made the signature, target_user_id and + target_device_id indicate the device being signed, and signature + is the signature of the device + """ + return self._simple_insert_many( + "e2e_cross_signing_signatures", + [ + { + "user_id": user_id, + "key_id": key_id, + "target_user_id": target_user_id, + "target_device_id": target_device_id, + "signature": signature, + } + for (key_id, target_user_id, target_device_id, signature) in signatures + ], + "add_e2e_signing_key", + ) diff --git a/synapse/storage/schema/delta/56/signing_keys.sql b/synapse/storage/schema/delta/56/signing_keys.sql index 51c96d3116..771740e970 100644 --- a/synapse/storage/schema/delta/56/signing_keys.sql +++ b/synapse/storage/schema/delta/56/signing_keys.sql @@ -13,6 +13,47 @@ * limitations under the License. */ +-- cross-signing keys +CREATE TABLE IF NOT EXISTS e2e_cross_signing_keys ( + user_id TEXT NOT NULL, + -- the type of cross-signing key (master, user_signing, or self_signing) + keytype TEXT NOT NULL, + -- the full key information, as a json-encoded dict + keydata TEXT NOT NULL, + -- time that the key was added + added_ts BIGINT NOT NULL +); + +CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, added_ts); + +-- cross-signing signatures +CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures ( + -- user who did the signing + user_id TEXT NOT NULL, + -- key used to sign + key_id TEXT NOT NULL, + -- user who was signed + target_user_id TEXT NOT NULL, + -- device/key that was signed + target_device_id TEXT NOT NULL, + -- the actual signature + signature TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); + +-- stream of user signature updates +CREATE TABLE IF NOT EXISTS user_signature_stream ( + -- uses the same stream ID as device list stream + stream_id BIGINT NOT NULL, + -- user who did the signing + from_user_id TEXT NOT NULL, + -- list of users who were signed, as a JSON array + user_ids TEXT NOT NULL +); + +CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream(stream_id); + -- device list needs to know which ones are "real" devices, and which ones are -- just used to avoid collisions ALTER TABLE devices ADD COLUMN hidden BOOLEAN NULLABLE; diff --git a/synapse/types.py b/synapse/types.py index 51eadb6ad4..7a80471a0c 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +18,8 @@ import string from collections import namedtuple import attr +from signedjson.key import decode_verify_key_bytes +from unpaddedbase64 import decode_base64 from synapse.api.errors import SynapseError @@ -475,3 +478,24 @@ class ReadReceipt(object): user_id = attr.ib() event_ids = attr.ib() data = attr.ib() + + +def get_verify_key_from_cross_signing_key(key_info): + """Get the key ID and signedjson verify key from a cross-signing key dict + + Args: + key_info (dict): a cross-signing key dict, which must have a "keys" + property that has exactly one item in it + + Returns: + (str, VerifyKey): the key ID and verify key for the cross-signing key + """ + # make sure that exactly one key is provided + if "keys" not in key_info: + raise SynapseError(400, "Invalid key") + keys = key_info["keys"] + if len(keys) != 1: + raise SynapseError(400, "Invalid key") + # and return that one key + for key_id, key_data in keys.items(): + return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 8dccc6826e..9ae4cb6ea2 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -145,3 +147,64 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, }, ) + + @defer.inlineCallbacks + def test_replace_master_key(self): + """uploading a new signing key should make the old signing key unavailable""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + keys2 = { + "master_key": { + # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw": "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys2) + + devices = yield self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) + self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) + + @defer.inlineCallbacks + def test_self_signing_key_doesnt_show_up_as_device(self): + """signing keys should be hidden when fetching a user's devices""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + res = None + try: + yield self.hs.get_device_handler().check_device_registered( + user_id=local_user, + device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + initial_device_display_name="new display name", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 400) + + res = yield self.handler.query_local_devices({local_user: None}) + self.assertDictEqual(res, {local_user: {}}) -- cgit 1.5.1 From 1cad8d7b6f736d86bd53b7f5e8b8417c302fdbd1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 26 Jul 2019 07:38:55 +0100 Subject: Convert RedactionTestCase to modern test style (#5768) --- changelog.d/5768.misc | 1 + tests/storage/test_redaction.py | 74 +++++++++++++++++++++-------------------- 2 files changed, 39 insertions(+), 36 deletions(-) create mode 100644 changelog.d/5768.misc (limited to 'tests') diff --git a/changelog.d/5768.misc b/changelog.d/5768.misc new file mode 100644 index 0000000000..7a9c88b4c2 --- /dev/null +++ b/changelog.d/5768.misc @@ -0,0 +1 @@ +Convert RedactionTestCase to modern test style. diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 1cb471205b..8488b6edc8 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,23 +17,21 @@ from mock import Mock -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID from tests import unittest -from tests.utils import create_room, setup_test_homeserver +from tests.utils import create_room -class RedactionTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver( - self.addCleanup, resource_for_federation=Mock(), http_client=None +class RedactionTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + resource_for_federation=Mock(), http_client=None ) + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -42,11 +41,12 @@ class RedactionTestCase(unittest.TestCase): self.room1 = RoomID.from_string("!abc123:test") - yield create_room(hs, self.room1.to_string(), self.u_alice.to_string()) + self.get_success( + create_room(hs, self.room1.to_string(), self.u_alice.to_string()) + ) self.depth = 1 - @defer.inlineCallbacks def inject_room_member( self, room, user, membership, replaces_state=None, extra_content={} ): @@ -63,15 +63,14 @@ class RedactionTestCase(unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.store.persist_event(event, context) + self.get_success(self.store.persist_event(event, context)) return event - @defer.inlineCallbacks def inject_message(self, room, user, body): self.depth += 1 @@ -86,15 +85,14 @@ class RedactionTestCase(unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.store.persist_event(event, context) + self.get_success(self.store.persist_event(event, context)) return event - @defer.inlineCallbacks def inject_redaction(self, room, event_id, user, reason): builder = self.event_builder_factory.for_room_version( RoomVersions.V1, @@ -108,20 +106,21 @@ class RedactionTestCase(unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.store.persist_event(event, context) + self.get_success(self.store.persist_event(event, context)) - @defer.inlineCallbacks def test_redact(self): - yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) - msg_event = yield self.inject_message(self.room1, self.u_alice, "t") + msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) # Check event has not been redacted: - event = yield self.store.get_event(msg_event.event_id) + event = self.get_success(self.store.get_event(msg_event.event_id)) self.assertObjectHasAttributes( { @@ -136,11 +135,11 @@ class RedactionTestCase(unittest.TestCase): # Redact event reason = "Because I said so" - yield self.inject_redaction( - self.room1, msg_event.event_id, self.u_alice, reason + self.get_success( + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) ) - event = yield self.store.get_event(msg_event.event_id) + event = self.get_success(self.store.get_event(msg_event.event_id)) self.assertEqual(msg_event.event_id, event.event_id) @@ -164,15 +163,18 @@ class RedactionTestCase(unittest.TestCase): event.unsigned["redacted_because"], ) - @defer.inlineCallbacks def test_redact_join(self): - yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) - msg_event = yield self.inject_room_member( - self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"} + msg_event = self.get_success( + self.inject_room_member( + self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"} + ) ) - event = yield self.store.get_event(msg_event.event_id) + event = self.get_success(self.store.get_event(msg_event.event_id)) self.assertObjectHasAttributes( { @@ -187,13 +189,13 @@ class RedactionTestCase(unittest.TestCase): # Redact event reason = "Because I said so" - yield self.inject_redaction( - self.room1, msg_event.event_id, self.u_alice, reason + self.get_success( + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) ) # Check redaction - event = yield self.store.get_event(msg_event.event_id) + event = self.get_success(self.store.get_event(msg_event.event_id)) self.assertTrue("redacted_because" in event.unsigned) -- cgit 1.5.1 From 865077f1d1f4866ab874c56b70abbd426fedfb97 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 30 Jul 2019 02:47:27 +1000 Subject: Room Complexity Client Implementation (#5783) --- changelog.d/5783.feature | 1 + docs/sample_config.yaml | 17 +++++++ synapse/config/server.py | 41 ++++++++++++++++ synapse/federation/federation_client.py | 36 ++++++++++++++ synapse/federation/transport/client.py | 31 +++++++++--- synapse/handlers/federation.py | 25 ++++++++++ synapse/handlers/room_member.py | 84 +++++++++++++++++++++++++++++++-- tests/federation/test_complexity.py | 77 ++++++++++++++++++++++++++++-- 8 files changed, 298 insertions(+), 14 deletions(-) create mode 100644 changelog.d/5783.feature (limited to 'tests') diff --git a/changelog.d/5783.feature b/changelog.d/5783.feature new file mode 100644 index 0000000000..18f5a3cb28 --- /dev/null +++ b/changelog.d/5783.feature @@ -0,0 +1 @@ +Synapse can now be configured to not join remote rooms of a given "complexity" (currently, state events) over federation. This option can be used to prevent adverse performance on resource-constrained homeservers. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7edf15207a..b92959692d 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -278,6 +278,23 @@ listeners: # Used by phonehome stats to group together related servers. #server_context: context +# Resource-constrained Homeserver Settings +# +# If limit_remote_rooms.enabled is True, the room complexity will be +# checked before a user joins a new remote room. If it is above +# limit_remote_rooms.complexity, it will disallow joining or +# instantly leave. +# +# limit_remote_rooms.complexity_error can be set to customise the text +# displayed to the user when a room above the complexity threshold has +# its join cancelled. +# +# Uncomment the below lines to enable: +#limit_remote_rooms: +# enabled: True +# complexity: 1.0 +# complexity_error: "This room is too complex." + # Whether to require a user to be in the room to add an alias to it. # Defaults to 'true'. # diff --git a/synapse/config/server.py b/synapse/config/server.py index 00170f1393..15449695d1 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -18,6 +18,7 @@ import logging import os.path +import attr from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -38,6 +39,12 @@ DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"] DEFAULT_ROOM_VERSION = "4" +ROOM_COMPLEXITY_TOO_GREAT = ( + "Your homeserver is unable to join rooms this large or complex. " + "Please speak to your server administrator, or upgrade your instance " + "to join this room." +) + class ServerConfig(Config): def read_config(self, config, **kwargs): @@ -247,6 +254,23 @@ class ServerConfig(Config): self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) + @attr.s + class LimitRemoteRoomsConfig(object): + enabled = attr.ib( + validator=attr.validators.instance_of(bool), default=False + ) + complexity = attr.ib( + validator=attr.validators.instance_of((int, float)), default=1.0 + ) + complexity_error = attr.ib( + validator=attr.validators.instance_of(str), + default=ROOM_COMPLEXITY_TOO_GREAT, + ) + + self.limit_remote_rooms = LimitRemoteRoomsConfig( + **config.get("limit_remote_rooms", {}) + ) + bind_port = config.get("bind_port") if bind_port: if config.get("no_tls", False): @@ -617,6 +641,23 @@ class ServerConfig(Config): # Used by phonehome stats to group together related servers. #server_context: context + # Resource-constrained Homeserver Settings + # + # If limit_remote_rooms.enabled is True, the room complexity will be + # checked before a user joins a new remote room. If it is above + # limit_remote_rooms.complexity, it will disallow joining or + # instantly leave. + # + # limit_remote_rooms.complexity_error can be set to customise the text + # displayed to the user when a room above the complexity threshold has + # its join cancelled. + # + # Uncomment the below lines to enable: + #limit_remote_rooms: + # enabled: True + # complexity: 1.0 + # complexity_error: "This room is too complex." + # Whether to require a user to be in the room to add an alias to it. # Defaults to 'true'. # diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 25ed1257f1..6e03ce21af 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -993,3 +993,39 @@ class FederationClient(FederationBase): ) raise RuntimeError("Failed to send to any server.") + + @defer.inlineCallbacks + def get_room_complexity(self, destination, room_id): + """ + Fetch the complexity of a remote room from another server. + + Args: + destination (str): The remote server + room_id (str): The room ID to ask about. + + Returns: + Deferred[dict] or Deferred[None]: Dict contains the complexity + metric versions, while None means we could not fetch the complexity. + """ + try: + complexity = yield self.transport_layer.get_room_complexity( + destination=destination, room_id=room_id + ) + defer.returnValue(complexity) + except CodeMessageException as e: + # We didn't manage to get it -- probably a 404. We are okay if other + # servers don't give it to us. + logger.debug( + "Failed to fetch room complexity via %s for %s, got a %d", + destination, + room_id, + e.code, + ) + except Exception: + logger.exception( + "Failed to fetch room complexity via %s for %s", destination, room_id + ) + + # If we don't manage to find it, return None. It's not an error if a + # server doesn't give it to us. + defer.returnValue(None) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 2a6709ff48..0cea0d2a10 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -21,7 +21,11 @@ from six.moves import urllib from twisted.internet import defer from synapse.api.constants import Membership -from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX +from synapse.api.urls import ( + FEDERATION_UNSTABLE_PREFIX, + FEDERATION_V1_PREFIX, + FEDERATION_V2_PREFIX, +) from synapse.logging.utils import log_function logger = logging.getLogger(__name__) @@ -935,6 +939,23 @@ class TransportLayerClient(object): destination=destination, path=path, data=content, ignore_backoff=True ) + def get_room_complexity(self, destination, room_id): + """ + Args: + destination (str): The remote server + room_id (str): The room ID to ask about. + """ + path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/rooms/%s/complexity", room_id) + + return self.client.get_json(destination=destination, path=path) + + +def _create_path(federation_prefix, path, *args): + """ + Ensures that all args are url encoded. + """ + return federation_prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args) + def _create_v1_path(path, *args): """Creates a path against V1 federation API from the path template and @@ -951,9 +972,7 @@ def _create_v1_path(path, *args): Returns: str """ - return FEDERATION_V1_PREFIX + path % tuple( - urllib.parse.quote(arg, "") for arg in args - ) + return _create_path(FEDERATION_V1_PREFIX, path, *args) def _create_v2_path(path, *args): @@ -971,6 +990,4 @@ def _create_v2_path(path, *args): Returns: str """ - return FEDERATION_V2_PREFIX + path % tuple( - urllib.parse.quote(arg, "") for arg in args - ) + return _create_path(FEDERATION_V2_PREFIX, path, *args) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 89b37dbc1c..10160bfe86 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2796,3 +2796,28 @@ class FederationHandler(BaseHandler): ) else: return user_joined_room(self.distributor, user, room_id) + + @defer.inlineCallbacks + def get_room_complexity(self, remote_room_hosts, room_id): + """ + Fetch the complexity of a remote room over federation. + + Args: + remote_room_hosts (list[str]): The remote servers to ask. + room_id (str): The room ID to ask about. + + Returns: + Deferred[dict] or Deferred[None]: Dict contains the complexity + metric versions, while None means we could not fetch the complexity. + """ + + for host in remote_room_hosts: + res = yield self.federation_client.get_room_complexity(host, room_id) + + # We got a result, return it. + if res: + defer.returnValue(res) + + # We fell off the bottom, couldn't get the complexity from anyone. Oh + # well. + defer.returnValue(None) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index baea08ddd0..249a6d9c5d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -26,8 +26,7 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer -import synapse.server -import synapse.types +from synapse import types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError from synapse.types import RoomID, UserID @@ -543,7 +542,7 @@ class RoomMemberHandler(object): ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: - requester = synapse.types.create_requester(target_user) + requester = types.create_requester(target_user) prev_event = yield self.event_creation_handler.deduplicate_state_event( event, context @@ -945,6 +944,47 @@ class RoomMemberMasterHandler(RoomMemberHandler): self.distributor.declare("user_joined_room") self.distributor.declare("user_left_room") + @defer.inlineCallbacks + def _is_remote_room_too_complex(self, room_id, remote_room_hosts): + """ + Check if complexity of a remote room is too great. + + Args: + room_id (str) + remote_room_hosts (list[str]) + + Returns: bool of whether the complexity is too great, or None + if unable to be fetched + """ + max_complexity = self.hs.config.limit_remote_rooms.complexity + complexity = yield self.federation_handler.get_room_complexity( + remote_room_hosts, room_id + ) + + if complexity: + if complexity["v1"] > max_complexity: + return True + return False + return None + + @defer.inlineCallbacks + def _is_local_room_too_complex(self, room_id): + """ + Check if the complexity of a local room is too great. + + Args: + room_id (str) + + Returns: bool + """ + max_complexity = self.hs.config.limit_remote_rooms.complexity + complexity = yield self.store.get_room_complexity(room_id) + + if complexity["v1"] > max_complexity: + return True + + return False + @defer.inlineCallbacks def _remote_join(self, requester, remote_room_hosts, room_id, user, content): """Implements RoomMemberHandler._remote_join @@ -952,7 +992,6 @@ class RoomMemberMasterHandler(RoomMemberHandler): # filter ourselves out of remote_room_hosts: do_invite_join ignores it # and if it is the only entry we'd like to return a 404 rather than a # 500. - remote_room_hosts = [ host for host in remote_room_hosts if host != self.hs.hostname ] @@ -960,6 +999,18 @@ class RoomMemberMasterHandler(RoomMemberHandler): if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") + if self.hs.config.limit_remote_rooms.enabled: + # Fetch the room complexity + too_complex = yield self._is_remote_room_too_complex( + room_id, remote_room_hosts + ) + if too_complex is True: + raise SynapseError( + code=400, + msg=self.hs.config.limit_remote_rooms.complexity_error, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + ) + # We don't do an auth check if we are doing an invite # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we @@ -969,6 +1020,31 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) yield self._user_joined_room(user, room_id) + # Check the room we just joined wasn't too large, if we didn't fetch the + # complexity of it before. + if self.hs.config.limit_remote_rooms.enabled: + if too_complex is False: + # We checked, and we're under the limit. + return + + # Check again, but with the local state events + too_complex = yield self._is_local_room_too_complex(room_id) + + if too_complex is False: + # We're under the limit. + return + + # The room is too large. Leave. + requester = types.create_requester(user, None, False, None) + yield self.update_membership( + requester=requester, target=user, room_id=room_id, action="leave" + ) + raise SynapseError( + code=400, + msg=self.hs.config.limit_remote_rooms.complexity_error, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + ) + @defer.inlineCallbacks def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): """Implements RoomMemberHandler._remote_reject_invite diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index a5b03005d7..51714a2b06 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -13,12 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer +from synapse.api.errors import Codes, SynapseError from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.types import UserID from synapse.util.ratelimitutils import FederationRateLimiter from tests import unittest @@ -33,9 +37,8 @@ class RoomComplexityTests(unittest.HomeserverTestCase): ] def default_config(self, name="test"): - config = super(RoomComplexityTests, self).default_config(name=name) - config["limit_large_remote_room_joins"] = True - config["limit_large_remote_room_complexity"] = 0.05 + config = super().default_config(name=name) + config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} return config def prepare(self, reactor, clock, homeserver): @@ -88,3 +91,71 @@ class RoomComplexityTests(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) + + def test_join_too_large(self): + + u1 = self.register_user("u1", "pass") + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + + d = handler._remote_join( + None, + ["otherserver.example"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. + f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400, f.value) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_join_too_large_once_joined(self): + + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + # Ok, this might seem a bit weird -- I want to test that we actually + # leave the room, but I don't want to simulate two servers. So, we make + # a local room, which we say we're joining remotely, even if there's no + # remote, because we mock that out. Then, we'll leave the (actually + # local) room, which will be propagated over federation in a real + # scenario. + room_1 = self.helper.create_room_as(u1, tok=u1_token) + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) + handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + + # Artificially raise the complexity + self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( + 600 + ) + + d = handler._remote_join( + None, + ["otherserver.example"], + room_1, + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. + f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) -- cgit 1.5.1 From 4e97eb89e5ed4517e5967a49acf6db987bb96d51 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 24 Jul 2019 22:45:35 +0100 Subject: Handle loops in redaction events --- synapse/storage/events_worker.py | 96 +++++++++++++++------------------------- tests/storage/test_redaction.py | 70 +++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 60 deletions(-) (limited to 'tests') diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index e15e7d86fe..c6fa7f82fd 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -483,7 +483,8 @@ class EventsWorkerStore(SQLBaseStore): if events_to_fetch: logger.debug("Also fetching redaction events %s", events_to_fetch) - result_map = {} + # build a map from event_id to EventBase + event_map = {} for event_id, row in fetched_events.items(): if not row: continue @@ -494,14 +495,37 @@ class EventsWorkerStore(SQLBaseStore): if not allow_rejected and rejected_reason: continue - cache_entry = yield self._get_event_from_row( - row["internal_metadata"], - row["json"], - row["redactions"], - rejected_reason=row["rejected_reason"], - format_version=row["format_version"], + d = json.loads(row["json"]) + internal_metadata = json.loads(row["internal_metadata"]) + + format_version = row["format_version"] + if format_version is None: + # This means that we stored the event before we had the concept + # of a event format version, so it must be a V1 event. + format_version = EventFormatVersions.V1 + + original_ev = event_type_from_format_version(format_version)( + event_dict=d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, ) + event_map[event_id] = original_ev + + # finally, we can decide whether each one nededs redacting, and build + # the cache entries. + result_map = {} + for event_id, original_ev in event_map.items(): + redactions = fetched_events[event_id]["redactions"] + redacted_event = self._maybe_redact_event_row( + original_ev, redactions, event_map + ) + + cache_entry = _EventCacheEntry( + event=original_ev, redacted_event=redacted_event + ) + + self._get_event_cache.prefill((event_id,), cache_entry) result_map[event_id] = cache_entry return result_map @@ -615,50 +639,7 @@ class EventsWorkerStore(SQLBaseStore): return event_dict - @defer.inlineCallbacks - def _get_event_from_row( - self, internal_metadata, js, redactions, format_version, rejected_reason=None - ): - """Parse an event row which has been read from the database - - Args: - internal_metadata (str): json-encoded internal_metadata column - js (str): json-encoded event body from event_json - redactions (list[str]): a list of the events which claim to have redacted - this event, from the redactions table - format_version: (str): the 'format_version' column - rejected_reason (str|None): the reason this event was rejected, if any - - Returns: - _EventCacheEntry - """ - with Measure(self._clock, "_get_event_from_row"): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) - - if format_version is None: - # This means that we stored the event before we had the concept - # of a event format version, so it must be a V1 event. - format_version = EventFormatVersions.V1 - - original_ev = event_type_from_format_version(format_version)( - event_dict=d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - redacted_event = yield self._maybe_redact_event_row(original_ev, redactions) - - cache_entry = _EventCacheEntry( - event=original_ev, redacted_event=redacted_event - ) - - self._get_event_cache.prefill((original_ev.event_id,), cache_entry) - - return cache_entry - - @defer.inlineCallbacks - def _maybe_redact_event_row(self, original_ev, redactions): + def _maybe_redact_event_row(self, original_ev, redactions, event_map): """Given an event object and a list of possible redacting event ids, determine whether to honour any of those redactions and if so return a redacted event. @@ -666,6 +647,8 @@ class EventsWorkerStore(SQLBaseStore): Args: original_ev (EventBase): redactions (iterable[str]): list of event ids of potential redaction events + event_map (dict[str, EventBase]): other events which have been fetched, in + which we can look up the redaaction events. Map from event id to event. Returns: Deferred[EventBase|None]: if the event should be redacted, a pruned @@ -675,15 +658,9 @@ class EventsWorkerStore(SQLBaseStore): # we choose to ignore redactions of m.room.create events. return None - if original_ev.type == "m.room.redaction": - # ... and redaction events - return None - - redaction_map = yield self._get_events_from_cache_or_db(redactions) - for redaction_id in redactions: - redaction_entry = redaction_map.get(redaction_id) - if not redaction_entry: + redaction_event = event_map.get(redaction_id) + if not redaction_event or redaction_event.rejected_reason: # we don't have the redaction event, or the redaction event was not # authorized. logger.debug( @@ -693,7 +670,6 @@ class EventsWorkerStore(SQLBaseStore): ) continue - redaction_event = redaction_entry.event if redaction_event.room_id != original_ev.room_id: logger.debug( "%s was redacted by %s but redaction was in a different room!", diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 8488b6edc8..d961b81d48 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -17,6 +17,8 @@ from mock import Mock +from twisted.internet import defer + from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID @@ -216,3 +218,71 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, event.unsigned["redacted_because"], ) + + def test_circular_redaction(self): + redaction_event_id1 = "$redaction1_id:test" + redaction_event_id2 = "$redaction2_id:test" + + class EventIdManglingBuilder: + def __init__(self, base_builder, event_id): + self._base_builder = base_builder + self._event_id = event_id + + @defer.inlineCallbacks + def build(self, prev_event_ids): + built_event = yield self._base_builder.build(prev_event_ids) + built_event.event_id = self._event_id + built_event._event_dict["event_id"] = self._event_id + return built_event + + @property + def room_id(self): + return self._base_builder.room_id + + event_1, context_1 = self.get_success( + self.event_creation_handler.create_new_client_event( + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id2, + }, + ), + redaction_event_id1, + ) + ) + ) + + self.get_success(self.store.persist_event(event_1, context_1)) + + event_2, context_2 = self.get_success( + self.event_creation_handler.create_new_client_event( + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id1, + }, + ), + redaction_event_id2, + ) + ) + ) + self.get_success(self.store.persist_event(event_2, context_2)) + + # fetch one of the redactions + fetched = self.get_success(self.store.get_event(redaction_event_id1)) + + # it should have been redacted + self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2) + self.assertEqual( + fetched.unsigned["redacted_because"].event_id, redaction_event_id2 + ) -- cgit 1.5.1 From 8c97f6414cf322fc5b42a92ed0df2fb70bfab3fc Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 30 Jul 2019 08:25:02 +0100 Subject: Remove non-functional 'expire_access_token' setting (#5782) The `expire_access_token` didn't do what it sounded like it should do. What it actually did was make Synapse enforce the 'time' caveat on macaroons used as access tokens, but since our access token macaroons never contained such a caveat, it was always a no-op. (The code to add 'time' caveats was removed back in v0.18.5, in #1656) --- changelog.d/5782.removal | 1 + docs/sample_config.yaml | 4 ---- synapse/api/auth.py | 28 ++++------------------ synapse/config/key.py | 6 ----- synapse/handlers/auth.py | 2 +- tests/handlers/test_register.py | 2 +- .../test_resource_limits_server_notices.py | 2 +- tests/utils.py | 1 - 8 files changed, 9 insertions(+), 37 deletions(-) create mode 100644 changelog.d/5782.removal (limited to 'tests') diff --git a/changelog.d/5782.removal b/changelog.d/5782.removal new file mode 100644 index 0000000000..658bf923ab --- /dev/null +++ b/changelog.d/5782.removal @@ -0,0 +1 @@ +Remove non-functional 'expire_access_token' setting. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b92959692d..08316597fa 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -942,10 +942,6 @@ uploads_path: "DATADIR/uploads" # # macaroon_secret_key: -# Used to enable access token expiration. -# -#expire_access_token: False - # a secret which is used to calculate HMACs for form values, to stop # falsification of values. Must be specified for the User Consent # forms to work. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 351790cca4..179644852a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -410,21 +410,16 @@ class Auth(object): try: user_id = self.get_user_id_from_macaroon(macaroon) - has_expiry = False guest = False for caveat in macaroon.caveats: - if caveat.caveat_id.startswith("time "): - has_expiry = True - elif caveat.caveat_id == "guest = true": + if caveat.caveat_id == "guest = true": guest = True - self.validate_macaroon( - macaroon, rights, self.hs.config.expire_access_token, user_id=user_id - ) + self.validate_macaroon(macaroon, rights, user_id=user_id) except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise InvalidClientTokenError("Invalid macaroon passed.") - if not has_expiry and rights == "access": + if rights == "access": self.token_cache[token] = (user_id, guest) return user_id, guest @@ -450,7 +445,7 @@ class Auth(object): return caveat.caveat_id[len(user_prefix) :] raise InvalidClientTokenError("No user caveat in macaroon") - def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): + def validate_macaroon(self, macaroon, type_string, user_id): """ validate that a Macaroon is understood by and was signed by this server. @@ -458,7 +453,6 @@ class Auth(object): macaroon(pymacaroons.Macaroon): The macaroon to validate type_string(str): The kind of token required (e.g. "access", "delete_pusher") - verify_expiry(bool): Whether to verify whether the macaroon has expired. user_id (str): The user_id required """ v = pymacaroons.Verifier() @@ -471,19 +465,7 @@ class Auth(object): v.satisfy_exact("type = " + type_string) v.satisfy_exact("user_id = %s" % user_id) v.satisfy_exact("guest = true") - - # verify_expiry should really always be True, but there exist access - # tokens in the wild which expire when they should not, so we can't - # enforce expiry yet (so we have to allow any caveat starting with - # 'time < ' in access tokens). - # - # On the other hand, short-term login tokens (as used by CAS login, for - # example) have an expiry time which we do want to enforce. - - if verify_expiry: - v.satisfy_general(self._verify_expiry) - else: - v.satisfy_general(lambda c: c.startswith("time < ")) + v.satisfy_general(self._verify_expiry) # access_tokens include a nonce for uniqueness: any value is acceptable v.satisfy_general(lambda c: c.startswith("nonce = ")) diff --git a/synapse/config/key.py b/synapse/config/key.py index 8fc74f9cdf..fe8386985c 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -116,8 +116,6 @@ class KeyConfig(Config): seed = bytes(self.signing_key[0]) self.macaroon_secret_key = hashlib.sha256(seed).digest() - self.expire_access_token = config.get("expire_access_token", False) - # a secret which is used to calculate HMACs for form values, to stop # falsification of values self.form_secret = config.get("form_secret", None) @@ -144,10 +142,6 @@ class KeyConfig(Config): # %(macaroon_secret_key)s - # Used to enable access token expiration. - # - #expire_access_token: False - # a secret which is used to calculate HMACs for form values, to stop # falsification of values. Must be specified for the User Consent # forms to work. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 05be5b7c48..0f3ebf7ef8 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -860,7 +860,7 @@ class AuthHandler(BaseHandler): try: macaroon = pymacaroons.Macaroon.deserialize(login_token) user_id = auth_api.get_user_id_from_macaroon(macaroon) - auth_api.validate_macaroon(macaroon, "login", True, user_id) + auth_api.validate_macaroon(macaroon, "login", user_id) except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) self.ratelimit_login_per_account(user_id) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 99dce45cfe..0ad0a88165 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -44,7 +44,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): hs_config["max_mau_value"] = 50 hs_config["limit_usage_by_mau"] = True - hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) + hs = self.setup_test_homeserver(config=hs_config) return hs def prepare(self, reactor, clock, hs): diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 984feb623f..cdf89e3383 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -36,7 +36,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): "room_name": "Server Notices", } - hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) + hs = self.setup_test_homeserver(config=hs_config) return hs def prepare(self, reactor, clock, hs): diff --git a/tests/utils.py b/tests/utils.py index 6350646263..f1eb9a545c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -126,7 +126,6 @@ def default_config(name, parse=False): "enable_registration": True, "enable_registration_captcha": False, "macaroon_secret_key": "not even a little secret", - "expire_access_token": False, "trusted_third_party_id_servers": [], "room_invite_state_types": [], "password_providers": [], -- cgit 1.5.1 From a9bcae9f50a14dc020634c532732c1a86b696d92 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 30 Jul 2019 17:42:56 +0100 Subject: Share SSL options for well-known requests --- synapse/crypto/context_factory.py | 8 ++++++++ synapse/http/federation/matrix_federation_agent.py | 16 +++++----------- tests/http/federation/test_matrix_federation_agent.py | 12 ++++++------ 3 files changed, 19 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 4f48e8e88d..06e63a96b5 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -31,6 +31,7 @@ from twisted.internet.ssl import ( platformTrust, ) from twisted.python.failure import Failure +from twisted.web.iweb import IPolicyForHTTPS logger = logging.getLogger(__name__) @@ -74,6 +75,7 @@ class ServerContextFactory(ContextFactory): return self._context +@implementer(IPolicyForHTTPS) class ClientTLSOptionsFactory(object): """Factory for Twisted SSLClientConnectionCreators that are used to make connections to remote servers for federation. @@ -146,6 +148,12 @@ class ClientTLSOptionsFactory(object): f = Failure() tls_protocol.failVerification(f) + def creatorForNetloc(self, hostname, port): + """Implements the IPolicyForHTTPS interace so that this can be passed + directly to agents. + """ + return self.get_options(hostname) + @implementer(IOpenSSLClientConnectionCreator) class SSLClientConnectionCreator(object): diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index c03ddb724f..a0d5139839 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -64,10 +64,6 @@ class MatrixFederationAgent(object): tls_client_options_factory (ClientTLSOptionsFactory|None): factory to use for fetching client tls options, or none to disable TLS. - _well_known_tls_policy (IPolicyForHTTPS|None): - TLS policy to use for fetching .well-known files. None to use a default - (browser-like) implementation. - _srv_resolver (SrvResolver|None): SRVResolver impl to use for looking up SRV records. None to use a default implementation. @@ -81,7 +77,6 @@ class MatrixFederationAgent(object): self, reactor, tls_client_options_factory, - _well_known_tls_policy=None, _srv_resolver=None, _well_known_cache=well_known_cache, ): @@ -98,13 +93,12 @@ class MatrixFederationAgent(object): self._pool.maxPersistentPerHost = 5 self._pool.cachedConnectionTimeout = 2 * 60 - agent_args = {} - if _well_known_tls_policy is not None: - # the param is called 'contextFactory', but actually passing a - # contextfactory is deprecated, and it expects an IPolicyForHTTPS. - agent_args["contextFactory"] = _well_known_tls_policy _well_known_agent = RedirectAgent( - Agent(self._reactor, pool=self._pool, **agent_args) + Agent( + self._reactor, + pool=self._pool, + contextFactory=tls_client_options_factory, + ) ) self._well_known_agent = _well_known_agent diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index b906686b49..4255add097 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -75,7 +75,6 @@ class MatrixFederationAgentTests(TestCase): config_dict = default_config("test", parse=False) config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()] - # config_dict["trusted_key_servers"] = [] self._config = config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") @@ -83,7 +82,6 @@ class MatrixFederationAgentTests(TestCase): self.agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=ClientTLSOptionsFactory(config), - _well_known_tls_policy=TrustingTLSPolicyForHTTPS(), _srv_resolver=self.mock_resolver, _well_known_cache=self.well_known_cache, ) @@ -691,16 +689,18 @@ class MatrixFederationAgentTests(TestCase): not signed by a CA """ - # we use the same test server as the other tests, but use an agent - # with _well_known_tls_policy left to the default, which will not - # trust it (since the presented cert is signed by a test CA) + # we use the same test server as the other tests, but use an agent with + # the config left to the default, which will not trust it (since the + # presented cert is signed by a test CA) self.mock_resolver.resolve_service.side_effect = lambda _: [] self.reactor.lookups["testserv"] = "1.2.3.4" + config = default_config("test", parse=True) + agent = MatrixFederationAgent( reactor=self.reactor, - tls_client_options_factory=ClientTLSOptionsFactory(self._config), + tls_client_options_factory=ClientTLSOptionsFactory(config), _srv_resolver=self.mock_resolver, _well_known_cache=self.well_known_cache, ) -- cgit 1.5.1 From 8f15832950148280ed5a9cf28b74971abfd09e4f Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 31 Jul 2019 20:39:22 +1000 Subject: Remove DelayedCall debugging from test runs (#5787) --- changelog.d/5787.misc | 1 + tests/unittest.py | 6 ------ 2 files changed, 1 insertion(+), 6 deletions(-) create mode 100644 changelog.d/5787.misc (limited to 'tests') diff --git a/changelog.d/5787.misc b/changelog.d/5787.misc new file mode 100644 index 0000000000..ead0b04b62 --- /dev/null +++ b/changelog.d/5787.misc @@ -0,0 +1 @@ +Remove DelayedCall debugging from the test suite, as it is no longer required in the vast majority of Synapse's tests. diff --git a/tests/unittest.py b/tests/unittest.py index f5fae21317..561cebc223 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -23,8 +23,6 @@ from mock import Mock from canonicaljson import json -import twisted -import twisted.logger from twisted.internet.defer import Deferred, succeed from twisted.python.threadpool import ThreadPool from twisted.trial import unittest @@ -80,10 +78,6 @@ class TestCase(unittest.TestCase): @around(self) def setUp(orig): - # enable debugging of delayed calls - this means that we get a - # traceback when a unit test exits leaving things on the reactor. - twisted.internet.base.DelayedCall.debug = True - # if we're not starting in the sentinel logcontext, then to be honest # all future bets are off. if LoggingContext.current_context() is not LoggingContext.sentinel: -- cgit 1.5.1 From bc3550352830bef2c484dd5d3dd5ff1fce29a6fc Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 1 Aug 2019 12:00:08 +0200 Subject: Add tests --- tests/rest/client/v2_alpha/test_register.py | 37 +++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 89a3f95c0a..bb867150f4 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -323,6 +323,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "renew_at": 172800000, # Time in ms for 2 days "renew_by_email_enabled": True, "renew_email_subject": "Renew your account", + "account_renewed_html_path": "account_renewed.html", + "invalid_token_html_path": "invalid_token.html", } # Email config. @@ -373,6 +375,19 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) + # Check that we're getting HTML back. + content_type = None + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type = header[1] + self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + + # Check that the HTML we're getting is the one we expect on a successful renewal. + expected_html = self.hs.config.account_validity.account_renewed_html_content + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + # Move 3 days forward. If the renewal failed, every authed request with # our access token should be denied from now, otherwise they should # succeed. @@ -381,6 +396,28 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) + def test_renewal_invalid_token(self): + # Hit the renewal endpoint with an invalid token and check that it behaves as + # expected, i.e. that it responds with 404 Not Found and the correct HTML. + url = "/_matrix/client/unstable/account_validity/renew?token=123" + request, channel = self.make_request(b"GET", url) + self.render(request) + self.assertEquals(channel.result["code"], b"404", channel.result) + + # Check that we're getting HTML back. + content_type = None + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type = header[1] + self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + + # Check that the HTML we're getting is the one we expect when using an + # invalid/unknown token. + expected_html = self.hs.config.account_validity.invalid_token_html_content + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + def test_manual_email_send(self): self.email_attempts = [] -- cgit 1.5.1 From 8c9adcc95dee892f90d6acbbe5c54acbf621720b Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 1 Aug 2019 22:09:05 -0400 Subject: fix formatting --- changelog.d/5769.feature | 2 +- tests/handlers/test_e2e_keys.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/changelog.d/5769.feature b/changelog.d/5769.feature index c34257cb8f..bf994ca327 100644 --- a/changelog.d/5769.feature +++ b/changelog.d/5769.feature @@ -1 +1 @@ -allow uploading of cross-signing keys \ No newline at end of file +Allow uploading of cross-signing keys. \ No newline at end of file diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 9ae4cb6ea2..a62c52eefa 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -176,7 +176,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): } yield self.handler.upload_signing_keys_for_user(local_user, keys2) - devices = yield self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) + devices = yield self.handler.query_devices( + {"device_keys": {local_user: []}}, 0, local_user + ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) @defer.inlineCallbacks -- cgit 1.5.1 From af9f1c07646031bf267aa20f2629d5b96c6603b6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 6 Aug 2019 16:27:46 +0100 Subject: Add a lower bound for TTL on well known results. It costs both us and the remote server for us to fetch the well known for every single request we send, so we add a minimum cache period. This is set to 5m so that we still honour the basic premise of "refetch frequently". --- synapse/http/federation/matrix_federation_agent.py | 4 ++++ tests/http/federation/test_matrix_federation_agent.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index a0d5139839..79e488c942 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -47,6 +47,9 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 # cap for .well-known cache period WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 +# lower bound for .well-known cache period +WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60 + logger = logging.getLogger(__name__) well_known_cache = TTLCache("well-known") @@ -356,6 +359,7 @@ class MatrixFederationAgent(object): cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) else: cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD) + cache_period = max(cache_period, WELL_KNOWN_MIN_CACHE_PERIOD) return (result, cache_period) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 4255add097..5e709c0c17 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -953,7 +953,7 @@ class MatrixFederationAgentTests(TestCase): well_known_server = self._handle_well_known_connection( client_factory, expected_sni=b"testserv", - response_headers={b"Cache-Control": b"max-age=10"}, + response_headers={b"Cache-Control": b"max-age=1000"}, content=b'{ "m.server": "target-server" }', ) @@ -969,7 +969,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(r, b"target-server") # expire the cache - self.reactor.pump((10.0,)) + self.reactor.pump((1000.0,)) # now it should connect again fetch_d = self.do_get_well_known(b"testserv") -- cgit 1.5.1 From 107ad133fcf5b5a87e2aa1d53ab415aa39a01266 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 7 Aug 2019 15:36:38 +0100 Subject: Move well known lookup into a separate clas --- synapse/http/federation/matrix_federation_agent.py | 166 ++----------------- synapse/http/federation/well_known_resolver.py | 184 +++++++++++++++++++++ .../federation/test_matrix_federation_agent.py | 39 +++-- 3 files changed, 216 insertions(+), 173 deletions(-) create mode 100644 synapse/http/federation/well_known_resolver.py (limited to 'tests') diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 79e488c942..71a15f434d 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -12,10 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json + import logging -import random -import time import attr from netaddr import IPAddress @@ -24,34 +22,16 @@ from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody -from twisted.web.http import stringToDatetime +from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list +from synapse.http.federation.well_known_resolver import WellKnownResolver from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock -from synapse.util.caches.ttlcache import TTLCache -from synapse.util.metrics import Measure - -# period to cache .well-known results for by default -WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600 - -# jitter to add to the .well-known default cache ttl -WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60 - -# period to cache failure to fetch .well-known for -WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 - -# cap for .well-known cache period -WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 - -# lower bound for .well-known cache period -WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60 logger = logging.getLogger(__name__) -well_known_cache = TTLCache("well-known") @implementer(IAgent) @@ -81,7 +61,7 @@ class MatrixFederationAgent(object): reactor, tls_client_options_factory, _srv_resolver=None, - _well_known_cache=well_known_cache, + _well_known_cache=None, ): self._reactor = reactor self._clock = Clock(reactor) @@ -96,20 +76,15 @@ class MatrixFederationAgent(object): self._pool.maxPersistentPerHost = 5 self._pool.cachedConnectionTimeout = 2 * 60 - _well_known_agent = RedirectAgent( - Agent( + self._well_known_resolver = WellKnownResolver( + self._reactor, + agent=Agent( self._reactor, pool=self._pool, contextFactory=tls_client_options_factory, - ) + ), + well_known_cache=_well_known_cache, ) - self._well_known_agent = _well_known_agent - - # our cache of .well-known lookup results, mapping from server name - # to delegated name. The values can be: - # `bytes`: a valid server-name - # `None`: there is no (valid) .well-known here - self._well_known_cache = _well_known_cache @defer.inlineCallbacks def request(self, method, uri, headers=None, bodyProducer=None): @@ -220,7 +195,10 @@ class MatrixFederationAgent(object): if lookup_well_known: # try a .well-known lookup - well_known_server = yield self._get_well_known(parsed_uri.host) + well_known_result = yield self._well_known_resolver.get_well_known( + parsed_uri.host + ) + well_known_server = well_known_result.delegated_server if well_known_server: # if we found a .well-known, start again, but don't do another @@ -283,86 +261,6 @@ class MatrixFederationAgent(object): target_port=port, ) - @defer.inlineCallbacks - def _get_well_known(self, server_name): - """Attempt to fetch and parse a .well-known file for the given server - - Args: - server_name (bytes): name of the server, from the requested url - - Returns: - Deferred[bytes|None]: either the new server name, from the .well-known, or - None if there was no .well-known file. - """ - try: - result = self._well_known_cache[server_name] - except KeyError: - # TODO: should we linearise so that we don't end up doing two .well-known - # requests for the same server in parallel? - with Measure(self._clock, "get_well_known"): - result, cache_period = yield self._do_get_well_known(server_name) - - if cache_period > 0: - self._well_known_cache.set(server_name, result, cache_period) - - return result - - @defer.inlineCallbacks - def _do_get_well_known(self, server_name): - """Actually fetch and parse a .well-known, without checking the cache - - Args: - server_name (bytes): name of the server, from the requested url - - Returns: - Deferred[Tuple[bytes|None|object],int]: - result, cache period, where result is one of: - - the new server name from the .well-known (as a `bytes`) - - None if there was no .well-known file. - - INVALID_WELL_KNOWN if the .well-known was invalid - """ - uri = b"https://%s/.well-known/matrix/server" % (server_name,) - uri_str = uri.decode("ascii") - logger.info("Fetching %s", uri_str) - try: - response = yield make_deferred_yieldable( - self._well_known_agent.request(b"GET", uri) - ) - body = yield make_deferred_yieldable(readBody(response)) - if response.code != 200: - raise Exception("Non-200 response %s" % (response.code,)) - - parsed_body = json.loads(body.decode("utf-8")) - logger.info("Response from .well-known: %s", parsed_body) - if not isinstance(parsed_body, dict): - raise Exception("not a dict") - if "m.server" not in parsed_body: - raise Exception("Missing key 'm.server'") - except Exception as e: - logger.info("Error fetching %s: %s", uri_str, e) - - # add some randomness to the TTL to avoid a stampeding herd every hour - # after startup - cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD - cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) - return (None, cache_period) - - result = parsed_body["m.server"].encode("ascii") - - cache_period = _cache_period_from_headers( - response.headers, time_now=self._reactor.seconds - ) - if cache_period is None: - cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD - # add some randomness to the TTL to avoid a stampeding herd every 24 hours - # after startup - cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) - else: - cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD) - cache_period = max(cache_period, WELL_KNOWN_MIN_CACHE_PERIOD) - - return (result, cache_period) - @implementer(IStreamClientEndpoint) class LoggingHostnameEndpoint(object): @@ -378,44 +276,6 @@ class LoggingHostnameEndpoint(object): return self.ep.connect(protocol_factory) -def _cache_period_from_headers(headers, time_now=time.time): - cache_controls = _parse_cache_control(headers) - - if b"no-store" in cache_controls: - return 0 - - if b"max-age" in cache_controls: - try: - max_age = int(cache_controls[b"max-age"]) - return max_age - except ValueError: - pass - - expires = headers.getRawHeaders(b"expires") - if expires is not None: - try: - expires_date = stringToDatetime(expires[-1]) - return expires_date - time_now() - except ValueError: - # RFC7234 says 'A cache recipient MUST interpret invalid date formats, - # especially the value "0", as representing a time in the past (i.e., - # "already expired"). - return 0 - - return None - - -def _parse_cache_control(headers): - cache_controls = {} - for hdr in headers.getRawHeaders(b"cache-control", []): - for directive in hdr.split(b","): - splits = [x.strip() for x in directive.split(b"=", 1)] - k = splits[0].lower() - v = splits[1] if len(splits) > 1 else None - cache_controls[k] = v - return cache_controls - - @attr.s class _RoutingResult(object): """The result returned by `_route_matrix_uri`. diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py new file mode 100644 index 0000000000..bab4ab015e --- /dev/null +++ b/synapse/http/federation/well_known_resolver.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import random +import time + +import attr + +from twisted.internet import defer +from twisted.web.client import RedirectAgent, readBody +from twisted.web.http import stringToDatetime + +from synapse.logging.context import make_deferred_yieldable +from synapse.util import Clock +from synapse.util.caches.ttlcache import TTLCache +from synapse.util.metrics import Measure + +# period to cache .well-known results for by default +WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600 + +# jitter to add to the .well-known default cache ttl +WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60 + +# period to cache failure to fetch .well-known for +WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 + +# cap for .well-known cache period +WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 + +# lower bound for .well-known cache period +WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60 + +logger = logging.getLogger(__name__) + + +@attr.s(slots=True, frozen=True) +class WellKnownLookupResult(object): + delegated_server = attr.ib() + + +class WellKnownResolver(object): + """Handles well-known lookups for matrix servers. + """ + + def __init__(self, reactor, agent, well_known_cache=None): + self._reactor = reactor + self._clock = Clock(reactor) + + if well_known_cache is None: + well_known_cache = TTLCache("well-known") + + self._well_known_cache = well_known_cache + self._well_known_agent = RedirectAgent(agent) + + @defer.inlineCallbacks + def get_well_known(self, server_name): + """Attempt to fetch and parse a .well-known file for the given server + + Args: + server_name (bytes): name of the server, from the requested url + + Returns: + Deferred[WellKnownLookupResult]: The result of the lookup + """ + try: + result = self._well_known_cache[server_name] + except KeyError: + # TODO: should we linearise so that we don't end up doing two .well-known + # requests for the same server in parallel? + with Measure(self._clock, "get_well_known"): + result, cache_period = yield self._do_get_well_known(server_name) + + if cache_period > 0: + self._well_known_cache.set(server_name, result, cache_period) + + return WellKnownLookupResult(delegated_server=result) + + @defer.inlineCallbacks + def _do_get_well_known(self, server_name): + """Actually fetch and parse a .well-known, without checking the cache + + Args: + server_name (bytes): name of the server, from the requested url + + Returns: + Deferred[Tuple[bytes|None|object],int]: + result, cache period, where result is one of: + - the new server name from the .well-known (as a `bytes`) + - None if there was no .well-known file. + - INVALID_WELL_KNOWN if the .well-known was invalid + """ + uri = b"https://%s/.well-known/matrix/server" % (server_name,) + uri_str = uri.decode("ascii") + logger.info("Fetching %s", uri_str) + try: + response = yield make_deferred_yieldable( + self._well_known_agent.request(b"GET", uri) + ) + body = yield make_deferred_yieldable(readBody(response)) + if response.code != 200: + raise Exception("Non-200 response %s" % (response.code,)) + + parsed_body = json.loads(body.decode("utf-8")) + logger.info("Response from .well-known: %s", parsed_body) + if not isinstance(parsed_body, dict): + raise Exception("not a dict") + if "m.server" not in parsed_body: + raise Exception("Missing key 'm.server'") + except Exception as e: + logger.info("Error fetching %s: %s", uri_str, e) + + # add some randomness to the TTL to avoid a stampeding herd every hour + # after startup + cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD + cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) + return (None, cache_period) + + result = parsed_body["m.server"].encode("ascii") + + cache_period = _cache_period_from_headers( + response.headers, time_now=self._reactor.seconds + ) + if cache_period is None: + cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD + # add some randomness to the TTL to avoid a stampeding herd every 24 hours + # after startup + cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) + else: + cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD) + cache_period = max(cache_period, WELL_KNOWN_MIN_CACHE_PERIOD) + + return (result, cache_period) + + +def _cache_period_from_headers(headers, time_now=time.time): + cache_controls = _parse_cache_control(headers) + + if b"no-store" in cache_controls: + return 0 + + if b"max-age" in cache_controls: + try: + max_age = int(cache_controls[b"max-age"]) + return max_age + except ValueError: + pass + + expires = headers.getRawHeaders(b"expires") + if expires is not None: + try: + expires_date = stringToDatetime(expires[-1]) + return expires_date - time_now() + except ValueError: + # RFC7234 says 'A cache recipient MUST interpret invalid date formats, + # especially the value "0", as representing a time in the past (i.e., + # "already expired"). + return 0 + + return None + + +def _parse_cache_control(headers): + cache_controls = {} + for hdr in headers.getRawHeaders(b"cache-control", []): + for directive in hdr.split(b","): + splits = [x.strip() for x in directive.split(b"=", 1)] + k = splits[0].lower() + v = splits[1] if len(splits) > 1 else None + cache_controls[k] = v + return cache_controls diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 5e709c0c17..1435baede2 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -25,17 +25,19 @@ from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOpti from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.web._newclient import ResponseNeverReceived +from twisted.web.client import Agent from twisted.web.http import HTTPChannel from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS from synapse.config.homeserver import HomeServerConfig from synapse.crypto.context_factory import ClientTLSOptionsFactory -from synapse.http.federation.matrix_federation_agent import ( - MatrixFederationAgent, +from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent +from synapse.http.federation.srv_resolver import Server +from synapse.http.federation.well_known_resolver import ( + WellKnownResolver, _cache_period_from_headers, ) -from synapse.http.federation.srv_resolver import Server from synapse.logging.context import LoggingContext from synapse.util.caches.ttlcache import TTLCache @@ -79,9 +81,10 @@ class MatrixFederationAgentTests(TestCase): self._config = config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") + self.tls_factory = ClientTLSOptionsFactory(config) self.agent = MatrixFederationAgent( reactor=self.reactor, - tls_client_options_factory=ClientTLSOptionsFactory(config), + tls_client_options_factory=self.tls_factory, _srv_resolver=self.mock_resolver, _well_known_cache=self.well_known_cache, ) @@ -928,20 +931,16 @@ class MatrixFederationAgentTests(TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) - @defer.inlineCallbacks - def do_get_well_known(self, serv): - try: - result = yield self.agent._get_well_known(serv) - logger.info("Result from well-known fetch: %s", result) - except Exception as e: - logger.warning("Error fetching well-known: %s", e) - raise - return result - def test_well_known_cache(self): + well_known_resolver = WellKnownResolver( + self.reactor, + Agent(self.reactor, contextFactory=self.tls_factory), + well_known_cache=self.well_known_cache, + ) + self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.do_get_well_known(b"testserv") + fetch_d = well_known_resolver.get_well_known(b"testserv") # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -958,21 +957,21 @@ class MatrixFederationAgentTests(TestCase): ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b"target-server") + self.assertEqual(r.delegated_server, b"target-server") # close the tcp connection well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.do_get_well_known(b"testserv") + fetch_d = well_known_resolver.get_well_known(b"testserv") r = self.successResultOf(fetch_d) - self.assertEqual(r, b"target-server") + self.assertEqual(r.delegated_server, b"target-server") # expire the cache self.reactor.pump((1000.0,)) # now it should connect again - fetch_d = self.do_get_well_known(b"testserv") + fetch_d = well_known_resolver.get_well_known(b"testserv") self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) @@ -986,7 +985,7 @@ class MatrixFederationAgentTests(TestCase): ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b"other-server") + self.assertEqual(r.delegated_server, b"other-server") class TestCachePeriodFromHeaders(TestCase): -- cgit 1.5.1 From 17e1e807264ce13774d9b343c96406795ea24c27 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 12 Aug 2019 15:39:14 +0100 Subject: Retry well-known lookup before expiry. This gives a bit of a grace period where we can attempt to refetch a remote `well-known`, while still using the cached result if that fails. Hopefully this will make the well-known resolution a bit more torelant of failures, rather than it immediately treating failures as "no result" and caching that for an hour. --- synapse/http/federation/well_known_resolver.py | 82 ++++++++++++++++------ synapse/util/caches/ttlcache.py | 8 ++- .../federation/test_matrix_federation_agent.py | 69 ++++++++++++++++++ tests/util/caches/test_ttlcache.py | 4 +- 4 files changed, 136 insertions(+), 27 deletions(-) (limited to 'tests') diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index d2866ff67d..bb250c6922 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -44,6 +44,12 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 # lower bound for .well-known cache period WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60 +# Attempt to refetch a cached well-known N% of the TTL before it expires. +# e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then +# we'll start trying to refetch 1 minute before it expires. +WELL_KNOWN_GRACE_PERIOD_FACTOR = 0.2 + + logger = logging.getLogger(__name__) @@ -80,15 +86,38 @@ class WellKnownResolver(object): Deferred[WellKnownLookupResult]: The result of the lookup """ try: - result = self._well_known_cache[server_name] + prev_result, expiry, ttl = self._well_known_cache.get_with_expiry( + server_name + ) + + now = self._clock.time() + if now < expiry - WELL_KNOWN_GRACE_PERIOD_FACTOR * ttl: + return WellKnownLookupResult(delegated_server=prev_result) except KeyError: - # TODO: should we linearise so that we don't end up doing two .well-known - # requests for the same server in parallel? + prev_result = None + + # TODO: should we linearise so that we don't end up doing two .well-known + # requests for the same server in parallel? + try: with Measure(self._clock, "get_well_known"): result, cache_period = yield self._do_get_well_known(server_name) - if cache_period > 0: - self._well_known_cache.set(server_name, result, cache_period) + except _FetchWellKnownFailure as e: + if prev_result and e.temporary: + # This is a temporary failure and we have a still valid cached + # result, so lets return that. Hopefully the next time we ask + # the remote will be back up again. + return WellKnownLookupResult(delegated_server=prev_result) + + result = None + + # add some randomness to the TTL to avoid a stampeding herd every hour + # after startup + cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD + cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) + + if cache_period > 0: + self._well_known_cache.set(server_name, result, cache_period) return WellKnownLookupResult(delegated_server=result) @@ -99,40 +128,42 @@ class WellKnownResolver(object): Args: server_name (bytes): name of the server, from the requested url + Raises: + _FetchWellKnownFailure if we fail to lookup a result + Returns: - Deferred[Tuple[bytes|None|object],int]: - result, cache period, where result is one of: - - the new server name from the .well-known (as a `bytes`) - - None if there was no .well-known file. - - INVALID_WELL_KNOWN if the .well-known was invalid + Deferred[Tuple[bytes,int]]: The lookup result and cache period. """ uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") logger.info("Fetching %s", uri_str) + + # We do this in two steps to differentiate between possibly transient + # errors (e.g. can't connect to host, 503 response) and more permenant + # errors (such as getting a 404 response). try: response = yield make_deferred_yieldable( self._well_known_agent.request(b"GET", uri) ) body = yield make_deferred_yieldable(readBody(response)) + + if 500 <= response.code < 600: + raise Exception("Non-200 response %s" % (response.code,)) + except Exception as e: + logger.info("Error fetching %s: %s", uri_str, e) + raise _FetchWellKnownFailure(temporary=True) + + try: if response.code != 200: raise Exception("Non-200 response %s" % (response.code,)) parsed_body = json.loads(body.decode("utf-8")) logger.info("Response from .well-known: %s", parsed_body) - if not isinstance(parsed_body, dict): - raise Exception("not a dict") - if "m.server" not in parsed_body: - raise Exception("Missing key 'm.server'") + + result = parsed_body["m.server"].encode("ascii") except Exception as e: logger.info("Error fetching %s: %s", uri_str, e) - - # add some randomness to the TTL to avoid a stampeding herd every hour - # after startup - cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD - cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) - return (None, cache_period) - - result = parsed_body["m.server"].encode("ascii") + raise _FetchWellKnownFailure(temporary=False) cache_period = _cache_period_from_headers( response.headers, time_now=self._reactor.seconds @@ -185,3 +216,10 @@ def _parse_cache_control(headers): v = splits[1] if len(splits) > 1 else None cache_controls[k] = v return cache_controls + + +@attr.s() +class _FetchWellKnownFailure(Exception): + # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be + # a temporary failure. + temporary = attr.ib() diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 2af8ca43b1..99646c7cf0 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -55,7 +55,7 @@ class TTLCache(object): if e != SENTINEL: self._expiry_list.remove(e) - entry = _CacheEntry(expiry_time=expiry, key=key, value=value) + entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value) self._data[key] = entry self._expiry_list.add(entry) @@ -87,7 +87,8 @@ class TTLCache(object): key: key to look up Returns: - Tuple[Any, float]: the value from the cache, and the expiry time + Tuple[Any, float, float]: the value from the cache, the expiry time + and the TTL Raises: KeyError if the entry is not found @@ -99,7 +100,7 @@ class TTLCache(object): self._metrics.inc_misses() raise self._metrics.inc_hits() - return e.value, e.expiry_time + return e.value, e.expiry_time, e.ttl def pop(self, key, default=SENTINEL): """Remove a value from the cache @@ -158,5 +159,6 @@ class _CacheEntry(object): # expiry_time is the first attribute, so that entries are sorted by expiry. expiry_time = attr.ib() + ttl = attr.ib() key = attr.ib() value = attr.ib() diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 1435baede2..2c568788b3 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -987,6 +987,75 @@ class MatrixFederationAgentTests(TestCase): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"other-server") + def test_well_known_cache_with_temp_failure(self): + """Test that we refetch well-known before the cache expires, and that + it ignores transient errors. + """ + + well_known_resolver = WellKnownResolver( + self.reactor, + Agent(self.reactor, contextFactory=self.tls_factory), + well_known_cache=self.well_known_cache, + ) + + self.reactor.lookups["testserv"] = "1.2.3.4" + + fetch_d = well_known_resolver.get_well_known(b"testserv") + + # there should be an attempt to connect on port 443 for the .well-known + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 443) + + well_known_server = self._handle_well_known_connection( + client_factory, + expected_sni=b"testserv", + response_headers={b"Cache-Control": b"max-age=1000"}, + content=b'{ "m.server": "target-server" }', + ) + + r = self.successResultOf(fetch_d) + self.assertEqual(r.delegated_server, b"target-server") + + # close the tcp connection + well_known_server.loseConnection() + + # Get close to the cache expiry, this will cause the resolver to do + # another lookup. + self.reactor.pump((900.0,)) + + fetch_d = well_known_resolver.get_well_known(b"testserv") + clients = self.reactor.tcpClients + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + + # fonx the connection attempt, this will be treated as a temporary + # failure. + client_factory.clientConnectionFailed(None, Exception("nope")) + + # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the + # .well-known request fails. + self.reactor.pump((0.4,)) + + # Resolver should return cached value, despite the lookup failing. + r = self.successResultOf(fetch_d) + self.assertEqual(r.delegated_server, b"target-server") + + # Expire the cache and repeat the request + self.reactor.pump((100.0,)) + + # Repated the request, this time it should fail if the lookup fails. + fetch_d = well_known_resolver.get_well_known(b"testserv") + + clients = self.reactor.tcpClients + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + client_factory.clientConnectionFailed(None, Exception("nope")) + self.reactor.pump((0.4,)) + + r = self.successResultOf(fetch_d) + self.assertEqual(r.delegated_server, None) + class TestCachePeriodFromHeaders(TestCase): def test_cache_control(self): diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py index c94cbb662b..816795c136 100644 --- a/tests/util/caches/test_ttlcache.py +++ b/tests/util/caches/test_ttlcache.py @@ -36,7 +36,7 @@ class CacheTestCase(unittest.TestCase): self.assertTrue("one" in self.cache) self.assertEqual(self.cache.get("one"), "1") self.assertEqual(self.cache["one"], "1") - self.assertEqual(self.cache.get_with_expiry("one"), ("1", 110)) + self.assertEqual(self.cache.get_with_expiry("one"), ("1", 110, 10)) self.assertEqual(self.cache._metrics.hits, 3) self.assertEqual(self.cache._metrics.misses, 0) @@ -77,7 +77,7 @@ class CacheTestCase(unittest.TestCase): self.assertEqual(self.cache["two"], "2") self.assertEqual(self.cache["three"], "3") - self.assertEqual(self.cache.get_with_expiry("two"), ("2", 120)) + self.assertEqual(self.cache.get_with_expiry("two"), ("2", 120, 20)) self.assertEqual(self.cache._metrics.hits, 5) self.assertEqual(self.cache._metrics.misses, 0) -- cgit 1.5.1 From e6e136decca12648933f974e4151fb936ad9e1fa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Aug 2019 18:04:46 +0100 Subject: Retry well known on fail. If we have recently seen a valid well-known for a domain we want to retry on (non-final) errors a few times, to handle temporary blips in networking/etc. --- synapse/http/federation/matrix_federation_agent.py | 26 +++-- synapse/http/federation/well_known_resolver.py | 122 +++++++++++++++++---- .../federation/test_matrix_federation_agent.py | 79 +++++++------ 3 files changed, 160 insertions(+), 67 deletions(-) (limited to 'tests') diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 71a15f434d..64f62aaeec 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -51,9 +51,9 @@ class MatrixFederationAgent(object): SRVResolver impl to use for looking up SRV records. None to use a default implementation. - _well_known_cache (TTLCache|None): - TTLCache impl for storing cached well-known lookups. None to use a default - implementation. + _well_known_resolver (WellKnownResolver|None): + WellKnownResolver to use to perform well-known lookups. None to use a + default implementation. """ def __init__( @@ -61,7 +61,7 @@ class MatrixFederationAgent(object): reactor, tls_client_options_factory, _srv_resolver=None, - _well_known_cache=None, + _well_known_resolver=None, ): self._reactor = reactor self._clock = Clock(reactor) @@ -76,15 +76,17 @@ class MatrixFederationAgent(object): self._pool.maxPersistentPerHost = 5 self._pool.cachedConnectionTimeout = 2 * 60 - self._well_known_resolver = WellKnownResolver( - self._reactor, - agent=Agent( + if _well_known_resolver is None: + _well_known_resolver = WellKnownResolver( self._reactor, - pool=self._pool, - contextFactory=tls_client_options_factory, - ), - well_known_cache=_well_known_cache, - ) + agent=Agent( + self._reactor, + pool=self._pool, + contextFactory=tls_client_options_factory, + ), + ) + + self._well_known_resolver = _well_known_resolver @defer.inlineCallbacks def request(self, method, uri, headers=None, bodyProducer=None): diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index bb250c6922..d59864e298 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -38,6 +38,13 @@ WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60 # period to cache failure to fetch .well-known for WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 +# period to cache failure to fetch .well-known if there has recently been a +# valid well-known for that domain. +WELL_KNOWN_DOWN_CACHE_PERIOD = 2 * 60 + +# period to remember there was a valid well-known after valid record expires +WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID = 2 * 3600 + # cap for .well-known cache period WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 @@ -49,11 +56,16 @@ WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60 # we'll start trying to refetch 1 minute before it expires. WELL_KNOWN_GRACE_PERIOD_FACTOR = 0.2 +# Number of times we retry fetching a well-known for a domain we know recently +# had a valid entry. +WELL_KNOWN_RETRY_ATTEMPTS = 3 + logger = logging.getLogger(__name__) _well_known_cache = TTLCache("well-known") +_had_valid_well_known_cache = TTLCache("had-valid-well-known") @attr.s(slots=True, frozen=True) @@ -65,14 +77,20 @@ class WellKnownResolver(object): """Handles well-known lookups for matrix servers. """ - def __init__(self, reactor, agent, well_known_cache=None): + def __init__( + self, reactor, agent, well_known_cache=None, had_well_known_cache=None + ): self._reactor = reactor self._clock = Clock(reactor) if well_known_cache is None: well_known_cache = _well_known_cache + if had_well_known_cache is None: + had_well_known_cache = _had_valid_well_known_cache + self._well_known_cache = well_known_cache + self._had_valid_well_known_cache = had_well_known_cache self._well_known_agent = RedirectAgent(agent) @defer.inlineCallbacks @@ -100,7 +118,7 @@ class WellKnownResolver(object): # requests for the same server in parallel? try: with Measure(self._clock, "get_well_known"): - result, cache_period = yield self._do_get_well_known(server_name) + result, cache_period = yield self._fetch_well_known(server_name) except _FetchWellKnownFailure as e: if prev_result and e.temporary: @@ -111,10 +129,20 @@ class WellKnownResolver(object): result = None - # add some randomness to the TTL to avoid a stampeding herd every hour - # after startup - cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD - cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) + if self._had_valid_well_known_cache.get(server_name, False): + # We have recently seen a valid well-known record for this + # server, so we cache the lack of well-known for a shorter time. + cache_period = WELL_KNOWN_DOWN_CACHE_PERIOD + cache_period += random.uniform( + 0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER + ) + else: + # add some randomness to the TTL to avoid a stampeding herd every hour + # after startup + cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD + cache_period += random.uniform( + 0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER + ) if cache_period > 0: self._well_known_cache.set(server_name, result, cache_period) @@ -122,7 +150,7 @@ class WellKnownResolver(object): return WellKnownLookupResult(delegated_server=result) @defer.inlineCallbacks - def _do_get_well_known(self, server_name): + def _fetch_well_known(self, server_name): """Actually fetch and parse a .well-known, without checking the cache Args: @@ -134,24 +162,17 @@ class WellKnownResolver(object): Returns: Deferred[Tuple[bytes,int]]: The lookup result and cache period. """ - uri = b"https://%s/.well-known/matrix/server" % (server_name,) - uri_str = uri.decode("ascii") - logger.info("Fetching %s", uri_str) + + had_valid_well_known = bool( + self._had_valid_well_known_cache.get(server_name, False) + ) # We do this in two steps to differentiate between possibly transient # errors (e.g. can't connect to host, 503 response) and more permenant # errors (such as getting a 404 response). - try: - response = yield make_deferred_yieldable( - self._well_known_agent.request(b"GET", uri) - ) - body = yield make_deferred_yieldable(readBody(response)) - - if 500 <= response.code < 600: - raise Exception("Non-200 response %s" % (response.code,)) - except Exception as e: - logger.info("Error fetching %s: %s", uri_str, e) - raise _FetchWellKnownFailure(temporary=True) + response, body = yield self._make_well_known_request( + server_name, retry=had_valid_well_known + ) try: if response.code != 200: @@ -161,8 +182,11 @@ class WellKnownResolver(object): logger.info("Response from .well-known: %s", parsed_body) result = parsed_body["m.server"].encode("ascii") + except defer.CancelledError: + # Bail if we've been cancelled + raise except Exception as e: - logger.info("Error fetching %s: %s", uri_str, e) + logger.info("Error parsing well-known for %s: %s", server_name, e) raise _FetchWellKnownFailure(temporary=False) cache_period = _cache_period_from_headers( @@ -177,8 +201,62 @@ class WellKnownResolver(object): cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD) cache_period = max(cache_period, WELL_KNOWN_MIN_CACHE_PERIOD) + # We got a success, mark as such in the cache + self._had_valid_well_known_cache.set( + server_name, + bool(result), + cache_period + WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID, + ) + return (result, cache_period) + @defer.inlineCallbacks + def _make_well_known_request(self, server_name, retry): + """Make the well known request. + + This will retry the request if requested and it fails (with unable + to connect or receives a 5xx error). + + Args: + server_name (bytes) + retry (bool): Whether to retry the request if it fails. + + Returns: + Deferred[tuple[IResponse, bytes]] Returns the response object and + body. Response may be a non-200 response. + """ + uri = b"https://%s/.well-known/matrix/server" % (server_name,) + uri_str = uri.decode("ascii") + + i = 0 + while True: + i += 1 + + logger.info("Fetching %s", uri_str) + try: + response = yield make_deferred_yieldable( + self._well_known_agent.request(b"GET", uri) + ) + body = yield make_deferred_yieldable(readBody(response)) + + if 500 <= response.code < 600: + raise Exception("Non-200 response %s" % (response.code,)) + + return response, body + except defer.CancelledError: + # Bail if we've been cancelled + raise + except Exception as e: + logger.info("Retry: %s", retry) + if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS: + logger.info("Error fetching %s: %s", uri_str, e) + raise _FetchWellKnownFailure(temporary=True) + + logger.info("Error fetching %s: %s. Retrying", uri_str, e) + + # Sleep briefly in the hopes that they come back up + yield self._clock.sleep(0.5) + def _cache_period_from_headers(headers, time_now=time.time): cache_controls = _parse_cache_control(headers) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 2c568788b3..4d3f31d18c 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -73,8 +73,6 @@ class MatrixFederationAgentTests(TestCase): self.mock_resolver = Mock() - self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) - config_dict = default_config("test", parse=False) config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()] @@ -82,11 +80,21 @@ class MatrixFederationAgentTests(TestCase): config.parse_config_dict(config_dict, "", "") self.tls_factory = ClientTLSOptionsFactory(config) + + self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) + self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) + self.well_known_resolver = WellKnownResolver( + self.reactor, + Agent(self.reactor, contextFactory=self.tls_factory), + well_known_cache=self.well_known_cache, + had_well_known_cache=self.had_well_known_cache, + ) + self.agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=self.tls_factory, _srv_resolver=self.mock_resolver, - _well_known_cache=self.well_known_cache, + _well_known_resolver=self.well_known_resolver, ) def _make_connection(self, client_factory, expected_sni): @@ -701,11 +709,18 @@ class MatrixFederationAgentTests(TestCase): config = default_config("test", parse=True) + # Build a new agent and WellKnownResolver with a different tls factory + tls_factory = ClientTLSOptionsFactory(config) agent = MatrixFederationAgent( reactor=self.reactor, - tls_client_options_factory=ClientTLSOptionsFactory(config), + tls_client_options_factory=tls_factory, _srv_resolver=self.mock_resolver, - _well_known_cache=self.well_known_cache, + _well_known_resolver=WellKnownResolver( + self.reactor, + Agent(self.reactor, contextFactory=tls_factory), + well_known_cache=self.well_known_cache, + had_well_known_cache=self.had_well_known_cache, + ), ) test_d = agent.request(b"GET", b"matrix://testserv/foo/bar") @@ -932,15 +947,9 @@ class MatrixFederationAgentTests(TestCase): self.successResultOf(test_d) def test_well_known_cache(self): - well_known_resolver = WellKnownResolver( - self.reactor, - Agent(self.reactor, contextFactory=self.tls_factory), - well_known_cache=self.well_known_cache, - ) - self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = well_known_resolver.get_well_known(b"testserv") + fetch_d = self.well_known_resolver.get_well_known(b"testserv") # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -963,7 +972,7 @@ class MatrixFederationAgentTests(TestCase): well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = well_known_resolver.get_well_known(b"testserv") + fetch_d = self.well_known_resolver.get_well_known(b"testserv") r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") @@ -971,7 +980,7 @@ class MatrixFederationAgentTests(TestCase): self.reactor.pump((1000.0,)) # now it should connect again - fetch_d = well_known_resolver.get_well_known(b"testserv") + fetch_d = self.well_known_resolver.get_well_known(b"testserv") self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) @@ -992,15 +1001,9 @@ class MatrixFederationAgentTests(TestCase): it ignores transient errors. """ - well_known_resolver = WellKnownResolver( - self.reactor, - Agent(self.reactor, contextFactory=self.tls_factory), - well_known_cache=self.well_known_cache, - ) - self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = well_known_resolver.get_well_known(b"testserv") + fetch_d = self.well_known_resolver.get_well_known(b"testserv") # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -1026,27 +1029,37 @@ class MatrixFederationAgentTests(TestCase): # another lookup. self.reactor.pump((900.0,)) - fetch_d = well_known_resolver.get_well_known(b"testserv") - clients = self.reactor.tcpClients - (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + fetch_d = self.well_known_resolver.get_well_known(b"testserv") - # fonx the connection attempt, this will be treated as a temporary - # failure. - client_factory.clientConnectionFailed(None, Exception("nope")) + # The resolver may retry a few times, so fonx all requests that come along + attempts = 0 + while self.reactor.tcpClients: + clients = self.reactor.tcpClients + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the - # .well-known request fails. - self.reactor.pump((0.4,)) + attempts += 1 + + # fonx the connection attempt, this will be treated as a temporary + # failure. + client_factory.clientConnectionFailed(None, Exception("nope")) + + # There's a few sleeps involved, so we have to pump the reactor a + # bit. + self.reactor.pump((1.0, 1.0)) + + # We expect to see more than one attempt as there was previously a valid + # well known. + self.assertGreater(attempts, 1) # Resolver should return cached value, despite the lookup failing. r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") - # Expire the cache and repeat the request - self.reactor.pump((100.0,)) + # Expire both caches and repeat the request + self.reactor.pump((10000.0,)) # Repated the request, this time it should fail if the lookup fails. - fetch_d = well_known_resolver.get_well_known(b"testserv") + fetch_d = self.well_known_resolver.get_well_known(b"testserv") clients = self.reactor.tcpClients (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) -- cgit 1.5.1 From f299c5414c2dd300103b0e11e7114123d8eb58a1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 8 Aug 2019 15:30:04 +0100 Subject: Refactor MatrixFederationAgent to retry SRV. This refactors MatrixFederationAgent to move the SRV lookup into the endpoint code, this has two benefits: 1. Its easier to retry different host/ports in the same way as HostnameEndpoint. 2. We avoid SRV lookups if we have a free connection in the pool --- synapse/http/federation/matrix_federation_agent.py | 356 ++++++++++----------- synapse/http/federation/srv_resolver.py | 35 +- .../federation/test_matrix_federation_agent.py | 63 +++- tests/http/federation/test_srv_resolver.py | 8 +- 4 files changed, 268 insertions(+), 194 deletions(-) (limited to 'tests') diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 71a15f434d..c208185791 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -14,21 +14,21 @@ # limitations under the License. import logging +import urllib -import attr -from netaddr import IPAddress +from netaddr import AddrFormatError, IPAddress from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import URI, Agent, HTTPConnectionPool +from twisted.web.client import Agent, HTTPConnectionPool from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent +from twisted.web.iweb import IAgent, IAgentEndpointFactory -from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list +from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver -from synapse.logging.context import make_deferred_yieldable +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.util import Clock logger = logging.getLogger(__name__) @@ -36,8 +36,9 @@ logger = logging.getLogger(__name__) @implementer(IAgent) class MatrixFederationAgent(object): - """An Agent-like thing which provides a `request` method which will look up a matrix - server and send an HTTP request to it. + """An Agent-like thing which provides a `request` method which correctly + handles resolving matrix server names when using matrix://. Handles standard + https URIs as normal. Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) @@ -65,23 +66,25 @@ class MatrixFederationAgent(object): ): self._reactor = reactor self._clock = Clock(reactor) - - self._tls_client_options_factory = tls_client_options_factory - if _srv_resolver is None: - _srv_resolver = SrvResolver() - self._srv_resolver = _srv_resolver - self._pool = HTTPConnectionPool(reactor) self._pool.retryAutomatically = False self._pool.maxPersistentPerHost = 5 self._pool.cachedConnectionTimeout = 2 * 60 + self._agent = Agent.usingEndpointFactory( + self._reactor, + MatrixHostnameEndpointFactory( + reactor, tls_client_options_factory, _srv_resolver + ), + pool=self._pool, + ) + self._well_known_resolver = WellKnownResolver( self._reactor, agent=Agent( self._reactor, - pool=self._pool, contextFactory=tls_client_options_factory, + pool=self._pool, ), well_known_cache=_well_known_cache, ) @@ -91,19 +94,15 @@ class MatrixFederationAgent(object): """ Args: method (bytes): HTTP method: GET/POST/etc - uri (bytes): Absolute URI to be retrieved - headers (twisted.web.http_headers.Headers|None): HTTP headers to send with the request, or None to send no extra headers. - bodyProducer (twisted.web.iweb.IBodyProducer|None): An object which can generate bytes to make up the body of this request (for example, the properly encoded contents of a file for a file upload). Or None if the request is to have no body. - Returns: Deferred[twisted.web.iweb.IResponse]: fires when the header of the response has been received (regardless of the @@ -111,210 +110,195 @@ class MatrixFederationAgent(object): response from being received (including problems that prevent the request from being sent). """ - parsed_uri = URI.fromBytes(uri, defaultPort=-1) - res = yield self._route_matrix_uri(parsed_uri) + # We use urlparse as that will set `port` to None if there is no + # explicit port. + parsed_uri = urllib.parse.urlparse(uri) - # set up the TLS connection params + # If this is a matrix:// URI check if the server has delegated matrix + # traffic using well-known delegation. # - # XXX disabling TLS is really only supported here for the benefit of the - # unit tests. We should make the UTs cope with TLS rather than having to make - # the code support the unit tests. - if self._tls_client_options_factory is None: - tls_options = None - else: - tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii") + # We have to do this here and not in the endpoint as we need to rewrite + # the host header with the delegated server name. + delegated_server = None + if ( + parsed_uri.scheme == b"matrix" + and not _is_ip_literal(parsed_uri.hostname) + and not parsed_uri.port + ): + well_known_result = yield self._well_known_resolver.get_well_known( + parsed_uri.hostname + ) + delegated_server = well_known_result.delegated_server + + if delegated_server: + # Ok, the server has delegated matrix traffic to somewhere else, so + # lets rewrite the URL to replace the server with the delegated + # server name. + uri = urllib.parse.urlunparse( + ( + parsed_uri.scheme, + delegated_server, + parsed_uri.path, + parsed_uri.params, + parsed_uri.query, + parsed_uri.fragment, + ) ) + parsed_uri = urllib.parse.urlparse(uri) - # make sure that the Host header is set correctly + # We need to make sure the host header is set to the netloc of the + # server. if headers is None: headers = Headers() else: headers = headers.copy() if not headers.hasHeader(b"host"): - headers.addRawHeader(b"host", res.host_header) + headers.addRawHeader(b"host", parsed_uri.netloc) - class EndpointFactory(object): - @staticmethod - def endpointForURI(_uri): - ep = LoggingHostnameEndpoint( - self._reactor, res.target_host, res.target_port - ) - if tls_options is not None: - ep = wrapClientTLS(tls_options, ep) - return ep + with PreserveLoggingContext(): + res = yield self._agent.request(method, uri, headers, bodyProducer) - agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool) - res = yield make_deferred_yieldable( - agent.request(method, uri, headers, bodyProducer) - ) return res - @defer.inlineCallbacks - def _route_matrix_uri(self, parsed_uri, lookup_well_known=True): - """Helper for `request`: determine the routing for a Matrix URI - Args: - parsed_uri (twisted.web.client.URI): uri to route. Note that it should be - parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1 - if there is no explicit port given. +@implementer(IAgentEndpointFactory) +class MatrixHostnameEndpointFactory(object): + """Factory for MatrixHostnameEndpoint for parsing to an Agent. + """ - lookup_well_known (bool): True if we should look up the .well-known file if - there is no SRV record. + def __init__(self, reactor, tls_client_options_factory, srv_resolver): + self._reactor = reactor + self._tls_client_options_factory = tls_client_options_factory - Returns: - Deferred[_RoutingResult] - """ - # check for an IP literal - try: - ip_address = IPAddress(parsed_uri.host.decode("ascii")) - except Exception: - # not an IP address - ip_address = None - - if ip_address: - port = parsed_uri.port - if port == -1: - port = 8448 - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=port, - ) + if srv_resolver is None: + srv_resolver = SrvResolver() - if parsed_uri.port != -1: - # there is an explicit port - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=parsed_uri.port, - ) + self._srv_resolver = srv_resolver - if lookup_well_known: - # try a .well-known lookup - well_known_result = yield self._well_known_resolver.get_well_known( - parsed_uri.host - ) - well_known_server = well_known_result.delegated_server - - if well_known_server: - # if we found a .well-known, start again, but don't do another - # .well-known lookup. - - # parse the server name in the .well-known response into host/port. - # (This code is lifted from twisted.web.client.URI.fromBytes). - if b":" in well_known_server: - well_known_host, well_known_port = well_known_server.rsplit(b":", 1) - try: - well_known_port = int(well_known_port) - except ValueError: - # the part after the colon could not be parsed as an int - # - we assume it is an IPv6 literal with no port (the closing - # ']' stops it being parsed as an int) - well_known_host, well_known_port = well_known_server, -1 - else: - well_known_host, well_known_port = well_known_server, -1 - - new_uri = URI( - scheme=parsed_uri.scheme, - netloc=well_known_server, - host=well_known_host, - port=well_known_port, - path=parsed_uri.path, - params=parsed_uri.params, - query=parsed_uri.query, - fragment=parsed_uri.fragment, - ) + def endpointForURI(self, parsed_uri): + return MatrixHostnameEndpoint( + self._reactor, + self._tls_client_options_factory, + self._srv_resolver, + parsed_uri, + ) - res = yield self._route_matrix_uri(new_uri, lookup_well_known=False) - return res - - # try a SRV lookup - service_name = b"_matrix._tcp.%s" % (parsed_uri.host,) - server_list = yield self._srv_resolver.resolve_service(service_name) - - if not server_list: - target_host = parsed_uri.host - port = 8448 - logger.debug( - "No SRV record for %s, using %s:%i", - parsed_uri.host.decode("ascii"), - target_host.decode("ascii"), - port, - ) + +@implementer(IStreamClientEndpoint) +class MatrixHostnameEndpoint(object): + """An endpoint that resolves matrix:// URLs using Matrix server name + resolution (i.e. via SRV). Does not check for well-known delegation. + """ + + def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri): + self._reactor = reactor + + # We reparse the URI so that defaultPort is -1 rather than 80 + self._parsed_uri = parsed_uri + + # set up the TLS connection params + # + # XXX disabling TLS is really only supported here for the benefit of the + # unit tests. We should make the UTs cope with TLS rather than having to make + # the code support the unit tests. + + if tls_client_options_factory is None: + self._tls_options = None else: - target_host, port = pick_server_from_list(server_list) - logger.debug( - "Picked %s:%i from SRV records for %s", - target_host.decode("ascii"), - port, - parsed_uri.host.decode("ascii"), + self._tls_options = tls_client_options_factory.get_options( + self._parsed_uri.host.decode("ascii") ) - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=target_host, - target_port=port, - ) + self._srv_resolver = srv_resolver + @defer.inlineCallbacks + def connect(self, protocol_factory): + """Implements IStreamClientEndpoint interface + """ -@implementer(IStreamClientEndpoint) -class LoggingHostnameEndpoint(object): - """A wrapper for HostnameEndpint which logs when it connects""" + first_exception = None - def __init__(self, reactor, host, port, *args, **kwargs): - self.host = host - self.port = port - self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs) + server_list = yield self._resolve_server() - def connect(self, protocol_factory): - logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port) - return self.ep.connect(protocol_factory) + for server in server_list: + host = server.host + port = server.port + try: + logger.info("Connecting to %s:%i", host.decode("ascii"), port) + endpoint = HostnameEndpoint(self._reactor, host, port) + if self._tls_options: + endpoint = wrapClientTLS(self._tls_options, endpoint) + result = yield make_deferred_yieldable( + endpoint.connect(protocol_factory) + ) -@attr.s -class _RoutingResult(object): - """The result returned by `_route_matrix_uri`. + return result + except Exception as e: + logger.info( + "Failed to connect to %s:%i: %s", host.decode("ascii"), port, e + ) + if not first_exception: + first_exception = e - Contains the parameters needed to direct a federation connection to a particular - server. + # We return the first failure because that's probably the most interesting. + if first_exception: + raise first_exception - Where a SRV record points to several servers, this object contains a single server - chosen from the list. - """ + # This shouldn't happen as we should always have at least one host/port + # to try and if that doesn't work then we'll have an exception. + raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,)) - host_header = attr.ib() - """ - The value we should assign to the Host header (host:port from the matrix - URI, or .well-known). + @defer.inlineCallbacks + def _resolve_server(self): + """Resolves the server name to a list of hosts and ports to attempt to + connect to. - :type: bytes - """ + Returns: + Deferred[list[Server]] + """ - tls_server_name = attr.ib() - """ - The server name we should set in the SNI (typically host, without port, from the - matrix URI or .well-known) + if self._parsed_uri.scheme != b"matrix": + return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)] - :type: bytes - """ + # Note: We don't do well-known lookup as that needs to have happened + # before now, due to needing to rewrite the Host header of the HTTP + # request. - target_host = attr.ib() - """ - The hostname (or IP literal) we should route the TCP connection to (the target of the - SRV record, or the hostname from the URL/.well-known) + parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes()) - :type: bytes - """ + host = parsed_uri.hostname + port = parsed_uri.port - target_port = attr.ib() - """ - The port we should route the TCP connection to (the target of the SRV record, or - the port from the URL/.well-known, or 8448) + # If there is an explicit port or the host is an IP address we bypass + # SRV lookups and just use the given host/port. + if port or _is_ip_literal(host): + return [Server(host, port or 8448)] - :type: int + server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host) + + if server_list: + return server_list + + # No SRV records, so we fallback to host and 8448 + return [Server(host, 8448)] + + +def _is_ip_literal(host): + """Test if the given host name is either an IPv4 or IPv6 literal. + + Args: + host (bytes) + + Returns: + bool """ + + host = host.decode("ascii") + + try: + IPAddress(host) + return True + except AddrFormatError: + return False diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index b32188766d..bbda0a23f4 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) SERVER_CACHE = {} -@attr.s +@attr.s(slots=True, frozen=True) class Server(object): """ Our record of an individual server which can be tried to reach a destination. @@ -83,6 +83,35 @@ def pick_server_from_list(server_list): raise RuntimeError("pick_server_from_list got to end of eligible server list.") +def _sort_server_list(server_list): + """Given a list of SRV records sort them into priority order and shuffle + each priority with the given weight. + """ + priority_map = {} + + for server in server_list: + priority_map.setdefault(server.priority, []).append(server) + + results = [] + for priority in sorted(priority_map): + servers = priority_map.pop(priority) + + while servers: + total_weight = sum(s.weight for s in servers) + target_weight = random.randint(0, total_weight) + + for s in servers: + target_weight -= s.weight + + if target_weight <= 0: + break + + results.append(s) + servers.remove(s) + + return results + + class SrvResolver(object): """Interface to the dns client to do SRV lookups, with result caching. @@ -120,7 +149,7 @@ class SrvResolver(object): if cache_entry: if all(s.expires > now for s in cache_entry): servers = list(cache_entry) - return servers + return _sort_server_list(servers) try: answers, _, _ = yield make_deferred_yieldable( @@ -169,4 +198,4 @@ class SrvResolver(object): ) self._cache[service_name] = list(servers) - return servers + return _sort_server_list(servers) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 2c568788b3..f97c8a59f6 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -41,9 +41,9 @@ from synapse.http.federation.well_known_resolver import ( from synapse.logging.context import LoggingContext from synapse.util.caches.ttlcache import TTLCache +from tests import unittest from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.server import FakeTransport, ThreadedMemoryReactorClock -from tests.unittest import TestCase from tests.utils import default_config logger = logging.getLogger(__name__) @@ -67,7 +67,8 @@ def get_connection_factory(): return test_server_connection_factory -class MatrixFederationAgentTests(TestCase): +@unittest.DEBUG +class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() @@ -1056,8 +1057,64 @@ class MatrixFederationAgentTests(TestCase): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) + def test_srv_fallbacks(self): + """Test that other SRV results are tried if the first one fails. + """ + + self.mock_resolver.resolve_service.side_effect = lambda _: [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + self.reactor.lookups["target.com"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix._tcp.testserv" + ) + + # We should see an attempt to connect to the first server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8443) + + # Fonx the connection + client_factory.clientConnectionFailed(None, Exception("nope")) + + # There's a 300ms delay in HostnameEndpoint + self.reactor.pump((0.4,)) + + # Hasn't failed yet + self.assertNoResult(test_d) + + # We shouldnow see an attempt to connect to the second server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8444) + + # make a test server, and wire up the client + http_server = self._make_connection(client_factory, expected_sni=b"testserv") + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + -class TestCachePeriodFromHeaders(TestCase): +class TestCachePeriodFromHeaders(unittest.TestCase): def test_cache_control(self): # uppercase self.assertEqual( diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 3b885ef64b..df034ab237 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -83,8 +83,10 @@ class SrvResolverTestCase(unittest.TestCase): service_name = b"test_service.example.com" - entry = Mock(spec_set=["expires"]) + entry = Mock(spec_set=["expires", "priority", "weight"]) entry.expires = 0 + entry.priority = 0 + entry.weight = 0 cache = {service_name: [entry]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) @@ -105,8 +107,10 @@ class SrvResolverTestCase(unittest.TestCase): service_name = b"test_service.example.com" - entry = Mock(spec_set=["expires"]) + entry = Mock(spec_set=["expires", "priority", "weight"]) entry.expires = 999999999 + entry.priority = 0 + entry.weight = 0 cache = {service_name: [entry]} resolver = SrvResolver( -- cgit 1.5.1 From 1dec31560e5712306e368a0adc6d9f84f924bdc9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Aug 2019 11:46:00 +0100 Subject: Change jitter to be a factor rather than absolute value --- synapse/http/federation/well_known_resolver.py | 23 +++++++++++----------- .../federation/test_matrix_federation_agent.py | 4 ++-- 2 files changed, 14 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index c846003886..5e9b0befb0 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -32,8 +32,8 @@ from synapse.util.metrics import Measure # period to cache .well-known results for by default WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600 -# jitter to add to the .well-known default cache ttl -WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60 +# jitter factor to add to the .well-known default cache ttls +WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 0.1 # period to cache failure to fetch .well-known for WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 @@ -133,16 +133,14 @@ class WellKnownResolver(object): # We have recently seen a valid well-known record for this # server, so we cache the lack of well-known for a shorter time. cache_period = WELL_KNOWN_DOWN_CACHE_PERIOD - cache_period += random.uniform( - 0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER - ) else: - # add some randomness to the TTL to avoid a stampeding herd every hour - # after startup cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD - cache_period += random.uniform( - 0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER - ) + + # add some randomness to the TTL to avoid a stampeding herd + cache_period *= random.uniform( + 1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER, + 1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER, + ) if cache_period > 0: self._well_known_cache.set(server_name, result, cache_period) @@ -194,7 +192,10 @@ class WellKnownResolver(object): cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD # add some randomness to the TTL to avoid a stampeding herd every 24 hours # after startup - cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) + cache_period *= random.uniform( + 1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER, + 1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER, + ) else: cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD) cache_period = max(cache_period, WELL_KNOWN_MIN_CACHE_PERIOD) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 4d3f31d18c..c55aad8e11 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -551,7 +551,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(self.well_known_cache[b"testserv"], b"target-server") # check the cache expires - self.reactor.pump((25 * 3600,)) + self.reactor.pump((48 * 3600,)) self.well_known_cache.expire() self.assertNotIn(b"testserv", self.well_known_cache) @@ -639,7 +639,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(self.well_known_cache[b"testserv"], b"target-server") # check the cache expires - self.reactor.pump((25 * 3600,)) + self.reactor.pump((48 * 3600,)) self.well_known_cache.expire() self.assertNotIn(b"testserv", self.well_known_cache) -- cgit 1.5.1 From 7777d353bfffc840b79391da107e593338a1a2fe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Aug 2019 11:46:54 +0100 Subject: Remove test debugs --- tests/federation/test_federation_server.py | 1 - tests/http/federation/test_matrix_federation_agent.py | 1 - tests/test_visibility.py | 1 - 3 files changed, 3 deletions(-) (limited to 'tests') diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index af15f4cc5a..b08be451aa 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -20,7 +20,6 @@ from synapse.federation.federation_server import server_matches_acl_event from tests import unittest -@unittest.DEBUG class ServerACLsTestCase(unittest.TestCase): def test_blacklisted_server(self): e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]}) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index f97c8a59f6..445a0e76ab 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -67,7 +67,6 @@ def get_connection_factory(): return test_server_connection_factory -@unittest.DEBUG class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() diff --git a/tests/test_visibility.py b/tests/test_visibility.py index e0605dac2f..18f1a0035d 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -74,7 +74,6 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) self.assertEqual(filtered[i].content["a"], "b") - @tests.unittest.DEBUG @defer.inlineCallbacks def test_erased_user(self): # 4 message events, from erased and unerased users, with a membership -- cgit 1.5.1 From 501994582899ad9d790029b3d7c48ba32f5720a9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 20 Aug 2019 11:20:10 +0100 Subject: Refactor the Appservice scheduler code Get rid of the labyrinthine `recoverer_fn` code, and clean up the startup code (it seemed to be previously inexplicably split between `ApplicationServiceScheduler.start` and `_Recoverer.start`). Add some docstrings too. --- changelog.d/5886.misc | 1 + synapse/appservice/scheduler.py | 110 ++++++++++++++++++++++--------------- tests/appservice/test_scheduler.py | 6 +- 3 files changed, 68 insertions(+), 49 deletions(-) create mode 100644 changelog.d/5886.misc (limited to 'tests') diff --git a/changelog.d/5886.misc b/changelog.d/5886.misc new file mode 100644 index 0000000000..22adba3d85 --- /dev/null +++ b/changelog.d/5886.misc @@ -0,0 +1 @@ +Refactor the Appservice scheduler code. diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 42a350bff8..03a14402c5 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -70,35 +70,37 @@ class ApplicationServiceScheduler(object): self.store = hs.get_datastore() self.as_api = hs.get_application_service_api() - def create_recoverer(service, callback): - return _Recoverer(self.clock, self.store, self.as_api, service, callback) - - self.txn_ctrl = _TransactionController( - self.clock, self.store, self.as_api, create_recoverer - ) + self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) @defer.inlineCallbacks def start(self): logger.info("Starting appservice scheduler") + # check for any DOWN ASes and start recoverers for them. - recoverers = yield _Recoverer.start( - self.clock, self.store, self.as_api, self.txn_ctrl.on_recovered + services = yield self.store.get_appservices_by_state( + ApplicationServiceState.DOWN ) - self.txn_ctrl.add_recoverers(recoverers) + + for service in services: + self.txn_ctrl.start_recoverer(service) def submit_event_for_as(self, service, event): self.queuer.enqueue(service, event) class _ServiceQueuer(object): - """Queues events for the same application service together, sending - transactions as soon as possible. Once a transaction is sent successfully, - this schedules any other events in the queue to run. + """Queue of events waiting to be sent to appservices. + + Groups events into transactions per-appservice, and sends them on to the + TransactionController. Makes sure that we only have one transaction in flight per + appservice at a given time. """ def __init__(self, txn_ctrl, clock): self.queued_events = {} # dict of {service_id: [events]} + + # the appservices which currently have a transaction in flight self.requests_in_flight = set() self.txn_ctrl = txn_ctrl self.clock = clock @@ -136,13 +138,29 @@ class _ServiceQueuer(object): class _TransactionController(object): - def __init__(self, clock, store, as_api, recoverer_fn): + """Transaction manager. + + Builds AppServiceTransactions and runs their lifecycle. Also starts a Recoverer + if a transaction fails. + + (Note we have only have one of these in the homeserver.) + + Args: + clock (synapse.util.Clock): + store (synapse.storage.DataStore): + as_api (synapse.appservice.api.ApplicationServiceApi): + """ + + def __init__(self, clock, store, as_api): self.clock = clock self.store = store self.as_api = as_api - self.recoverer_fn = recoverer_fn - # keep track of how many recoverers there are - self.recoverers = [] + + # map from service id to recoverer instance + self.recoverers = {} + + # for UTs + self.RECOVERER_CLASS = _Recoverer @defer.inlineCallbacks def send(self, service, events): @@ -154,42 +172,45 @@ class _TransactionController(object): if sent: yield txn.complete(self.store) else: - run_in_background(self._start_recoverer, service) + run_in_background(self._on_txn_fail, service) except Exception: logger.exception("Error creating appservice transaction") - run_in_background(self._start_recoverer, service) + run_in_background(self._on_txn_fail, service) @defer.inlineCallbacks def on_recovered(self, recoverer): - self.recoverers.remove(recoverer) logger.info( "Successfully recovered application service AS ID %s", recoverer.service.id ) + self.recoverers.pop(recoverer.service.id) logger.info("Remaining active recoverers: %s", len(self.recoverers)) yield self.store.set_appservice_state( recoverer.service, ApplicationServiceState.UP ) - def add_recoverers(self, recoverers): - for r in recoverers: - self.recoverers.append(r) - if len(recoverers) > 0: - logger.info("New active recoverers: %s", len(self.recoverers)) - @defer.inlineCallbacks - def _start_recoverer(self, service): + def _on_txn_fail(self, service): try: yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) - logger.info( - "Application service falling behind. Starting recoverer. AS ID %s", - service.id, - ) - recoverer = self.recoverer_fn(service, self.on_recovered) - self.add_recoverers([recoverer]) - recoverer.recover() + self.start_recoverer(service) except Exception: logger.exception("Error starting AS recoverer") + def start_recoverer(self, service): + """Start a Recoverer for the given service + + Args: + service (synapse.appservice.ApplicationService): + """ + logger.info("Starting recoverer for AS ID %s", service.id) + assert service.id not in self.recoverers + recoverer = self.RECOVERER_CLASS( + self.clock, self.store, self.as_api, service, self.on_recovered + ) + self.recoverers[service.id] = recoverer + recoverer.recover() + logger.info("Now %i active recoverers", len(self.recoverers)) + @defer.inlineCallbacks def _is_service_up(self, service): state = yield self.store.get_appservice_state(service) @@ -197,18 +218,17 @@ class _TransactionController(object): class _Recoverer(object): - @staticmethod - @defer.inlineCallbacks - def start(clock, store, as_api, callback): - services = yield store.get_appservices_by_state(ApplicationServiceState.DOWN) - recoverers = [_Recoverer(clock, store, as_api, s, callback) for s in services] - for r in recoverers: - logger.info( - "Starting recoverer for AS ID %s which was marked as " "DOWN", - r.service.id, - ) - r.recover() - return recoverers + """Manages retries and backoff for a DOWN appservice. + + We have one of these for each appservice which is currently considered DOWN. + + Args: + clock (synapse.util.Clock): + store (synapse.storage.DataStore): + as_api (synapse.appservice.api.ApplicationServiceApi): + service (synapse.appservice.ApplicationService): the service we are managing + callback (callable[_Recoverer]): called once the service recovers. + """ def __init__(self, clock, store, as_api, service, callback): self.clock = clock diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 04b8c2c07c..52f89d3f83 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -37,11 +37,9 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.recoverer = Mock() self.recoverer_fn = Mock(return_value=self.recoverer) self.txnctrl = _TransactionController( - clock=self.clock, - store=self.store, - as_api=self.as_api, - recoverer_fn=self.recoverer_fn, + clock=self.clock, store=self.store, as_api=self.as_api ) + self.txnctrl.RECOVERER_CLASS = self.recoverer_fn def test_single_service_up_txn_sent(self): # Test: The AS is up and the txn is successfully sent. -- cgit 1.5.1 From 0fb5189072f61416e7f616aa9d90d8c981e6b8b3 Mon Sep 17 00:00:00 2001 From: Half-Shot Date: Fri, 23 Aug 2019 09:25:35 +0100 Subject: Fix registration test --- tests/storage/test_registration.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 0253c4ac05..4578cc3b60 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -49,6 +49,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "consent_server_notice_sent": None, "appservice_id": None, "creation_ts": 1000, + "user_type": None, }, (yield self.store.get_user_by_id(self.user_id)), ) -- cgit 1.5.1 From 7dc398586c2156a456d9526ac0e42c1fec9f8143 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Wed, 28 Aug 2019 21:18:53 +1000 Subject: Implement a structured logging output system. (#5680) --- .buildkite/docker-compose.py35.pg95.yaml | 1 + .buildkite/docker-compose.py37.pg11.yaml | 1 + .buildkite/docker-compose.py37.pg95.yaml | 1 + .buildkite/pipeline.yml | 10 +- .gitignore | 5 +- changelog.d/5680.misc | 1 + docs/structured_logging.md | 83 +++++++ synapse/app/_base.py | 12 +- synapse/app/admin_cmd.py | 4 +- synapse/app/appservice.py | 4 +- synapse/app/client_reader.py | 4 +- synapse/app/event_creator.py | 4 +- synapse/app/federation_reader.py | 4 +- synapse/app/federation_sender.py | 4 +- synapse/app/frontend_proxy.py | 4 +- synapse/app/homeserver.py | 4 +- synapse/app/media_repository.py | 4 +- synapse/app/pusher.py | 4 +- synapse/app/synchrotron.py | 4 +- synapse/app/user_dir.py | 4 +- synapse/config/logger.py | 103 +++++---- synapse/handlers/federation.py | 5 +- synapse/logging/_structured.py | 374 +++++++++++++++++++++++++++++++ synapse/logging/_terse_json.py | 278 +++++++++++++++++++++++ synapse/logging/context.py | 14 +- synapse/python_dependencies.py | 6 +- tests/logging/__init__.py | 0 tests/logging/test_structured.py | 197 ++++++++++++++++ tests/logging/test_terse_json.py | 234 +++++++++++++++++++ tests/server.py | 27 ++- tox.ini | 10 + 31 files changed, 1328 insertions(+), 82 deletions(-) create mode 100644 changelog.d/5680.misc create mode 100644 docs/structured_logging.md create mode 100644 synapse/logging/_structured.py create mode 100644 synapse/logging/_terse_json.py create mode 100644 tests/logging/__init__.py create mode 100644 tests/logging/test_structured.py create mode 100644 tests/logging/test_terse_json.py (limited to 'tests') diff --git a/.buildkite/docker-compose.py35.pg95.yaml b/.buildkite/docker-compose.py35.pg95.yaml index 2f14387fbc..aaea33006b 100644 --- a/.buildkite/docker-compose.py35.pg95.yaml +++ b/.buildkite/docker-compose.py35.pg95.yaml @@ -6,6 +6,7 @@ services: image: postgres:9.5 environment: POSTGRES_PASSWORD: postgres + command: -c fsync=off testenv: image: python:3.5 diff --git a/.buildkite/docker-compose.py37.pg11.yaml b/.buildkite/docker-compose.py37.pg11.yaml index f3eec05ceb..1b32675e78 100644 --- a/.buildkite/docker-compose.py37.pg11.yaml +++ b/.buildkite/docker-compose.py37.pg11.yaml @@ -6,6 +6,7 @@ services: image: postgres:11 environment: POSTGRES_PASSWORD: postgres + command: -c fsync=off testenv: image: python:3.7 diff --git a/.buildkite/docker-compose.py37.pg95.yaml b/.buildkite/docker-compose.py37.pg95.yaml index 2a41db8eba..7679f6508d 100644 --- a/.buildkite/docker-compose.py37.pg95.yaml +++ b/.buildkite/docker-compose.py37.pg95.yaml @@ -6,6 +6,7 @@ services: image: postgres:9.5 environment: POSTGRES_PASSWORD: postgres + command: -c fsync=off testenv: image: python:3.7 diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index b75269a155..d9327227ed 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -45,8 +45,15 @@ steps: - docker#v3.0.1: image: "python:3.6" - - wait + - command: + - "python -m pip install tox" + - "tox -e mypy" + label: ":mypy: mypy" + plugins: + - docker#v3.0.1: + image: "python:3.5" + - wait - command: - "apt-get update && apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev" @@ -55,6 +62,7 @@ steps: label: ":python: 3.5 / SQLite / Old Deps" env: TRIAL_FLAGS: "-j 2" + LANG: "C.UTF-8" plugins: - docker#v3.0.1: image: "ubuntu:xenial" # We use xenail to get an old sqlite and python diff --git a/.gitignore b/.gitignore index f6168a8819..e53d4908d5 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ _trial_temp*/ /*.signing.key /env/ /homeserver*.yaml +/logs /media_store/ /uploads @@ -29,8 +30,9 @@ _trial_temp*/ /.vscode/ # build products -/.coverage* !/.coveragerc +/.coverage* +/.mypy_cache/ /.tox /build/ /coverage.* @@ -38,4 +40,3 @@ _trial_temp*/ /docs/build/ /htmlcov /pip-wheel-metadata/ - diff --git a/changelog.d/5680.misc b/changelog.d/5680.misc new file mode 100644 index 0000000000..46a403a188 --- /dev/null +++ b/changelog.d/5680.misc @@ -0,0 +1 @@ +Lay the groundwork for structured logging output. diff --git a/docs/structured_logging.md b/docs/structured_logging.md new file mode 100644 index 0000000000..decec9b8fa --- /dev/null +++ b/docs/structured_logging.md @@ -0,0 +1,83 @@ +# Structured Logging + +A structured logging system can be useful when your logs are destined for a machine to parse and process. By maintaining its machine-readable characteristics, it enables more efficient searching and aggregations when consumed by software such as the "ELK stack". + +Synapse's structured logging system is configured via the file that Synapse's `log_config` config option points to. The file must be YAML and contain `structured: true`. It must contain a list of "drains" (places where logs go to). + +A structured logging configuration looks similar to the following: + +```yaml +structured: true + +loggers: + synapse: + level: INFO + synapse.storage.SQL: + level: WARNING + +drains: + console: + type: console + location: stdout + file: + type: file_json + location: homeserver.log +``` + +The above logging config will set Synapse as 'INFO' logging level by default, with the SQL layer at 'WARNING', and will have two logging drains (to the console and to a file, stored as JSON). + +## Drain Types + +Drain types can be specified by the `type` key. + +### `console` + +Outputs human-readable logs to the console. + +Arguments: + +- `location`: Either `stdout` or `stderr`. + +### `console_json` + +Outputs machine-readable JSON logs to the console. + +Arguments: + +- `location`: Either `stdout` or `stderr`. + +### `console_json_terse` + +Outputs machine-readable JSON logs to the console, separated by newlines. This +format is not designed to be read and re-formatted into human-readable text, but +is optimal for a logging aggregation system. + +Arguments: + +- `location`: Either `stdout` or `stderr`. + +### `file` + +Outputs human-readable logs to a file. + +Arguments: + +- `location`: An absolute path to the file to log to. + +### `file_json` + +Outputs machine-readable logs to a file. + +Arguments: + +- `location`: An absolute path to the file to log to. + +### `network_json_terse` + +Delivers machine-readable JSON logs to a log aggregator over TCP. This is +compatible with LogStash's TCP input with the codec set to `json_lines`. + +Arguments: + +- `host`: Hostname or IP address of the log aggregator. +- `port`: Numerical port to contact on the host. \ No newline at end of file diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 69dcf3523f..c30fdeee9a 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -36,18 +36,20 @@ from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) +# list of tuples of function, args list, kwargs dict _sighup_callbacks = [] -def register_sighup(func): +def register_sighup(func, *args, **kwargs): """ Register a function to be called when a SIGHUP occurs. Args: func (function): Function to be called when sent a SIGHUP signal. - Will be called with a single argument, the homeserver. + Will be called with a single default argument, the homeserver. + *args, **kwargs: args and kwargs to be passed to the target function. """ - _sighup_callbacks.append(func) + _sighup_callbacks.append((func, args, kwargs)) def start_worker_reactor(appname, config, run_command=reactor.run): @@ -248,8 +250,8 @@ def start(hs, listeners=None): # we're not using systemd. sdnotify(b"RELOADING=1") - for i in _sighup_callbacks: - i(hs) + for i, args, kwargs in _sighup_callbacks: + i(hs, *args, **kwargs) sdnotify(b"READY=1") diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 1fd52a5526..04751a6a5e 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -227,8 +227,6 @@ def start(config_options): config.start_pushers = False config.send_federation = False - setup_logging(config, use_worker_options=True) - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -241,6 +239,8 @@ def start(config_options): database_engine=database_engine, ) + setup_logging(ss, config, use_worker_options=True) + ss.setup() # We use task.react as the basic run command as it correctly handles tearing diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 54bb114dec..767b87d2db 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -141,8 +141,6 @@ def start(config_options): assert config.worker_app == "synapse.app.appservice" - setup_logging(config, use_worker_options=True) - events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -167,6 +165,8 @@ def start(config_options): database_engine=database_engine, ) + setup_logging(ps, config, use_worker_options=True) + ps.setup() reactor.addSystemEventTrigger( "before", "startup", _base.start, ps, config.worker_listeners diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 721bb5b119..86193d35a8 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -179,8 +179,6 @@ def start(config_options): assert config.worker_app == "synapse.app.client_reader" - setup_logging(config, use_worker_options=True) - events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -193,6 +191,8 @@ def start(config_options): database_engine=database_engine, ) + setup_logging(ss, config, use_worker_options=True) + ss.setup() reactor.addSystemEventTrigger( "before", "startup", _base.start, ss, config.worker_listeners diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index 473c8895d0..c67fe69a50 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -175,8 +175,6 @@ def start(config_options): assert config.worker_replication_http_port is not None - setup_logging(config, use_worker_options=True) - # This should only be done on the user directory worker or the master config.update_user_directory = False @@ -192,6 +190,8 @@ def start(config_options): database_engine=database_engine, ) + setup_logging(ss, config, use_worker_options=True) + ss.setup() reactor.addSystemEventTrigger( "before", "startup", _base.start, ss, config.worker_listeners diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 5255d9e8cc..1ef027a88c 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -160,8 +160,6 @@ def start(config_options): assert config.worker_app == "synapse.app.federation_reader" - setup_logging(config, use_worker_options=True) - events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -174,6 +172,8 @@ def start(config_options): database_engine=database_engine, ) + setup_logging(ss, config, use_worker_options=True) + ss.setup() reactor.addSystemEventTrigger( "before", "startup", _base.start, ss, config.worker_listeners diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index c5a2880e69..04fbb407af 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -171,8 +171,6 @@ def start(config_options): assert config.worker_app == "synapse.app.federation_sender" - setup_logging(config, use_worker_options=True) - events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -197,6 +195,8 @@ def start(config_options): database_engine=database_engine, ) + setup_logging(ss, config, use_worker_options=True) + ss.setup() reactor.addSystemEventTrigger( "before", "startup", _base.start, ss, config.worker_listeners diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index e2822ca848..611d285421 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -232,8 +232,6 @@ def start(config_options): assert config.worker_main_http_uri is not None - setup_logging(config, use_worker_options=True) - events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -246,6 +244,8 @@ def start(config_options): database_engine=database_engine, ) + setup_logging(ss, config, use_worker_options=True) + ss.setup() reactor.addSystemEventTrigger( "before", "startup", _base.start, ss, config.worker_listeners diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8233905844..04f1ed14f3 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -341,8 +341,6 @@ def setup(config_options): # generating config files and shouldn't try to continue. sys.exit(0) - synapse.config.logger.setup_logging(config, use_worker_options=False) - events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -356,6 +354,8 @@ def setup(config_options): database_engine=database_engine, ) + synapse.config.logger.setup_logging(hs, config, use_worker_options=False) + logger.info("Preparing database: %s...", config.database_config["name"]) try: diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 3a168577c7..2ac783ffa3 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -155,8 +155,6 @@ def start(config_options): "Please add ``enable_media_repo: false`` to the main config\n" ) - setup_logging(config, use_worker_options=True) - events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -169,6 +167,8 @@ def start(config_options): database_engine=database_engine, ) + setup_logging(ss, config, use_worker_options=True) + ss.setup() reactor.addSystemEventTrigger( "before", "startup", _base.start, ss, config.worker_listeners diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 692ffa2f04..d84732ee3c 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -184,8 +184,6 @@ def start(config_options): assert config.worker_app == "synapse.app.pusher" - setup_logging(config, use_worker_options=True) - events.USE_FROZEN_DICTS = config.use_frozen_dicts if config.start_pushers: @@ -210,6 +208,8 @@ def start(config_options): database_engine=database_engine, ) + setup_logging(ps, config, use_worker_options=True) + ps.setup() def start(): diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index a1c3b162f7..473026fce5 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -435,8 +435,6 @@ def start(config_options): assert config.worker_app == "synapse.app.synchrotron" - setup_logging(config, use_worker_options=True) - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -450,6 +448,8 @@ def start(config_options): application_service_handler=SynchrotronApplicationService(), ) + setup_logging(ss, config, use_worker_options=True) + ss.setup() reactor.addSystemEventTrigger( "before", "startup", _base.start, ss, config.worker_listeners diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index cb29a1afab..e01afb39f2 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -197,8 +197,6 @@ def start(config_options): assert config.worker_app == "synapse.app.user_dir" - setup_logging(config, use_worker_options=True) - events.USE_FROZEN_DICTS = config.use_frozen_dicts database_engine = create_engine(config.database_config) @@ -223,6 +221,8 @@ def start(config_options): database_engine=database_engine, ) + setup_logging(ss, config, use_worker_options=True) + ss.setup() reactor.addSystemEventTrigger( "before", "startup", _base.start, ss, config.worker_listeners diff --git a/synapse/config/logger.py b/synapse/config/logger.py index d321d00b80..981df5a10c 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -25,6 +25,10 @@ from twisted.logger import STDLibLogObserver, globalLogBeginner import synapse from synapse.app import _base as appbase +from synapse.logging._structured import ( + reload_structured_logging, + setup_structured_logging, +) from synapse.logging.context import LoggingContextFilter from synapse.util.versionstring import get_version_string @@ -119,21 +123,10 @@ class LoggingConfig(Config): log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file)) -def setup_logging(config, use_worker_options=False): - """ Set up python logging - - Args: - config (LoggingConfig | synapse.config.workers.WorkerConfig): - configuration data - - use_worker_options (bool): True to use the 'worker_log_config' option - instead of 'log_config'. - - register_sighup (func | None): Function to call to register a - sighup handler. +def _setup_stdlib_logging(config, log_config): + """ + Set up Python stdlib logging. """ - log_config = config.worker_log_config if use_worker_options else config.log_config - if log_config is None: log_format = ( "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" @@ -151,35 +144,10 @@ def setup_logging(config, use_worker_options=False): handler.addFilter(LoggingContextFilter(request="")) logger.addHandler(handler) else: + logging.config.dictConfig(log_config) - def load_log_config(): - with open(log_config, "r") as f: - logging.config.dictConfig(yaml.safe_load(f)) - - def sighup(*args): - # it might be better to use a file watcher or something for this. - load_log_config() - logging.info("Reloaded log config from %s due to SIGHUP", log_config) - - load_log_config() - appbase.register_sighup(sighup) - - # make sure that the first thing we log is a thing we can grep backwards - # for - logging.warn("***** STARTING SERVER *****") - logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse)) - logging.info("Server hostname: %s", config.server_name) - - # It's critical to point twisted's internal logging somewhere, otherwise it - # stacks up and leaks kup to 64K object; - # see: https://twistedmatrix.com/trac/ticket/8164 - # - # Routing to the python logging framework could be a performance problem if - # the handlers blocked for a long time as python.logging is a blocking API - # see https://twistedmatrix.com/documents/current/core/howto/logger.html - # filed as https://github.com/matrix-org/synapse/issues/1727 - # - # However this may not be too much of a problem if we are just writing to a file. + # Route Twisted's native logging through to the standard library logging + # system. observer = STDLibLogObserver() def _log(event): @@ -201,3 +169,54 @@ def setup_logging(config, use_worker_options=False): ) if not config.no_redirect_stdio: print("Redirected stdout/stderr to logs") + + +def _reload_stdlib_logging(*args, log_config=None): + logger = logging.getLogger("") + + if not log_config: + logger.warn("Reloaded a blank config?") + + logging.config.dictConfig(log_config) + + +def setup_logging(hs, config, use_worker_options=False): + """ + Set up the logging subsystem. + + Args: + config (LoggingConfig | synapse.config.workers.WorkerConfig): + configuration data + + use_worker_options (bool): True to use the 'worker_log_config' option + instead of 'log_config'. + """ + log_config = config.worker_log_config if use_worker_options else config.log_config + + def read_config(*args, callback=None): + if log_config is None: + return None + + with open(log_config, "rb") as f: + log_config_body = yaml.safe_load(f.read()) + + if callback: + callback(log_config=log_config_body) + logging.info("Reloaded log config from %s due to SIGHUP", log_config) + + return log_config_body + + log_config_body = read_config() + + if log_config_body and log_config_body.get("structured") is True: + setup_structured_logging(hs, config, log_config_body) + appbase.register_sighup(read_config, callback=reload_structured_logging) + else: + _setup_stdlib_logging(config, log_config_body) + appbase.register_sighup(read_config, callback=_reload_stdlib_logging) + + # make sure that the first thing we log is a thing we can grep backwards + # for + logging.warn("***** STARTING SERVER *****") + logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse)) + logging.info("Server hostname: %s", config.server_name) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c86903b98b..94306c94a9 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -326,8 +326,9 @@ class FederationHandler(BaseHandler): ours = yield self.store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id - # type: list[dict[tuple[str, str], str]] - state_maps = list(ours.values()) + state_maps = list( + ours.values() + ) # type: list[dict[tuple[str, str], str]] # we don't need this any more, let's delete it. del ours diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py new file mode 100644 index 0000000000..0367d6dfc4 --- /dev/null +++ b/synapse/logging/_structured.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os.path +import sys +import typing +import warnings + +import attr +from constantly import NamedConstant, Names, ValueConstant, Values +from zope.interface import implementer + +from twisted.logger import ( + FileLogObserver, + FilteringLogObserver, + ILogObserver, + LogBeginner, + Logger, + LogLevel, + LogLevelFilterPredicate, + LogPublisher, + eventAsText, + globalLogBeginner, + jsonFileLogObserver, +) + +from synapse.config._base import ConfigError +from synapse.logging._terse_json import ( + TerseJSONToConsoleLogObserver, + TerseJSONToTCPLogObserver, +) +from synapse.logging.context import LoggingContext + + +def stdlib_log_level_to_twisted(level: str) -> LogLevel: + """ + Convert a stdlib log level to Twisted's log level. + """ + lvl = level.lower().replace("warning", "warn") + return LogLevel.levelWithName(lvl) + + +@attr.s +@implementer(ILogObserver) +class LogContextObserver(object): + """ + An ILogObserver which adds Synapse-specific log context information. + + Attributes: + observer (ILogObserver): The target parent observer. + """ + + observer = attr.ib() + + def __call__(self, event: dict) -> None: + """ + Consume a log event and emit it to the parent observer after filtering + and adding log context information. + + Args: + event (dict) + """ + # Filter out some useless events that Twisted outputs + if "log_text" in event: + if event["log_text"].startswith("DNSDatagramProtocol starting on "): + return + + if event["log_text"].startswith("(UDP Port "): + return + + if event["log_text"].startswith("Timing out client") or event[ + "log_format" + ].startswith("Timing out client"): + return + + context = LoggingContext.current_context() + + # Copy the context information to the log event. + if context is not None: + context.copy_to_twisted_log_entry(event) + else: + # If there's no logging context, not even the root one, we might be + # starting up or it might be from non-Synapse code. Log it as if it + # came from the root logger. + event["request"] = None + event["scope"] = None + + self.observer(event) + + +class PythonStdlibToTwistedLogger(logging.Handler): + """ + Transform a Python stdlib log message into a Twisted one. + """ + + def __init__(self, observer, *args, **kwargs): + """ + Args: + observer (ILogObserver): A Twisted logging observer. + *args, **kwargs: Args/kwargs to be passed to logging.Handler. + """ + self.observer = observer + super().__init__(*args, **kwargs) + + def emit(self, record: logging.LogRecord) -> None: + """ + Emit a record to Twisted's observer. + + Args: + record (logging.LogRecord) + """ + + self.observer( + { + "log_time": record.created, + "log_text": record.getMessage(), + "log_format": "{log_text}", + "log_namespace": record.name, + "log_level": stdlib_log_level_to_twisted(record.levelname), + } + ) + + +def SynapseFileLogObserver(outFile: typing.io.TextIO) -> FileLogObserver: + """ + A log observer that formats events like the traditional log formatter and + sends them to `outFile`. + + Args: + outFile (file object): The file object to write to. + """ + + def formatEvent(_event: dict) -> str: + event = dict(_event) + event["log_level"] = event["log_level"].name.upper() + event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + ( + event.get("log_format", "{log_text}") or "{log_text}" + ) + return eventAsText(event, includeSystem=False) + "\n" + + return FileLogObserver(outFile, formatEvent) + + +class DrainType(Names): + CONSOLE = NamedConstant() + CONSOLE_JSON = NamedConstant() + CONSOLE_JSON_TERSE = NamedConstant() + FILE = NamedConstant() + FILE_JSON = NamedConstant() + NETWORK_JSON_TERSE = NamedConstant() + + +class OutputPipeType(Values): + stdout = ValueConstant(sys.__stdout__) + stderr = ValueConstant(sys.__stderr__) + + +@attr.s +class DrainConfiguration(object): + name = attr.ib() + type = attr.ib() + location = attr.ib() + options = attr.ib(default=None) + + +@attr.s +class NetworkJSONTerseOptions(object): + maximum_buffer = attr.ib(type=int) + + +DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}} + + +def parse_drain_configs( + drains: dict +) -> typing.Generator[DrainConfiguration, None, None]: + """ + Parse the drain configurations. + + Args: + drains (dict): A list of drain configurations. + + Yields: + DrainConfiguration instances. + + Raises: + ConfigError: If any of the drain configuration items are invalid. + """ + for name, config in drains.items(): + if "type" not in config: + raise ConfigError("Logging drains require a 'type' key.") + + try: + logging_type = DrainType.lookupByName(config["type"].upper()) + except ValueError: + raise ConfigError( + "%s is not a known logging drain type." % (config["type"],) + ) + + if logging_type in [ + DrainType.CONSOLE, + DrainType.CONSOLE_JSON, + DrainType.CONSOLE_JSON_TERSE, + ]: + location = config.get("location") + if location is None or location not in ["stdout", "stderr"]: + raise ConfigError( + ( + "The %s drain needs the 'location' key set to " + "either 'stdout' or 'stderr'." + ) + % (logging_type,) + ) + + pipe = OutputPipeType.lookupByName(location).value + + yield DrainConfiguration(name=name, type=logging_type, location=pipe) + + elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]: + if "location" not in config: + raise ConfigError( + "The %s drain needs the 'location' key set." % (logging_type,) + ) + + location = config.get("location") + if os.path.abspath(location) != location: + raise ConfigError( + "File paths need to be absolute, '%s' is a relative path" + % (location,) + ) + yield DrainConfiguration(name=name, type=logging_type, location=location) + + elif logging_type in [DrainType.NETWORK_JSON_TERSE]: + host = config.get("host") + port = config.get("port") + maximum_buffer = config.get("maximum_buffer", 1000) + yield DrainConfiguration( + name=name, + type=logging_type, + location=(host, port), + options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer), + ) + + else: + raise ConfigError( + "The %s drain type is currently not implemented." + % (config["type"].upper(),) + ) + + +def setup_structured_logging( + hs, + config, + log_config: dict, + logBeginner: LogBeginner = globalLogBeginner, + redirect_stdlib_logging: bool = True, +) -> LogPublisher: + """ + Set up Twisted's structured logging system. + + Args: + hs: The homeserver to use. + config (HomeserverConfig): The configuration of the Synapse homeserver. + log_config (dict): The log configuration to use. + """ + if config.no_redirect_stdio: + raise ConfigError( + "no_redirect_stdio cannot be defined using structured logging." + ) + + logger = Logger() + + if "drains" not in log_config: + raise ConfigError("The logging configuration requires a list of drains.") + + observers = [] + + for observer in parse_drain_configs(log_config["drains"]): + # Pipe drains + if observer.type == DrainType.CONSOLE: + logger.debug( + "Starting up the {name} console logger drain", name=observer.name + ) + observers.append(SynapseFileLogObserver(observer.location)) + elif observer.type == DrainType.CONSOLE_JSON: + logger.debug( + "Starting up the {name} JSON console logger drain", name=observer.name + ) + observers.append(jsonFileLogObserver(observer.location)) + elif observer.type == DrainType.CONSOLE_JSON_TERSE: + logger.debug( + "Starting up the {name} terse JSON console logger drain", + name=observer.name, + ) + observers.append( + TerseJSONToConsoleLogObserver(observer.location, metadata={}) + ) + + # File drains + elif observer.type == DrainType.FILE: + logger.debug("Starting up the {name} file logger drain", name=observer.name) + log_file = open(observer.location, "at", buffering=1, encoding="utf8") + observers.append(SynapseFileLogObserver(log_file)) + elif observer.type == DrainType.FILE_JSON: + logger.debug( + "Starting up the {name} JSON file logger drain", name=observer.name + ) + log_file = open(observer.location, "at", buffering=1, encoding="utf8") + observers.append(jsonFileLogObserver(log_file)) + + elif observer.type == DrainType.NETWORK_JSON_TERSE: + metadata = {"server_name": hs.config.server_name} + log_observer = TerseJSONToTCPLogObserver( + hs=hs, + host=observer.location[0], + port=observer.location[1], + metadata=metadata, + maximum_buffer=observer.options.maximum_buffer, + ) + log_observer.start() + observers.append(log_observer) + else: + # We should never get here, but, just in case, throw an error. + raise ConfigError("%s drain type cannot be configured" % (observer.type,)) + + publisher = LogPublisher(*observers) + log_filter = LogLevelFilterPredicate() + + for namespace, namespace_config in log_config.get( + "loggers", DEFAULT_LOGGERS + ).items(): + # Set the log level for twisted.logger.Logger namespaces + log_filter.setLogLevelForNamespace( + namespace, + stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")), + ) + + # Also set the log levels for the stdlib logger namespaces, to prevent + # them getting to PythonStdlibToTwistedLogger and having to be formatted + if "level" in namespace_config: + logging.getLogger(namespace).setLevel(namespace_config.get("level")) + + f = FilteringLogObserver(publisher, [log_filter]) + lco = LogContextObserver(f) + + if redirect_stdlib_logging: + stuff_into_twisted = PythonStdlibToTwistedLogger(lco) + stdliblogger = logging.getLogger() + stdliblogger.addHandler(stuff_into_twisted) + + # Always redirect standard I/O, otherwise other logging outputs might miss + # it. + logBeginner.beginLoggingTo([lco], redirectStandardIO=True) + + return publisher + + +def reload_structured_logging(*args, log_config=None) -> None: + warnings.warn( + "Currently the structured logging system can not be reloaded, doing nothing" + ) diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py new file mode 100644 index 0000000000..7f1e8f23fe --- /dev/null +++ b/synapse/logging/_terse_json.py @@ -0,0 +1,278 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Log formatters that output terse JSON. +""" + +import sys +from collections import deque +from ipaddress import IPv4Address, IPv6Address, ip_address +from math import floor +from typing.io import TextIO + +import attr +from simplejson import dumps + +from twisted.application.internet import ClientService +from twisted.internet.endpoints import ( + HostnameEndpoint, + TCP4ClientEndpoint, + TCP6ClientEndpoint, +) +from twisted.internet.protocol import Factory, Protocol +from twisted.logger import FileLogObserver, Logger +from twisted.python.failure import Failure + + +def flatten_event(event: dict, metadata: dict, include_time: bool = False): + """ + Flatten a Twisted logging event to an dictionary capable of being sent + as a log event to a logging aggregation system. + + The format is vastly simplified and is not designed to be a "human readable + string" in the sense that traditional logs are. Instead, the structure is + optimised for searchability and filtering, with human-understandable log + keys. + + Args: + event (dict): The Twisted logging event we are flattening. + metadata (dict): Additional data to include with each log message. This + can be information like the server name. Since the target log + consumer does not know who we are other than by host IP, this + allows us to forward through static information. + include_time (bool): Should we include the `time` key? If False, the + event time is stripped from the event. + """ + new_event = {} + + # If it's a failure, make the new event's log_failure be the traceback text. + if "log_failure" in event: + new_event["log_failure"] = event["log_failure"].getTraceback() + + # If it's a warning, copy over a string representation of the warning. + if "warning" in event: + new_event["warning"] = str(event["warning"]) + + # Stdlib logging events have "log_text" as their human-readable portion, + # Twisted ones have "log_format". For now, include the log_format, so that + # context only given in the log format (e.g. what is being logged) is + # available. + if "log_text" in event: + new_event["log"] = event["log_text"] + else: + new_event["log"] = event["log_format"] + + # We want to include the timestamp when forwarding over the network, but + # exclude it when we are writing to stdout. This is because the log ingester + # (e.g. logstash, fluentd) can add its own timestamp. + if include_time: + new_event["time"] = round(event["log_time"], 2) + + # Convert the log level to a textual representation. + new_event["level"] = event["log_level"].name.upper() + + # Ignore these keys, and do not transfer them over to the new log object. + # They are either useless (isError), transferred manually above (log_time, + # log_level, etc), or contain Python objects which are not useful for output + # (log_logger, log_source). + keys_to_delete = [ + "isError", + "log_failure", + "log_format", + "log_level", + "log_logger", + "log_source", + "log_system", + "log_time", + "log_text", + "observer", + "warning", + ] + + # If it's from the Twisted legacy logger (twisted.python.log), it adds some + # more keys we want to purge. + if event.get("log_namespace") == "log_legacy": + keys_to_delete.extend(["message", "system", "time"]) + + # Rather than modify the dictionary in place, construct a new one with only + # the content we want. The original event should be considered 'frozen'. + for key in event.keys(): + + if key in keys_to_delete: + continue + + if isinstance(event[key], (str, int, bool, float)) or event[key] is None: + # If it's a plain type, include it as is. + new_event[key] = event[key] + else: + # If it's not one of those basic types, write out a string + # representation. This should probably be a warning in development, + # so that we are sure we are only outputting useful data. + new_event[key] = str(event[key]) + + # Add the metadata information to the event (e.g. the server_name). + new_event.update(metadata) + + return new_event + + +def TerseJSONToConsoleLogObserver(outFile: TextIO, metadata: dict) -> FileLogObserver: + """ + A log observer that formats events to a flattened JSON representation. + + Args: + outFile: The file object to write to. + metadata: Metadata to be added to each log object. + """ + + def formatEvent(_event: dict) -> str: + flattened = flatten_event(_event, metadata) + return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n" + + return FileLogObserver(outFile, formatEvent) + + +@attr.s +class TerseJSONToTCPLogObserver(object): + """ + An IObserver that writes JSON logs to a TCP target. + + Args: + hs (HomeServer): The Homeserver that is being logged for. + host: The host of the logging target. + port: The logging target's port. + metadata: Metadata to be added to each log entry. + """ + + hs = attr.ib() + host = attr.ib(type=str) + port = attr.ib(type=int) + metadata = attr.ib(type=dict) + maximum_buffer = attr.ib(type=int) + _buffer = attr.ib(default=attr.Factory(deque), type=deque) + _writer = attr.ib(default=None) + _logger = attr.ib(default=attr.Factory(Logger)) + + def start(self) -> None: + + # Connect without DNS lookups if it's a direct IP. + try: + ip = ip_address(self.host) + if isinstance(ip, IPv4Address): + endpoint = TCP4ClientEndpoint( + self.hs.get_reactor(), self.host, self.port + ) + elif isinstance(ip, IPv6Address): + endpoint = TCP6ClientEndpoint( + self.hs.get_reactor(), self.host, self.port + ) + except ValueError: + endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port) + + factory = Factory.forProtocol(Protocol) + self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) + self._service.startService() + + def _write_loop(self) -> None: + """ + Implement the write loop. + """ + if self._writer: + return + + self._writer = self._service.whenConnected() + + @self._writer.addBoth + def writer(r): + if isinstance(r, Failure): + r.printTraceback(file=sys.__stderr__) + self._writer = None + self.hs.get_reactor().callLater(1, self._write_loop) + return + + try: + for event in self._buffer: + r.transport.write( + dumps(event, ensure_ascii=False, separators=(",", ":")).encode( + "utf8" + ) + ) + r.transport.write(b"\n") + self._buffer.clear() + except Exception as e: + sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),)) + + self._writer = False + self.hs.get_reactor().callLater(1, self._write_loop) + + def _handle_pressure(self) -> None: + """ + Handle backpressure by shedding events. + + The buffer will, in this order, until the buffer is below the maximum: + - Shed DEBUG events + - Shed INFO events + - Shed the middle 50% of the events. + """ + if len(self._buffer) <= self.maximum_buffer: + return + + # Strip out DEBUGs + self._buffer = deque( + filter(lambda event: event["level"] != "DEBUG", self._buffer) + ) + + if len(self._buffer) <= self.maximum_buffer: + return + + # Strip out INFOs + self._buffer = deque( + filter(lambda event: event["level"] != "INFO", self._buffer) + ) + + if len(self._buffer) <= self.maximum_buffer: + return + + # Cut the middle entries out + buffer_split = floor(self.maximum_buffer / 2) + + old_buffer = self._buffer + self._buffer = deque() + + for i in range(buffer_split): + self._buffer.append(old_buffer.popleft()) + + end_buffer = [] + for i in range(buffer_split): + end_buffer.append(old_buffer.pop()) + + self._buffer.extend(reversed(end_buffer)) + + def __call__(self, event: dict) -> None: + flattened = flatten_event(event, self.metadata, include_time=True) + self._buffer.append(flattened) + + # Handle backpressure, if it exists. + try: + self._handle_pressure() + except Exception: + # If handling backpressure fails,clear the buffer and log the + # exception. + self._buffer.clear() + self._logger.failure("Failed clearing backpressure") + + # Try and write immediately. + self._write_loop() diff --git a/synapse/logging/context.py b/synapse/logging/context.py index b456c31f70..63379bfb93 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -25,6 +25,7 @@ See doc/log_contexts.rst for details on how this works. import logging import threading import types +from typing import Any, List from twisted.internet import defer, threads @@ -194,7 +195,7 @@ class LoggingContext(object): class Sentinel(object): """Sentinel to represent the root context""" - __slots__ = [] + __slots__ = [] # type: List[Any] def __str__(self): return "sentinel" @@ -202,6 +203,10 @@ class LoggingContext(object): def copy_to(self, record): pass + def copy_to_twisted_log_entry(self, record): + record["request"] = None + record["scope"] = None + def start(self): pass @@ -330,6 +335,13 @@ class LoggingContext(object): # we also track the current scope: record.scope = self.scope + def copy_to_twisted_log_entry(self, record): + """ + Copy logging fields from this context to a Twisted log record. + """ + record["request"] = self.request + record["scope"] = self.scope + def start(self): if get_thread_id() != self.main_thread: logger.warning("Started logcontext %s on different thread", self) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index c6465c0386..ec0ac547c1 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -47,9 +47,9 @@ REQUIREMENTS = [ "idna>=2.5", # validating SSL certs for IP addresses requires service_identity 18.1. "service_identity>=18.1.0", - # our logcontext handling relies on the ability to cancel inlineCallbacks - # (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7. - "Twisted>=18.7.0", + # Twisted 18.9 introduces some logger improvements that the structured + # logger utilises + "Twisted>=18.9.0", "treq>=15.1", # Twisted has required pyopenssl 16.0 since about Twisted 16.6. "pyopenssl>=16.0.0", diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py new file mode 100644 index 0000000000..a786de0233 --- /dev/null +++ b/tests/logging/test_structured.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import os.path +import shutil +import sys +import textwrap + +from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile + +from synapse.config.logger import setup_logging +from synapse.logging._structured import setup_structured_logging +from synapse.logging.context import LoggingContext + +from tests.unittest import DEBUG, HomeserverTestCase + + +class FakeBeginner(object): + def beginLoggingTo(self, observers, **kwargs): + self.observers = observers + + +class StructuredLoggingTestCase(HomeserverTestCase): + """ + Tests for Synapse's structured logging support. + """ + + def test_output_to_json_round_trip(self): + """ + Synapse logs can be outputted to JSON and then read back again. + """ + temp_dir = self.mktemp() + os.mkdir(temp_dir) + self.addCleanup(shutil.rmtree, temp_dir) + + json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json")) + + log_config = { + "drains": {"jsonfile": {"type": "file_json", "location": json_log_file}} + } + + # Begin the logger with our config + beginner = FakeBeginner() + setup_structured_logging( + self.hs, self.hs.config, log_config, logBeginner=beginner + ) + + # Make a logger and send an event + logger = Logger( + namespace="tests.logging.test_structured", observer=beginner.observers[0] + ) + logger.info("Hello there, {name}!", name="wally") + + # Read the log file and check it has the event we sent + with open(json_log_file, "r") as f: + logged_events = list(eventsFromJSONLogFile(f)) + self.assertEqual(len(logged_events), 1) + + # The event pulled from the file should render fine + self.assertEqual( + eventAsText(logged_events[0], includeTimestamp=False), + "[tests.logging.test_structured#info] Hello there, wally!", + ) + + def test_output_to_text(self): + """ + Synapse logs can be outputted to text. + """ + temp_dir = self.mktemp() + os.mkdir(temp_dir) + self.addCleanup(shutil.rmtree, temp_dir) + + log_file = os.path.abspath(os.path.join(temp_dir, "out.log")) + + log_config = {"drains": {"file": {"type": "file", "location": log_file}}} + + # Begin the logger with our config + beginner = FakeBeginner() + setup_structured_logging( + self.hs, self.hs.config, log_config, logBeginner=beginner + ) + + # Make a logger and send an event + logger = Logger( + namespace="tests.logging.test_structured", observer=beginner.observers[0] + ) + logger.info("Hello there, {name}!", name="wally") + + # Read the log file and check it has the event we sent + with open(log_file, "r") as f: + logged_events = f.read().strip().split("\n") + self.assertEqual(len(logged_events), 1) + + # The event pulled from the file should render fine + self.assertTrue( + logged_events[0].endswith( + " - tests.logging.test_structured - INFO - None - Hello there, wally!" + ) + ) + + def test_collects_logcontext(self): + """ + Test that log outputs have the attached logging context. + """ + log_config = {"drains": {}} + + # Begin the logger with our config + beginner = FakeBeginner() + publisher = setup_structured_logging( + self.hs, self.hs.config, log_config, logBeginner=beginner + ) + + logs = [] + + publisher.addObserver(logs.append) + + # Make a logger and send an event + logger = Logger( + namespace="tests.logging.test_structured", observer=beginner.observers[0] + ) + + with LoggingContext("testcontext", request="somereq"): + logger.info("Hello there, {name}!", name="steve") + + self.assertEqual(len(logs), 1) + self.assertEqual(logs[0]["request"], "somereq") + + +class StructuredLoggingConfigurationFileTestCase(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + tempdir = self.mktemp() + os.mkdir(tempdir) + log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml")) + self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log")) + + config = self.default_config() + config["log_config"] = log_config_file + + with open(log_config_file, "w") as f: + f.write( + textwrap.dedent( + """\ + structured: true + + drains: + file: + type: file_json + location: %s + """ + % (self.homeserver_log,) + ) + ) + + self.addCleanup(self._sys_cleanup) + + return self.setup_test_homeserver(config=config) + + def _sys_cleanup(self): + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + + # Do not remove! We need the logging system to be set other than WARNING. + @DEBUG + def test_log_output(self): + """ + When a structured logging config is given, Synapse will use it. + """ + setup_logging(self.hs, self.hs.config) + + # Make a logger and send an event + logger = Logger(namespace="tests.logging.test_structured") + + with LoggingContext("testcontext", request="somereq"): + logger.info("Hello there, {name}!", name="steve") + + with open(self.homeserver_log, "r") as f: + logged_events = [ + eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f) + ] + + logs = "\n".join(logged_events) + self.assertTrue("***** STARTING SERVER *****" in logs) + self.assertTrue("Hello there, steve!" in logs) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py new file mode 100644 index 0000000000..514282591d --- /dev/null +++ b/tests/logging/test_terse_json.py @@ -0,0 +1,234 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from collections import Counter + +from twisted.logger import Logger + +from synapse.logging._structured import setup_structured_logging + +from tests.server import connect_client +from tests.unittest import HomeserverTestCase + +from .test_structured import FakeBeginner + + +class TerseJSONTCPTestCase(HomeserverTestCase): + def test_log_output(self): + """ + The Terse JSON outputter delivers simplified structured logs over TCP. + """ + log_config = { + "drains": { + "tersejson": { + "type": "network_json_terse", + "host": "127.0.0.1", + "port": 8000, + } + } + } + + # Begin the logger with our config + beginner = FakeBeginner() + setup_structured_logging( + self.hs, self.hs.config, log_config, logBeginner=beginner + ) + + logger = Logger( + namespace="tests.logging.test_terse_json", observer=beginner.observers[0] + ) + logger.info("Hello there, {name}!", name="wally") + + # Trigger the connection + self.pump() + + _, server = connect_client(self.reactor, 0) + + # Trigger data being sent + self.pump() + + # One log message, with a single trailing newline + logs = server.data.decode("utf8").splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(server.data.count(b"\n"), 1) + + log = json.loads(logs[0]) + + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "time", + "level", + "log_namespace", + "request", + "scope", + "server_name", + "name", + ] + self.assertEqual(set(log.keys()), set(expected_log_keys)) + + # It contains the data we expect. + self.assertEqual(log["name"], "wally") + + def test_log_backpressure_debug(self): + """ + When backpressure is hit, DEBUG logs will be shed. + """ + log_config = { + "loggers": {"synapse": {"level": "DEBUG"}}, + "drains": { + "tersejson": { + "type": "network_json_terse", + "host": "127.0.0.1", + "port": 8000, + "maximum_buffer": 10, + } + }, + } + + # Begin the logger with our config + beginner = FakeBeginner() + setup_structured_logging( + self.hs, + self.hs.config, + log_config, + logBeginner=beginner, + redirect_stdlib_logging=False, + ) + + logger = Logger( + namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] + ) + + # Send some debug messages + for i in range(0, 3): + logger.debug("debug %s" % (i,)) + + # Send a bunch of useful messages + for i in range(0, 7): + logger.info("test message %s" % (i,)) + + # The last debug message pushes it past the maximum buffer + logger.debug("too much debug") + + # Allow the reconnection + _, server = connect_client(self.reactor, 0) + self.pump() + + # Only the 7 infos made it through, the debugs were elided + logs = server.data.splitlines() + self.assertEqual(len(logs), 7) + + def test_log_backpressure_info(self): + """ + When backpressure is hit, DEBUG and INFO logs will be shed. + """ + log_config = { + "loggers": {"synapse": {"level": "DEBUG"}}, + "drains": { + "tersejson": { + "type": "network_json_terse", + "host": "127.0.0.1", + "port": 8000, + "maximum_buffer": 10, + } + }, + } + + # Begin the logger with our config + beginner = FakeBeginner() + setup_structured_logging( + self.hs, + self.hs.config, + log_config, + logBeginner=beginner, + redirect_stdlib_logging=False, + ) + + logger = Logger( + namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] + ) + + # Send some debug messages + for i in range(0, 3): + logger.debug("debug %s" % (i,)) + + # Send a bunch of useful messages + for i in range(0, 10): + logger.warn("test warn %s" % (i,)) + + # Send a bunch of info messages + for i in range(0, 3): + logger.info("test message %s" % (i,)) + + # The last debug message pushes it past the maximum buffer + logger.debug("too much debug") + + # Allow the reconnection + client, server = connect_client(self.reactor, 0) + self.pump() + + # The 10 warnings made it through, the debugs and infos were elided + logs = list(map(json.loads, server.data.decode("utf8").splitlines())) + self.assertEqual(len(logs), 10) + + self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) + + def test_log_backpressure_cut_middle(self): + """ + When backpressure is hit, and no more DEBUG and INFOs cannot be culled, + it will cut the middle messages out. + """ + log_config = { + "loggers": {"synapse": {"level": "DEBUG"}}, + "drains": { + "tersejson": { + "type": "network_json_terse", + "host": "127.0.0.1", + "port": 8000, + "maximum_buffer": 10, + } + }, + } + + # Begin the logger with our config + beginner = FakeBeginner() + setup_structured_logging( + self.hs, + self.hs.config, + log_config, + logBeginner=beginner, + redirect_stdlib_logging=False, + ) + + logger = Logger( + namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] + ) + + # Send a bunch of useful messages + for i in range(0, 20): + logger.warn("test warn", num=i) + + # Allow the reconnection + client, server = connect_client(self.reactor, 0) + self.pump() + + # The first five and last five warnings made it through, the debugs and + # infos were elided + logs = list(map(json.loads, server.data.decode("utf8").splitlines())) + self.assertEqual(len(logs), 10) + self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) + self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs]) diff --git a/tests/server.py b/tests/server.py index e573c4e4c5..c8269619b1 100644 --- a/tests/server.py +++ b/tests/server.py @@ -11,9 +11,13 @@ from twisted.internet import address, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import Deferred, fail, succeed from twisted.internet.error import DNSLookupError -from twisted.internet.interfaces import IReactorPluggableNameResolver, IResolverSimple +from twisted.internet.interfaces import ( + IReactorPluggableNameResolver, + IReactorTCP, + IResolverSimple, +) from twisted.python.failure import Failure -from twisted.test.proto_helpers import MemoryReactorClock +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http import unquote from twisted.web.http_headers import Headers @@ -465,3 +469,22 @@ class FakeTransport(object): self.buffer = self.buffer[len(to_write) :] if self.buffer and self.autoflush: self._reactor.callLater(0.0, self.flush) + + +def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol: + """ + Connect a client to a fake TCP transport. + + Args: + reactor + factory: The connecting factory to build. + """ + factory = reactor.tcpClients[client_id][2] + client = factory.buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, reactor)) + client.makeConnection(FakeTransport(server, reactor)) + + reactor.tcpClients.pop(client_id) + + return client, server diff --git a/tox.ini b/tox.ini index 09b4b8fc3c..f9a3b7e49a 100644 --- a/tox.ini +++ b/tox.ini @@ -146,3 +146,13 @@ commands = coverage combine coverage xml codecov -X gcov + +[testenv:mypy] +basepython = python3.5 +deps = + {[base]deps} + mypy +extras = all +commands = mypy --ignore-missing-imports \ + synapse/logging/_structured.py \ + synapse/logging/_terse_json.py \ No newline at end of file -- cgit 1.5.1 From 6d97843793d59bc5d307475a6a6185ff107e116b Mon Sep 17 00:00:00 2001 From: Jorik Schellekens Date: Wed, 28 Aug 2019 13:12:22 +0100 Subject: Config templating (#5900) Template config files * Imagine a system composed entirely of x, y, z etc and the basic operations.. Wait George, why XOR? Why not just neq? George: Eh, I didn't think of that.. Co-Authored-By: Erik Johnston --- changelog.d/5900.feature | 1 + docs/sample_config.yaml | 16 +++---- synapse/config/_base.py | 37 ++++++++++++++++ synapse/config/database.py | 27 +++++++---- synapse/config/server.py | 84 ++++++++++++++++++++++++++++------- synapse/config/tls.py | 50 ++++++++++++++++----- tests/config/test_database.py | 52 ++++++++++++++++++++++ tests/config/test_server.py | 101 +++++++++++++++++++++++++++++++++++++++++- tests/config/test_tls.py | 44 ++++++++++++++++++ 9 files changed, 366 insertions(+), 46 deletions(-) create mode 100644 changelog.d/5900.feature create mode 100644 tests/config/test_database.py (limited to 'tests') diff --git a/changelog.d/5900.feature b/changelog.d/5900.feature new file mode 100644 index 0000000000..b62d88a76b --- /dev/null +++ b/changelog.d/5900.feature @@ -0,0 +1 @@ +Add support for config templating. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index ae1cafc5f3..6da1167632 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -205,9 +205,9 @@ listeners: # - port: 8008 tls: false - bind_addresses: ['::1', '127.0.0.1'] type: http x_forwarded: true + bind_addresses: ['::1', '127.0.0.1'] resources: - names: [client, federation] @@ -392,10 +392,10 @@ listeners: # permission to listen on port 80. # acme: - # ACME support is disabled by default. Uncomment the following line - # (and tls_certificate_path and tls_private_key_path above) to enable it. + # ACME support is disabled by default. Set this to `true` and uncomment + # tls_certificate_path and tls_private_key_path above to enable it. # - #enabled: true + enabled: False # Endpoint to use to request certificates. If you only want to test, # use Let's Encrypt's staging url: @@ -406,17 +406,17 @@ acme: # Port number to listen on for the HTTP-01 challenge. Change this if # you are forwarding connections through Apache/Nginx/etc. # - #port: 80 + port: 80 # Local addresses to listen on for incoming connections. # Again, you may want to change this if you are forwarding connections # through Apache/Nginx/etc. # - #bind_addresses: ['::', '0.0.0.0'] + bind_addresses: ['::', '0.0.0.0'] # How many days remaining on a certificate before it is renewed. # - #reprovision_threshold: 30 + reprovision_threshold: 30 # The domain that the certificate should be for. Normally this # should be the same as your Matrix domain (i.e., 'server_name'), but, @@ -430,7 +430,7 @@ acme: # # If not set, defaults to your 'server_name'. # - #domain: matrix.example.com + domain: matrix.example.com # file to use for the account key. This will be generated if it doesn't # exist. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 6ce5cd07fb..31f6530978 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -181,6 +181,11 @@ class Config(object): generate_secrets=False, report_stats=None, open_private_ports=False, + listeners=None, + database_conf=None, + tls_certificate_path=None, + tls_private_key_path=None, + acme_domain=None, ): """Build a default configuration file @@ -207,6 +212,33 @@ class Config(object): open_private_ports (bool): True to leave private ports (such as the non-TLS HTTP listener) open to the internet. + listeners (list(dict)|None): A list of descriptions of the listeners + synapse should start with each of which specifies a port (str), a list of + resources (list(str)), tls (bool) and type (str). For example: + [{ + "port": 8448, + "resources": [{"names": ["federation"]}], + "tls": True, + "type": "http", + }, + { + "port": 443, + "resources": [{"names": ["client"]}], + "tls": False, + "type": "http", + }], + + + database (str|None): The database type to configure, either `psycog2` + or `sqlite3`. + + tls_certificate_path (str|None): The path to the tls certificate. + + tls_private_key_path (str|None): The path to the tls private key. + + acme_domain (str|None): The domain acme will try to validate. If + specified acme will be enabled. + Returns: str: the yaml config file """ @@ -220,6 +252,11 @@ class Config(object): generate_secrets=generate_secrets, report_stats=report_stats, open_private_ports=open_private_ports, + listeners=listeners, + database_conf=database_conf, + tls_certificate_path=tls_certificate_path, + tls_private_key_path=tls_private_key_path, + acme_domain=acme_domain, ) ) diff --git a/synapse/config/database.py b/synapse/config/database.py index 746a6cd1f4..118aafbd4a 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from textwrap import indent + +import yaml from ._base import Config @@ -38,20 +41,28 @@ class DatabaseConfig(Config): self.set_databasepath(config.get("database_path")) - def generate_config_section(self, data_dir_path, **kwargs): - database_path = os.path.join(data_dir_path, "homeserver.db") - return ( - """\ - ## Database ## - - database: - # The database engine name + def generate_config_section(self, data_dir_path, database_conf, **kwargs): + if not database_conf: + database_path = os.path.join(data_dir_path, "homeserver.db") + database_conf = ( + """# The database engine name name: "sqlite3" # Arguments to pass to the engine args: # Path to the database database: "%(database_path)s" + """ + % locals() + ) + else: + database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip() + + return ( + """\ + ## Database ## + database: + %(database_conf)s # Number of events to cache in memory. # #event_cache_size: 10K diff --git a/synapse/config/server.py b/synapse/config/server.py index 15449695d1..2abdef0971 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -17,8 +17,11 @@ import logging import os.path +import re +from textwrap import indent import attr +import yaml from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -352,7 +355,7 @@ class ServerConfig(Config): return any(l["tls"] for l in self.listeners) def generate_config_section( - self, server_name, data_dir_path, open_private_ports, **kwargs + self, server_name, data_dir_path, open_private_ports, listeners, **kwargs ): _, bind_port = parse_and_validate_server_name(server_name) if bind_port is not None: @@ -366,11 +369,68 @@ class ServerConfig(Config): # Bring DEFAULT_ROOM_VERSION into the local-scope for use in the # default config string default_room_version = DEFAULT_ROOM_VERSION + secure_listeners = [] + unsecure_listeners = [] + private_addresses = ["::1", "127.0.0.1"] + if listeners: + for listener in listeners: + if listener["tls"]: + secure_listeners.append(listener) + else: + # If we don't want open ports we need to bind the listeners + # to some address other than 0.0.0.0. Here we chose to use + # localhost. + # If the addresses are already bound we won't overwrite them + # however. + if not open_private_ports: + listener.setdefault("bind_addresses", private_addresses) + + unsecure_listeners.append(listener) + + secure_http_bindings = indent( + yaml.dump(secure_listeners), " " * 10 + ).lstrip() + + unsecure_http_bindings = indent( + yaml.dump(unsecure_listeners), " " * 10 + ).lstrip() + + if not unsecure_listeners: + unsecure_http_bindings = ( + """- port: %(unsecure_port)s + tls: false + type: http + x_forwarded: true""" + % locals() + ) + + if not open_private_ports: + unsecure_http_bindings += ( + "\n bind_addresses: ['::1', '127.0.0.1']" + ) + + unsecure_http_bindings += """ + + resources: + - names: [client, federation] + compress: false""" + + if listeners: + # comment out this block + unsecure_http_bindings = "#" + re.sub( + "\n {10}", + lambda match: match.group(0) + "#", + unsecure_http_bindings, + ) - unsecure_http_binding = "port: %i\n tls: false" % (unsecure_port,) - if not open_private_ports: - unsecure_http_binding += ( - "\n bind_addresses: ['::1', '127.0.0.1']" + if not secure_listeners: + secure_http_bindings = ( + """#- port: %(bind_port)s + # type: http + # tls: true + # resources: + # - names: [client, federation]""" + % locals() ) return ( @@ -556,11 +616,7 @@ class ServerConfig(Config): # will also need to give Synapse a TLS key and certificate: see the TLS section # below.) # - #- port: %(bind_port)s - # type: http - # tls: true - # resources: - # - names: [client, federation] + %(secure_http_bindings)s # Unsecure HTTP listener: for when matrix traffic passes through a reverse proxy # that unwraps TLS. @@ -568,13 +624,7 @@ class ServerConfig(Config): # If you plan to use a reverse proxy, please see # https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.rst. # - - %(unsecure_http_binding)s - type: http - x_forwarded: true - - resources: - - names: [client, federation] - compress: false + %(unsecure_http_bindings)s # example additional_resources: # diff --git a/synapse/config/tls.py b/synapse/config/tls.py index ca508a224f..c0148aa95c 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -239,12 +239,38 @@ class TlsConfig(Config): self.tls_fingerprints.append({"sha256": sha256_fingerprint}) def generate_config_section( - self, config_dir_path, server_name, data_dir_path, **kwargs + self, + config_dir_path, + server_name, + data_dir_path, + tls_certificate_path, + tls_private_key_path, + acme_domain, + **kwargs ): + """If the acme_domain is specified acme will be enabled. + If the TLS paths are not specified the default will be certs in the + config directory""" + base_key_name = os.path.join(config_dir_path, server_name) - tls_certificate_path = base_key_name + ".tls.crt" - tls_private_key_path = base_key_name + ".tls.key" + if bool(tls_certificate_path) != bool(tls_private_key_path): + raise ConfigError( + "Please specify both a cert path and a key path or neither." + ) + + tls_enabled = ( + "" if tls_certificate_path and tls_private_key_path or acme_domain else "#" + ) + + if not tls_certificate_path: + tls_certificate_path = base_key_name + ".tls.crt" + if not tls_private_key_path: + tls_private_key_path = base_key_name + ".tls.key" + + acme_enabled = bool(acme_domain) + acme_domain = "matrix.example.com" + default_acme_account_file = os.path.join(data_dir_path, "acme_account.key") # this is to avoid the max line length. Sorrynotsorry @@ -269,11 +295,11 @@ class TlsConfig(Config): # instance, if using certbot, use `fullchain.pem` as your certificate, # not `cert.pem`). # - #tls_certificate_path: "%(tls_certificate_path)s" + %(tls_enabled)stls_certificate_path: "%(tls_certificate_path)s" # PEM-encoded private key for TLS # - #tls_private_key_path: "%(tls_private_key_path)s" + %(tls_enabled)stls_private_key_path: "%(tls_private_key_path)s" # Whether to verify TLS server certificates for outbound federation requests. # @@ -340,10 +366,10 @@ class TlsConfig(Config): # permission to listen on port 80. # acme: - # ACME support is disabled by default. Uncomment the following line - # (and tls_certificate_path and tls_private_key_path above) to enable it. + # ACME support is disabled by default. Set this to `true` and uncomment + # tls_certificate_path and tls_private_key_path above to enable it. # - #enabled: true + enabled: %(acme_enabled)s # Endpoint to use to request certificates. If you only want to test, # use Let's Encrypt's staging url: @@ -354,17 +380,17 @@ class TlsConfig(Config): # Port number to listen on for the HTTP-01 challenge. Change this if # you are forwarding connections through Apache/Nginx/etc. # - #port: 80 + port: 80 # Local addresses to listen on for incoming connections. # Again, you may want to change this if you are forwarding connections # through Apache/Nginx/etc. # - #bind_addresses: ['::', '0.0.0.0'] + bind_addresses: ['::', '0.0.0.0'] # How many days remaining on a certificate before it is renewed. # - #reprovision_threshold: 30 + reprovision_threshold: 30 # The domain that the certificate should be for. Normally this # should be the same as your Matrix domain (i.e., 'server_name'), but, @@ -378,7 +404,7 @@ class TlsConfig(Config): # # If not set, defaults to your 'server_name'. # - #domain: matrix.example.com + domain: %(acme_domain)s # file to use for the account key. This will be generated if it doesn't # exist. diff --git a/tests/config/test_database.py b/tests/config/test_database.py new file mode 100644 index 0000000000..151d3006ac --- /dev/null +++ b/tests/config/test_database.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import yaml + +from synapse.config.database import DatabaseConfig + +from tests import unittest + + +class DatabaseConfigTestCase(unittest.TestCase): + def test_database_configured_correctly_no_database_conf_param(self): + conf = yaml.safe_load( + DatabaseConfig().generate_config_section("/data_dir_path", None) + ) + + expected_database_conf = { + "name": "sqlite3", + "args": {"database": "/data_dir_path/homeserver.db"}, + } + + self.assertEqual(conf["database"], expected_database_conf) + + def test_database_configured_correctly_database_conf_param(self): + + database_conf = { + "name": "my super fast datastore", + "args": { + "user": "matrix", + "password": "synapse_database_password", + "host": "synapse_database_host", + "database": "matrix", + }, + } + + conf = yaml.safe_load( + DatabaseConfig().generate_config_section("/data_dir_path", database_conf) + ) + + self.assertEqual(conf["database"], database_conf) diff --git a/tests/config/test_server.py b/tests/config/test_server.py index 1ca5ea54ca..a10d017120 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.config.server import is_threepid_reserved +import yaml + +from synapse.config.server import ServerConfig, is_threepid_reserved from tests import unittest @@ -29,3 +31,100 @@ class ServerConfigTestCase(unittest.TestCase): self.assertTrue(is_threepid_reserved(config, user1)) self.assertFalse(is_threepid_reserved(config, user3)) self.assertFalse(is_threepid_reserved(config, user1_msisdn)) + + def test_unsecure_listener_no_listeners_open_private_ports_false(self): + conf = yaml.safe_load( + ServerConfig().generate_config_section( + "che.org", "/data_dir_path", False, None + ) + ) + + expected_listeners = [ + { + "port": 8008, + "tls": False, + "type": "http", + "x_forwarded": True, + "bind_addresses": ["::1", "127.0.0.1"], + "resources": [{"names": ["client", "federation"], "compress": False}], + } + ] + + self.assertEqual(conf["listeners"], expected_listeners) + + def test_unsecure_listener_no_listeners_open_private_ports_true(self): + conf = yaml.safe_load( + ServerConfig().generate_config_section( + "che.org", "/data_dir_path", True, None + ) + ) + + expected_listeners = [ + { + "port": 8008, + "tls": False, + "type": "http", + "x_forwarded": True, + "resources": [{"names": ["client", "federation"], "compress": False}], + } + ] + + self.assertEqual(conf["listeners"], expected_listeners) + + def test_listeners_set_correctly_open_private_ports_false(self): + listeners = [ + { + "port": 8448, + "resources": [{"names": ["federation"]}], + "tls": True, + "type": "http", + }, + { + "port": 443, + "resources": [{"names": ["client"]}], + "tls": False, + "type": "http", + }, + ] + + conf = yaml.safe_load( + ServerConfig().generate_config_section( + "this.one.listens", "/data_dir_path", True, listeners + ) + ) + + self.assertEqual(conf["listeners"], listeners) + + def test_listeners_set_correctly_open_private_ports_true(self): + listeners = [ + { + "port": 8448, + "resources": [{"names": ["federation"]}], + "tls": True, + "type": "http", + }, + { + "port": 443, + "resources": [{"names": ["client"]}], + "tls": False, + "type": "http", + }, + { + "port": 1243, + "resources": [{"names": ["client"]}], + "tls": False, + "type": "http", + "bind_addresses": ["this_one_is_bound"], + }, + ] + + expected_listeners = listeners.copy() + expected_listeners[1]["bind_addresses"] = ["::1", "127.0.0.1"] + + conf = yaml.safe_load( + ServerConfig().generate_config_section( + "this.one.listens", "/data_dir_path", True, listeners + ) + ) + + self.assertEqual(conf["listeners"], expected_listeners) diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 4f8a87a3df..8e0c4b9533 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -16,6 +16,8 @@ import os +import yaml + from OpenSSL import SSL from synapse.config.tls import ConfigError, TlsConfig @@ -191,3 +193,45 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) + + def test_acme_disabled_in_generated_config_no_acme_domain_provied(self): + """ + Checks acme is disabled by default. + """ + conf = TestConfig() + conf.read_config( + yaml.safe_load( + TestConfig().generate_config_section( + "/config_dir_path", + "my_super_secure_server", + "/data_dir_path", + "/tls_cert_path", + "tls_private_key", + None, # This is the acme_domain + ) + ), + "/config_dir_path", + ) + + self.assertFalse(conf.acme_enabled) + + def test_acme_enabled_in_generated_config_domain_provided(self): + """ + Checks acme is enabled if the acme_domain arg is set to some string. + """ + conf = TestConfig() + conf.read_config( + yaml.safe_load( + TestConfig().generate_config_section( + "/config_dir_path", + "my_super_secure_server", + "/data_dir_path", + "/tls_cert_path", + "tls_private_key", + "my_supe_secure_server", # This is the acme_domain + ) + ), + "/config_dir_path", + ) + + self.assertTrue(conf.acme_enabled) -- cgit 1.5.1 From a4bf72c30c5953b721a64eae89db186fa8735bb3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 29 Aug 2019 17:38:51 +0100 Subject: Censor redactions in DB after a month --- synapse/storage/events.py | 88 +++++++++++++++++++++- .../storage/schema/delta/56/redaction_censor.sql | 17 +++++ tests/storage/test_redaction.py | 71 +++++++++++++++++ 3 files changed, 175 insertions(+), 1 deletion(-) create mode 100644 synapse/storage/schema/delta/56/redaction_censor.sql (limited to 'tests') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5a95c36a8b..2970da6829 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -23,7 +23,7 @@ from functools import wraps from six import iteritems, text_type from six.moves import range -from canonicaljson import json +from canonicaljson import encode_canonical_json, json from prometheus_client import Counter, Histogram from twisted.internet import defer @@ -33,6 +33,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.events.utils import prune_event_dict from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.utils import log_function from synapse.metrics import BucketCollector @@ -262,6 +263,13 @@ class EventsStore( hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + def _censor_redactions(): + return run_as_background_process( + "_censor_redactions", self._censor_redactions + ) + + hs.get_clock().looping_call(_censor_redactions, 10 * 60 * 1000) + @defer.inlineCallbacks def _read_forward_extremities(self): def fetch(txn): @@ -1548,6 +1556,84 @@ class EventsStore( (event.event_id, event.redacts), ) + @defer.inlineCallbacks + def _censor_redactions(self): + """Censors all redactions older than a month that haven't been censored. + + By censor we mean update the event_json table with the redacted event. + + Returns: + Deferred + """ + + if self.stream_ordering_month_ago is None: + return + + max_pos = self.stream_ordering_month_ago + + # We fetch all redactions that point to an event that we have that has + # a stream ordering from over a month ago, that we haven't yet censored + # in the DB. + sql = """ + SELECT er.event_id, redacts FROM redactions + INNER JOIN events AS er USING (event_id) + INNER JOIN events AS eb ON (er.room_id = eb.room_id AND redacts = eb.event_id) + WHERE NOT have_censored + AND ? <= er.stream_ordering AND er.stream_ordering <= ? + ORDER BY er.stream_ordering ASC + LIMIT ? + """ + + rows = yield self._execute( + "_censor_redactions_fetch", None, sql, -max_pos, max_pos, 100 + ) + + updates = [] + + for redaction_id, event_id in rows: + redaction_event = yield self.get_event(redaction_id, allow_none=True) + original_event = yield self.get_event( + event_id, allow_rejected=True, allow_none=True + ) + + # The SQL above ensures that we have both the redaction and + # original event, so if the `get_event` calls return None it + # means that the redaction wasn't allowed. Either way we know that + # the result won't change so we mark the fact that we've checked. + if ( + redaction_event + and original_event + and original_event.internal_metadata.is_redacted() + ): + # Redaction was allowed + pruned_json = encode_canonical_json( + prune_event_dict(original_event.get_dict()) + ) + else: + # Redaction wasn't allowed + pruned_json = None + + updates.append((redaction_id, event_id, pruned_json)) + + def _update_censor_txn(txn): + for redaction_id, event_id, pruned_json in updates: + if pruned_json: + self._simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) + + self._simple_update_one_txn( + txn, + table="redactions", + keyvalues={"event_id": redaction_id}, + updatevalues={"have_censored": True}, + ) + + yield self.runInteraction("_update_censor_txn", _update_censor_txn) + @defer.inlineCallbacks def count_daily_messages(self): """ diff --git a/synapse/storage/schema/delta/56/redaction_censor.sql b/synapse/storage/schema/delta/56/redaction_censor.sql new file mode 100644 index 0000000000..fe51b02309 --- /dev/null +++ b/synapse/storage/schema/delta/56/redaction_censor.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false; +CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored; diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index d961b81d48..0c9f3c7071 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -17,6 +17,8 @@ from mock import Mock +from canonicaljson import json + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership @@ -286,3 +288,72 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.assertEqual( fetched.unsigned["redacted_because"].event_id, redaction_event_id2 ) + + def test_redact_censor(self): + """Test that a redacted event gets censored in the DB after a month + """ + + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) + + msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + + # Check event has not been redacted: + event = self.get_success(self.store.get_event(msg_event.event_id)) + + self.assertObjectHasAttributes( + { + "type": EventTypes.Message, + "user_id": self.u_alice.to_string(), + "content": {"body": "t", "msgtype": "message"}, + }, + event, + ) + + self.assertFalse("redacted_because" in event.unsigned) + + # Redact event + reason = "Because I said so" + self.get_success( + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) + ) + + event = self.get_success(self.store.get_event(msg_event.event_id)) + + self.assertTrue("redacted_because" in event.unsigned) + + self.assertObjectHasAttributes( + { + "type": EventTypes.Message, + "user_id": self.u_alice.to_string(), + "content": {}, + }, + event, + ) + + event_json = self.get_success( + self.store._simple_select_one_onecol( + table="event_json", + keyvalues={"event_id": msg_event.event_id}, + retcol="json", + ) + ) + + self.assert_dict( + {"content": {"body": "t", "msgtype": "message"}}, json.loads(event_json) + ) + + # Advance by 30 days + self.reactor.advance(60 * 60 * 24 * 31) + self.reactor.advance(60 * 60 * 2) + + event_json = self.get_success( + self.store._simple_select_one_onecol( + table="event_json", + keyvalues={"event_id": msg_event.event_id}, + retcol="json", + ) + ) + + self.assert_dict({"content": {}}, json.loads(event_json)) -- cgit 1.5.1 From 4548d1f87e3ff3dc24b0af8f944276137d3228e3 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 30 Aug 2019 16:28:26 +0100 Subject: Remove unnecessary parentheses around return statements (#5931) Python will return a tuple whether there are parentheses around the returned values or not. I'm just sick of my editor complaining about this all over the place :) --- changelog.d/5931.misc | 1 + synapse/api/auth.py | 14 ++--- synapse/app/frontend_proxy.py | 8 +-- synapse/crypto/event_signing.py | 4 +- synapse/federation/federation_client.py | 8 +-- synapse/federation/federation_server.py | 22 ++++---- synapse/handlers/account_data.py | 4 +- synapse/handlers/auth.py | 8 +-- synapse/handlers/federation.py | 2 +- synapse/handlers/initial_sync.py | 4 +- synapse/handlers/presence.py | 4 +- synapse/handlers/receipts.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/sync.py | 8 +-- synapse/handlers/typing.py | 2 +- synapse/http/federation/well_known_resolver.py | 2 +- synapse/module_api/__init__.py | 2 +- synapse/notifier.py | 6 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/replication/http/federation.py | 8 +-- synapse/replication/http/login.py | 2 +- synapse/replication/http/membership.py | 6 +- synapse/replication/http/register.py | 4 +- synapse/replication/http/send_event.py | 2 +- synapse/replication/tcp/streams/_base.py | 8 +-- synapse/rest/admin/__init__.py | 26 ++++----- synapse/rest/admin/media.py | 6 +- synapse/rest/admin/purge_room_servlet.py | 2 +- synapse/rest/admin/server_notice_servlet.py | 2 +- synapse/rest/admin/users.py | 4 +- synapse/rest/client/v1/directory.py | 16 +++--- synapse/rest/client/v1/events.py | 8 +-- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/login.py | 6 +- synapse/rest/client/v1/logout.py | 8 +-- synapse/rest/client/v1/presence.py | 6 +- synapse/rest/client/v1/profile.py | 18 +++--- synapse/rest/client/v1/push_rule.py | 10 ++-- synapse/rest/client/v1/pusher.py | 6 +- synapse/rest/client/v1/room.py | 48 ++++++++-------- synapse/rest/client/v1/voip.py | 4 +- synapse/rest/client/v2_alpha/account.py | 24 ++++---- synapse/rest/client/v2_alpha/account_data.py | 8 +-- synapse/rest/client/v2_alpha/capabilities.py | 2 +- synapse/rest/client/v2_alpha/devices.py | 10 ++-- synapse/rest/client/v2_alpha/filter.py | 4 +- synapse/rest/client/v2_alpha/groups.py | 64 +++++++++++----------- synapse/rest/client/v2_alpha/keys.py | 8 +-- synapse/rest/client/v2_alpha/notifications.py | 2 +- synapse/rest/client/v2_alpha/read_marker.py | 2 +- synapse/rest/client/v2_alpha/receipts.py | 2 +- synapse/rest/client/v2_alpha/register.py | 10 ++-- synapse/rest/client/v2_alpha/relations.py | 8 +-- synapse/rest/client/v2_alpha/report_event.py | 2 +- synapse/rest/client/v2_alpha/room_keys.py | 14 ++--- .../client/v2_alpha/room_upgrade_rest_servlet.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/client/v2_alpha/tags.py | 6 +- synapse/rest/client/v2_alpha/thirdparty.py | 10 ++-- synapse/rest/client/v2_alpha/user_directory.py | 4 +- synapse/rest/media/v1/media_repository.py | 4 +- synapse/rest/media/v1/thumbnailer.py | 4 +- .../resource_limits_server_notices.py | 2 +- synapse/storage/account_data.py | 8 +-- synapse/storage/appservice.py | 2 +- synapse/storage/deviceinbox.py | 4 +- synapse/storage/devices.py | 10 ++-- synapse/storage/events.py | 10 ++-- synapse/storage/presence.py | 2 +- synapse/storage/pusher.py | 2 +- synapse/storage/receipts.py | 2 +- synapse/storage/stream.py | 12 ++-- synapse/storage/util/id_generators.py | 4 +- synapse/streams/config.py | 2 +- tests/handlers/test_register.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 2 +- tests/server.py | 2 +- tests/test_server.py | 2 +- tests/test_state.py | 2 +- tests/utils.py | 2 +- 81 files changed, 287 insertions(+), 286 deletions(-) create mode 100644 changelog.d/5931.misc (limited to 'tests') diff --git a/changelog.d/5931.misc b/changelog.d/5931.misc new file mode 100644 index 0000000000..ac8e74f5b9 --- /dev/null +++ b/changelog.d/5931.misc @@ -0,0 +1 @@ +Remove unnecessary parentheses in return statements. \ No newline at end of file diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 7b3a5a8221..fd3cdf50b0 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -276,25 +276,25 @@ class Auth(object): self.get_access_token_from_request(request) ) if app_service is None: - return (None, None) + return None, None if app_service.ip_range_whitelist: ip_address = IPAddress(self.hs.get_ip_from_request(request)) if ip_address not in app_service.ip_range_whitelist: - return (None, None) + return None, None if b"user_id" not in request.args: - return (app_service.sender, app_service) + return app_service.sender, app_service user_id = request.args[b"user_id"][0].decode("utf8") if app_service.sender == user_id: - return (app_service.sender, app_service) + return app_service.sender, app_service if not app_service.is_interested_in_user(user_id): raise AuthError(403, "Application service cannot masquerade as this user.") if not (yield self.store.get_user_by_id(user_id)): raise AuthError(403, "Application service has not registered this user") - return (user_id, app_service) + return user_id, app_service @defer.inlineCallbacks def get_user_by_access_token(self, token, rights="access"): @@ -694,7 +694,7 @@ class Auth(object): # * The user is a guest user, and has joined the room # else it will throw. member_event = yield self.check_user_was_in_room(room_id, user_id) - return (member_event.membership, member_event.event_id) + return member_event.membership, member_event.event_id except AuthError: visibility = yield self.state.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" @@ -703,7 +703,7 @@ class Auth(object): visibility and visibility.content["history_visibility"] == "world_readable" ): - return (Membership.JOIN, None) + return Membership.JOIN, None return raise AuthError( 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 611d285421..9504bfbc70 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -70,12 +70,12 @@ class PresenceStatusStubServlet(RestServlet): except HttpResponseException as e: raise e.to_synapse_error() - return (200, result) + return 200, result @defer.inlineCallbacks def on_PUT(self, request, user_id): yield self.auth.get_user_by_req(request) - return (200, {}) + return 200, {} class KeyUploadServlet(RestServlet): @@ -126,11 +126,11 @@ class KeyUploadServlet(RestServlet): self.main_uri + request.uri.decode("ascii"), body, headers=headers ) - return (200, result) + return 200, result else: # Just interested in counts. result = yield self.store.count_e2e_one_time_keys(user_id, device_id) - return (200, {"one_time_key_counts": result}) + return 200, {"one_time_key_counts": result} class FrontendProxySlavedStore( diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 41eabbe717..694fb2c816 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -83,7 +83,7 @@ def compute_content_hash(event_dict, hash_algorithm): event_json_bytes = encode_canonical_json(event_dict) hashed = hash_algorithm(event_json_bytes) - return (hashed.name, hashed.digest()) + return hashed.name, hashed.digest() def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): @@ -106,7 +106,7 @@ def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): event_dict.pop("unsigned", None) event_json_bytes = encode_canonical_json(event_dict) hashed = hash_algorithm(event_json_bytes) - return (hashed.name, hashed.digest()) + return hashed.name, hashed.digest() def compute_event_signature(event_dict, signature_name, signing_key): diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index bec3080895..6ee6216660 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -355,7 +355,7 @@ class FederationClient(FederationBase): auth_chain.sort(key=lambda e: e.depth) - return (pdus, auth_chain) + return pdus, auth_chain except HttpResponseException as e: if e.code == 400 or e.code == 404: logger.info("Failed to use get_room_state_ids API, falling back") @@ -404,7 +404,7 @@ class FederationClient(FederationBase): signed_auth.sort(key=lambda e: e.depth) - return (signed_pdus, signed_auth) + return signed_pdus, signed_auth @defer.inlineCallbacks def get_events_from_store_or_dest(self, destination, room_id, event_ids): @@ -429,7 +429,7 @@ class FederationClient(FederationBase): missing_events.discard(k) if not missing_events: - return (signed_events, failed_to_fetch) + return signed_events, failed_to_fetch logger.debug( "Fetching unknown state/auth events %s for room %s", @@ -465,7 +465,7 @@ class FederationClient(FederationBase): # We removed all events we successfully fetched from `batch` failed_to_fetch.update(batch) - return (signed_events, failed_to_fetch) + return signed_events, failed_to_fetch @defer.inlineCallbacks @log_function diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 05fd49f3c1..e5f0b90aec 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -100,7 +100,7 @@ class FederationServer(FederationBase): res = self._transaction_from_pdus(pdus).get_dict() - return (200, res) + return 200, res @defer.inlineCallbacks @log_function @@ -163,7 +163,7 @@ class FederationServer(FederationBase): yield self.transaction_actions.set_response( origin, transaction, 400, response ) - return (400, response) + return 400, response received_pdus_counter.inc(len(transaction.pdus)) @@ -265,7 +265,7 @@ class FederationServer(FederationBase): logger.debug("Returning: %s", str(response)) yield self.transaction_actions.set_response(origin, transaction, 200, response) - return (200, response) + return 200, response @defer.inlineCallbacks def received_edu(self, origin, edu_type, content): @@ -298,7 +298,7 @@ class FederationServer(FederationBase): event_id, ) - return (200, resp) + return 200, resp @defer.inlineCallbacks def on_state_ids_request(self, origin, room_id, event_id): @@ -315,7 +315,7 @@ class FederationServer(FederationBase): state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) - return (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}) + return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} @defer.inlineCallbacks def _on_context_state_request_compute(self, room_id, event_id): @@ -345,15 +345,15 @@ class FederationServer(FederationBase): pdu = yield self.handler.get_persisted_pdu(origin, event_id) if pdu: - return (200, self._transaction_from_pdus([pdu]).get_dict()) + return 200, self._transaction_from_pdus([pdu]).get_dict() else: - return (404, "") + return 404, "" @defer.inlineCallbacks def on_query_request(self, query_type, args): received_queries_counter.labels(query_type).inc() resp = yield self.registry.on_query(query_type, args) - return (200, resp) + return 200, resp @defer.inlineCallbacks def on_make_join_request(self, origin, room_id, user_id, supported_versions): @@ -435,7 +435,7 @@ class FederationServer(FederationBase): logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) yield self.handler.on_send_leave_request(origin, pdu) - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_event_auth(self, origin, room_id, event_id): @@ -446,7 +446,7 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() auth_pdus = yield self.handler.on_event_auth(event_id) res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} - return (200, res) + return 200, res @defer.inlineCallbacks def on_query_auth_request(self, origin, content, room_id, event_id): @@ -499,7 +499,7 @@ class FederationServer(FederationBase): "missing": ret.get("missing", []), } - return (200, send_content) + return 200, send_content @log_function def on_query_client_keys(self, origin, content): diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 8acd9f9a83..38bc67191c 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -51,8 +51,8 @@ class AccountDataEventSource(object): {"type": account_data_type, "content": content, "room_id": room_id} ) - return (results, current_stream_id) + return results, current_stream_id @defer.inlineCallbacks def get_pagination_rows(self, user, config, key): - return ([], config.to_id) + return [], config.to_id diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0f3ebf7ef8..f844409d21 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -280,7 +280,7 @@ class AuthHandler(BaseHandler): creds, list(clientdict), ) - return (creds, clientdict, session["id"]) + return creds, clientdict, session["id"] ret = self._auth_dict_for_flows(flows, session) ret["completed"] = list(creds) @@ -722,7 +722,7 @@ class AuthHandler(BaseHandler): known_login_type = True is_valid = yield provider.check_password(qualified_user_id, password) if is_valid: - return (qualified_user_id, None) + return qualified_user_id, None if not hasattr(provider, "get_supported_login_types") or not hasattr( provider, "check_auth" @@ -766,7 +766,7 @@ class AuthHandler(BaseHandler): ) if canonical_user_id: - return (canonical_user_id, None) + return canonical_user_id, None if not known_login_type: raise SynapseError(400, "Unknown login type %s" % login_type) @@ -816,7 +816,7 @@ class AuthHandler(BaseHandler): result = (result, None) return result - return (None, None) + return None, None @defer.inlineCallbacks def _check_local_password(self, user_id, password): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 94306c94a9..538b16efd6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1428,7 +1428,7 @@ class FederationHandler(BaseHandler): assert event.user_id == user_id assert event.state_key == user_id assert event.room_id == room_id - return (origin, event, format_ver) + return origin, event, format_ver @defer.inlineCallbacks @log_function diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 42d6650ed9..595f75400b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -449,7 +449,7 @@ class InitialSyncHandler(BaseHandler): # * The user is a guest user, and has joined the room # else it will throw. member_event = yield self.auth.check_user_was_in_room(room_id, user_id) - return (member_event.membership, member_event.event_id) + return member_event.membership, member_event.event_id return except AuthError: visibility = yield self.state_handler.get_current_state( @@ -459,7 +459,7 @@ class InitialSyncHandler(BaseHandler): visibility and visibility.content["history_visibility"] == "world_readable" ): - return (Membership.JOIN, None) + return Membership.JOIN, None return raise AuthError( 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 94a9ca0357..8377a0ddc2 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1032,7 +1032,7 @@ class PresenceEventSource(object): # # Hence this guard where we just return nothing so that the sync # doesn't return. C.f. #5503. - return ([], max_token) + return [], max_token presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache @@ -1279,7 +1279,7 @@ def get_interested_parties(store, states): # Always notify self users_to_states.setdefault(state.user_id, []).append(state) - return (room_ids_to_states, users_to_states) + return room_ids_to_states, users_to_states @defer.inlineCallbacks diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 73973502a4..6854c751a6 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -148,7 +148,7 @@ class ReceiptEventSource(object): to_key = yield self.get_current_key() if from_key == to_key: - return ([], to_key) + return [], to_key events = yield self.store.get_linearized_receipts_for_rooms( room_ids, from_key=from_key, to_key=to_key diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 4631fab94e..be0425a33b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -622,7 +622,7 @@ class RegistrationHandler(BaseHandler): initial_display_name=initial_display_name, is_guest=is_guest, ) - return (r["device_id"], r["access_token"]) + return r["device_id"], r["access_token"] valid_until_ms = None if self.session_lifetime is not None: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 249a6d9c5d..f03a2bd540 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -903,7 +903,7 @@ class RoomMemberHandler(object): if not public_keys: public_keys.append(fallback_public_key) display_name = data["display_name"] - return (token, public_keys, fallback_public_key, display_name) + return token, public_keys, fallback_public_key, display_name @defer.inlineCallbacks def _is_host_in_room(self, current_state_ids): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ef7f2ca980..d582f8e494 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -378,7 +378,7 @@ class SyncHandler(object): event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) - return (now_token, ephemeral_by_room) + return now_token, ephemeral_by_room @defer.inlineCallbacks def _load_filtered_recents( @@ -1332,7 +1332,7 @@ class SyncHandler(object): ) if not tags_by_room: logger.debug("no-oping sync") - return ([], [], [], []) + return [], [], [], [] ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id @@ -1642,7 +1642,7 @@ class SyncHandler(object): ) room_entries.append(entry) - return (room_entries, invited, newly_joined_rooms, newly_left_rooms) + return room_entries, invited, newly_joined_rooms, newly_left_rooms @defer.inlineCallbacks def _get_all_rooms(self, sync_result_builder, ignored_users): @@ -1716,7 +1716,7 @@ class SyncHandler(object): ) ) - return (room_entries, invited, []) + return room_entries, invited, [] @defer.inlineCallbacks def _generate_room_entry( diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index f882330293..ca8ae9fb5b 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -319,4 +319,4 @@ class TypingNotificationEventSource(object): return self.get_typing_handler()._latest_room_serial def get_pagination_rows(self, user, pagination_config, key): - return ([], pagination_config.from_key) + return [], pagination_config.from_key diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 5e9b0befb0..7ddfad286d 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -207,7 +207,7 @@ class WellKnownResolver(object): cache_period + WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID, ) - return (result, cache_period) + return result, cache_period @defer.inlineCallbacks def _make_well_known_request(self, server_name, retry): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 41147d4292..735b882363 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -101,7 +101,7 @@ class ModuleApi(object): ) user_id = yield self.register_user(localpart, displayname, emails) _, access_token = yield self.register_device(user_id) - return (user_id, access_token) + return user_id, access_token def register_user(self, localpart, displayname=None, emails=[]): """Registers a new user with given localpart and optional displayname, emails. diff --git a/synapse/notifier.py b/synapse/notifier.py index bd80c801b6..4e091314e6 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -472,11 +472,11 @@ class Notifier(object): joined_room_ids = yield self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: - return ([explicit_room_id], True) + return [explicit_room_id], True if (yield self._is_world_readable(explicit_room_id)): - return ([explicit_room_id], False) + return [explicit_room_id], False raise AuthError(403, "Non-joined access not allowed") - return (joined_room_ids, True) + return joined_room_ids, True @defer.inlineCallbacks def _is_world_readable(self, room_id): diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index c831975635..22491f3700 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -134,7 +134,7 @@ class BulkPushRuleEvaluator(object): pl_event = auth_events.get(POWER_KEY) - return (pl_event.content if pl_event else {}, sender_level) + return pl_event.content if pl_event else {}, sender_level @defer.inlineCallbacks def action_for_event_by_user(self, event, context): diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index fed4f08820..2f16955954 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -113,7 +113,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event_and_contexts, backfilled ) - return (200, {}) + return 200, {} class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): @@ -156,7 +156,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): result = yield self.registry.on_edu(edu_type, origin, edu_content) - return (200, result) + return 200, result class ReplicationGetQueryRestServlet(ReplicationEndpoint): @@ -204,7 +204,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): result = yield self.registry.on_query(query_type, args) - return (200, result) + return 200, result class ReplicationCleanRoomRestServlet(ReplicationEndpoint): @@ -238,7 +238,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): def _handle_request(self, request, room_id): yield self.store.clean_room_for_join(room_id) - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index f17d3a2da4..786f5232b2 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -64,7 +64,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): user_id, device_id, initial_display_name, is_guest ) - return (200, {"device_id": device_id, "access_token": access_token}) + return 200, {"device_id": device_id, "access_token": access_token} def register_servlets(hs, http_server): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 4217335d88..b9ce3477ad 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -83,7 +83,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): remote_room_hosts, room_id, user_id, event_content ) - return (200, {}) + return 200, {} class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): @@ -153,7 +153,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): yield self.store.locally_reject_invite(user_id, room_id) ret = {} - return (200, ret) + return 200, ret class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): @@ -202,7 +202,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): else: raise Exception("Unrecognized change: %r", change) - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 3341320a87..87fe2dd9b0 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -90,7 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): address=content["address"], ) - return (200, {}) + return 200, {} class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): @@ -143,7 +143,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): bind_msisdn=bind_msisdn, ) - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index eff7bd7305..adb9b2f7f4 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -117,7 +117,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index c10b85d2ff..f03111c259 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -158,7 +158,7 @@ class Stream(object): updates, current_token = yield self.get_updates_since(self.last_token) self.last_token = current_token - return (updates, current_token) + return updates, current_token @defer.inlineCallbacks def get_updates_since(self, from_token): @@ -172,14 +172,14 @@ class Stream(object): sent over the replication steam. """ if from_token in ("NOW", "now"): - return ([], self.upto_token) + return [], self.upto_token current_token = self.upto_token from_token = int(from_token) if from_token == current_token: - return ([], current_token) + return [], current_token if self._LIMITED: rows = yield self.update_function( @@ -198,7 +198,7 @@ class Stream(object): if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: raise Exception("stream %s has fallen behind" % (self.NAME)) - return (updates, current_token) + return updates, current_token def current_token(self): """Gets the current token of the underlying streams. Should be provided diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index fa91cc8dee..b4761adaed 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -69,7 +69,7 @@ class UsersRestServlet(RestServlet): ret = yield self.handlers.admin_handler.get_users() - return (200, ret) + return 200, ret class VersionServlet(RestServlet): @@ -120,7 +120,7 @@ class UserRegisterServlet(RestServlet): nonce = self.hs.get_secrets().token_hex(64) self.nonces[nonce] = int(self.reactor.seconds()) - return (200, {"nonce": nonce}) + return 200, {"nonce": nonce} @defer.inlineCallbacks def on_POST(self, request): @@ -212,7 +212,7 @@ class UserRegisterServlet(RestServlet): ) result = yield register._create_registration_details(user_id, body) - return (200, result) + return 200, result class WhoisRestServlet(RestServlet): @@ -237,7 +237,7 @@ class WhoisRestServlet(RestServlet): ret = yield self.handlers.admin_handler.get_whois(target_user) - return (200, ret) + return 200, ret class PurgeHistoryRestServlet(RestServlet): @@ -322,7 +322,7 @@ class PurgeHistoryRestServlet(RestServlet): room_id, token, delete_local_events=delete_local_events ) - return (200, {"purge_id": purge_id}) + return 200, {"purge_id": purge_id} class PurgeHistoryStatusRestServlet(RestServlet): @@ -347,7 +347,7 @@ class PurgeHistoryStatusRestServlet(RestServlet): if purge_status is None: raise NotFoundError("purge id '%s' not found" % purge_id) - return (200, purge_status.asdict()) + return 200, purge_status.asdict() class DeactivateAccountRestServlet(RestServlet): @@ -379,7 +379,7 @@ class DeactivateAccountRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - return (200, {"id_server_unbind_result": id_server_unbind_result}) + return 200, {"id_server_unbind_result": id_server_unbind_result} class ShutdownRoomRestServlet(RestServlet): @@ -549,7 +549,7 @@ class ResetPasswordRestServlet(RestServlet): yield self._set_password_handler.set_password( target_user_id, new_password, requester ) - return (200, {}) + return 200, {} class GetUsersPaginatedRestServlet(RestServlet): @@ -591,7 +591,7 @@ class GetUsersPaginatedRestServlet(RestServlet): logger.info("limit: %s, start: %s", limit, start) ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) - return (200, ret) + return 200, ret @defer.inlineCallbacks def on_POST(self, request, target_user_id): @@ -619,7 +619,7 @@ class GetUsersPaginatedRestServlet(RestServlet): logger.info("limit: %s, start: %s", limit, start) ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) - return (200, ret) + return 200, ret class SearchUsersRestServlet(RestServlet): @@ -662,7 +662,7 @@ class SearchUsersRestServlet(RestServlet): logger.info("term: %s ", term) ret = yield self.handlers.admin_handler.search_users(term) - return (200, ret) + return 200, ret class DeleteGroupAdminRestServlet(RestServlet): @@ -685,7 +685,7 @@ class DeleteGroupAdminRestServlet(RestServlet): raise SynapseError(400, "Can only delete local groups") yield self.group_server.delete_group(group_id, requester.user.to_string()) - return (200, {}) + return 200, {} class AccountValidityRenewServlet(RestServlet): @@ -716,7 +716,7 @@ class AccountValidityRenewServlet(RestServlet): ) res = {"expiration_ts": expiration_ts} - return (200, res) + return 200, res ######################################################################################## diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 824df919f2..f3f63f0be7 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -49,7 +49,7 @@ class QuarantineMediaInRoom(RestServlet): room_id, requester.user.to_string() ) - return (200, {"num_quarantined": num_quarantined}) + return 200, {"num_quarantined": num_quarantined} class ListMediaInRoom(RestServlet): @@ -70,7 +70,7 @@ class ListMediaInRoom(RestServlet): local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id) - return (200, {"local": local_mxcs, "remote": remote_mxcs}) + return 200, {"local": local_mxcs, "remote": remote_mxcs} class PurgeMediaCacheRestServlet(RestServlet): @@ -89,7 +89,7 @@ class PurgeMediaCacheRestServlet(RestServlet): ret = yield self.media_repository.delete_old_remote_media(before_ts) - return (200, ret) + return 200, ret def register_servlets_for_media_repo(hs, http_server): diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py index 2922eb543e..f474066542 100644 --- a/synapse/rest/admin/purge_room_servlet.py +++ b/synapse/rest/admin/purge_room_servlet.py @@ -54,4 +54,4 @@ class PurgeRoomServlet(RestServlet): await self.pagination_handler.purge_room(body["room_id"]) - return (200, {}) + return 200, {} diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index 656526fea5..ae2cbe2e0a 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -92,7 +92,7 @@ class SendServerNoticeServlet(RestServlet): event_content=body["content"], ) - return (200, {"event_id": event.event_id}) + return 200, {"event_id": event.event_id} def on_PUT(self, request, txn_id): return self.txns.fetch_or_execute_request( diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 5364117420..9720a3bab0 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -71,7 +71,7 @@ class UserAdminServlet(RestServlet): is_admin = yield self.handlers.admin_handler.get_user_server_admin(target_user) is_admin = bool(is_admin) - return (200, {"admin": is_admin}) + return 200, {"admin": is_admin} @defer.inlineCallbacks def on_PUT(self, request, user_id): @@ -97,4 +97,4 @@ class UserAdminServlet(RestServlet): target_user, set_admin_to ) - return (200, {}) + return 200, {} diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 4284738021..4ea3666874 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -54,7 +54,7 @@ class ClientDirectoryServer(RestServlet): dir_handler = self.handlers.directory_handler res = yield dir_handler.get_association(room_alias) - return (200, res) + return 200, res @defer.inlineCallbacks def on_PUT(self, request, room_alias): @@ -87,7 +87,7 @@ class ClientDirectoryServer(RestServlet): requester, room_alias, room_id, servers ) - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_DELETE(self, request, room_alias): @@ -102,7 +102,7 @@ class ClientDirectoryServer(RestServlet): service.url, room_alias.to_string(), ) - return (200, {}) + return 200, {} except InvalidClientCredentialsError: # fallback to default user behaviour if they aren't an AS pass @@ -118,7 +118,7 @@ class ClientDirectoryServer(RestServlet): "User %s deleted alias %s", user.to_string(), room_alias.to_string() ) - return (200, {}) + return 200, {} class ClientDirectoryListServer(RestServlet): @@ -136,7 +136,7 @@ class ClientDirectoryListServer(RestServlet): if room is None: raise NotFoundError("Unknown room") - return (200, {"visibility": "public" if room["is_public"] else "private"}) + return 200, {"visibility": "public" if room["is_public"] else "private"} @defer.inlineCallbacks def on_PUT(self, request, room_id): @@ -149,7 +149,7 @@ class ClientDirectoryListServer(RestServlet): requester, room_id, visibility ) - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_DELETE(self, request, room_id): @@ -159,7 +159,7 @@ class ClientDirectoryListServer(RestServlet): requester, room_id, "private" ) - return (200, {}) + return 200, {} class ClientAppserviceDirectoryListServer(RestServlet): @@ -193,4 +193,4 @@ class ClientAppserviceDirectoryListServer(RestServlet): requester.app_service.id, network_id, room_id, visibility ) - return (200, {}) + return 200, {} diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 53ebed2203..6651b4cf07 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -67,10 +67,10 @@ class EventStreamRestServlet(RestServlet): is_guest=is_guest, ) - return (200, chunk) + return 200, chunk def on_OPTIONS(self, request): - return (200, {}) + return 200, {} # TODO: Unit test gets, with and without auth, with different kinds of events. @@ -91,9 +91,9 @@ class EventRestServlet(RestServlet): time_now = self.clock.time_msec() if event: event = yield self._event_serializer.serialize_event(event, time_now) - return (200, event) + return 200, event else: - return (404, "Event not found.") + return 404, "Event not found." def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 70b8478e90..2da3cd7511 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -42,7 +42,7 @@ class InitialSyncRestServlet(RestServlet): include_archived=include_archived, ) - return (200, content) + return 200, content def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 5762b9fd06..25a1b67092 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -121,10 +121,10 @@ class LoginRestServlet(RestServlet): ({"type": t} for t in self.auth_handler.get_supported_login_types()) ) - return (200, {"flows": flows}) + return 200, {"flows": flows} def on_OPTIONS(self, request): - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_POST(self, request): @@ -152,7 +152,7 @@ class LoginRestServlet(RestServlet): well_known_data = self._well_known_builder.get_well_known() if well_known_data: result["well_known"] = well_known_data - return (200, result) + return 200, result @defer.inlineCallbacks def _do_other_login(self, login_submission): diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 2769f3a189..4785a34d75 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -33,7 +33,7 @@ class LogoutRestServlet(RestServlet): self._device_handler = hs.get_device_handler() def on_OPTIONS(self, request): - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_POST(self, request): @@ -49,7 +49,7 @@ class LogoutRestServlet(RestServlet): requester.user.to_string(), requester.device_id ) - return (200, {}) + return 200, {} class LogoutAllRestServlet(RestServlet): @@ -62,7 +62,7 @@ class LogoutAllRestServlet(RestServlet): self._device_handler = hs.get_device_handler() def on_OPTIONS(self, request): - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_POST(self, request): @@ -75,7 +75,7 @@ class LogoutAllRestServlet(RestServlet): # .. and then delete any access tokens which weren't associated with # devices. yield self._auth_handler.delete_access_tokens_for_user(user_id) - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 1eb1068c98..0153525cef 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -56,7 +56,7 @@ class PresenceStatusRestServlet(RestServlet): state = yield self.presence_handler.get_state(target_user=user) state = format_user_presence_state(state, self.clock.time_msec()) - return (200, state) + return 200, state @defer.inlineCallbacks def on_PUT(self, request, user_id): @@ -88,10 +88,10 @@ class PresenceStatusRestServlet(RestServlet): if self.hs.config.use_presence: yield self.presence_handler.set_state(user, state) - return (200, {}) + return 200, {} def on_OPTIONS(self, request): - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 2657ae45bb..bbce2e2b71 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -48,7 +48,7 @@ class ProfileDisplaynameRestServlet(RestServlet): if displayname is not None: ret["displayname"] = displayname - return (200, ret) + return 200, ret @defer.inlineCallbacks def on_PUT(self, request, user_id): @@ -61,14 +61,14 @@ class ProfileDisplaynameRestServlet(RestServlet): try: new_name = content["displayname"] except Exception: - return (400, "Unable to parse name") + return 400, "Unable to parse name" yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) - return (200, {}) + return 200, {} def on_OPTIONS(self, request, user_id): - return (200, {}) + return 200, {} class ProfileAvatarURLRestServlet(RestServlet): @@ -98,7 +98,7 @@ class ProfileAvatarURLRestServlet(RestServlet): if avatar_url is not None: ret["avatar_url"] = avatar_url - return (200, ret) + return 200, ret @defer.inlineCallbacks def on_PUT(self, request, user_id): @@ -110,14 +110,14 @@ class ProfileAvatarURLRestServlet(RestServlet): try: new_name = content["avatar_url"] except Exception: - return (400, "Unable to parse name") + return 400, "Unable to parse name" yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) - return (200, {}) + return 200, {} def on_OPTIONS(self, request, user_id): - return (200, {}) + return 200, {} class ProfileRestServlet(RestServlet): @@ -150,7 +150,7 @@ class ProfileRestServlet(RestServlet): if avatar_url is not None: ret["avatar_url"] = avatar_url - return (200, ret) + return 200, ret def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index c3ae8b98a8..9f8c3d09e3 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -69,7 +69,7 @@ class PushRuleRestServlet(RestServlet): if "attr" in spec: yield self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) - return (200, {}) + return 200, {} if spec["rule_id"].startswith("."): # Rule ids starting with '.' are reserved for server default rules. @@ -106,7 +106,7 @@ class PushRuleRestServlet(RestServlet): except RuleNotFoundException as e: raise SynapseError(400, str(e)) - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_DELETE(self, request, path): @@ -123,7 +123,7 @@ class PushRuleRestServlet(RestServlet): try: yield self.store.delete_push_rule(user_id, namespaced_rule_id) self.notify_user(user_id) - return (200, {}) + return 200, {} except StoreError as e: if e.code == 404: raise NotFoundError() @@ -151,10 +151,10 @@ class PushRuleRestServlet(RestServlet): ) if path[0] == "": - return (200, rules) + return 200, rules elif path[0] == "global": result = _filter_ruleset_with_path(rules["global"], path[1:]) - return (200, result) + return 200, result else: raise UnrecognizedRequestError() diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index ebc3dec516..41660682d9 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -62,7 +62,7 @@ class PushersRestServlet(RestServlet): if k not in allowed_keys: del p[k] - return (200, {"pushers": pushers}) + return 200, {"pushers": pushers} def on_OPTIONS(self, _): return 200, {} @@ -94,7 +94,7 @@ class PushersSetRestServlet(RestServlet): yield self.pusher_pool.remove_pusher( content["app_id"], content["pushkey"], user_id=user.to_string() ) - return (200, {}) + return 200, {} assert_params_in_dict( content, @@ -143,7 +143,7 @@ class PushersSetRestServlet(RestServlet): self.notifier.on_new_replication_data() - return (200, {}) + return 200, {} def on_OPTIONS(self, _): return 200, {} diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 4b2344e696..f244e8f469 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -91,14 +91,14 @@ class RoomCreateRestServlet(TransactionRestServlet): requester, self.get_room_config(request) ) - return (200, info) + return 200, info def get_room_config(self, request): user_supplied_config = parse_json_object_from_request(request) return user_supplied_config def on_OPTIONS(self, request): - return (200, {}) + return 200, {} # TODO: Needs unit testing for generic events @@ -173,9 +173,9 @@ class RoomStateEventRestServlet(TransactionRestServlet): if format == "event": event = format_event_for_client_v2(data.get_dict()) - return (200, event) + return 200, event elif format == "content": - return (200, data.get_dict()["content"]) + return 200, data.get_dict()["content"] @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): @@ -210,7 +210,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): ret = {} if event: ret = {"event_id": event.event_id} - return (200, ret) + return 200, ret # TODO: Needs unit testing for generic events + feedback @@ -244,10 +244,10 @@ class RoomSendEventRestServlet(TransactionRestServlet): requester, event_dict, txn_id=txn_id ) - return (200, {"event_id": event.event_id}) + return 200, {"event_id": event.event_id} def on_GET(self, request, room_id, event_type, txn_id): - return (200, "Not implemented") + return 200, "Not implemented" def on_PUT(self, request, room_id, event_type, txn_id): return self.txns.fetch_or_execute_request( @@ -307,7 +307,7 @@ class JoinRoomAliasServlet(TransactionRestServlet): third_party_signed=content.get("third_party_signed", None), ) - return (200, {"room_id": room_id}) + return 200, {"room_id": room_id} def on_PUT(self, request, room_identifier, txn_id): return self.txns.fetch_or_execute_request( @@ -360,7 +360,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit=limit, since_token=since_token ) - return (200, data) + return 200, data @defer.inlineCallbacks def on_POST(self, request): @@ -405,7 +405,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): network_tuple=network_tuple, ) - return (200, data) + return 200, data # TODO: Needs unit testing @@ -456,7 +456,7 @@ class RoomMemberListRestServlet(RestServlet): continue chunk.append(event) - return (200, {"chunk": chunk}) + return 200, {"chunk": chunk} # deprecated in favour of /members?membership=join? @@ -477,7 +477,7 @@ class JoinedRoomMemberListRestServlet(RestServlet): requester, room_id ) - return (200, {"joined": users_with_profile}) + return 200, {"joined": users_with_profile} # TODO: Needs better unit testing @@ -510,7 +510,7 @@ class RoomMessageListRestServlet(RestServlet): event_filter=event_filter, ) - return (200, msgs) + return 200, msgs # TODO: Needs unit testing @@ -531,7 +531,7 @@ class RoomStateRestServlet(RestServlet): user_id=requester.user.to_string(), is_guest=requester.is_guest, ) - return (200, events) + return 200, events # TODO: Needs unit testing @@ -550,7 +550,7 @@ class RoomInitialSyncRestServlet(RestServlet): content = yield self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) - return (200, content) + return 200, content class RoomEventServlet(RestServlet): @@ -581,7 +581,7 @@ class RoomEventServlet(RestServlet): time_now = self.clock.time_msec() if event: event = yield self._event_serializer.serialize_event(event, time_now) - return (200, event) + return 200, event return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) @@ -633,7 +633,7 @@ class RoomEventContextServlet(RestServlet): results["state"], time_now ) - return (200, results) + return 200, results class RoomForgetRestServlet(TransactionRestServlet): @@ -652,7 +652,7 @@ class RoomForgetRestServlet(TransactionRestServlet): yield self.room_member_handler.forget(user=requester.user, room_id=room_id) - return (200, {}) + return 200, {} def on_PUT(self, request, room_id, txn_id): return self.txns.fetch_or_execute_request( @@ -702,7 +702,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): requester, txn_id, ) - return (200, {}) + return 200, {} return target = requester.user @@ -729,7 +729,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): if membership_action == "join": return_value["room_id"] = room_id - return (200, return_value) + return 200, return_value def _has_3pid_invite_keys(self, content): for key in {"id_server", "medium", "address"}: @@ -771,7 +771,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): txn_id=txn_id, ) - return (200, {"event_id": event.event_id}) + return 200, {"event_id": event.event_id} def on_PUT(self, request, room_id, event_id, txn_id): return self.txns.fetch_or_execute_request( @@ -816,7 +816,7 @@ class RoomTypingRestServlet(RestServlet): target_user=target_user, auth_user=requester.user, room_id=room_id ) - return (200, {}) + return 200, {} class SearchRestServlet(RestServlet): @@ -838,7 +838,7 @@ class SearchRestServlet(RestServlet): requester.user, content, batch ) - return (200, results) + return 200, results class JoinedRoomsRestServlet(RestServlet): @@ -854,7 +854,7 @@ class JoinedRoomsRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) - return (200, {"joined_rooms": list(room_ids)}) + return 200, {"joined_rooms": list(room_ids)} def register_txn_path(servlet, regex_string, http_server, with_get=False): diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 497cddf8b8..2afdbb89e5 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -60,7 +60,7 @@ class VoipRestServlet(RestServlet): password = turnPassword else: - return (200, {}) + return 200, {} return ( 200, @@ -73,7 +73,7 @@ class VoipRestServlet(RestServlet): ) def on_OPTIONS(self, request): - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 934ed5d16d..0620a4d0cf 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -117,7 +117,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Wrap the session id in a JSON object ret = {"sid": sid} - return (200, ret) + return 200, ret @defer.inlineCallbacks def send_password_reset(self, email, client_secret, send_attempt, next_link=None): @@ -221,7 +221,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND) ret = yield self.identity_handler.requestMsisdnToken(**body) - return (200, ret) + return 200, ret class PasswordResetSubmitTokenServlet(RestServlet): @@ -330,7 +330,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): ) response_code = 200 if valid else 400 - return (response_code, {"success": valid}) + return response_code, {"success": valid} class PasswordRestServlet(RestServlet): @@ -399,7 +399,7 @@ class PasswordRestServlet(RestServlet): yield self._set_password_handler.set_password(user_id, new_password, requester) - return (200, {}) + return 200, {} def on_OPTIONS(self, _): return 200, {} @@ -434,7 +434,7 @@ class DeactivateAccountRestServlet(RestServlet): yield self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase ) - return (200, {}) + return 200, {} yield self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) @@ -447,7 +447,7 @@ class DeactivateAccountRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - return (200, {"id_server_unbind_result": id_server_unbind_result}) + return 200, {"id_server_unbind_result": id_server_unbind_result} class EmailThreepidRequestTokenRestServlet(RestServlet): @@ -481,7 +481,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) ret = yield self.identity_handler.requestEmailToken(**body) - return (200, ret) + return 200, ret class MsisdnThreepidRequestTokenRestServlet(RestServlet): @@ -516,7 +516,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) ret = yield self.identity_handler.requestMsisdnToken(**body) - return (200, ret) + return 200, ret class ThreepidRestServlet(RestServlet): @@ -536,7 +536,7 @@ class ThreepidRestServlet(RestServlet): threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) - return (200, {"threepids": threepids}) + return 200, {"threepids": threepids} @defer.inlineCallbacks def on_POST(self, request): @@ -568,7 +568,7 @@ class ThreepidRestServlet(RestServlet): logger.debug("Binding threepid %s to %s", threepid, user_id) yield self.identity_handler.bind_threepid(threePidCreds, user_id) - return (200, {}) + return 200, {} class ThreepidDeleteRestServlet(RestServlet): @@ -603,7 +603,7 @@ class ThreepidDeleteRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - return (200, {"id_server_unbind_result": id_server_unbind_result}) + return 200, {"id_server_unbind_result": id_server_unbind_result} class WhoamiRestServlet(RestServlet): @@ -617,7 +617,7 @@ class WhoamiRestServlet(RestServlet): def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - return (200, {"user_id": requester.user.to_string()}) + return 200, {"user_id": requester.user.to_string()} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 98f2f6f4b5..f0db204ffa 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -55,7 +55,7 @@ class AccountDataServlet(RestServlet): self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_GET(self, request, user_id, account_data_type): @@ -70,7 +70,7 @@ class AccountDataServlet(RestServlet): if event is None: raise NotFoundError("Account data not found") - return (200, event) + return 200, event class RoomAccountDataServlet(RestServlet): @@ -112,7 +112,7 @@ class RoomAccountDataServlet(RestServlet): self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_GET(self, request, user_id, room_id, account_data_type): @@ -127,7 +127,7 @@ class RoomAccountDataServlet(RestServlet): if event is None: raise NotFoundError("Room account data not found") - return (200, event) + return 200, event def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index a4fa45fe11..acd58af193 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -58,7 +58,7 @@ class CapabilitiesRestServlet(RestServlet): "m.change_password": {"enabled": change_password}, } } - return (200, response) + return 200, response def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 9adf76cc0c..26d0235208 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -48,7 +48,7 @@ class DevicesRestServlet(RestServlet): devices = yield self.device_handler.get_devices_by_user( requester.user.to_string() ) - return (200, {"devices": devices}) + return 200, {"devices": devices} class DeleteDevicesRestServlet(RestServlet): @@ -91,7 +91,7 @@ class DeleteDevicesRestServlet(RestServlet): yield self.device_handler.delete_devices( requester.user.to_string(), body["devices"] ) - return (200, {}) + return 200, {} class DeviceRestServlet(RestServlet): @@ -114,7 +114,7 @@ class DeviceRestServlet(RestServlet): device = yield self.device_handler.get_device( requester.user.to_string(), device_id ) - return (200, device) + return 200, device @interactive_auth_handler @defer.inlineCallbacks @@ -137,7 +137,7 @@ class DeviceRestServlet(RestServlet): ) yield self.device_handler.delete_device(requester.user.to_string(), device_id) - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_PUT(self, request, device_id): @@ -147,7 +147,7 @@ class DeviceRestServlet(RestServlet): yield self.device_handler.update_device( requester.user.to_string(), device_id, body ) - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 22be0ee3c5..c6ddf24c8d 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -56,7 +56,7 @@ class GetFilterRestServlet(RestServlet): user_localpart=target_user.localpart, filter_id=filter_id ) - return (200, filter.get_filter_json()) + return 200, filter.get_filter_json() except (KeyError, StoreError): raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND) @@ -89,7 +89,7 @@ class CreateFilterRestServlet(RestServlet): user_localpart=target_user.localpart, user_filter=content ) - return (200, {"filter_id": str(filter_id)}) + return 200, {"filter_id": str(filter_id)} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index e629c4256d..999a0fa80c 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -47,7 +47,7 @@ class GroupServlet(RestServlet): group_id, requester_user_id ) - return (200, group_description) + return 200, group_description @defer.inlineCallbacks def on_POST(self, request, group_id): @@ -59,7 +59,7 @@ class GroupServlet(RestServlet): group_id, requester_user_id, content ) - return (200, {}) + return 200, {} class GroupSummaryServlet(RestServlet): @@ -83,7 +83,7 @@ class GroupSummaryServlet(RestServlet): group_id, requester_user_id ) - return (200, get_group_summary) + return 200, get_group_summary class GroupSummaryRoomsCatServlet(RestServlet): @@ -120,7 +120,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): content=content, ) - return (200, resp) + return 200, resp @defer.inlineCallbacks def on_DELETE(self, request, group_id, category_id, room_id): @@ -131,7 +131,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): group_id, requester_user_id, room_id=room_id, category_id=category_id ) - return (200, resp) + return 200, resp class GroupCategoryServlet(RestServlet): @@ -157,7 +157,7 @@ class GroupCategoryServlet(RestServlet): group_id, requester_user_id, category_id=category_id ) - return (200, category) + return 200, category @defer.inlineCallbacks def on_PUT(self, request, group_id, category_id): @@ -169,7 +169,7 @@ class GroupCategoryServlet(RestServlet): group_id, requester_user_id, category_id=category_id, content=content ) - return (200, resp) + return 200, resp @defer.inlineCallbacks def on_DELETE(self, request, group_id, category_id): @@ -180,7 +180,7 @@ class GroupCategoryServlet(RestServlet): group_id, requester_user_id, category_id=category_id ) - return (200, resp) + return 200, resp class GroupCategoriesServlet(RestServlet): @@ -204,7 +204,7 @@ class GroupCategoriesServlet(RestServlet): group_id, requester_user_id ) - return (200, category) + return 200, category class GroupRoleServlet(RestServlet): @@ -228,7 +228,7 @@ class GroupRoleServlet(RestServlet): group_id, requester_user_id, role_id=role_id ) - return (200, category) + return 200, category @defer.inlineCallbacks def on_PUT(self, request, group_id, role_id): @@ -240,7 +240,7 @@ class GroupRoleServlet(RestServlet): group_id, requester_user_id, role_id=role_id, content=content ) - return (200, resp) + return 200, resp @defer.inlineCallbacks def on_DELETE(self, request, group_id, role_id): @@ -251,7 +251,7 @@ class GroupRoleServlet(RestServlet): group_id, requester_user_id, role_id=role_id ) - return (200, resp) + return 200, resp class GroupRolesServlet(RestServlet): @@ -275,7 +275,7 @@ class GroupRolesServlet(RestServlet): group_id, requester_user_id ) - return (200, category) + return 200, category class GroupSummaryUsersRoleServlet(RestServlet): @@ -312,7 +312,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): content=content, ) - return (200, resp) + return 200, resp @defer.inlineCallbacks def on_DELETE(self, request, group_id, role_id, user_id): @@ -323,7 +323,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): group_id, requester_user_id, user_id=user_id, role_id=role_id ) - return (200, resp) + return 200, resp class GroupRoomServlet(RestServlet): @@ -347,7 +347,7 @@ class GroupRoomServlet(RestServlet): group_id, requester_user_id ) - return (200, result) + return 200, result class GroupUsersServlet(RestServlet): @@ -371,7 +371,7 @@ class GroupUsersServlet(RestServlet): group_id, requester_user_id ) - return (200, result) + return 200, result class GroupInvitedUsersServlet(RestServlet): @@ -395,7 +395,7 @@ class GroupInvitedUsersServlet(RestServlet): group_id, requester_user_id ) - return (200, result) + return 200, result class GroupSettingJoinPolicyServlet(RestServlet): @@ -420,7 +420,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): group_id, requester_user_id, content ) - return (200, result) + return 200, result class GroupCreateServlet(RestServlet): @@ -450,7 +450,7 @@ class GroupCreateServlet(RestServlet): group_id, requester_user_id, content ) - return (200, result) + return 200, result class GroupAdminRoomsServlet(RestServlet): @@ -477,7 +477,7 @@ class GroupAdminRoomsServlet(RestServlet): group_id, requester_user_id, room_id, content ) - return (200, result) + return 200, result @defer.inlineCallbacks def on_DELETE(self, request, group_id, room_id): @@ -488,7 +488,7 @@ class GroupAdminRoomsServlet(RestServlet): group_id, requester_user_id, room_id ) - return (200, result) + return 200, result class GroupAdminRoomsConfigServlet(RestServlet): @@ -516,7 +516,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): group_id, requester_user_id, room_id, config_key, content ) - return (200, result) + return 200, result class GroupAdminUsersInviteServlet(RestServlet): @@ -546,7 +546,7 @@ class GroupAdminUsersInviteServlet(RestServlet): group_id, user_id, requester_user_id, config ) - return (200, result) + return 200, result class GroupAdminUsersKickServlet(RestServlet): @@ -573,7 +573,7 @@ class GroupAdminUsersKickServlet(RestServlet): group_id, user_id, requester_user_id, content ) - return (200, result) + return 200, result class GroupSelfLeaveServlet(RestServlet): @@ -598,7 +598,7 @@ class GroupSelfLeaveServlet(RestServlet): group_id, requester_user_id, requester_user_id, content ) - return (200, result) + return 200, result class GroupSelfJoinServlet(RestServlet): @@ -623,7 +623,7 @@ class GroupSelfJoinServlet(RestServlet): group_id, requester_user_id, content ) - return (200, result) + return 200, result class GroupSelfAcceptInviteServlet(RestServlet): @@ -648,7 +648,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): group_id, requester_user_id, content ) - return (200, result) + return 200, result class GroupSelfUpdatePublicityServlet(RestServlet): @@ -672,7 +672,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): publicise = content["publicise"] yield self.store.update_group_publicity(group_id, requester_user_id, publicise) - return (200, {}) + return 200, {} class PublicisedGroupsForUserServlet(RestServlet): @@ -694,7 +694,7 @@ class PublicisedGroupsForUserServlet(RestServlet): result = yield self.groups_handler.get_publicised_groups_for_user(user_id) - return (200, result) + return 200, result class PublicisedGroupsForUsersServlet(RestServlet): @@ -719,7 +719,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) - return (200, result) + return 200, result class GroupsForUserServlet(RestServlet): @@ -741,7 +741,7 @@ class GroupsForUserServlet(RestServlet): result = yield self.groups_handler.get_joined_groups(requester_user_id) - return (200, result) + return 200, result def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index b218a3f334..64b6898eb8 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -105,7 +105,7 @@ class KeyUploadServlet(RestServlet): result = yield self.e2e_keys_handler.upload_keys_for_user( user_id, device_id, body ) - return (200, result) + return 200, result class KeyQueryServlet(RestServlet): @@ -159,7 +159,7 @@ class KeyQueryServlet(RestServlet): timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) result = yield self.e2e_keys_handler.query_devices(body, timeout) - return (200, result) + return 200, result class KeyChangesServlet(RestServlet): @@ -200,7 +200,7 @@ class KeyChangesServlet(RestServlet): results = yield self.device_handler.get_user_ids_changed(user_id, from_token) - return (200, results) + return 200, results class OneTimeKeyServlet(RestServlet): @@ -235,7 +235,7 @@ class OneTimeKeyServlet(RestServlet): timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) - return (200, result) + return 200, result def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index d034863a3c..10c1ad5b07 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -88,7 +88,7 @@ class NotificationsServlet(RestServlet): returned_push_actions.append(returned_pa) next_token = str(pa["stream_ordering"]) - return (200, {"notifications": returned_push_actions, "next_token": next_token}) + return 200, {"notifications": returned_push_actions, "next_token": next_token} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index d93d6a9f24..b3bf8567e1 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -59,7 +59,7 @@ class ReadMarkerRestServlet(RestServlet): event_id=read_marker_event_id, ) - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 98a97b7059..0dab03d227 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -52,7 +52,7 @@ class ReceiptRestServlet(RestServlet): room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 9510a1e2b0..65f9fce2ff 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -94,7 +94,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) ret = yield self.identity_handler.requestEmailToken(**body) - return (200, ret) + return 200, ret class MsisdnRegisterRequestTokenRestServlet(RestServlet): @@ -137,7 +137,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) ret = yield self.identity_handler.requestMsisdnToken(**body) - return (200, ret) + return 200, ret class UsernameAvailabilityRestServlet(RestServlet): @@ -177,7 +177,7 @@ class UsernameAvailabilityRestServlet(RestServlet): yield self.registration_handler.check_username(username) - return (200, {"available": True}) + return 200, {"available": True} class RegisterRestServlet(RestServlet): @@ -279,7 +279,7 @@ class RegisterRestServlet(RestServlet): result = yield self._do_appservice_registration( desired_username, access_token, body ) - return (200, result) # we throw for non 200 responses + return 200, result # we throw for non 200 responses return # for regular registration, downcase the provided username before @@ -487,7 +487,7 @@ class RegisterRestServlet(RestServlet): bind_msisdn=params.get("bind_msisdn"), ) - return (200, return_dict) + return 200, return_dict def on_OPTIONS(self, _): return 200, {} diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 1538b247e5..040b37c504 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -118,7 +118,7 @@ class RelationSendServlet(RestServlet): requester, event_dict=event_dict, txn_id=txn_id ) - return (200, {"event_id": event.event_id}) + return 200, {"event_id": event.event_id} class RelationPaginationServlet(RestServlet): @@ -198,7 +198,7 @@ class RelationPaginationServlet(RestServlet): return_value["chunk"] = events return_value["original_event"] = original_event - return (200, return_value) + return 200, return_value class RelationAggregationPaginationServlet(RestServlet): @@ -270,7 +270,7 @@ class RelationAggregationPaginationServlet(RestServlet): to_token=to_token, ) - return (200, pagination_chunk.to_dict()) + return 200, pagination_chunk.to_dict() class RelationAggregationGroupPaginationServlet(RestServlet): @@ -356,7 +356,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): return_value = result.to_dict() return_value["chunk"] = events - return (200, return_value) + return 200, return_value def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 3fdd4584a3..e7449864cd 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -72,7 +72,7 @@ class ReportEventRestServlet(RestServlet): received_ts=self.clock.time_msec(), ) - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 10dec96208..df4f44cd36 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -135,7 +135,7 @@ class RoomKeysServlet(RestServlet): body = {"rooms": {room_id: body}} yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_GET(self, request, room_id, session_id): @@ -218,7 +218,7 @@ class RoomKeysServlet(RestServlet): else: room_keys = room_keys["rooms"][room_id] - return (200, room_keys) + return 200, room_keys @defer.inlineCallbacks def on_DELETE(self, request, room_id, session_id): @@ -242,7 +242,7 @@ class RoomKeysServlet(RestServlet): yield self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id ) - return (200, {}) + return 200, {} class RoomKeysNewVersionServlet(RestServlet): @@ -293,7 +293,7 @@ class RoomKeysNewVersionServlet(RestServlet): info = parse_json_object_from_request(request) new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) - return (200, {"version": new_version}) + return 200, {"version": new_version} # we deliberately don't have a PUT /version, as these things really should # be immutable to avoid people footgunning @@ -338,7 +338,7 @@ class RoomKeysVersionServlet(RestServlet): except SynapseError as e: if e.code == 404: raise SynapseError(404, "No backup found", Codes.NOT_FOUND) - return (200, info) + return 200, info @defer.inlineCallbacks def on_DELETE(self, request, version): @@ -358,7 +358,7 @@ class RoomKeysVersionServlet(RestServlet): user_id = requester.user.to_string() yield self.e2e_room_keys_handler.delete_version(user_id, version) - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_PUT(self, request, version): @@ -392,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet): ) yield self.e2e_room_keys_handler.update_version(user_id, version, info) - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index 14ba61a63e..d2c3316eb7 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -80,7 +80,7 @@ class RoomUpgradeRestServlet(RestServlet): ret = {"replacement_room": new_room_id} - return (200, ret) + return 200, ret def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 7b32dd2212..c98c5a3802 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -174,7 +174,7 @@ class SyncRestServlet(RestServlet): time_now, sync_result, requester.access_token_id, filter ) - return (200, response_content) + return 200, response_content @defer.inlineCallbacks def encode_response(self, time_now, sync_result, access_token_id, filter): diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index d173544355..3b555669a0 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -45,7 +45,7 @@ class TagListServlet(RestServlet): tags = yield self.store.get_tags_for_room(user_id, room_id) - return (200, {"tags": tags}) + return 200, {"tags": tags} class TagServlet(RestServlet): @@ -76,7 +76,7 @@ class TagServlet(RestServlet): self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - return (200, {}) + return 200, {} @defer.inlineCallbacks def on_DELETE(self, request, user_id, room_id, tag): @@ -88,7 +88,7 @@ class TagServlet(RestServlet): self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - return (200, {}) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 158e686b01..2e8d672471 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -40,7 +40,7 @@ class ThirdPartyProtocolsServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols() - return (200, protocols) + return 200, protocols class ThirdPartyProtocolServlet(RestServlet): @@ -60,9 +60,9 @@ class ThirdPartyProtocolServlet(RestServlet): only_protocol=protocol ) if protocol in protocols: - return (200, protocols[protocol]) + return 200, protocols[protocol] else: - return (404, {"error": "Unknown protocol"}) + return 404, {"error": "Unknown protocol"} class ThirdPartyUserServlet(RestServlet): @@ -85,7 +85,7 @@ class ThirdPartyUserServlet(RestServlet): ThirdPartyEntityKind.USER, protocol, fields ) - return (200, results) + return 200, results class ThirdPartyLocationServlet(RestServlet): @@ -108,7 +108,7 @@ class ThirdPartyLocationServlet(RestServlet): ThirdPartyEntityKind.LOCATION, protocol, fields ) - return (200, results) + return 200, results def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 7ab2b80e46..2863affbab 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -60,7 +60,7 @@ class UserDirectorySearchRestServlet(RestServlet): user_id = requester.user.to_string() if not self.hs.config.user_directory_search_enabled: - return (200, {"limited": False, "results": []}) + return 200, {"limited": False, "results": []} body = parse_json_object_from_request(request) @@ -76,7 +76,7 @@ class UserDirectorySearchRestServlet(RestServlet): user_id, search_term, limit ) - return (200, results) + return 200, results def register_servlets(hs, http_server): diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index cf5759e9a6..d4ea09260c 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -318,14 +318,14 @@ class MediaRepository(object): responder = yield self.media_storage.fetch_media(file_info) if responder: - return (responder, media_info) + return responder, media_info # Failed to find the file anywhere, lets download it. media_info = yield self._download_remote_file(server_name, media_id, file_id) responder = yield self.media_storage.fetch_media(file_info) - return (responder, media_info) + return responder, media_info @defer.inlineCallbacks def _download_remote_file(self, server_name, media_id, file_id): diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 90d8e6bffe..c995d7e043 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -78,9 +78,9 @@ class Thumbnailer(object): """ if max_width * self.height < max_height * self.width: - return (max_width, (max_width * self.height) // self.width) + return max_width, (max_width * self.height) // self.width else: - return ((max_height * self.width) // self.height, max_height) + return (max_height * self.width) // self.height, max_height def scale(self, width, height, output_type): """Rescales the image to the given dimensions. diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 729c097e6d..81c4aff496 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -193,4 +193,4 @@ class ResourceLimitsServerNotices(object): if event_id in referenced_events: referenced_events.remove(event.event_id) - return (currently_blocked, referenced_events) + return currently_blocked, referenced_events diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 9fa5b4f3d6..6afbfc0d74 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -90,7 +90,7 @@ class AccountDataWorkerStore(SQLBaseStore): room_data = by_room.setdefault(row["room_id"], {}) room_data[row["account_data_type"]] = json.loads(row["content"]) - return (global_account_data, by_room) + return global_account_data, by_room return self.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn @@ -205,7 +205,7 @@ class AccountDataWorkerStore(SQLBaseStore): ) txn.execute(sql, (last_room_id, current_id, limit)) room_results = txn.fetchall() - return (global_results, room_results) + return global_results, room_results return self.runInteraction( "get_all_updated_account_data_txn", get_updated_account_data_txn @@ -244,13 +244,13 @@ class AccountDataWorkerStore(SQLBaseStore): room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = json.loads(row[2]) - return (global_account_data, account_data_by_room) + return global_account_data, account_data_by_room changed = self._account_data_stream_cache.has_entity_changed( user_id, int(stream_id) ) if not changed: - return ({}, {}) + return {}, {} return self.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 05d9c05c3f..36657753cd 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -358,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore( events = yield self.get_events_as_list(event_ids) - return (upper_bound, events) + return upper_bound, events class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 79bb0ea46d..4dca9de617 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -66,7 +66,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): messages.append(json.loads(row[1])) if len(messages) < limit: stream_pos = current_stream_id - return (messages, stream_pos) + return messages, stream_pos return self.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn @@ -157,7 +157,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): messages.append(json.loads(row[1])) if len(messages) < limit: stream_pos = current_stream_id - return (messages, stream_pos) + return messages, stream_pos return self.runInteraction( "get_new_device_msgs_for_remote", diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index e11881161d..76542c512d 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -94,7 +94,7 @@ class DeviceWorkerStore(SQLBaseStore): destination, int(from_stream_id) ) if not has_changed: - return (now_stream_id, []) + return now_stream_id, [] # We retrieve n+1 devices from the list of outbound pokes where n is # our outbound device update limit. We then check if the very last @@ -117,7 +117,7 @@ class DeviceWorkerStore(SQLBaseStore): # Return an empty list if there are no updates if not updates: - return (now_stream_id, []) + return now_stream_id, [] # if we have exceeded the limit, we need to exclude any results with the # same stream_id as the last row. @@ -167,13 +167,13 @@ class DeviceWorkerStore(SQLBaseStore): # skip that stream_id and return an empty list, and continue with the next # stream_id next time. if not query_map: - return (stream_id_cutoff, []) + return stream_id_cutoff, [] results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) - return (now_stream_id, results) + return now_stream_id, results def _get_devices_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id, limit @@ -352,7 +352,7 @@ class DeviceWorkerStore(SQLBaseStore): else: results[user_id] = yield self._get_cached_devices_for_user(user_id) - return (user_ids_not_in_cache, results) + return user_ids_not_in_cache, results @cachedInlineCallbacks(num_args=2, tree=True) def _get_cached_user_device(self, user_id, device_id): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5a95c36a8b..32050868ff 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -810,7 +810,7 @@ class EventsStore( # If they old and new groups are the same then we don't need to do # anything. if old_state_groups == new_state_groups: - return (None, None) + return None, None if len(new_state_groups) == 1 and len(old_state_groups) == 1: # If we're going from one state group to another, lets check if @@ -827,7 +827,7 @@ class EventsStore( # the current state in memory then lets also return that, # but it doesn't matter if we don't. new_state = state_groups_map.get(new_state_group) - return (new_state, delta_ids) + return new_state, delta_ids # Now that we have calculated new_state_groups we need to get # their state IDs so we can resolve to a single state set. @@ -839,7 +839,7 @@ class EventsStore( if len(new_state_groups) == 1: # If there is only one state group, then we know what the current # state is. - return (state_groups_map[new_state_groups.pop()], None) + return state_groups_map[new_state_groups.pop()], None # Ok, we need to defer to the state handler to resolve our state sets. @@ -868,7 +868,7 @@ class EventsStore( state_res_store=StateResolutionStore(self), ) - return (res.state, None) + return res.state, None @defer.inlineCallbacks def _calculate_state_delta(self, room_id, current_state): @@ -891,7 +891,7 @@ class EventsStore( if ev_id != existing_state.get(key) } - return (to_delete, to_insert) + return to_delete, to_insert @log_function def _persist_events_txn( diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 1a0f2d5768..5db6f2d84a 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -90,7 +90,7 @@ class PresenceStore(SQLBaseStore): presence_states, ) - return (stream_orderings[-1], self._presence_id_gen.get_current_token()) + return stream_orderings[-1], self._presence_id_gen.get_current_token() def _update_presence_txn(self, txn, stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index b431d24b8a..3e0e834a62 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -133,7 +133,7 @@ class PusherWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) deleted = txn.fetchall() - return (updated, deleted) + return updated, deleted return self.runInteraction( "get_all_updated_pushers", get_all_updated_pushers_txn diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 6aa6d98ebb..290ddb30e8 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -478,7 +478,7 @@ class ReceiptsStore(ReceiptsWorkerStore): max_persisted_id = self._receipts_id_gen.get_current_token() - return (stream_id, max_persisted_id) + return stream_id, max_persisted_id def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): return self.runInteraction( diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 856c2ee8d8..490454f19a 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -364,7 +364,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): the chunk of events returned. """ if from_key == to_key: - return ([], from_key) + return [], from_key from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream @@ -374,7 +374,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) if not has_changed: - return ([], from_key) + return [], from_key def f(txn): sql = ( @@ -407,7 +407,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # get. key = from_key - return (ret, key) + return ret, key @defer.inlineCallbacks def get_membership_changes_for_user(self, user_id, from_key, to_key): @@ -496,7 +496,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """ # Allow a zero limit here, and no-op. if limit == 0: - return ([], end_token) + return [], end_token end_token = RoomStreamToken.parse(end_token) @@ -511,7 +511,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # We want to return the results in ascending order. rows.reverse() - return (rows, token) + return rows, token def get_room_event_after_stream_ordering(self, room_id, stream_ordering): """Gets details of the first event in a room at or after a stream ordering @@ -783,7 +783,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): events = yield self.get_events_as_list(event_ids) - return (upper_bound, events) + return upper_bound, events def get_federation_out_pos(self, typ): return self._simple_select_one_onecol( diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f1c8d99419..cbb0a4810a 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -195,6 +195,6 @@ class ChainedIdGenerator(object): with self._lock: if self._unfinished_ids: stream_id, chained_id = self._unfinished_ids[0] - return (stream_id - 1, chained_id) + return stream_id - 1, chained_id - return (self._current_max, self.chained_generator.get_current_token()) + return self._current_max, self.chained_generator.get_current_token() diff --git a/synapse/streams/config.py b/synapse/streams/config.py index f7f5906a99..02994ab2a5 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -37,7 +37,7 @@ class SourcePaginationConfig(object): self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None def __repr__(self): - return ("StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)") % ( + return "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)" % ( self.from_key, self.to_key, self.direction, diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 0ad0a88165..e10296a5e4 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -283,4 +283,4 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user, requester, displayname, by_admin=True ) - return (user_id, token) + return user_id, token diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index bb867150f4..ab4d7d70d0 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -472,7 +472,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): added_at=now, ) ) - return (user_id, tok) + return user_id, tok def test_manual_email_send_expired_account(self): user_id = self.register_user("kermit", "monkey") diff --git a/tests/server.py b/tests/server.py index c8269619b1..e397ebe8fa 100644 --- a/tests/server.py +++ b/tests/server.py @@ -338,7 +338,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): def get_clock(): clock = ThreadedMemoryReactorClock() hs_clock = Clock(clock) - return (clock, hs_clock) + return clock, hs_clock @attr.s(cmp=False) diff --git a/tests/test_server.py b/tests/test_server.py index 2a7d407c98..98fef21d55 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -57,7 +57,7 @@ class JsonResourceTests(unittest.TestCase): def _callback(request, **kwargs): got_kwargs.update(kwargs) - return (200, kwargs) + return 200, kwargs res = JsonResource(self.homeserver) res.register_paths( diff --git a/tests/test_state.py b/tests/test_state.py index 6d33566f47..610ec9fb46 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -106,7 +106,7 @@ class StateGroupStore(object): } def get_state_group_delta(self, name): - return (None, None) + return None, None def register_events(self, events): for e in events: diff --git a/tests/utils.py b/tests/utils.py index f1eb9a545c..46ef2959f2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -464,7 +464,7 @@ class MockHttpResource(HttpServer): args = [urlparse.unquote(u) for u in matcher.groups()] (code, response) = yield func(mock_request, *args) - return (code, response) + return code, response except CodeMessageException as e: return (e.code, cs_error(e.msg, code=e.errcode)) -- cgit 1.5.1 From 6e834e94fcc97811e4cc8185e86c6b9da06eb28e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Sep 2019 13:04:27 +0100 Subject: Fix and refactor room and user stats (#5971) Previously the stats were not being correctly populated. --- changelog.d/5971.bugfix | 1 + docs/room_and_user_statistics.md | 62 ++ synapse/config/stats.py | 13 +- synapse/handlers/stats.py | 307 +++--- synapse/storage/events.py | 5 +- synapse/storage/registration.py | 12 + synapse/storage/roommember.py | 44 +- .../storage/schema/delta/56/stats_separated.sql | 152 +++ synapse/storage/stats.py | 1036 ++++++++++++++------ tests/handlers/test_stats.py | 643 +++++++++--- tests/rest/client/v1/utils.py | 8 +- 11 files changed, 1642 insertions(+), 641 deletions(-) create mode 100644 changelog.d/5971.bugfix create mode 100644 docs/room_and_user_statistics.md create mode 100644 synapse/storage/schema/delta/56/stats_separated.sql (limited to 'tests') diff --git a/changelog.d/5971.bugfix b/changelog.d/5971.bugfix new file mode 100644 index 0000000000..9ea095103b --- /dev/null +++ b/changelog.d/5971.bugfix @@ -0,0 +1 @@ +Fix room and user stats tracking. diff --git a/docs/room_and_user_statistics.md b/docs/room_and_user_statistics.md new file mode 100644 index 0000000000..e1facb38d4 --- /dev/null +++ b/docs/room_and_user_statistics.md @@ -0,0 +1,62 @@ +Room and User Statistics +======================== + +Synapse maintains room and user statistics (as well as a cache of room state), +in various tables. These can be used for administrative purposes but are also +used when generating the public room directory. + + +# Synapse Developer Documentation + +## High-Level Concepts + +### Definitions + +* **subject**: Something we are tracking stats about – currently a room or user. +* **current row**: An entry for a subject in the appropriate current statistics + table. Each subject can have only one. +* **historical row**: An entry for a subject in the appropriate historical + statistics table. Each subject can have any number of these. + +### Overview + +Stats are maintained as time series. There are two kinds of column: + +* absolute columns – where the value is correct for the time given by `end_ts` + in the stats row. (Imagine a line graph for these values) + * They can also be thought of as 'gauges' in Prometheus, if you are familiar. +* per-slice columns – where the value corresponds to how many of the occurrences + occurred within the time slice given by `(end_ts − bucket_size)…end_ts` + or `start_ts…end_ts`. (Imagine a histogram for these values) + +Stats are maintained in two tables (for each type): current and historical. + +Current stats correspond to the present values. Each subject can only have one +entry. + +Historical stats correspond to values in the past. Subjects may have multiple +entries. + +## Concepts around the management of stats + +### Current rows + +Current rows contain the most up-to-date statistics for a room. +They only contain absolute columns + +### Historical rows + +Historical rows can always be considered to be valid for the time slice and +end time specified. + +* historical rows will not exist for every time slice – they will be omitted + if there were no changes. In this case, the following assumptions can be + made to interpolate/recreate missing rows: + - absolute fields have the same values as in the preceding row + - per-slice fields are zero (`0`) +* historical rows will not be retained forever – rows older than a configurable + time will be purged. + +#### Purge + +The purging of historical rows is not yet implemented. diff --git a/synapse/config/stats.py b/synapse/config/stats.py index b518a3ed9c..b18ddbd1fa 100644 --- a/synapse/config/stats.py +++ b/synapse/config/stats.py @@ -27,19 +27,16 @@ class StatsConfig(Config): def read_config(self, config, **kwargs): self.stats_enabled = True - self.stats_bucket_size = 86400 + self.stats_bucket_size = 86400 * 1000 self.stats_retention = sys.maxsize stats_config = config.get("stats", None) if stats_config: self.stats_enabled = stats_config.get("enabled", self.stats_enabled) - self.stats_bucket_size = ( - self.parse_duration(stats_config.get("bucket_size", "1d")) / 1000 + self.stats_bucket_size = self.parse_duration( + stats_config.get("bucket_size", "1d") ) - self.stats_retention = ( - self.parse_duration( - stats_config.get("retention", "%ds" % (sys.maxsize,)) - ) - / 1000 + self.stats_retention = self.parse_duration( + stats_config.get("retention", "%ds" % (sys.maxsize,)) ) def generate_config_section(self, config_dir_path, server_name, **kwargs): diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 4449da6669..921735edb3 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -14,15 +14,14 @@ # limitations under the License. import logging +from collections import Counter from twisted.internet import defer -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import EventTypes, Membership from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID -from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -62,11 +61,10 @@ class StatsHandler(StateDeltasHandler): def notify_new_event(self): """Called when there may be more deltas to process """ - if not self.hs.config.stats_enabled: + if not self.hs.config.stats_enabled or self._is_processing: return - if self._is_processing: - return + self._is_processing = True @defer.inlineCallbacks def process(): @@ -75,39 +73,72 @@ class StatsHandler(StateDeltasHandler): finally: self._is_processing = False - self._is_processing = True run_as_background_process("stats.notify_new_event", process) @defer.inlineCallbacks def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: - self.pos = yield self.store.get_stats_stream_pos() - - # If still None then the initial background update hasn't happened yet - if self.pos is None: - return None + self.pos = yield self.store.get_stats_positions() # Loop round handling deltas until we're up to date + while True: - with Measure(self.clock, "stats_delta"): - deltas = yield self.store.get_current_state_deltas(self.pos) - if not deltas: - return + deltas = yield self.store.get_current_state_deltas(self.pos) + + if deltas: + logger.debug("Handling %d state deltas", len(deltas)) + room_deltas, user_deltas = yield self._handle_deltas(deltas) + + max_pos = deltas[-1]["stream_id"] + else: + room_deltas = {} + user_deltas = {} + max_pos = yield self.store.get_room_max_stream_ordering() - logger.info("Handling %d state deltas", len(deltas)) - yield self._handle_deltas(deltas) + # Then count deltas for total_events and total_event_bytes. + room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes( + self.pos, max_pos + ) + + for room_id, fields in room_count.items(): + room_deltas.setdefault(room_id, {}).update(fields) + + for user_id, fields in user_count.items(): + user_deltas.setdefault(user_id, {}).update(fields) + + logger.debug("room_deltas: %s", room_deltas) + logger.debug("user_deltas: %s", user_deltas) - self.pos = deltas[-1]["stream_id"] - yield self.store.update_stats_stream_pos(self.pos) + # Always call this so that we update the stats position. + yield self.store.bulk_update_stats_delta( + self.clock.time_msec(), + updates={"room": room_deltas, "user": user_deltas}, + stream_id=max_pos, + ) + + event_processing_positions.labels("stats").set(max_pos) - event_processing_positions.labels("stats").set(self.pos) + if self.pos == max_pos: + break + + self.pos = max_pos @defer.inlineCallbacks def _handle_deltas(self, deltas): + """Called with the state deltas to process + + Returns: + Deferred[tuple[dict[str, Counter], dict[str, counter]]] + Resovles to two dicts, the room deltas and the user deltas, + mapping from room/user ID to changes in the various fields. """ - Called with the state deltas to process - """ + + room_to_stats_deltas = {} + user_to_stats_deltas = {} + + room_to_state_updates = {} + for delta in deltas: typ = delta["type"] state_key = delta["state_key"] @@ -115,11 +146,10 @@ class StatsHandler(StateDeltasHandler): event_id = delta["event_id"] stream_id = delta["stream_id"] prev_event_id = delta["prev_event_id"] - stream_pos = delta["stream_id"] - logger.debug("Handling: %r %r, %s", typ, state_key, event_id) + logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id) - token = yield self.store.get_earliest_token_for_room_stats(room_id) + token = yield self.store.get_earliest_token_for_stats("room", room_id) # If the earliest token to begin from is larger than our current # stream ID, skip processing this delta. @@ -131,203 +161,130 @@ class StatsHandler(StateDeltasHandler): continue if event_id is None and prev_event_id is None: - # Errr... + logger.error( + "event ID is None and so is the previous event ID. stream_id: %s", + stream_id, + ) continue event_content = {} + sender = None if event_id is not None: event = yield self.store.get_event(event_id, allow_none=True) if event: event_content = event.content or {} + sender = event.sender + + # All the values in this dict are deltas (RELATIVE changes) + room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter()) - # We use stream_pos here rather than fetch by event_id as event_id - # may be None - now = yield self.store.get_received_ts_by_stream_pos(stream_pos) + room_state = room_to_state_updates.setdefault(room_id, {}) - # quantise time to the nearest bucket - now = (now // 1000 // self.stats_bucket_size) * self.stats_bucket_size + if prev_event_id is None: + # this state event doesn't overwrite another, + # so it is a new effective/current state event + room_stats_delta["current_state_events"] += 1 if typ == EventTypes.Member: # we could use _get_key_change here but it's a bit inefficient # given we're not testing for a specific result; might as well # just grab the prev_membership and membership strings and # compare them. - prev_event_content = {} + # We take None rather than leave as a previous membership + # in the absence of a previous event because we do not want to + # reduce the leave count when a new-to-the-room user joins. + prev_membership = None if prev_event_id is not None: prev_event = yield self.store.get_event( prev_event_id, allow_none=True ) if prev_event: prev_event_content = prev_event.content + prev_membership = prev_event_content.get( + "membership", Membership.LEAVE + ) membership = event_content.get("membership", Membership.LEAVE) - prev_membership = prev_event_content.get("membership", Membership.LEAVE) - - if prev_membership == membership: - continue - if prev_membership == Membership.JOIN: - yield self.store.update_stats_delta( - now, "room", room_id, "joined_members", -1 - ) + if prev_membership is None: + logger.debug("No previous membership for this user.") + elif membership == prev_membership: + pass # noop + elif prev_membership == Membership.JOIN: + room_stats_delta["joined_members"] -= 1 elif prev_membership == Membership.INVITE: - yield self.store.update_stats_delta( - now, "room", room_id, "invited_members", -1 - ) + room_stats_delta["invited_members"] -= 1 elif prev_membership == Membership.LEAVE: - yield self.store.update_stats_delta( - now, "room", room_id, "left_members", -1 - ) + room_stats_delta["left_members"] -= 1 elif prev_membership == Membership.BAN: - yield self.store.update_stats_delta( - now, "room", room_id, "banned_members", -1 - ) + room_stats_delta["banned_members"] -= 1 else: - err = "%s is not a valid prev_membership" % (repr(prev_membership),) - logger.error(err) - raise ValueError(err) + raise ValueError( + "%r is not a valid prev_membership" % (prev_membership,) + ) + if membership == prev_membership: + pass # noop if membership == Membership.JOIN: - yield self.store.update_stats_delta( - now, "room", room_id, "joined_members", +1 - ) + room_stats_delta["joined_members"] += 1 elif membership == Membership.INVITE: - yield self.store.update_stats_delta( - now, "room", room_id, "invited_members", +1 - ) + room_stats_delta["invited_members"] += 1 + + if sender and self.is_mine_id(sender): + user_to_stats_deltas.setdefault(sender, Counter())[ + "invites_sent" + ] += 1 + elif membership == Membership.LEAVE: - yield self.store.update_stats_delta( - now, "room", room_id, "left_members", +1 - ) + room_stats_delta["left_members"] += 1 elif membership == Membership.BAN: - yield self.store.update_stats_delta( - now, "room", room_id, "banned_members", +1 - ) + room_stats_delta["banned_members"] += 1 else: - err = "%s is not a valid membership" % (repr(membership),) - logger.error(err) - raise ValueError(err) + raise ValueError("%r is not a valid membership" % (membership,)) user_id = state_key if self.is_mine_id(user_id): - # update user_stats as it's one of our users - public = yield self._is_public_room(room_id) - - if membership == Membership.LEAVE: - yield self.store.update_stats_delta( - now, - "user", - user_id, - "public_rooms" if public else "private_rooms", - -1, - ) - elif membership == Membership.JOIN: - yield self.store.update_stats_delta( - now, - "user", - user_id, - "public_rooms" if public else "private_rooms", - +1, - ) + # this accounts for transitions like leave → ban and so on. + has_changed_joinedness = (prev_membership == Membership.JOIN) != ( + membership == Membership.JOIN + ) - elif typ == EventTypes.Create: - # Newly created room. Add it with all blank portions. - yield self.store.update_room_state( - room_id, - { - "join_rules": None, - "history_visibility": None, - "encryption": None, - "name": None, - "topic": None, - "avatar": None, - "canonical_alias": None, - }, - ) + if has_changed_joinedness: + delta = +1 if membership == Membership.JOIN else -1 - elif typ == EventTypes.JoinRules: - yield self.store.update_room_state( - room_id, {"join_rules": event_content.get("join_rule")} - ) + user_to_stats_deltas.setdefault(user_id, Counter())[ + "joined_rooms" + ] += delta - is_public = yield self._get_key_change( - prev_event_id, event_id, "join_rule", JoinRules.PUBLIC - ) - if is_public is not None: - yield self.update_public_room_stats(now, room_id, is_public) + room_stats_delta["local_users_in_room"] += delta + elif typ == EventTypes.Create: + room_state["is_federatable"] = event_content.get("m.federate", True) + if sender and self.is_mine_id(sender): + user_to_stats_deltas.setdefault(sender, Counter())[ + "rooms_created" + ] += 1 + elif typ == EventTypes.JoinRules: + room_state["join_rules"] = event_content.get("join_rule") elif typ == EventTypes.RoomHistoryVisibility: - yield self.store.update_room_state( - room_id, - {"history_visibility": event_content.get("history_visibility")}, - ) - - is_public = yield self._get_key_change( - prev_event_id, event_id, "history_visibility", "world_readable" + room_state["history_visibility"] = event_content.get( + "history_visibility" ) - if is_public is not None: - yield self.update_public_room_stats(now, room_id, is_public) - elif typ == EventTypes.Encryption: - yield self.store.update_room_state( - room_id, {"encryption": event_content.get("algorithm")} - ) + room_state["encryption"] = event_content.get("algorithm") elif typ == EventTypes.Name: - yield self.store.update_room_state( - room_id, {"name": event_content.get("name")} - ) + room_state["name"] = event_content.get("name") elif typ == EventTypes.Topic: - yield self.store.update_room_state( - room_id, {"topic": event_content.get("topic")} - ) + room_state["topic"] = event_content.get("topic") elif typ == EventTypes.RoomAvatar: - yield self.store.update_room_state( - room_id, {"avatar": event_content.get("url")} - ) + room_state["avatar"] = event_content.get("url") elif typ == EventTypes.CanonicalAlias: - yield self.store.update_room_state( - room_id, {"canonical_alias": event_content.get("alias")} - ) + room_state["canonical_alias"] = event_content.get("alias") + elif typ == EventTypes.GuestAccess: + room_state["guest_access"] = event_content.get("guest_access") - @defer.inlineCallbacks - def update_public_room_stats(self, ts, room_id, is_public): - """ - Increment/decrement a user's number of public rooms when a room they are - in changes to/from public visibility. + for room_id, state in room_to_state_updates.items(): + yield self.store.update_room_state(room_id, state) - Args: - ts (int): Timestamp in seconds - room_id (str) - is_public (bool) - """ - # For now, blindly iterate over all local users in the room so that - # we can handle the whole problem of copying buckets over as needed - user_ids = yield self.store.get_users_in_room(room_id) - - for user_id in user_ids: - if self.hs.is_mine(UserID.from_string(user_id)): - yield self.store.update_stats_delta( - ts, "user", user_id, "public_rooms", +1 if is_public else -1 - ) - yield self.store.update_stats_delta( - ts, "user", user_id, "private_rooms", -1 if is_public else +1 - ) - - @defer.inlineCallbacks - def _is_public_room(self, room_id): - join_rules = yield self.state.get_current_state(room_id, EventTypes.JoinRules) - history_visibility = yield self.state.get_current_state( - room_id, EventTypes.RoomHistoryVisibility - ) - - if (join_rules and join_rules.content.get("join_rule") == JoinRules.PUBLIC) or ( - ( - history_visibility - and history_visibility.content.get("history_visibility") - == "world_readable" - ) - ): - return True - else: - return False + return room_to_stats_deltas, user_to_stats_deltas diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 32050868ff..1958afe1d7 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -2270,8 +2270,9 @@ class EventsStore( "room_aliases", "room_depth", "room_memberships", - "room_state", - "room_stats", + "room_stats_state", + "room_stats_current", + "room_stats_historical", "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 3f50324253..2d3c7e2dc9 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -869,6 +869,17 @@ class RegistrationStore( (user_id_obj.localpart, create_profile_with_displayname), ) + if self.hs.config.stats_enabled: + # we create a new completed user statistics row + + # we don't strictly need current_token since this user really can't + # have any state deltas before now (as it is a new user), but still, + # we include it for completeness. + current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) + self._update_stats_delta_txn( + txn, now, "user", user_id, {}, complete_with_stream_id=current_token + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) @@ -1140,6 +1151,7 @@ class RegistrationStore( deferred str|None: A str representing a link to redirect the user to if there is one. """ + # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): row = self._simple_select_one_txn( diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index eecb276465..f8b682ebd9 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -112,29 +112,31 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): - def f(txn): - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - sql = """ - SELECT state_key FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? AND membership = ? - """ - else: - sql = """ - SELECT state_key FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? - """ + return self.runInteraction( + "get_users_in_room", self.get_users_in_room_txn, room_id + ) - txn.execute(sql, (room_id, Membership.JOIN)) - return [to_ascii(r[0]) for r in txn] + def get_users_in_room_txn(self, txn, room_id): + # If we can assume current_state_events.membership is up to date + # then we can avoid a join, which is a Very Good Thing given how + # frequently this function gets called. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT state_key FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? AND membership = ? + """ + else: + sql = """ + SELECT state_key FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? + """ - return self.runInteraction("get_users_in_room", f) + txn.execute(sql, (room_id, Membership.JOIN)) + return [to_ascii(r[0]) for r in txn] @cached(max_entries=100000) def get_room_summary(self, room_id): diff --git a/synapse/storage/schema/delta/56/stats_separated.sql b/synapse/storage/schema/delta/56/stats_separated.sql new file mode 100644 index 0000000000..163529c071 --- /dev/null +++ b/synapse/storage/schema/delta/56/stats_separated.sql @@ -0,0 +1,152 @@ +/* Copyright 2018 New Vector Ltd + * Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +----- First clean up from previous versions of room stats. + +-- First remove old stats stuff +DROP TABLE IF EXISTS room_stats; +DROP TABLE IF EXISTS room_state; +DROP TABLE IF EXISTS room_stats_state; +DROP TABLE IF EXISTS user_stats; +DROP TABLE IF EXISTS room_stats_earliest_tokens; +DROP TABLE IF EXISTS _temp_populate_stats_position; +DROP TABLE IF EXISTS _temp_populate_stats_rooms; +DROP TABLE IF EXISTS stats_stream_pos; + +-- Unschedule old background updates if they're still scheduled +DELETE FROM background_updates WHERE update_name IN ( + 'populate_stats_createtables', + 'populate_stats_process_rooms', + 'populate_stats_process_users', + 'populate_stats_cleanup' +); + +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', ''); + +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); + +----- Create tables for our version of room stats. + +-- single-row table to track position of incremental updates +DROP TABLE IF EXISTS stats_incremental_position; +CREATE TABLE stats_incremental_position ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); + +-- insert a null row and make sure it is the only one. +INSERT INTO stats_incremental_position ( + stream_id +) SELECT COALESCE(MAX(stream_ordering), 0) from events; + +-- represents PRESENT room statistics for a room +-- only holds absolute fields +DROP TABLE IF EXISTS room_stats_current; +CREATE TABLE room_stats_current ( + room_id TEXT NOT NULL PRIMARY KEY, + + -- These are absolute counts + current_state_events INT NOT NULL, + joined_members INT NOT NULL, + invited_members INT NOT NULL, + left_members INT NOT NULL, + banned_members INT NOT NULL, + + local_users_in_room INT NOT NULL, + + -- The maximum delta stream position that this row takes into account. + completed_delta_stream_id BIGINT NOT NULL +); + + +-- represents HISTORICAL room statistics for a room +DROP TABLE IF EXISTS room_stats_historical; +CREATE TABLE room_stats_historical ( + room_id TEXT NOT NULL, + -- These stats cover the time from (end_ts - bucket_size)...end_ts (in ms). + -- Note that end_ts is quantised. + end_ts BIGINT NOT NULL, + bucket_size BIGINT NOT NULL, + + -- These stats are absolute counts + current_state_events BIGINT NOT NULL, + joined_members BIGINT NOT NULL, + invited_members BIGINT NOT NULL, + left_members BIGINT NOT NULL, + banned_members BIGINT NOT NULL, + local_users_in_room BIGINT NOT NULL, + + -- These stats are per time slice + total_events BIGINT NOT NULL, + total_event_bytes BIGINT NOT NULL, + + PRIMARY KEY (room_id, end_ts) +); + +-- We use this index to speed up deletion of ancient room stats. +CREATE INDEX room_stats_historical_end_ts ON room_stats_historical (end_ts); + +-- represents PRESENT statistics for a user +-- only holds absolute fields +DROP TABLE IF EXISTS user_stats_current; +CREATE TABLE user_stats_current ( + user_id TEXT NOT NULL PRIMARY KEY, + + joined_rooms BIGINT NOT NULL, + + -- The maximum delta stream position that this row takes into account. + completed_delta_stream_id BIGINT NOT NULL +); + +-- represents HISTORICAL statistics for a user +DROP TABLE IF EXISTS user_stats_historical; +CREATE TABLE user_stats_historical ( + user_id TEXT NOT NULL, + end_ts BIGINT NOT NULL, + bucket_size BIGINT NOT NULL, + + joined_rooms BIGINT NOT NULL, + + invites_sent BIGINT NOT NULL, + rooms_created BIGINT NOT NULL, + total_events BIGINT NOT NULL, + total_event_bytes BIGINT NOT NULL, + + PRIMARY KEY (user_id, end_ts) +); + +-- We use this index to speed up deletion of ancient user stats. +CREATE INDEX user_stats_historical_end_ts ON user_stats_historical (end_ts); + + +CREATE TABLE room_stats_state ( + room_id TEXT NOT NULL, + name TEXT, + canonical_alias TEXT, + join_rules TEXT, + history_visibility TEXT, + encryption TEXT, + avatar TEXT, + guest_access TEXT, + is_federatable BOOLEAN, + topic TEXT +); + +CREATE UNIQUE INDEX room_stats_state_room ON room_stats_state(room_id); diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index e13efed417..6560173c08 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018, 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,17 +15,22 @@ # limitations under the License. import logging +from itertools import chain from twisted.internet import defer +from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership -from synapse.storage.prepare_database import get_statements +from synapse.storage import PostgresEngine from synapse.storage.state_deltas import StateDeltasStore from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) # these fields track absolutes (e.g. total number of rooms on the server) +# You can think of these as Prometheus Gauges. +# You can draw these stats on a line graph. +# Example: number of users in a room ABSOLUTE_STATS_FIELDS = { "room": ( "current_state_events", @@ -32,14 +38,23 @@ ABSOLUTE_STATS_FIELDS = { "invited_members", "left_members", "banned_members", - "state_events", + "local_users_in_room", ), - "user": ("public_rooms", "private_rooms"), + "user": ("joined_rooms",), } -TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} +# these fields are per-timeslice and so should be reset to 0 upon a new slice +# You can draw these stats on a histogram. +# Example: number of events sent locally during a time slice +PER_SLICE_FIELDS = { + "room": ("total_events", "total_event_bytes"), + "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"), +} + +TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} -TEMP_TABLE = "_temp_populate_stats" +# these are the tables (& ID columns) which contain our actual subjects +TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): @@ -51,136 +66,102 @@ class StatsStore(StateDeltasStore): self.stats_enabled = hs.config.stats_enabled self.stats_bucket_size = hs.config.stats_bucket_size - self.register_background_update_handler( - "populate_stats_createtables", self._populate_stats_createtables - ) + self.stats_delta_processing_lock = DeferredLock() + self.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) self.register_background_update_handler( - "populate_stats_cleanup", self._populate_stats_cleanup + "populate_stats_process_users", self._populate_stats_process_users ) + # we no longer need to perform clean-up, but we will give ourselves + # the potential to reintroduce it in the future – so documentation + # will still encourage the use of this no-op handler. + self.register_noop_background_update("populate_stats_cleanup") + self.register_noop_background_update("populate_stats_prepare") - @defer.inlineCallbacks - def _populate_stats_createtables(self, progress, batch_size): - - if not self.stats_enabled: - yield self._end_background_update("populate_stats_createtables") - return 1 - - # Get all the rooms that we want to process. - def _make_staging_area(txn): - # Create the temporary tables - stmts = get_statements( - """ - -- We just recreate the table, we'll be reinserting the - -- correct entries again later anyway. - DROP TABLE IF EXISTS {temp}_rooms; - - CREATE TABLE IF NOT EXISTS {temp}_rooms( - room_id TEXT NOT NULL, - events BIGINT NOT NULL - ); - - CREATE INDEX {temp}_rooms_events - ON {temp}_rooms(events); - CREATE INDEX {temp}_rooms_id - ON {temp}_rooms(room_id); - """.format( - temp=TEMP_TABLE - ).splitlines() - ) - - for statement in stmts: - txn.execute(statement) - - sql = ( - "CREATE TABLE IF NOT EXISTS " - + TEMP_TABLE - + "_position(position TEXT NOT NULL)" - ) - txn.execute(sql) - - # Get rooms we want to process from the database, only adding - # those that we haven't (i.e. those not in room_stats_earliest_token) - sql = """ - INSERT INTO %s_rooms (room_id, events) - SELECT c.room_id, count(*) FROM current_state_events AS c - LEFT JOIN room_stats_earliest_token AS t USING (room_id) - WHERE t.room_id IS NULL - GROUP BY c.room_id - """ % ( - TEMP_TABLE, - ) - txn.execute(sql) + def quantise_stats_time(self, ts): + """ + Quantises a timestamp to be a multiple of the bucket size. - new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.runInteraction("populate_stats_temp_build", _make_staging_area) - yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) - self.get_earliest_token_for_room_stats.invalidate_all() + Args: + ts (int): the timestamp to quantise, in milliseconds since the Unix + Epoch - yield self._end_background_update("populate_stats_createtables") - return 1 + Returns: + int: a timestamp which + - is divisible by the bucket size; + - is no later than `ts`; and + - is the largest such timestamp. + """ + return (ts // self.stats_bucket_size) * self.stats_bucket_size @defer.inlineCallbacks - def _populate_stats_cleanup(self, progress, batch_size): + def _populate_stats_process_users(self, progress, batch_size): """ - Update the user directory stream position, then clean up the old tables. + This is a background update which regenerates statistics for users. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_cleanup") + yield self._end_background_update("populate_stats_process_users") return 1 - position = yield self._simple_select_one_onecol( - TEMP_TABLE + "_position", None, "position" + last_user_id = progress.get("last_user_id", "") + + def _get_next_batch(txn): + sql = """ + SELECT DISTINCT name FROM users + WHERE name > ? + ORDER BY name ASC + LIMIT ? + """ + txn.execute(sql, (last_user_id, batch_size)) + return [r for r, in txn] + + users_to_work_on = yield self.runInteraction( + "_populate_stats_process_users", _get_next_batch ) - yield self.update_stats_stream_pos(position) - def _delete_staging_area(txn): - txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") - txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") + # No more rooms -- complete the transaction. + if not users_to_work_on: + yield self._end_background_update("populate_stats_process_users") + return 1 - yield self.runInteraction("populate_stats_cleanup", _delete_staging_area) + for user_id in users_to_work_on: + yield self._calculate_and_set_initial_state_for_user(user_id) + progress["last_user_id"] = user_id - yield self._end_background_update("populate_stats_cleanup") - return 1 + yield self.runInteraction( + "populate_stats_process_users", + self._background_update_progress_txn, + "populate_stats_process_users", + progress, + ) + + return len(users_to_work_on) @defer.inlineCallbacks def _populate_stats_process_rooms(self, progress, batch_size): - + """ + This is a background update which regenerates statistics for rooms. + """ if not self.stats_enabled: yield self._end_background_update("populate_stats_process_rooms") return 1 - # If we don't have progress filed, delete everything. - if not progress: - yield self.delete_all_stats() + last_room_id = progress.get("last_room_id", "") def _get_next_batch(txn): - # Only fetch 250 rooms, so we don't fetch too many at once, even - # if those 250 rooms have less than batch_size state events. sql = """ - SELECT room_id, events FROM %s_rooms - ORDER BY events DESC - LIMIT 250 - """ % ( - TEMP_TABLE, - ) - txn.execute(sql) - rooms_to_work_on = txn.fetchall() - - if not rooms_to_work_on: - return None - - # Get how many are left to process, so we can give status on how - # far we are in processing - txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") - progress["remaining"] = txn.fetchone()[0] - - return rooms_to_work_on + SELECT DISTINCT room_id FROM current_state_events + WHERE room_id > ? + ORDER BY room_id ASC + LIMIT ? + """ + txn.execute(sql, (last_room_id, batch_size)) + return [r for r, in txn] rooms_to_work_on = yield self.runInteraction( - "populate_stats_temp_read", _get_next_batch + "populate_stats_rooms_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. @@ -188,154 +169,28 @@ class StatsStore(StateDeltasStore): yield self._end_background_update("populate_stats_process_rooms") return 1 - logger.info( - "Processing the next %d rooms of %d remaining", - len(rooms_to_work_on), - progress["remaining"], - ) - - # Number of state events we've processed by going through each room - processed_event_count = 0 - - for room_id, event_count in rooms_to_work_on: - - current_state_ids = yield self.get_current_state_ids(room_id) - - join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) - history_visibility_id = current_state_ids.get( - (EventTypes.RoomHistoryVisibility, "") - ) - encryption_id = current_state_ids.get((EventTypes.RoomEncryption, "")) - name_id = current_state_ids.get((EventTypes.Name, "")) - topic_id = current_state_ids.get((EventTypes.Topic, "")) - avatar_id = current_state_ids.get((EventTypes.RoomAvatar, "")) - canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, "")) - - event_ids = [ - join_rules_id, - history_visibility_id, - encryption_id, - name_id, - topic_id, - avatar_id, - canonical_alias_id, - ] - - state_events = yield self.get_events( - [ev for ev in event_ids if ev is not None] - ) - - def _get_or_none(event_id, arg): - event = state_events.get(event_id) - if event: - return event.content.get(arg) - return None - - yield self.update_room_state( - room_id, - { - "join_rules": _get_or_none(join_rules_id, "join_rule"), - "history_visibility": _get_or_none( - history_visibility_id, "history_visibility" - ), - "encryption": _get_or_none(encryption_id, "algorithm"), - "name": _get_or_none(name_id, "name"), - "topic": _get_or_none(topic_id, "topic"), - "avatar": _get_or_none(avatar_id, "url"), - "canonical_alias": _get_or_none(canonical_alias_id, "alias"), - }, - ) + for room_id in rooms_to_work_on: + yield self._calculate_and_set_initial_state_for_room(room_id) + progress["last_room_id"] = room_id - now = self.hs.get_reactor().seconds() - - # quantise time to the nearest bucket - now = (now // self.stats_bucket_size) * self.stats_bucket_size - - def _fetch_data(txn): - - # Get the current token of the room - current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) - - current_state_events = len(current_state_ids) - - membership_counts = self._get_user_counts_in_room_txn(txn, room_id) - - total_state_events = self._get_total_state_event_counts_txn( - txn, room_id - ) - - self._update_stats_txn( - txn, - "room", - room_id, - now, - { - "bucket_size": self.stats_bucket_size, - "current_state_events": current_state_events, - "joined_members": membership_counts.get(Membership.JOIN, 0), - "invited_members": membership_counts.get(Membership.INVITE, 0), - "left_members": membership_counts.get(Membership.LEAVE, 0), - "banned_members": membership_counts.get(Membership.BAN, 0), - "state_events": total_state_events, - }, - ) - self._simple_insert_txn( - txn, - "room_stats_earliest_token", - {"room_id": room_id, "token": current_token}, - ) - - # We've finished a room. Delete it from the table. - self._simple_delete_one_txn( - txn, TEMP_TABLE + "_rooms", {"room_id": room_id} - ) - - yield self.runInteraction("update_room_stats", _fetch_data) - - # Update the remaining counter. - progress["remaining"] -= 1 - yield self.runInteraction( - "populate_stats", - self._background_update_progress_txn, - "populate_stats_process_rooms", - progress, - ) - - processed_event_count += event_count - - if processed_event_count > batch_size: - # Don't process any more rooms, we've hit our batch size. - return processed_event_count + yield self.runInteraction( + "_populate_stats_process_rooms", + self._background_update_progress_txn, + "populate_stats_process_rooms", + progress, + ) - return processed_event_count + return len(rooms_to_work_on) - def delete_all_stats(self): + def get_stats_positions(self): """ - Delete all statistics records. + Returns the stats processor positions. """ - - def _delete_all_stats_txn(txn): - txn.execute("DELETE FROM room_state") - txn.execute("DELETE FROM room_stats") - txn.execute("DELETE FROM room_stats_earliest_token") - txn.execute("DELETE FROM user_stats") - - return self.runInteraction("delete_all_stats", _delete_all_stats_txn) - - def get_stats_stream_pos(self): return self._simple_select_one_onecol( - table="stats_stream_pos", + table="stats_incremental_position", keyvalues={}, retcol="stream_id", - desc="stats_stream_pos", - ) - - def update_stats_stream_pos(self, stream_id): - return self._simple_update_one( - table="stats_stream_pos", - keyvalues={}, - updatevalues={"stream_id": stream_id}, - desc="update_stats_stream_pos", + desc="stats_incremental_position", ) def update_room_state(self, room_id, fields): @@ -361,42 +216,87 @@ class StatsStore(StateDeltasStore): fields[col] = None return self._simple_upsert( - table="room_state", + table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, desc="update_room_state", ) - def get_deltas_for_room(self, room_id, start, size=100): + def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): """ - Get statistics deltas for a given room. + Get statistics for a given subject. Args: - room_id (str) + stats_type (str): The type of subject + stats_id (str): The ID of the subject (e.g. room_id or user_id) start (int): Pagination start. Number of entries, not timestamp. size (int): How many entries to return. Returns: Deferred[list[dict]], where the dict has the keys of - ABSOLUTE_STATS_FIELDS["room"] and "ts". + ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". """ - return self._simple_select_list_paginate( - "room_stats", - {"room_id": room_id}, - "ts", + return self.runInteraction( + "get_statistics_for_subject", + self._get_statistics_for_subject_txn, + stats_type, + stats_id, + start, + size, + ) + + def _get_statistics_for_subject_txn( + self, txn, stats_type, stats_id, start, size=100 + ): + """ + Transaction-bound version of L{get_statistics_for_subject}. + """ + + table, id_col = TYPE_TO_TABLE[stats_type] + selected_columns = list( + ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] + ) + + slice_list = self._simple_select_list_paginate_txn( + txn, + table + "_historical", + {id_col: stats_id}, + "end_ts", start, size, - retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]), + retcols=selected_columns + ["bucket_size", "end_ts"], order_direction="DESC", ) - def get_all_room_state(self): - return self._simple_select_list( - "room_state", None, retcols=("name", "topic", "canonical_alias") + return slice_list + + def get_room_stats_state(self, room_id): + """ + Returns the current room_stats_state for a room. + + Args: + room_id (str): The ID of the room to return state for. + + Returns (dict): + Dictionary containing these keys: + "name", "topic", "canonical_alias", "avatar", "join_rules", + "history_visibility" + """ + return self._simple_select_one( + "room_stats_state", + {"room_id": room_id}, + retcols=( + "name", + "topic", + "canonical_alias", + "avatar", + "join_rules", + "history_visibility", + ), ) @cached() - def get_earliest_token_for_room_stats(self, room_id): + def get_earliest_token_for_stats(self, stats_type, id): """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -406,79 +306,571 @@ class StatsStore(StateDeltasStore): Returns: Deferred[int] """ + table, id_col = TYPE_TO_TABLE[stats_type] + return self._simple_select_one_onecol( - "room_stats_earliest_token", - {"room_id": room_id}, - retcol="token", + "%s_current" % (table,), + keyvalues={id_col: id}, + retcol="completed_delta_stream_id", allow_none=True, ) - def update_stats(self, stats_type, stats_id, ts, fields): - table, id_col = TYPE_TO_ROOM[stats_type] - return self._simple_upsert( - table=table, - keyvalues={id_col: stats_id, "ts": ts}, - values=fields, - desc="update_stats", + def bulk_update_stats_delta(self, ts, updates, stream_id): + """Bulk update stats tables for a given stream_id and updates the stats + incremental position. + + Args: + ts (int): Current timestamp in ms + updates(dict[str, dict[str, dict[str, Counter]]]): The updates to + commit as a mapping stats_type -> stats_id -> field -> delta. + stream_id (int): Current position. + + Returns: + Deferred + """ + + def _bulk_update_stats_delta_txn(txn): + for stats_type, stats_updates in updates.items(): + for stats_id, fields in stats_updates.items(): + self._update_stats_delta_txn( + txn, + ts=ts, + stats_type=stats_type, + stats_id=stats_id, + fields=fields, + complete_with_stream_id=stream_id, + ) + + self._simple_update_one_txn( + txn, + table="stats_incremental_position", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + ) + + return self.runInteraction( + "bulk_update_stats_delta", _bulk_update_stats_delta_txn ) - def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields): - table, id_col = TYPE_TO_ROOM[stats_type] - return self._simple_upsert_txn( - txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields + def update_stats_delta( + self, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id, + absolute_field_overrides=None, + ): + """ + Updates the statistics for a subject, with a delta (difference/relative + change). + + Args: + ts (int): timestamp of the change + stats_type (str): "room" or "user" – the kind of subject + stats_id (str): the subject's ID (room ID or user ID) + fields (dict[str, int]): Deltas of stats values. + complete_with_stream_id (int, optional): + If supplied, converts an incomplete row into a complete row, + with the supplied stream_id marked as the stream_id where the + row was completed. + absolute_field_overrides (dict[str, int]): Current stats values + (i.e. not deltas) of absolute fields. + Does not work with per-slice fields. + """ + + return self.runInteraction( + "update_stats_delta", + self._update_stats_delta_txn, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id=complete_with_stream_id, + absolute_field_overrides=absolute_field_overrides, ) - def update_stats_delta(self, ts, stats_type, stats_id, field, value): - def _update_stats_delta(txn): - table, id_col = TYPE_TO_ROOM[stats_type] - - sql = ( - "SELECT * FROM %s" - " WHERE %s=? and ts=(" - " SELECT MAX(ts) FROM %s" - " WHERE %s=?" - ")" - ) % (table, id_col, table, id_col) - txn.execute(sql, (stats_id, stats_id)) - rows = self.cursor_to_dict(txn) - if len(rows) == 0: - # silently skip as we don't have anything to apply a delta to yet. - # this tries to minimise any race between the initial sync and - # subsequent deltas arriving. - return - - current_ts = ts - latest_ts = rows[0]["ts"] - if current_ts < latest_ts: - # This one is in the past, but we're just encountering it now. - # Mark it as part of the current bucket. - current_ts = latest_ts - elif ts != latest_ts: - # we have to copy our absolute counters over to the new entry. - values = { - key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type] - } - values[id_col] = stats_id - values["ts"] = ts - values["bucket_size"] = self.stats_bucket_size - - self._simple_insert_txn(txn, table=table, values=values) - - # actually update the new value - if stats_type in ABSOLUTE_STATS_FIELDS[stats_type]: - self._simple_update_txn( - txn, - table=table, - keyvalues={id_col: stats_id, "ts": current_ts}, - updatevalues={field: value}, + def _update_stats_delta_txn( + self, + txn, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id, + absolute_field_overrides=None, + ): + if absolute_field_overrides is None: + absolute_field_overrides = {} + + table, id_col = TYPE_TO_TABLE[stats_type] + + quantised_ts = self.quantise_stats_time(int(ts)) + end_ts = quantised_ts + self.stats_bucket_size + + # Lets be paranoid and check that all the given field names are known + abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type] + slice_field_names = PER_SLICE_FIELDS[stats_type] + for field in chain(fields.keys(), absolute_field_overrides.keys()): + if field not in abs_field_names and field not in slice_field_names: + # guard against potential SQL injection dodginess + raise ValueError( + "%s is not a recognised field" + " for stats type %s" % (field, stats_type) ) + + # Per slice fields do not get added to the _current table + + # This calculates the deltas (`field = field + ?` values) + # for absolute fields, + # * defaulting to 0 if not specified + # (required for the INSERT part of upserting to work) + # * omitting overrides specified in `absolute_field_overrides` + deltas_of_absolute_fields = { + key: fields.get(key, 0) + for key in abs_field_names + if key not in absolute_field_overrides + } + + # Keep the delta stream ID field up to date + absolute_field_overrides = absolute_field_overrides.copy() + absolute_field_overrides["completed_delta_stream_id"] = complete_with_stream_id + + # first upsert the `_current` table + self._upsert_with_additive_relatives_txn( + txn=txn, + table=table + "_current", + keyvalues={id_col: stats_id}, + absolutes=absolute_field_overrides, + additive_relatives=deltas_of_absolute_fields, + ) + + per_slice_additive_relatives = { + key: fields.get(key, 0) for key in slice_field_names + } + self._upsert_copy_from_table_with_additive_relatives_txn( + txn=txn, + into_table=table + "_historical", + keyvalues={id_col: stats_id}, + extra_dst_insvalues={"bucket_size": self.stats_bucket_size}, + extra_dst_keyvalues={"end_ts": end_ts}, + additive_relatives=per_slice_additive_relatives, + src_table=table + "_current", + copy_columns=abs_field_names, + ) + + def _upsert_with_additive_relatives_txn( + self, txn, table, keyvalues, absolutes, additive_relatives + ): + """Used to update values in the stats tables. + + This is basically a slightly convoluted upsert that *adds* to any + existing rows. + + Args: + txn + table (str): Table name + keyvalues (dict[str, any]): Row-identifying key values + absolutes (dict[str, any]): Absolute (set) fields + additive_relatives (dict[str, int]): Fields that will be added onto + if existing row present. + """ + if self.database_engine.can_native_upsert: + absolute_updates = [ + "%(field)s = EXCLUDED.%(field)s" % {"field": field} + for field in absolutes.keys() + ] + + relative_updates = [ + "%(field)s = EXCLUDED.%(field)s + %(table)s.%(field)s" + % {"table": table, "field": field} + for field in additive_relatives.keys() + ] + + insert_cols = [] + qargs = [] + + for (key, val) in chain( + keyvalues.items(), absolutes.items(), additive_relatives.items() + ): + insert_cols.append(key) + qargs.append(val) + + sql = """ + INSERT INTO %(table)s (%(insert_cols_cs)s) + VALUES (%(insert_vals_qs)s) + ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s + """ % { + "table": table, + "insert_cols_cs": ", ".join(insert_cols), + "insert_vals_qs": ", ".join( + ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives)) + ), + "key_columns": ", ".join(keyvalues), + "updates": ", ".join(chain(absolute_updates, relative_updates)), + } + + txn.execute(sql, qargs) + else: + self.database_engine.lock_table(txn, table) + retcols = list(chain(absolutes.keys(), additive_relatives.keys())) + current_row = self._simple_select_one_txn( + txn, table, keyvalues, retcols, allow_none=True + ) + if current_row is None: + merged_dict = {**keyvalues, **absolutes, **additive_relatives} + self._simple_insert_txn(txn, table, merged_dict) + else: + for (key, val) in additive_relatives.items(): + current_row[key] += val + current_row.update(absolutes) + self._simple_update_one_txn(txn, table, keyvalues, current_row) + + def _upsert_copy_from_table_with_additive_relatives_txn( + self, + txn, + into_table, + keyvalues, + extra_dst_keyvalues, + extra_dst_insvalues, + additive_relatives, + src_table, + copy_columns, + ): + """Updates the historic stats table with latest updates. + + This involves copying "absolute" fields from the `_current` table, and + adding relative fields to any existing values. + + Args: + txn: Transaction + into_table (str): The destination table to UPSERT the row into + keyvalues (dict[str, any]): Row-identifying key values + extra_dst_keyvalues (dict[str, any]): Additional keyvalues + for `into_table`. + extra_dst_insvalues (dict[str, any]): Additional values to insert + on new row creation for `into_table`. + additive_relatives (dict[str, any]): Fields that will be added onto + if existing row present. (Must be disjoint from copy_columns.) + src_table (str): The source table to copy from + copy_columns (iterable[str]): The list of columns to copy + """ + if self.database_engine.can_native_upsert: + ins_columns = chain( + keyvalues, + copy_columns, + additive_relatives, + extra_dst_keyvalues, + extra_dst_insvalues, + ) + sel_exprs = chain( + keyvalues, + copy_columns, + ( + "?" + for _ in chain( + additive_relatives, extra_dst_keyvalues, extra_dst_insvalues + ) + ), + ) + keyvalues_where = ("%s = ?" % f for f in keyvalues) + + sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns) + sets_ar = ( + "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f) + for f in additive_relatives + ) + + sql = """ + INSERT INTO %(into_table)s (%(ins_columns)s) + SELECT %(sel_exprs)s + FROM %(src_table)s + WHERE %(keyvalues_where)s + ON CONFLICT (%(keyvalues)s) + DO UPDATE SET %(sets)s + """ % { + "into_table": into_table, + "ins_columns": ", ".join(ins_columns), + "sel_exprs": ", ".join(sel_exprs), + "keyvalues_where": " AND ".join(keyvalues_where), + "src_table": src_table, + "keyvalues": ", ".join( + chain(keyvalues.keys(), extra_dst_keyvalues.keys()) + ), + "sets": ", ".join(chain(sets_cc, sets_ar)), + } + + qargs = list( + chain( + additive_relatives.values(), + extra_dst_keyvalues.values(), + extra_dst_insvalues.values(), + keyvalues.values(), + ) + ) + txn.execute(sql, qargs) + else: + self.database_engine.lock_table(txn, into_table) + src_row = self._simple_select_one_txn( + txn, src_table, keyvalues, copy_columns + ) + all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} + dest_current_row = self._simple_select_one_txn( + txn, + into_table, + keyvalues=all_dest_keyvalues, + retcols=list(chain(additive_relatives.keys(), copy_columns)), + allow_none=True, + ) + + if dest_current_row is None: + merged_dict = { + **keyvalues, + **extra_dst_keyvalues, + **extra_dst_insvalues, + **src_row, + **additive_relatives, + } + self._simple_insert_txn(txn, into_table, merged_dict) else: - sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % ( - table, - field, - field, - id_col, + for (key, val) in additive_relatives.items(): + src_row[key] = dest_current_row[key] + val + self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + + def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): + """Fetches the counts of events in the given range of stream IDs. + + Args: + min_pos (int) + max_pos (int) + + Returns: + Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field + changes. + """ + + return self.runInteraction( + "stats_incremental_total_events_and_bytes", + self.get_changes_room_total_events_and_bytes_txn, + min_pos, + max_pos, + ) + + def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos): + """Gets the total_events and total_event_bytes counts for rooms and + senders, in a range of stream_orderings (including backfilled events). + + Args: + txn + low_pos (int): Low stream ordering + high_pos (int): High stream ordering + + Returns: + tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The + room and user deltas for total_events/total_event_bytes in the + format of `stats_id` -> fields + """ + + if low_pos >= high_pos: + # nothing to do here. + return {}, {} + + if isinstance(self.database_engine, PostgresEngine): + new_bytes_expression = "OCTET_LENGTH(json)" + else: + new_bytes_expression = "LENGTH(CAST(json AS BLOB))" + + sql = """ + SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes + FROM events INNER JOIN event_json USING (event_id) + WHERE (? < stream_ordering AND stream_ordering <= ?) + OR (? <= stream_ordering AND stream_ordering <= ?) + GROUP BY events.room_id + """ % ( + new_bytes_expression, + ) + + txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) + + room_deltas = { + room_id: {"total_events": new_events, "total_event_bytes": new_bytes} + for room_id, new_events, new_bytes in txn + } + + sql = """ + SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes + FROM events INNER JOIN event_json USING (event_id) + WHERE (? < stream_ordering AND stream_ordering <= ?) + OR (? <= stream_ordering AND stream_ordering <= ?) + GROUP BY events.sender + """ % ( + new_bytes_expression, + ) + + txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) + + user_deltas = { + user_id: {"total_events": new_events, "total_event_bytes": new_bytes} + for user_id, new_events, new_bytes in txn + if self.hs.is_mine_id(user_id) + } + + return room_deltas, user_deltas + + @defer.inlineCallbacks + def _calculate_and_set_initial_state_for_room(self, room_id): + """Calculate and insert an entry into room_stats_current. + + Args: + room_id (str) + + Returns: + Deferred[tuple[dict, dict, int]]: A tuple of room state, membership + counts and stream position. + """ + + def _fetch_current_state_stats(txn): + pos = self.get_room_max_stream_ordering() + + rows = self._simple_select_many_txn( + txn, + table="current_state_events", + column="type", + iterable=[ + EventTypes.Create, + EventTypes.JoinRules, + EventTypes.RoomHistoryVisibility, + EventTypes.Encryption, + EventTypes.Name, + EventTypes.Topic, + EventTypes.RoomAvatar, + EventTypes.CanonicalAlias, + ], + keyvalues={"room_id": room_id, "state_key": ""}, + retcols=["event_id"], + ) + + event_ids = [row["event_id"] for row in rows] + + txn.execute( + """ + SELECT membership, count(*) FROM current_state_events + WHERE room_id = ? AND type = 'm.room.member' + GROUP BY membership + """, + (room_id,), + ) + membership_counts = {membership: cnt for membership, cnt in txn} + + txn.execute( + """ + SELECT COALESCE(count(*), 0) FROM current_state_events + WHERE room_id = ? + """, + (room_id,), + ) + + current_state_events_count, = txn.fetchone() + + users_in_room = self.get_users_in_room_txn(txn, room_id) + + return ( + event_ids, + membership_counts, + current_state_events_count, + users_in_room, + pos, + ) + + ( + event_ids, + membership_counts, + current_state_events_count, + users_in_room, + pos, + ) = yield self.runInteraction( + "get_initial_state_for_room", _fetch_current_state_stats + ) + + state_event_map = yield self.get_events(event_ids, get_prev_content=False) + + room_state = { + "join_rules": None, + "history_visibility": None, + "encryption": None, + "name": None, + "topic": None, + "avatar": None, + "canonical_alias": None, + "is_federatable": True, + } + + for event in state_event_map.values(): + if event.type == EventTypes.JoinRules: + room_state["join_rules"] = event.content.get("join_rule") + elif event.type == EventTypes.RoomHistoryVisibility: + room_state["history_visibility"] = event.content.get( + "history_visibility" ) - txn.execute(sql, (value, stats_id, current_ts)) + elif event.type == EventTypes.Encryption: + room_state["encryption"] = event.content.get("algorithm") + elif event.type == EventTypes.Name: + room_state["name"] = event.content.get("name") + elif event.type == EventTypes.Topic: + room_state["topic"] = event.content.get("topic") + elif event.type == EventTypes.RoomAvatar: + room_state["avatar"] = event.content.get("url") + elif event.type == EventTypes.CanonicalAlias: + room_state["canonical_alias"] = event.content.get("alias") + elif event.type == EventTypes.Create: + room_state["is_federatable"] = event.content.get("m.federate", True) + + yield self.update_room_state(room_id, room_state) + + local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] + + yield self.update_stats_delta( + ts=self.clock.time_msec(), + stats_type="room", + stats_id=room_id, + fields={}, + complete_with_stream_id=pos, + absolute_field_overrides={ + "current_state_events": current_state_events_count, + "joined_members": membership_counts.get(Membership.JOIN, 0), + "invited_members": membership_counts.get(Membership.INVITE, 0), + "left_members": membership_counts.get(Membership.LEAVE, 0), + "banned_members": membership_counts.get(Membership.BAN, 0), + "local_users_in_room": len(local_users_in_room), + }, + ) + + @defer.inlineCallbacks + def _calculate_and_set_initial_state_for_user(self, user_id): + def _calculate_and_set_initial_state_for_user_txn(txn): + pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) - return self.runInteraction("update_stats_delta", _update_stats_delta) + txn.execute( + """ + SELECT COUNT(distinct room_id) FROM current_state_events + WHERE type = 'm.room.member' AND state_key = ? + AND membership = 'join' + """, + (user_id,), + ) + count, = txn.fetchone() + return count, pos + + joined_rooms, pos = yield self.runInteraction( + "calculate_and_set_initial_state_for_user", + _calculate_and_set_initial_state_for_user_txn, + ) + + yield self.update_stats_delta( + ts=self.clock.time_msec(), + stats_type="user", + stats_id=user_id, + fields={}, + complete_with_stream_id=pos, + absolute_field_overrides={"joined_rooms": joined_rooms}, + ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index a8b858eb4f..7569b6fab5 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -13,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock - -from twisted.internet import defer - -from synapse.api.constants import EventTypes, Membership +from synapse import storage from synapse.rest import admin from synapse.rest.client.v1 import login, room from tests import unittest +# The expected number of state events in a fresh public room. +EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5 +# The expected number of state events in a fresh private room. +EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6 + class StatsRoomTests(unittest.HomeserverTestCase): @@ -33,7 +34,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() self.handler = self.hs.get_stats_handler() @@ -47,7 +47,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.get_success( self.store._simple_insert( "background_updates", - {"update_name": "populate_stats_createtables", "progress_json": "{}"}, + {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( @@ -56,7 +56,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): { "update_name": "populate_stats_process_rooms", "progress_json": "{}", - "depends_on": "populate_stats_createtables", + "depends_on": "populate_stats_prepare", }, ) ) @@ -64,18 +64,58 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._simple_insert( "background_updates", { - "update_name": "populate_stats_cleanup", + "update_name": "populate_stats_process_users", "progress_json": "{}", "depends_on": "populate_stats_process_rooms", }, ) ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_cleanup", + "progress_json": "{}", + "depends_on": "populate_stats_process_users", + }, + ) + ) + + def get_all_room_state(self): + return self.store._simple_select_list( + "room_stats_state", None, retcols=("name", "topic", "canonical_alias") + ) + + def _get_current_stats(self, stats_type, stat_id): + table, id_col = storage.stats.TYPE_TO_TABLE[stats_type] + + cols = list(storage.stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list( + storage.stats.PER_SLICE_FIELDS[stats_type] + ) + + end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) + + return self.get_success( + self.store._simple_select_one( + table + "_historical", + {id_col: stat_id, end_ts: end_ts}, + cols, + allow_none=True, + ) + ) + + def _perform_background_initial_update(self): + # Do the initial population of the stats via the background update + self._add_background_updates() + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) def test_initial_room(self): """ The background updates will build the table from scratch. """ - r = self.get_success(self.store.get_all_room_state()) + r = self.get_success(self.get_all_room_state()) self.assertEqual(len(r), 0) # Disable stats @@ -91,7 +131,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) # Stats disabled, shouldn't have done anything - r = self.get_success(self.store.get_all_room_state()) + r = self.get_success(self.get_all_room_state()) self.assertEqual(len(r), 0) # Enable stats @@ -104,7 +144,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): while not self.get_success(self.store.has_completed_background_updates()): self.get_success(self.store.do_next_background_update(100), by=0.1) - r = self.get_success(self.store.get_all_room_state()) + r = self.get_success(self.get_all_room_state()) self.assertEqual(len(r), 1) self.assertEqual(r[0]["topic"], "foo") @@ -114,6 +154,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): Ingestion via notify_new_event will ignore tokens that the background update have already processed. """ + self.reactor.advance(86401) self.hs.config.stats_enabled = False @@ -138,12 +179,18 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.hs.config.stats_enabled = True self.handler.stats_enabled = True self.store._all_done = False - self.get_success(self.store.update_stats_stream_pos(None)) + self.get_success( + self.store._simple_update_one( + table="stats_incremental_position", + keyvalues={}, + updatevalues={"stream_id": 0}, + ) + ) self.get_success( self.store._simple_insert( "background_updates", - {"update_name": "populate_stats_createtables", "progress_json": "{}"}, + {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) @@ -154,6 +201,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) self.helper.join(room=room_1, user=u2, tok=u2_token) + # orig_delta_processor = self.store. + # Now do the initial ingestion. self.get_success( self.store._simple_insert( @@ -185,8 +234,15 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token) self.helper.join(room=room_1, user=u3, tok=u3_token) - # Get the deltas! There should be two -- day 1, and day 2. - r = self.get_success(self.store.get_deltas_for_room(room_1, 0)) + # self.handler.notify_new_event() + + # We need to let the delta processor advance… + self.pump(10 * 60) + + # Get the slices! There should be two -- day 1, and day 2. + r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0)) + + self.assertEqual(len(r), 2) # The oldest has 2 joined members self.assertEqual(r[-1]["joined_members"], 2) @@ -194,111 +250,476 @@ class StatsRoomTests(unittest.HomeserverTestCase): # The newest has 3 self.assertEqual(r[0]["joined_members"], 3) - def test_incorrect_state_transition(self): - """ - If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to - (JOIN, INVITE, LEAVE, BAN), an error is raised. - """ - events = { - "a1": {"membership": Membership.LEAVE}, - "a2": {"membership": "not a real thing"}, - } - - def get_event(event_id, allow_none=True): - m = Mock() - m.content = events[event_id] - d = defer.Deferred() - self.reactor.callLater(0.0, d.callback, m) - return d - - def get_received_ts(event_id): - return defer.succeed(1) - - self.store.get_received_ts = get_received_ts - self.store.get_event = get_event - - deltas = [ - { - "type": EventTypes.Member, - "state_key": "some_user", - "room_id": "room", - "event_id": "a1", - "prev_event_id": "a2", - "stream_id": 60, - } - ] - - f = self.get_failure(self.handler._handle_deltas(deltas), ValueError) + def test_create_user(self): + """ + When we create a user, it should have statistics already ready. + """ + + u1 = self.register_user("u1", "pass") + + u1stats = self._get_current_stats("user", u1) + + self.assertIsNotNone(u1stats) + + # not in any rooms by default + self.assertEqual(u1stats["joined_rooms"], 0) + + def test_create_room(self): + """ + When we create a room, it should have statistics already ready. + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + r1stats = self._get_current_stats("room", r1) + r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False) + r2stats = self._get_current_stats("room", r2) + + self.assertIsNotNone(r1stats) + self.assertIsNotNone(r2stats) + + # contains the default things you'd expect in a fresh room self.assertEqual( - f.value.args[0], "'not a real thing' is not a valid prev_membership" - ) - - # And the other way... - deltas = [ - { - "type": EventTypes.Member, - "state_key": "some_user", - "room_id": "room", - "event_id": "a2", - "prev_event_id": "a1", - "stream_id": 100, - } - ] - - f = self.get_failure(self.handler._handle_deltas(deltas), ValueError) + r1stats["total_events"], + EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM, + "Wrong number of total_events in new room's stats!" + " You may need to update this if more state events are added to" + " the room creation process.", + ) self.assertEqual( - f.value.args[0], "'not a real thing' is not a valid membership" + r2stats["total_events"], + EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM, + "Wrong number of total_events in new room's stats!" + " You may need to update this if more state events are added to" + " the room creation process.", ) - def test_redacted_prev_event(self): + self.assertEqual( + r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM + ) + self.assertEqual( + r2stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM + ) + + self.assertEqual(r1stats["joined_members"], 1) + self.assertEqual(r1stats["invited_members"], 0) + self.assertEqual(r1stats["banned_members"], 0) + + self.assertEqual(r2stats["joined_members"], 1) + self.assertEqual(r2stats["invited_members"], 0) + self.assertEqual(r2stats["banned_members"], 0) + + def test_send_message_increments_total_events(self): """ - If the prev_event does not exist, then it is assumed to be a LEAVE. + When we send a message, it increments total_events. """ + + self._perform_background_initial_update() + u1 = self.register_user("u1", "pass") - u1_token = self.login("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + r1stats_ante = self._get_current_stats("room", r1) - room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.helper.send(r1, "hiss", tok=u1token) - # Do the initial population of the user directory via the background update - self._add_background_updates() + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + + def test_send_state_event_nonoverwriting(self): + """ + When we send a non-overwriting state event, it increments total_events AND current_state_events + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + self.helper.send_state( + r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby" + ) + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.send_state( + r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy" + ) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 1, + ) + + def test_send_state_event_overwriting(self): + """ + When we send an overwriting state event, it increments total_events ONLY + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + self.helper.send_state( + r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby" + ) + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.send_state( + r1, "cat.hissing", {"value": False}, tok=u1token, state_key="tabby" + ) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 0, + ) + + def test_join_first_time(self): + """ + When a user joins a room for the first time, total_events, current_state_events and + joined_members should increase by exactly 1. + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + u2 = self.register_user("u2", "pass") + u2token = self.login("u2", "pass") + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.join(r1, u2, tok=u2token) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 1, + ) + self.assertEqual( + r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1 + ) + + def test_join_after_leave(self): + """ + When a user joins a room after being previously left, total_events and + joined_members should increase by exactly 1. + current_state_events should not increase. + left_members should decrease by exactly 1. + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + u2 = self.register_user("u2", "pass") + u2token = self.login("u2", "pass") + + self.helper.join(r1, u2, tok=u2token) + self.helper.leave(r1, u2, tok=u2token) + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.join(r1, u2, tok=u2token) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 0, + ) + self.assertEqual( + r1stats_post["joined_members"] - r1stats_ante["joined_members"], +1 + ) + self.assertEqual( + r1stats_post["left_members"] - r1stats_ante["left_members"], -1 + ) + + def test_invited(self): + """ + When a user invites another user, current_state_events, total_events and + invited_members should increase by exactly 1. + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + u2 = self.register_user("u2", "pass") + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.invite(r1, u1, u2, tok=u1token) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 1, + ) + self.assertEqual( + r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1 + ) + + def test_join_after_invite(self): + """ + When a user joins a room after being invited, total_events and + joined_members should increase by exactly 1. + current_state_events should not increase. + invited_members should decrease by exactly 1. + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + u2 = self.register_user("u2", "pass") + u2token = self.login("u2", "pass") + + self.helper.invite(r1, u1, u2, tok=u1token) + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.join(r1, u2, tok=u2token) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 0, + ) + self.assertEqual( + r1stats_post["joined_members"] - r1stats_ante["joined_members"], +1 + ) + self.assertEqual( + r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1 + ) + + def test_left(self): + """ + When a user leaves a room after joining, total_events and + left_members should increase by exactly 1. + current_state_events should not increase. + joined_members should decrease by exactly 1. + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + u2 = self.register_user("u2", "pass") + u2token = self.login("u2", "pass") + + self.helper.join(r1, u2, tok=u2token) + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.leave(r1, u2, tok=u2token) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 0, + ) + self.assertEqual( + r1stats_post["left_members"] - r1stats_ante["left_members"], +1 + ) + self.assertEqual( + r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1 + ) + + def test_banned(self): + """ + When a user is banned from a room after joining, total_events and + left_members should increase by exactly 1. + current_state_events should not increase. + banned_members should decrease by exactly 1. + """ + + self._perform_background_initial_update() + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + u2 = self.register_user("u2", "pass") + u2token = self.login("u2", "pass") + + self.helper.join(r1, u2, tok=u2token) + + r1stats_ante = self._get_current_stats("room", r1) + + self.helper.change_membership(r1, u1, u2, "ban", tok=u1token) + + r1stats_post = self._get_current_stats("room", r1) + + self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + self.assertEqual( + r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], + 0, + ) + self.assertEqual( + r1stats_post["banned_members"] - r1stats_ante["banned_members"], +1 + ) + self.assertEqual( + r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1 + ) + + def test_initial_background_update(self): + """ + Test that statistics can be generated by the initial background update + handler. + + This test also tests that stats rows are not created for new subjects + when stats are disabled. However, it may be desirable to change this + behaviour eventually to still keep current rows. + """ + + self.hs.config.stats_enabled = False + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + # test that these subjects, which were created during a time of disabled + # stats, do not have stats. + self.assertIsNone(self._get_current_stats("room", r1)) + self.assertIsNone(self._get_current_stats("user", u1)) + + self.hs.config.stats_enabled = True + + self._perform_background_initial_update() + + r1stats = self._get_current_stats("room", r1) + u1stats = self._get_current_stats("user", u1) + + self.assertEqual(r1stats["joined_members"], 1) + self.assertEqual( + r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM + ) + + self.assertEqual(u1stats["joined_rooms"], 1) + + def test_incomplete_stats(self): + """ + This tests that we track incomplete statistics. + + We first test that incomplete stats are incrementally generated, + following the preparation of a background regen. + + We then test that these incomplete rows are completed by the background + regen. + """ + + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + u2 = self.register_user("u2", "pass") + u2token = self.login("u2", "pass") + u3 = self.register_user("u3", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token, is_public=False) + + # preparation stage of the initial background update + # Ugh, have to reset this flag + self.store._all_done = False + + self.get_success( + self.store._simple_delete( + "room_stats_current", {"1": 1}, "test_delete_stats" + ) + ) + self.get_success( + self.store._simple_delete( + "user_stats_current", {"1": 1}, "test_delete_stats" + ) + ) + + self.helper.invite(r1, u1, u2, tok=u1token) + self.helper.join(r1, u2, tok=u2token) + self.helper.invite(r1, u1, u3, tok=u1token) + self.helper.send(r1, "thou shalt yield", tok=u1token) + + # now do the background updates + + self.store._all_done = False + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + "depends_on": "populate_stats_prepare", + }, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_process_users", + "progress_json": "{}", + "depends_on": "populate_stats_process_rooms", + }, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_cleanup", + "progress_json": "{}", + "depends_on": "populate_stats_process_users", + }, + ) + ) while not self.get_success(self.store.has_completed_background_updates()): self.get_success(self.store.do_next_background_update(100), by=0.1) - events = {"a1": None, "a2": {"membership": Membership.JOIN}} - - def get_event(event_id, allow_none=True): - if events.get(event_id): - m = Mock() - m.content = events[event_id] - else: - m = None - d = defer.Deferred() - self.reactor.callLater(0.0, d.callback, m) - return d - - def get_received_ts(event_id): - return defer.succeed(1) - - self.store.get_received_ts = get_received_ts - self.store.get_event = get_event - - deltas = [ - { - "type": EventTypes.Member, - "state_key": "some_user:test", - "room_id": room_1, - "event_id": "a2", - "prev_event_id": "a1", - "stream_id": 100, - } - ] - - # Handle our fake deltas, which has a user going from LEAVE -> JOIN. - self.get_success(self.handler._handle_deltas(deltas)) - - # One delta, with two joined members -- the room creator, and our fake - # user. - r = self.get_success(self.store.get_deltas_for_room(room_1, 0)) - self.assertEqual(len(r), 1) - self.assertEqual(r[0]["joined_members"], 2) + r1stats_complete = self._get_current_stats("room", r1) + u1stats_complete = self._get_current_stats("user", u1) + u2stats_complete = self._get_current_stats("user", u2) + + # now we make our assertions + + # check that _complete rows are complete and correct + self.assertEqual(r1stats_complete["joined_members"], 2) + self.assertEqual(r1stats_complete["invited_members"], 1) + + self.assertEqual( + r1stats_complete["current_state_events"], + 2 + EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM, + ) + + self.assertEqual(u1stats_complete["joined_rooms"], 1) + self.assertEqual(u2stats_complete["joined_rooms"], 1) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 9915367144..cdded88b7f 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -128,8 +128,12 @@ class RestHelper(object): return channel.json_body - def send_state(self, room_id, event_type, body, tok, expect_code=200): - path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type) + def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""): + path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( + room_id, + event_type, + state_key, + ) if tok: path = path + "?access_token=%s" % tok -- cgit 1.5.1 From ac4746ac4bb4d9371c5a25e94ecccd83effb8b9a Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 17 Jul 2019 22:11:31 -0400 Subject: allow uploading signatures of master key signed by devices --- synapse/handlers/e2e_keys.py | 232 ++++++++++++++++++++++------------- synapse/rest/client/v2_alpha/keys.py | 2 +- synapse/storage/end_to_end_keys.py | 2 +- tests/handlers/test_e2e_keys.py | 227 +++++++++++++++++++++++++++++++++- 4 files changed, 378 insertions(+), 85 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9747b517ff..1148803c1e 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -20,7 +20,9 @@ import logging from six import iteritems from canonicaljson import encode_canonical_json, json +from signedjson.key import decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json +from unpaddedbase64 import decode_base64 from twisted.internet import defer @@ -619,8 +621,11 @@ class E2eKeysHandler(object): """ failures = {} - signature_list = [] # signatures to be stored - self_device_ids = [] # what devices have been updated, for notifying + # signatures to be stored. Each item will be a tuple of + # (signing_key_id, target_user_id, target_device_id, signature) + signature_list = [] + # what devices have been updated, for notifying + self_device_ids = [] # split between checking signatures for own user and signatures for # other users, since we verify them with different keys @@ -630,46 +635,107 @@ class E2eKeysHandler(object): self_device_ids = list(self_signatures.keys()) try: # get our self-signing key to verify the signatures - self_signing_key = yield self.store.get_e2e_cross_signing_key( - user_id, "self_signing" - ) - if self_signing_key is None: - raise SynapseError( - 404, - "No self-signing key found", - Codes.NOT_FOUND + self_signing_key, self_signing_key_id, self_signing_verify_key \ + = yield self._get_e2e_cross_signing_verify_key( + user_id, "self_signing" ) - self_signing_key_id, self_signing_verify_key \ - = get_verify_key_from_cross_signing_key(self_signing_key) + # get our master key, since it may be signed + master_key, master_key_id, master_verify_key \ + = yield self._get_e2e_cross_signing_verify_key( + user_id, "master" + ) - # fetch our stored devices so that we can compare with what was sent - user_devices = [] - for device in self_signatures.keys(): - user_devices.append((user_id, device)) - devices = yield self.store.get_e2e_device_keys(user_devices) + # fetch our stored devices. This is used to 1. verify + # signatures on the master key, and 2. to can compare with what + # was sent if the device was signed + devices = yield self.store.get_e2e_device_keys([(user_id, None)]) if user_id not in devices: raise SynapseError( - 404, - "No device key found", - Codes.NOT_FOUND + 404, "No device keys found", Codes.NOT_FOUND ) devices = devices[user_id] for device_id, device in self_signatures.items(): try: if ("signatures" not in device or - user_id not in device["signatures"] or - self_signing_key_id not in device["signatures"][user_id]): + user_id not in device["signatures"]): + # no signature was sent + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + if device_id == master_verify_key.version: + # we have master key signed by devices: for each + # device that signed, check the signature. Since + # the "failures" property in the response only has + # granularity up to the signed device, either all + # of the signatures on the master key succeed, or + # all fail. So loop over the signatures and add + # them to a separate signature list. If everything + # works out, then add them all to the main + # signature list. (In practice, we're likely to + # only have only one signature anyways.) + master_key_signature_list = [] + for signing_key_id, signature in device["signatures"][user_id].items(): + alg, signing_device_id = signing_key_id.split(":", 1) + if (signing_device_id not in devices or + signing_key_id not in + devices[signing_device_id]["keys"]["keys"]): + # signed by an unknown device, or the + # device does not have the key + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + sigs = device["signatures"] + del device["signatures"] + # use pop to avoid exception if key doesn't exist + device.pop("unsigned", None) + master_key.pop("signature", None) + master_key.pop("unsigned", None) + + if master_key != device: + raise SynapseError( + 400, "Key does not match" + ) + + # get the key and check the signature + pubkey = devices[signing_device_id]["keys"]["keys"][signing_key_id] + verify_key = decode_verify_key_bytes( + signing_key_id, decode_base64(pubkey) + ) + device["signatures"] = sigs + try: + verify_signed_json(device, user_id, verify_key) + except SignatureVerifyException: + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + master_key_signature_list.append( + (signing_key_id, user_id, device_id, signature) + ) + + signature_list.extend(master_key_signature_list) + continue + + # at this point, we have a device that should be signed + # by the self-signing key + if self_signing_key_id not in device["signatures"][user_id]: # no signature was sent raise SynapseError( - 400, - "Invalid signature", - Codes.INVALID_SIGNATURE + 400, "Invalid signature", Codes.INVALID_SIGNATURE ) - stored_device = devices[device_id]["keys"] + stored_device = None + try: + stored_device = devices[device_id]["keys"] + except KeyError: + raise SynapseError( + 404, "Unknown device", Codes.NOT_FOUND + ) if self_signing_key_id in stored_device.get("signatures", {}) \ .get(user_id, {}): # we already have a signature on this device, so we @@ -698,69 +764,50 @@ class E2eKeysHandler(object): # if signatures isn't empty, then we have signatures for other # users. These signatures will be signed by the user signing key - # get our user-signing key to verify the signatures - user_signing_key = yield self.store.get_e2e_cross_signing_key( - user_id, "user_signing" - ) - if user_signing_key is None: - for user, devicemap in signatures.items(): - failures[user] = { - device: _exception_to_failure(SynapseError( - 404, - "No user-signing key found", - Codes.NOT_FOUND - )) - for device in devicemap.keys() - } - else: - user_signing_key_id, user_signing_verify_key \ - = get_verify_key_from_cross_signing_key(user_signing_key) + try: + # get our user-signing key to verify the signatures + user_signing_key, user_signing_key_id, user_signing_verify_key \ + = yield self._get_e2e_cross_signing_verify_key( + user_id, "user_signing" + ) for user, devicemap in signatures.items(): device_id = None try: # get the user's master key, to make sure it matches # what was sent - stored_key = yield self.store.get_e2e_cross_signing_key( - user, "master", user_id - ) - if stored_key is None: - logger.error( - "upload signature: no user key found for %s", user - ) - raise SynapseError( - 404, - "User's master key not found", - Codes.NOT_FOUND + stored_key, stored_key_id, _ \ + = yield self._get_e2e_cross_signing_verify_key( + user, "master", user_id ) # make sure that the user's master key is the one that # was signed (and no others) - device_id = get_verify_key_from_cross_signing_key(stored_key)[0] \ - .split(":", 1)[1] + device_id = stored_key_id.split(":", 1)[1] if device_id not in devicemap: + # set device to None so that the failure gets + # marked on all the signatures + device_id = None logger.error( "upload signature: wrong device: %s vs %s", device, devicemap ) raise SynapseError( - 404, - "Unknown device", - Codes.NOT_FOUND + 404, "Unknown device", Codes.NOT_FOUND ) - if len(devicemap) > 1: + key = devicemap[device_id] + del devicemap[device_id] + if len(devicemap) > 0: + # other devices were signed -- mark those as failures logger.error("upload signature: too many devices specified") + failure = _exception_to_failure(SynapseError( + 404, "Unknown device", Codes.NOT_FOUND + )) failures[user] = { - device: _exception_to_failure(SynapseError( - 404, - "Unknown device", - Codes.NOT_FOUND - )) + device: failure for device in devicemap.keys() } - key = devicemap[device_id] - if user_signing_key_id in stored_key.get("signatures", {}) \ .get(user_id, {}): # we already have the signature, so we can skip it @@ -770,25 +817,31 @@ class E2eKeysHandler(object): user_id, user_signing_verify_key, key, stored_key ) - signature = key["signatures"][user_id][user_signing_key_id] - signed_users.append(user) + signature = key["signatures"][user_id][user_signing_key_id] signature_list.append( (user_signing_key_id, user, device_id, signature) ) except SynapseError as e: + failure = _exception_to_failure(e) if device_id is None: failures[user] = { - device_id: _exception_to_failure(e) + device_id: failure for device_id in devicemap.keys() } else: - failures.setdefault(user, {})[device_id] \ - = _exception_to_failure(e) + failures.setdefault(user, {})[device_id] = failure + except SynapseError as e: + failure = _exception_to_failure(e) + for user, devicemap in signature.items(): + failures[user] = { + device_id: failure + for device_id in devicemap.keys() + } # store the signature, and send the appropriate notifications for sync logger.debug("upload signature failures: %r", failures) - yield self.store.store_e2e_device_signatures(user_id, signature_list) + yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) if len(self_device_ids): yield self.device_handler.notify_device_update(user_id, self_device_ids) @@ -797,6 +850,22 @@ class E2eKeysHandler(object): defer.returnValue({"failures": failures}) + @defer.inlineCallbacks + def _get_e2e_cross_signing_verify_key(self, user_id, key_type, from_user_id=None): + key = yield self.store.get_e2e_cross_signing_key( + user_id, key_type, from_user_id + ) + if key is None: + logger.error("no %s key found for %s", key_type, user_id) + raise SynapseError( + 404, + "No %s key found for %s" % (key_type, user_id), + Codes.NOT_FOUND + ) + key_id, verify_key = get_verify_key_from_cross_signing_key(key) + return key, key_id, verify_key + + def _check_cross_signing_key(key, user_id, key_type, signing_key=None): """Check a cross-signing key uploaded by a user. Performs some basic sanity checking, and ensures that it is signed, if a signature is required. @@ -851,21 +920,17 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): # make sure that the device submitted matches what we have stored del signed_device["signatures"] - if "unsigned" in signed_device: - del signed_device["unsigned"] - if "signatures" in stored_device: - del stored_device["signatures"] - if "unsigned" in stored_device: - del stored_device["unsigned"] + # use pop to avoid exception if key doesn't exist + signed_device.pop("unsigned", None) + stored_device.pop("signatures", None) + stored_device.pop("unsigned", None) if signed_device != stored_device: logger.error( "upload signatures: key does not match %s vs %s", signed_device, stored_device ) raise SynapseError( - 400, - "Key does not match", - "M_MISMATCHED_KEY" + 400, "Key does not match", ) # check the signature @@ -887,6 +952,9 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): def _exception_to_failure(e): + if isinstance(e, SynapseError): + return {"status": e.code, "errcode": e.errcode, "message": str(e)} + if isinstance(e, CodeMessageException): return {"status": e.code, "message": str(e)} diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 5c288d48b7..cb3c52cb8e 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -303,7 +303,7 @@ class SignaturesUploadServlet(RestServlet): } } """ - PATTERNS = client_v2_patterns("/keys/signatures/upload$") + PATTERNS = client_patterns("/keys/signatures/upload$") def __init__(self, hs): """ diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index fe786f3093..e68ce318af 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -132,7 +132,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # get signatures on the device signature_sql = ( "SELECT * " - " FROM e2e_device_signatures " + " FROM e2e_cross_signing_signatures " " WHERE %s" ) % ( " OR ".join("(" + q + ")" for q in signature_query_clauses) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index a62c52eefa..b1d3a4cfae 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -17,9 +17,10 @@ import mock +import signedjson.key as key +import signedjson.sign as sign from twisted.internet import defer -import synapse.api.errors import synapse.handlers.e2e_keys import synapse.storage from synapse.api import errors @@ -210,3 +211,227 @@ class E2eKeysHandlerTestCase(unittest.TestCase): res = yield self.handler.query_local_devices({local_user: None}) self.assertDictEqual(res, {local_user: {}}) + + @defer.inlineCallbacks + def test_upload_signatures(self): + """should check signatures that are uploaded""" + # set up a user with cross-signing keys and a device. This user will + # try uploading signatures + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY" + device_key = { + "user_id": local_user, + "device_id": device_id, + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "curve25519:xyz": "curve25519+key", + "ed25519:xyz": device_pubkey + }, + "signatures": { + local_user: { + "ed25519:xyz": "something" + } + } + } + device_signing_key = key.decode_signing_key_base64( + "ed25519", + "xyz", + "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" + ) + + yield self.handler.upload_keys_for_user( + local_user, device_id, {"device_keys": device_key} + ) + + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + master_key = { + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:" + master_pubkey: master_pubkey + } + } + master_signing_key = key.decode_signing_key_base64( + "ed25519", master_pubkey, + "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0" + ) + usersigning_pubkey = "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw" + usersigning_key = { + # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs + "user_id": local_user, + "usage": ["user_signing"], + "keys": { + "ed25519:" + usersigning_pubkey: usersigning_pubkey, + } + } + usersigning_signing_key = key.decode_signing_key_base64( + "ed25519", usersigning_pubkey, + "4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs" + ) + sign.sign_json(usersigning_key, local_user, master_signing_key) + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + selfsigning_pubkey = "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ" + selfsigning_key = { + "user_id": local_user, + "usage": ["self_signing"], + "keys": { + "ed25519:" + selfsigning_pubkey: selfsigning_pubkey, + } + } + selfsigning_signing_key = key.decode_signing_key_base64( + "ed25519", selfsigning_pubkey, + "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8" + ) + sign.sign_json(selfsigning_key, local_user, master_signing_key) + cross_signing_keys = { + "master_key": master_key, + "user_signing_key": usersigning_key, + "self_signing_key": selfsigning_key, + } + yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + + # set up another user with a master key. This user will be signed by + # the first user + other_user = "@otherboris:" + self.hs.hostname + other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM" + other_master_key = { + # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI + "user_id": other_user, + "usage": ["master"], + "keys": { + "ed25519:" + other_master_pubkey: other_master_pubkey + } + } + yield self.handler.upload_signing_keys_for_user(other_user, { + "master_key": other_master_key + }) + + # test various signature failures (see below) + ret = yield self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + device_id: { + "user_id": local_user, + "device_id": device_id, + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "curve25519:xyz": "curve25519+key", + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + "ed25519:xyz": device_pubkey + }, + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something", + } + } + }, + # fails because device is unknown + # should fail with NOT_FOUND + "unknown": { + "user_id": local_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something", + } + } + }, + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + master_pubkey: { + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:" + master_pubkey: master_pubkey + }, + "signatures": { + local_user: { + "ed25519:" + device_pubkey: "something", + } + } + } + }, + other_user: { + # fails because the device is not the user's master-signing key + # should fail with NOT_FOUND + "unknown": { + "user_id": other_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something", + } + } + }, + other_master_pubkey: { + # fails because the key doesn't match what the server has + # should fail with UNKNOWN + "user_id": other_user, + "usage": ["master"], + "keys": { + "ed25519:" + other_master_pubkey: other_master_pubkey + }, + "something": "random", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something", + } + } + } + } + } + ) + + user_failures = ret["failures"][local_user] + self.assertEqual(user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE) + self.assertEqual(user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE) + self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) + + other_user_failures = ret["failures"][other_user] + self.assertEqual(other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) + self.assertEqual(other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN) + + # test successful signatures + del device_key["signatures"] + sign.sign_json(device_key, local_user, selfsigning_signing_key) + sign.sign_json(master_key, local_user, device_signing_key) + sign.sign_json(other_master_key, local_user, usersigning_signing_key) + ret = yield self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + device_id: device_key, + master_pubkey: master_key + }, + other_user: { + other_master_pubkey: other_master_key + } + } + ) + + self.assertEqual(ret["failures"], {}) + + # fetch the signed keys/devices and make sure that the signatures are there + ret = yield self.handler.query_devices( + {"device_keys": {local_user: [], other_user: []}}, + 0, local_user + ) + + self.assertEqual( + ret["device_keys"][local_user]["xyz"]["signatures"][local_user]["ed25519:" + selfsigning_pubkey], + device_key["signatures"][local_user]["ed25519:" + selfsigning_pubkey] + ) + self.assertEqual( + ret["master_keys"][local_user]["signatures"][local_user]["ed25519:" + device_id], + master_key["signatures"][local_user]["ed25519:" + device_id] + ) + self.assertEqual( + ret["master_keys"][other_user]["signatures"][local_user]["ed25519:" + usersigning_pubkey], + other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey] + ) -- cgit 1.5.1 From 7d6c70fc7ad08b94b8b577c537953a8d9b568562 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 22 Jul 2019 12:52:39 -0400 Subject: make black happy --- synapse/handlers/e2e_keys.py | 147 ++++++++++++++++------------------- synapse/rest/client/v2_alpha/keys.py | 1 + synapse/storage/end_to_end_keys.py | 24 +++--- tests/handlers/test_e2e_keys.py | 147 +++++++++++++++-------------------- 4 files changed, 141 insertions(+), 178 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 1148803c1e..74bceddc46 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -635,16 +635,14 @@ class E2eKeysHandler(object): self_device_ids = list(self_signatures.keys()) try: # get our self-signing key to verify the signatures - self_signing_key, self_signing_key_id, self_signing_verify_key \ - = yield self._get_e2e_cross_signing_verify_key( - user_id, "self_signing" - ) + self_signing_key, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( + user_id, "self_signing" + ) # get our master key, since it may be signed - master_key, master_key_id, master_verify_key \ - = yield self._get_e2e_cross_signing_verify_key( - user_id, "master" - ) + master_key, master_key_id, master_verify_key = yield self._get_e2e_cross_signing_verify_key( + user_id, "master" + ) # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to can compare with what @@ -652,15 +650,15 @@ class E2eKeysHandler(object): devices = yield self.store.get_e2e_device_keys([(user_id, None)]) if user_id not in devices: - raise SynapseError( - 404, "No device keys found", Codes.NOT_FOUND - ) + raise SynapseError(404, "No device keys found", Codes.NOT_FOUND) devices = devices[user_id] for device_id, device in self_signatures.items(): try: - if ("signatures" not in device or - user_id not in device["signatures"]): + if ( + "signatures" not in device + or user_id not in device["signatures"] + ): # no signature was sent raise SynapseError( 400, "Invalid signature", Codes.INVALID_SIGNATURE @@ -678,15 +676,21 @@ class E2eKeysHandler(object): # signature list. (In practice, we're likely to # only have only one signature anyways.) master_key_signature_list = [] - for signing_key_id, signature in device["signatures"][user_id].items(): + for signing_key_id, signature in device["signatures"][ + user_id + ].items(): alg, signing_device_id = signing_key_id.split(":", 1) - if (signing_device_id not in devices or - signing_key_id not in - devices[signing_device_id]["keys"]["keys"]): + if ( + signing_device_id not in devices + or signing_key_id + not in devices[signing_device_id]["keys"]["keys"] + ): # signed by an unknown device, or the # device does not have the key raise SynapseError( - 400, "Invalid signature", Codes.INVALID_SIGNATURE + 400, + "Invalid signature", + Codes.INVALID_SIGNATURE, ) sigs = device["signatures"] @@ -697,12 +701,12 @@ class E2eKeysHandler(object): master_key.pop("unsigned", None) if master_key != device: - raise SynapseError( - 400, "Key does not match" - ) + raise SynapseError(400, "Key does not match") # get the key and check the signature - pubkey = devices[signing_device_id]["keys"]["keys"][signing_key_id] + pubkey = devices[signing_device_id]["keys"]["keys"][ + signing_key_id + ] verify_key = decode_verify_key_bytes( signing_key_id, decode_base64(pubkey) ) @@ -711,7 +715,9 @@ class E2eKeysHandler(object): verify_signed_json(device, user_id, verify_key) except SignatureVerifyException: raise SynapseError( - 400, "Invalid signature", Codes.INVALID_SIGNATURE + 400, + "Invalid signature", + Codes.INVALID_SIGNATURE, ) master_key_signature_list.append( @@ -733,11 +739,10 @@ class E2eKeysHandler(object): try: stored_device = devices[device_id]["keys"] except KeyError: - raise SynapseError( - 404, "Unknown device", Codes.NOT_FOUND - ) - if self_signing_key_id in stored_device.get("signatures", {}) \ - .get(user_id, {}): + raise SynapseError(404, "Unknown device", Codes.NOT_FOUND) + if self_signing_key_id in stored_device.get( + "signatures", {} + ).get(user_id, {}): # we already have a signature on this device, so we # can skip it, since it should be exactly the same continue @@ -751,8 +756,9 @@ class E2eKeysHandler(object): (self_signing_key_id, user_id, device_id, signature) ) except SynapseError as e: - failures.setdefault(user_id, {})[device_id] \ - = _exception_to_failure(e) + failures.setdefault(user_id, {})[ + device_id + ] = _exception_to_failure(e) except SynapseError as e: failures[user_id] = { device: _exception_to_failure(e) @@ -766,20 +772,18 @@ class E2eKeysHandler(object): try: # get our user-signing key to verify the signatures - user_signing_key, user_signing_key_id, user_signing_verify_key \ - = yield self._get_e2e_cross_signing_verify_key( - user_id, "user_signing" - ) + user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( + user_id, "user_signing" + ) for user, devicemap in signatures.items(): device_id = None try: # get the user's master key, to make sure it matches # what was sent - stored_key, stored_key_id, _ \ - = yield self._get_e2e_cross_signing_verify_key( - user, "master", user_id - ) + stored_key, stored_key_id, _ = yield self._get_e2e_cross_signing_verify_key( + user, "master", user_id + ) # make sure that the user's master key is the one that # was signed (and no others) @@ -790,26 +794,25 @@ class E2eKeysHandler(object): device_id = None logger.error( "upload signature: wrong device: %s vs %s", - device, devicemap - ) - raise SynapseError( - 404, "Unknown device", Codes.NOT_FOUND + device, + devicemap, ) + raise SynapseError(404, "Unknown device", Codes.NOT_FOUND) key = devicemap[device_id] del devicemap[device_id] if len(devicemap) > 0: # other devices were signed -- mark those as failures logger.error("upload signature: too many devices specified") - failure = _exception_to_failure(SynapseError( - 404, "Unknown device", Codes.NOT_FOUND - )) + failure = _exception_to_failure( + SynapseError(404, "Unknown device", Codes.NOT_FOUND) + ) failures[user] = { - device: failure - for device in devicemap.keys() + device: failure for device in devicemap.keys() } - if user_signing_key_id in stored_key.get("signatures", {}) \ - .get(user_id, {}): + if user_signing_key_id in stored_key.get("signatures", {}).get( + user_id, {} + ): # we already have the signature, so we can skip it continue @@ -826,8 +829,7 @@ class E2eKeysHandler(object): failure = _exception_to_failure(e) if device_id is None: failures[user] = { - device_id: failure - for device_id in devicemap.keys() + device_id: failure for device_id in devicemap.keys() } else: failures.setdefault(user, {})[device_id] = failure @@ -835,8 +837,7 @@ class E2eKeysHandler(object): failure = _exception_to_failure(e) for user, devicemap in signature.items(): failures[user] = { - device_id: failure - for device_id in devicemap.keys() + device_id: failure for device_id in devicemap.keys() } # store the signature, and send the appropriate notifications for sync @@ -846,7 +847,9 @@ class E2eKeysHandler(object): if len(self_device_ids): yield self.device_handler.notify_device_update(user_id, self_device_ids) if len(signed_users): - yield self.device_handler.notify_user_signature_update(user_id, signed_users) + yield self.device_handler.notify_user_signature_update( + user_id, signed_users + ) defer.returnValue({"failures": failures}) @@ -858,9 +861,7 @@ class E2eKeysHandler(object): if key is None: logger.error("no %s key found for %s", key_type, user_id) raise SynapseError( - 404, - "No %s key found for %s" % (key_type, user_id), - Codes.NOT_FOUND + 404, "No %s key found for %s" % (key_type, user_id), Codes.NOT_FOUND ) key_id, verify_key = get_verify_key_from_cross_signing_key(key) return key, key_id, verify_key @@ -907,14 +908,13 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): key_id = "%s:%s" % (verify_key.alg, verify_key.version) # make sure the device is signed - if ("signatures" not in signed_device or user_id not in signed_device["signatures"] - or key_id not in signed_device["signatures"][user_id]): + if ( + "signatures" not in signed_device + or user_id not in signed_device["signatures"] + or key_id not in signed_device["signatures"][user_id] + ): logger.error("upload signature: user not found in signatures") - raise SynapseError( - 400, - "Invalid signature", - Codes.INVALID_SIGNATURE - ) + raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) signature = signed_device["signatures"][user_id][key_id] @@ -927,28 +927,19 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): if signed_device != stored_device: logger.error( "upload signatures: key does not match %s vs %s", - signed_device, stored_device - ) - raise SynapseError( - 400, "Key does not match", + signed_device, + stored_device, ) + raise SynapseError(400, "Key does not match") # check the signature - signed_device["signatures"] = { - user_id: { - key_id: signature - } - } + signed_device["signatures"] = {user_id: {key_id: signature}} try: verify_signed_json(signed_device, user_id, verify_key) except SignatureVerifyException: logger.error("invalid signature on key") - raise SynapseError( - 400, - "Invalid signature", - Codes.INVALID_SIGNATURE - ) + raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) def _exception_to_failure(e): diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index cb3c52cb8e..a205281830 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -303,6 +303,7 @@ class SignaturesUploadServlet(RestServlet): } } """ + PATTERNS = client_patterns("/keys/signatures/upload$") def __init__(self, hs): diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index e68ce318af..258e8dcb47 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -62,9 +62,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # add cross-signing signatures to the keys if "signatures" in device_info: for sig_user_id, sigs in device_info["signatures"].items(): - device_info["keys"].setdefault("signatures", {}) \ - .setdefault(sig_user_id, {}) \ - .update(sigs) + device_info["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) return results @@ -131,12 +131,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # get signatures on the device signature_sql = ( - "SELECT * " - " FROM e2e_cross_signing_signatures " - " WHERE %s" - ) % ( - " OR ".join("(" + q + ")" for q in signature_query_clauses) - ) + "SELECT * " " FROM e2e_cross_signing_signatures " " WHERE %s" + ) % (" OR ".join("(" + q + ")" for q in signature_query_clauses)) txn.execute(signature_sql, signature_query_params) rows = self.cursor_to_dict(txn) @@ -144,12 +140,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for row in rows: target_user_id = row["target_user_id"] target_device_id = row["target_device_id"] - if target_user_id in result \ - and target_device_id in result[target_user_id]: - result[target_user_id][target_device_id] \ - .setdefault("signatures", {}) \ - .setdefault(row["user_id"], {})[row["key_id"]] \ - = row["signature"] + if target_user_id in result and target_device_id in result[target_user_id]: + result[target_user_id][target_device_id].setdefault( + "signatures", {} + ).setdefault(row["user_id"], {})[row["key_id"]] = row["signature"] log_kv(result) return result diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index b1d3a4cfae..8c0ee3f7d3 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -225,20 +225,11 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_id": local_user, "device_id": device_id, "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], - "keys": { - "curve25519:xyz": "curve25519+key", - "ed25519:xyz": device_pubkey - }, - "signatures": { - local_user: { - "ed25519:xyz": "something" - } - } + "keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey}, + "signatures": {local_user: {"ed25519:xyz": "something"}}, } device_signing_key = key.decode_signing_key_base64( - "ed25519", - "xyz", - "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" + "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" ) yield self.handler.upload_keys_for_user( @@ -250,26 +241,20 @@ class E2eKeysHandlerTestCase(unittest.TestCase): master_key = { "user_id": local_user, "usage": ["master"], - "keys": { - "ed25519:" + master_pubkey: master_pubkey - } + "keys": {"ed25519:" + master_pubkey: master_pubkey}, } master_signing_key = key.decode_signing_key_base64( - "ed25519", master_pubkey, - "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0" + "ed25519", master_pubkey, "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0" ) usersigning_pubkey = "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw" usersigning_key = { # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs "user_id": local_user, "usage": ["user_signing"], - "keys": { - "ed25519:" + usersigning_pubkey: usersigning_pubkey, - } + "keys": {"ed25519:" + usersigning_pubkey: usersigning_pubkey}, } usersigning_signing_key = key.decode_signing_key_base64( - "ed25519", usersigning_pubkey, - "4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs" + "ed25519", usersigning_pubkey, "4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs" ) sign.sign_json(usersigning_key, local_user, master_signing_key) # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 @@ -277,13 +262,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): selfsigning_key = { "user_id": local_user, "usage": ["self_signing"], - "keys": { - "ed25519:" + selfsigning_pubkey: selfsigning_pubkey, - } + "keys": {"ed25519:" + selfsigning_pubkey: selfsigning_pubkey}, } selfsigning_signing_key = key.decode_signing_key_base64( - "ed25519", selfsigning_pubkey, - "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8" + "ed25519", selfsigning_pubkey, "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8" ) sign.sign_json(selfsigning_key, local_user, master_signing_key) cross_signing_keys = { @@ -301,13 +283,11 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI "user_id": other_user, "usage": ["master"], - "keys": { - "ed25519:" + other_master_pubkey: other_master_pubkey - } + "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, } - yield self.handler.upload_signing_keys_for_user(other_user, { - "master_key": other_master_key - }) + yield self.handler.upload_signing_keys_for_user( + other_user, {"master_key": other_master_key} + ) # test various signature failures (see below) ret = yield self.handler.upload_signatures_for_device_keys( @@ -319,17 +299,18 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_id: { "user_id": local_user, "device_id": device_id, - "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "algorithms": [ + "m.olm.curve25519-aes-sha256", + "m.megolm.v1.aes-sha", + ], "keys": { "curve25519:xyz": "curve25519+key", # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA - "ed25519:xyz": device_pubkey + "ed25519:xyz": device_pubkey, }, "signatures": { - local_user: { - "ed25519:" + selfsigning_pubkey: "something", - } - } + local_user: {"ed25519:" + selfsigning_pubkey: "something"} + }, }, # fails because device is unknown # should fail with NOT_FOUND @@ -337,25 +318,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_id": local_user, "device_id": "unknown", "signatures": { - local_user: { - "ed25519:" + selfsigning_pubkey: "something", - } - } + local_user: {"ed25519:" + selfsigning_pubkey: "something"} + }, }, # fails because the signature is invalid # should fail with INVALID_SIGNATURE master_pubkey: { "user_id": local_user, "usage": ["master"], - "keys": { - "ed25519:" + master_pubkey: master_pubkey - }, + "keys": {"ed25519:" + master_pubkey: master_pubkey}, "signatures": { - local_user: { - "ed25519:" + device_pubkey: "something", - } - } - } + local_user: {"ed25519:" + device_pubkey: "something"} + }, + }, }, other_user: { # fails because the device is not the user's master-signing key @@ -364,38 +339,40 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_id": other_user, "device_id": "unknown", "signatures": { - local_user: { - "ed25519:" + usersigning_pubkey: "something", - } - } + local_user: {"ed25519:" + usersigning_pubkey: "something"} + }, }, other_master_pubkey: { # fails because the key doesn't match what the server has # should fail with UNKNOWN "user_id": other_user, "usage": ["master"], - "keys": { - "ed25519:" + other_master_pubkey: other_master_pubkey - }, + "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, "something": "random", "signatures": { - local_user: { - "ed25519:" + usersigning_pubkey: "something", - } - } - } - } - } + local_user: {"ed25519:" + usersigning_pubkey: "something"} + }, + }, + }, + }, ) user_failures = ret["failures"][local_user] - self.assertEqual(user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE) - self.assertEqual(user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE) + self.assertEqual( + user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE + ) + self.assertEqual( + user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE + ) self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) other_user_failures = ret["failures"][other_user] - self.assertEqual(other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) - self.assertEqual(other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN) + self.assertEqual( + other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND + ) + self.assertEqual( + other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN + ) # test successful signatures del device_key["signatures"] @@ -405,33 +382,33 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ret = yield self.handler.upload_signatures_for_device_keys( local_user, { - local_user: { - device_id: device_key, - master_pubkey: master_key - }, - other_user: { - other_master_pubkey: other_master_key - } - } + local_user: {device_id: device_key, master_pubkey: master_key}, + other_user: {other_master_pubkey: other_master_key}, + }, ) self.assertEqual(ret["failures"], {}) # fetch the signed keys/devices and make sure that the signatures are there ret = yield self.handler.query_devices( - {"device_keys": {local_user: [], other_user: []}}, - 0, local_user + {"device_keys": {local_user: [], other_user: []}}, 0, local_user ) self.assertEqual( - ret["device_keys"][local_user]["xyz"]["signatures"][local_user]["ed25519:" + selfsigning_pubkey], - device_key["signatures"][local_user]["ed25519:" + selfsigning_pubkey] + ret["device_keys"][local_user]["xyz"]["signatures"][local_user][ + "ed25519:" + selfsigning_pubkey + ], + device_key["signatures"][local_user]["ed25519:" + selfsigning_pubkey], ) self.assertEqual( - ret["master_keys"][local_user]["signatures"][local_user]["ed25519:" + device_id], - master_key["signatures"][local_user]["ed25519:" + device_id] + ret["master_keys"][local_user]["signatures"][local_user][ + "ed25519:" + device_id + ], + master_key["signatures"][local_user]["ed25519:" + device_id], ) self.assertEqual( - ret["master_keys"][other_user]["signatures"][local_user]["ed25519:" + usersigning_pubkey], - other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey] + ret["master_keys"][other_user]["signatures"][local_user][ + "ed25519:" + usersigning_pubkey + ], + other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) -- cgit 1.5.1 From 9061b4198af4b30bb99d98aab7ad227f8ed636f8 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 22 Jul 2019 12:58:04 -0400 Subject: make isort happy --- tests/handlers/test_e2e_keys.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 8c0ee3f7d3..c900451e03 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -19,6 +19,7 @@ import mock import signedjson.key as key import signedjson.sign as sign + from twisted.internet import defer import synapse.handlers.e2e_keys -- cgit 1.5.1 From 5914fd09c725342d03f702a50ec1da6290e946a9 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 22 Jul 2019 13:01:10 -0400 Subject: add test --- tests/handlers/test_e2e_keys.py | 88 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) (limited to 'tests') diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index c900451e03..316dd6259d 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -183,6 +183,94 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) + @defer.inlineCallbacks + def test_reupload_signatures(self): + """re-uploading a signature should not fail""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ": "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ" + }, + }, + "self_signing_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["self_signing"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + }, + } + master_signing_key = key.decode_signing_key_base64( + "ed25519", + "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ", + "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8", + ) + sign.sign_json(keys1["self_signing_key"], local_user, master_signing_key) + signing_key = key.decode_signing_key_base64( + "ed25519", + "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", + ) + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + # upload two device keys, which will be signed later by the self-signing key + device_key_1 = { + "user_id": local_user, + "device_id": "abc", + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "ed25519:abc": "base64+ed25519+key", + "curve25519:abc": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:abc": "base64+signature"}}, + } + device_key_2 = { + "user_id": local_user, + "device_id": "def", + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "ed25519:def": "base64+ed25519+key", + "curve25519:def": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:def": "base64+signature"}}, + } + + yield self.handler.upload_keys_for_user( + local_user, "abc", {"device_keys": device_key_1} + ) + yield self.handler.upload_keys_for_user( + local_user, "def", {"device_keys": device_key_2} + ) + + # sign the first device key and upload it + del device_key_1["signatures"] + sign.sign_json(device_key_1, local_user, signing_key) + yield self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1}} + ) + + # sign the second device key and upload both device keys. The server + # should ignore the first device key since it already has a valid + # signature for it + del device_key_2["signatures"] + sign.sign_json(device_key_2, local_user, signing_key) + yield self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + ) + + device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" + device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" + devices = yield self.handler.query_devices({"device_keys": {local_user: []}}, 0) + del devices["device_keys"][local_user]["abc"]["unsigned"] + del devices["device_keys"][local_user]["def"]["unsigned"] + self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) + self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) + @defer.inlineCallbacks def test_self_signing_key_doesnt_show_up_as_device(self): """signing keys should be hidden when fetching a user's devices""" -- cgit 1.5.1 From 3ff0422d2dbfa668df365da99a4b7caeea85528d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Sep 2019 17:16:03 +0100 Subject: Make redaction retention period configurable --- docs/sample_config.yaml | 5 +++++ synapse/config/server.py | 15 +++++++++++++++ synapse/storage/events.py | 6 ++++-- tests/storage/test_redaction.py | 4 +++- 4 files changed, 27 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 43969bbb70..e23b80d2b8 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -306,6 +306,11 @@ listeners: # #allow_per_room_profiles: false +# How long to keep redacted events in unredacted form in the database. +# By default redactions are kept indefinitely. +# +#redaction_retention_period: 30d + ## TLS ## diff --git a/synapse/config/server.py b/synapse/config/server.py index 2abdef0971..8efab924d4 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -162,6 +162,16 @@ class ServerConfig(Config): self.mau_trial_days = config.get("mau_trial_days", 0) + # How long to keep redacted events in the database in unredacted form + # before redacting them. + redaction_retention_period = config.get("redaction_retention_period") + if redaction_retention_period: + self.redaction_retention_period = self.parse_duration( + redaction_retention_period + ) + else: + self.redaction_retention_period = None + # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") @@ -718,6 +728,11 @@ class ServerConfig(Config): # Defaults to 'true'. # #allow_per_room_profiles: false + + # How long to keep redacted events in unredacted form in the database. + # By default redactions are kept indefinitely. + # + #redaction_retention_period: 30d """ % locals() ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 2970da6829..d0d1781c90 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1566,10 +1566,12 @@ class EventsStore( Deferred """ - if self.stream_ordering_month_ago is None: + if not self.hs.config.redaction_retention_period: return - max_pos = self.stream_ordering_month_ago + max_pos = yield self.find_first_stream_ordering_after_ts( + self._clock.time_msec() - self.hs.config.redaction_retention_period + ) # We fetch all redactions that point to an event that we have that has # a stream ordering from over a month ago, that we haven't yet censored diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 0c9f3c7071..f0e86d41a8 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -344,7 +344,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): {"content": {"body": "t", "msgtype": "message"}}, json.loads(event_json) ) - # Advance by 30 days + # Advance by 30 days, then advance again to ensure that the looping call + # for updating the stream position gets called and then the looping call + # for the censoring gets called. self.reactor.advance(60 * 60 * 24 * 31) self.reactor.advance(60 * 60 * 2) -- cgit 1.5.1 From ad9b64b4969537ac339469152eaa437bcf4b6609 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Sep 2019 17:17:47 +0100 Subject: Fix test --- tests/storage/test_redaction.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index f0e86d41a8..deecfad9fb 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -31,8 +31,10 @@ from tests.utils import create_room class RedactionTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): + config = self.default_config() + config["redaction_retention_period"] = "30d" return self.setup_test_homeserver( - resource_for_federation=Mock(), http_client=None + resource_for_federation=Mock(), http_client=None, config=config ) def prepare(self, reactor, clock, hs): -- cgit 1.5.1 From e47af0f086839c5d22a0de87a32a49386abef8df Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 5 Sep 2019 17:03:14 -0400 Subject: fix test --- tests/handlers/test_e2e_keys.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 316dd6259d..7a59ec5085 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -265,7 +265,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" - devices = yield self.handler.query_devices({"device_keys": {local_user: []}}, 0) + devices = yield self.handler.query_devices( + {"device_keys": {local_user: []}}, 0, 0 + ) del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"] self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) -- cgit 1.5.1 From 55d5b3af8863167432017f23cd8a04a0c14c9d23 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 7 Sep 2019 01:45:51 +1000 Subject: Servers-known-about statistic (#5981) --- changelog.d/5981.feature | 1 + docs/sample_config.yaml | 10 ++++ synapse/config/metrics.py | 31 ++++++++++ synapse/storage/roommember.py | 59 ++++++++++++++++++ tests/config/test_generate.py | 25 ++++---- tests/config/test_load.py | 34 +++++++---- tests/storage/test_roommember.py | 126 +++++++++++++++++++++++++++------------ 7 files changed, 226 insertions(+), 60 deletions(-) create mode 100644 changelog.d/5981.feature (limited to 'tests') diff --git a/changelog.d/5981.feature b/changelog.d/5981.feature new file mode 100644 index 0000000000..e39514273d --- /dev/null +++ b/changelog.d/5981.feature @@ -0,0 +1 @@ +Setting metrics_flags.known_servers to True in the configuration will publish the synapse_federation_known_servers metric over Prometheus. This represents the total number of servers your server knows about (i.e. is in rooms with), including itself. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 186cdbedd2..93c0edd8ce 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -958,6 +958,16 @@ account_threepid_delegates: #sentry: # dsn: "..." +# Flags to enable Prometheus metrics which are not suitable to be +# enabled by default, either for performance reasons or limited use. +# +metrics_flags: + # Publish synapse_federation_known_servers, a g auge of the number of + # servers this homeserver knows about, including itself. May cause + # performance problems on large homeservers. + # + #known_servers: true + # Whether or not to report anonymized homeserver usage statistics. # report_stats: true|false diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 3698441963..653b990e67 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import attr + from ._base import Config, ConfigError MISSING_SENTRY = """Missing sentry-sdk library. This is required to enable sentry @@ -20,6 +23,18 @@ MISSING_SENTRY = """Missing sentry-sdk library. This is required to enable sentr """ +@attr.s +class MetricsFlags(object): + known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool)) + + @classmethod + def all_off(cls): + """ + Instantiate the flags with all options set to off. + """ + return cls(**{x.name: False for x in attr.fields(cls)}) + + class MetricsConfig(Config): def read_config(self, config, **kwargs): self.enable_metrics = config.get("enable_metrics", False) @@ -27,6 +42,12 @@ class MetricsConfig(Config): self.metrics_port = config.get("metrics_port") self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1") + if self.enable_metrics: + _metrics_config = config.get("metrics_flags") or {} + self.metrics_flags = MetricsFlags(**_metrics_config) + else: + self.metrics_flags = MetricsFlags.all_off() + self.sentry_enabled = "sentry" in config if self.sentry_enabled: try: @@ -58,6 +79,16 @@ class MetricsConfig(Config): #sentry: # dsn: "..." + # Flags to enable Prometheus metrics which are not suitable to be + # enabled by default, either for performance reasons or limited use. + # + metrics_flags: + # Publish synapse_federation_known_servers, a g auge of the number of + # servers this homeserver knows about, including itself. May cause + # performance problems on large homeservers. + # + #known_servers: true + # Whether or not to report anonymized homeserver usage statistics. """ diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index f8b682ebd9..4df8ebdacd 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -24,8 +24,10 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction +from synapse.storage.engines import Sqlite3Engine from synapse.storage.events_worker import EventsWorkerStore from synapse.types import get_domain_from_id from synapse.util.async_helpers import Linearizer @@ -74,6 +76,63 @@ class RoomMemberWorkerStore(EventsWorkerStore): self._check_safe_current_state_events_membership_updated_txn(txn) txn.close() + if self.hs.config.metrics_flags.known_servers: + self._known_servers_count = 1 + self.hs.get_clock().looping_call( + run_as_background_process, + 60 * 1000, + "_count_known_servers", + self._count_known_servers, + ) + self.hs.get_clock().call_later( + 1000, + run_as_background_process, + "_count_known_servers", + self._count_known_servers, + ) + LaterGauge( + "synapse_federation_known_servers", + "", + [], + lambda: self._known_servers_count, + ) + + @defer.inlineCallbacks + def _count_known_servers(self): + """ + Count the servers that this server knows about. + + The statistic is stored on the class for the + `synapse_federation_known_servers` LaterGauge to collect. + """ + + def _transact(txn): + if isinstance(self.database_engine, Sqlite3Engine): + query = """ + SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) + FROM ( + SELECT rm.user_id as user_id, instr(rm.user_id, ':') + AS pos FROM room_memberships as rm + INNER JOIN current_state_events as c ON rm.event_id = c.event_id + WHERE c.type = 'm.room.member' + ) as out + """ + else: + query = """ + SELECT COUNT(DISTINCT split_part(state_key, ':', 2)) + FROM current_state_events + WHERE type = 'm.room.member' AND membership = 'join'; + """ + txn.execute(query) + return list(txn)[0][0] + + count = yield self.runInteraction("get_known_servers", _transact) + + # We always know about ourselves, even if we have nothing in + # room_memberships (for example, the server is new). + self._known_servers_count = max([count, 1]) + return self._known_servers_count + def _check_safe_current_state_events_membership_updated_txn(self, txn): """Checks if it is safe to assume the new current_state_events membership column is up to date diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 5017cbce85..2684e662de 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -17,6 +17,8 @@ import os.path import re import shutil import tempfile +from contextlib import redirect_stdout +from io import StringIO from synapse.config.homeserver import HomeServerConfig @@ -32,17 +34,18 @@ class ConfigGenerationTestCase(unittest.TestCase): shutil.rmtree(self.dir) def test_generate_config_generates_files(self): - HomeServerConfig.load_or_generate_config( - "", - [ - "--generate-config", - "-c", - self.file, - "--report-stats=yes", - "-H", - "lemurs.win", - ], - ) + with redirect_stdout(StringIO()): + HomeServerConfig.load_or_generate_config( + "", + [ + "--generate-config", + "-c", + self.file, + "--report-stats=yes", + "-H", + "lemurs.win", + ], + ) self.assertSetEqual( set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]), diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 6bfc1970ad..b3e557bd6a 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -15,6 +15,8 @@ import os.path import shutil import tempfile +from contextlib import redirect_stdout +from io import StringIO import yaml @@ -26,7 +28,6 @@ from tests import unittest class ConfigLoadingTestCase(unittest.TestCase): def setUp(self): self.dir = tempfile.mkdtemp() - print(self.dir) self.file = os.path.join(self.dir, "homeserver.yaml") def tearDown(self): @@ -94,18 +95,27 @@ class ConfigLoadingTestCase(unittest.TestCase): ) self.assertTrue(config.enable_registration) + def test_stats_enabled(self): + self.generate_config_and_remove_lines_containing("enable_metrics") + self.add_lines_to_config(["enable_metrics: true"]) + + # The default Metrics Flags are off by default. + config = HomeServerConfig.load_config("", ["-c", self.file]) + self.assertFalse(config.metrics_flags.known_servers) + def generate_config(self): - HomeServerConfig.load_or_generate_config( - "", - [ - "--generate-config", - "-c", - self.file, - "--report-stats=yes", - "-H", - "lemurs.win", - ], - ) + with redirect_stdout(StringIO()): + HomeServerConfig.load_or_generate_config( + "", + [ + "--generate-config", + "-c", + self.file, + "--report-stats=yes", + "-H", + "lemurs.win", + ], + ) def generate_config_and_remove_lines_containing(self, needle): self.generate_config() diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 64cb294c37..447a3c6ffb 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,78 +14,129 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from mock import Mock - -from twisted.internet import defer +from unittest.mock import Mock from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions -from synapse.types import Requester, RoomID, UserID +from synapse.rest.admin import register_servlets_for_client_rest_resource +from synapse.rest.client.v1 import login, room +from synapse.types import Requester, UserID from tests import unittest -from tests.utils import create_room, setup_test_homeserver -class RoomMemberStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver( - self.addCleanup, resource_for_federation=Mock(), http_client=None +class RoomMemberStoreTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + register_servlets_for_client_rest_resource, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver( + resource_for_federation=Mock(), http_client=None ) + return hs + + def prepare(self, reactor, clock, hs): + # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() - self.u_alice = UserID.from_string("@alice:test") - self.u_bob = UserID.from_string("@bob:test") + self.u_alice = self.register_user("alice", "pass") + self.t_alice = self.login("alice", "pass") + self.u_bob = self.register_user("bob", "pass") # User elsewhere on another host self.u_charlie = UserID.from_string("@charlie:elsewhere") - self.room = RoomID.from_string("!abc123:test") - - yield create_room(hs, self.room.to_string(), self.u_alice.to_string()) - - @defer.inlineCallbacks def inject_room_member(self, room, user, membership, replaces_state=None): builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), + "sender": user, + "state_key": user, + "room_id": room, "content": {"membership": membership}, }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.store.persist_event(event, context) + self.get_success(self.store.persist_event(event, context)) return event - @defer.inlineCallbacks def test_one_member(self): - yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN) - - self.assertEquals( - [self.room.to_string()], - [ - m.room_id - for m in ( - yield self.store.get_rooms_for_user_where_membership_is( - self.u_alice.to_string(), [Membership.JOIN] - ) - ) - ], + + # Alice creates the room, and is automatically joined + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + + rooms_for_user = self.get_success( + self.store.get_rooms_for_user_where_membership_is( + self.u_alice, [Membership.JOIN] + ) ) + self.assertEquals([self.room], [m.room_id for m in rooms_for_user]) + + def test_count_known_servers(self): + """ + _count_known_servers will calculate how many servers are in a room. + """ + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.inject_room_member(self.room, self.u_bob, Membership.JOIN) + self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) + + servers = self.get_success(self.store._count_known_servers()) + self.assertEqual(servers, 2) + + def test_count_known_servers_stat_counter_disabled(self): + """ + If enabled, the metrics for how many servers are known will be counted. + """ + self.assertTrue("_known_servers_count" not in self.store.__dict__.keys()) + + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.inject_room_member(self.room, self.u_bob, Membership.JOIN) + self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) + + self.pump(20) + + self.assertTrue("_known_servers_count" not in self.store.__dict__.keys()) + + @unittest.override_config( + {"enable_metrics": True, "metrics_flags": {"known_servers": True}} + ) + def test_count_known_servers_stat_counter_enabled(self): + """ + If enabled, the metrics for how many servers are known will be counted. + """ + # Initialises to 1 -- itself + self.assertEqual(self.store._known_servers_count, 1) + + self.pump(20) + + # No rooms have been joined, so technically the SQL returns 0, but it + # will still say it knows about itself. + self.assertEqual(self.store._known_servers_count, 1) + + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.inject_room_member(self.room, self.u_bob, Membership.JOIN) + self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) + + self.pump(20) + + # It now knows about Charlie's server. + self.assertEqual(self.store._known_servers_count, 2) + class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): -- cgit 1.5.1 From ab729e31cfca4d1a958937bb576010271b9c8044 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 6 Sep 2019 17:52:37 -0400 Subject: use something that's the right type for user_id --- tests/handlers/test_e2e_keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 7a59ec5085..854eb6c024 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -266,7 +266,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, 0 + {"device_keys": {local_user: []}}, 0, local_user ) del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"] -- cgit 1.5.1 From be618e055178f4aa9865ab426182218312bed07f Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Mon, 9 Sep 2019 14:43:51 +0300 Subject: Only count real users when checking for auto-creation of auto-join room Previously if the first registered user was a "support" or "bot" user, when the first real user registers, the auto-join rooms were not created. Fix to exclude non-real (ie users with a special user type) users when counting how many users there are to determine whether we should auto-create a room. Signed-off-by: Jason Robinson --- changelog.d/6004.bugfix | 1 + synapse/handlers/register.py | 12 ++++-------- synapse/storage/registration.py | 39 +++++++++++++++++++++++++++++++++++++++ tests/handlers/test_register.py | 29 +++++++++++++++++++++++++++-- 4 files changed, 71 insertions(+), 10 deletions(-) create mode 100644 changelog.d/6004.bugfix (limited to 'tests') diff --git a/changelog.d/6004.bugfix b/changelog.d/6004.bugfix new file mode 100644 index 0000000000..45c179c8fd --- /dev/null +++ b/changelog.d/6004.bugfix @@ -0,0 +1 @@ +Only count real users when checking for auto-creation of auto-join room. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 975da57ffd..06bd03b77c 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -275,16 +275,12 @@ class RegistrationHandler(BaseHandler): fake_requester = create_requester(user_id) # try to create the room if we're the first real user on the server. Note - # that an auto-generated support user is not a real user and will never be + # that an auto-generated support or bot user is not a real user and will never be # the user to create the room should_auto_create_rooms = False - is_support = yield self.store.is_support_user(user_id) - # There is an edge case where the first user is the support user, then - # the room is never created, though this seems unlikely and - # recoverable from given the support user being involved in the first - # place. - if self.hs.config.autocreate_auto_join_rooms and not is_support: - count = yield self.store.count_all_users() + is_real_user = yield self.store.is_real_user(user_id) + if self.hs.config.autocreate_auto_join_rooms and is_real_user: + count = yield self.store.count_real_users() should_auto_create_rooms = count == 1 for r in self.hs.config.auto_join_rooms: logger.info("Auto-joining %s to %s", user_id, r) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 5138792a5f..b054d86ae0 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -322,6 +322,21 @@ class RegistrationWorkerStore(SQLBaseStore): return None + @cachedInlineCallbacks() + def is_real_user(self, user_id): + """Determines if the user is a real user, ie does not have a 'user_type'. + + Args: + user_id (str): user id to test + + Returns: + Deferred[bool]: True if user 'user_type' is null or empty string + """ + res = yield self.runInteraction( + "is_real_user", self.is_real_user_txn, user_id + ) + return res + @cachedInlineCallbacks() def is_support_user(self, user_id): """Determines if the user is of type UserTypes.SUPPORT @@ -337,6 +352,16 @@ class RegistrationWorkerStore(SQLBaseStore): ) return res + def is_real_user_txn(self, txn, user_id): + res = self._simple_select_one_onecol_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + retcol="user_type", + allow_none=True, + ) + return True if res is None or res == "" else False + def is_support_user_txn(self, txn, user_id): res = self._simple_select_one_onecol_txn( txn=txn, @@ -421,6 +446,20 @@ class RegistrationWorkerStore(SQLBaseStore): ret = yield self.runInteraction("count_users", _count_users) return ret + @defer.inlineCallbacks + def count_real_users(self): + """Counts all users without a special user_type registered on the homeserver.""" + + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null or user_type = ''") + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_real_users", _count_users) + return ret + @defer.inlineCallbacks def find_next_generated_user_id_localpart(self): """ diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e10296a5e4..1e9ba3a201 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -171,11 +171,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase): rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) - def test_auto_create_auto_join_rooms_when_support_user_exists(self): + def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_support_user = Mock(return_value=True) + self.store.is_real_user = Mock(return_value=False) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -183,6 +183,31 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias = RoomAlias.from_string(room_alias_str) self.get_failure(directory_handler.get_association(room_alias), SynapseError) + def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + + self.store.count_real_users = Mock(return_value=1) + self.store.is_real_user = Mock(return_value=True) + user_id = self.get_success(self.handler.register_user(localpart="real")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + self.assertTrue(room_id["room_id"] in rooms) + self.assertEqual(len(rooms), 1) + + def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + + self.store.count_real_users = Mock(return_value=2) + self.store.is_real_user = Mock(return_value=True) + user_id = self.get_success(self.handler.register_user(localpart="real")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 0) + def test_auto_create_auto_join_where_no_consent(self): """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. -- cgit 1.5.1 From aeb9b2179eaa4b468bec937570d3ac7de7ccaaea Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 10 Sep 2019 00:14:58 +1000 Subject: Add a build info metric to Prometheus (#6005) --- changelog.d/6005.feature | 1 + synapse/metrics/__init__.py | 12 ++++++++++++ tests/test_metrics.py | 22 ++++++++++++++++++++-- 3 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6005.feature (limited to 'tests') diff --git a/changelog.d/6005.feature b/changelog.d/6005.feature new file mode 100644 index 0000000000..ed6491d3e4 --- /dev/null +++ b/changelog.d/6005.feature @@ -0,0 +1 @@ +The new Prometheus metric `synapse_build_info` exposes the Python version, OS version, and Synapse version of the running server. diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 488280b4a6..b5c9595cb9 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -29,11 +29,13 @@ from prometheus_client.core import REGISTRY, GaugeMetricFamily, HistogramMetricF from twisted.internet import reactor +import synapse from synapse.metrics._exposition import ( MetricsResource, generate_latest, start_http_server, ) +from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) @@ -385,6 +387,16 @@ event_processing_last_ts = Gauge("synapse_event_processing_last_ts", "", ["name" # finished being processed. event_processing_lag = Gauge("synapse_event_processing_lag", "", ["name"]) +# Build info of the running server. +build_info = Gauge( + "synapse_build_info", "Build information", ["pythonversion", "version", "osversion"] +) +build_info.labels( + " ".join([platform.python_implementation(), platform.python_version()]), + get_version_string(synapse), + " ".join([platform.system(), platform.release()]), +).set(1) + last_ticked = time.time() diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 2edbae5c6d..270f853d60 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from synapse.metrics import InFlightGauge +from synapse.metrics import REGISTRY, InFlightGauge, generate_latest from tests import unittest @@ -111,3 +111,21 @@ class TestMauLimit(unittest.TestCase): } return results + + +class BuildInfoTests(unittest.TestCase): + def test_get_build(self): + """ + The synapse_build_info metric reports the OS version, Python version, + and Synapse version. + """ + items = list( + filter( + lambda x: b"synapse_build_info{" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + ) + self.assertEqual(len(items), 1) + self.assertTrue(b"osversion=" in items[0]) + self.assertTrue(b"pythonversion=" in items[0]) + self.assertTrue(b"version=" in items[0]) -- cgit 1.5.1 From caa9d6fed719a8a80eb4a998d32f09577d04f927 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 11 Sep 2019 11:16:23 +0100 Subject: Add test for admin redaction ratelimiting. --- tests/rest/client/test_redactions.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index fe66e397c4..1b1e991c42 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -30,6 +30,14 @@ class RedactionsTestCase(HomeserverTestCase): sync.register_servlets, ] + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["rc_message"] = {"per_second": 0.2, "burst_count": 10} + config["rc_admin_redaction"] = {"per_second": 1, "burst_count": 100} + + return self.setup_test_homeserver(config=config) + def prepare(self, reactor, clock, hs): # register a couple of users self.mod_user_id = self.register_user("user1", "pass") @@ -177,3 +185,20 @@ class RedactionsTestCase(HomeserverTestCase): self._redact_event( self.other_access_token, self.room_id, create_event_id, expect_code=403 ) + + def test_redact_event_as_moderator_ratelimit(self): + """Tests that the correct ratelimiting is applied to redactions + """ + + message_ids = [] + # as a regular user, send messages to redact + for _ in range(20): + b = self.helper.send(room_id=self.room_id, tok=self.other_access_token) + message_ids.append(b["event_id"]) + self.reactor.advance(10) # To get around ratelimits + + # as the moderator, send a bunch of redactions redaction + for msg_id in message_ids: + # These should all succeed, even though this would be denied by + # standard message ratelimiter + self._redact_event(self.mod_access_token, self.room_id, msg_id) -- cgit 1.5.1 From 57dd41a45b4df5d736e2f30d40926b60f367b500 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 11 Sep 2019 13:54:50 +0100 Subject: Fix comments Co-Authored-By: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- synapse/config/ratelimiting.py | 2 +- synapse/handlers/_base.py | 2 +- synapse/handlers/message.py | 2 +- tests/rest/client/test_redactions.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index b4df6612d6..587e2862b7 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -110,7 +110,7 @@ class RatelimitConfig(Config): # attempts for this account. # - one for ratelimiting redactions by room admins. If this is not explicitly # set then it uses the same ratelimiting as per rc_message. This is useful - # to allow room admins to quickly deal with abuse quickly. + # to allow room admins to deal with abuse quickly. # # The defaults are as shown below. # diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 853b72d8e7..d15c6282fb 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -105,7 +105,7 @@ class BaseHandler(object): if is_admin_redaction and self.hs.config.rc_admin_redaction: # If we have separate config for admin redactions we use a separate - # ratelimiter. + # ratelimiter allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action( user_id, time_now, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f975909416..1f8272784e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -731,7 +731,7 @@ class EventCreationHandler(object): if ratelimit: # We check if this is a room admin redacting an event so that we # can apply different ratelimiting. We do this by simply checking - # its not a self-redaction (to avoid having to look up whether the + # it's not a self-redaction (to avoid having to look up whether the # user is actually admin or not). is_admin_redaction = False if event.type == EventTypes.Redaction: diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 1b1e991c42..d2bcf256fa 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -197,8 +197,8 @@ class RedactionsTestCase(HomeserverTestCase): message_ids.append(b["event_id"]) self.reactor.advance(10) # To get around ratelimits - # as the moderator, send a bunch of redactions redaction + # as the moderator, send a bunch of redactions for msg_id in message_ids: # These should all succeed, even though this would be denied by - # standard message ratelimiter + # the standard message ratelimiter self._redact_event(self.mod_access_token, self.room_id, msg_id) -- cgit 1.5.1 From 6d847d8ce69f2cb849633265aaeb4a9df4ff713d Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Wed, 11 Sep 2019 20:22:18 +0300 Subject: Ensure support users can be registered even if MAU limit is reached This allows support users to be created even on MAU limits via the admin API. Support users are excluded from MAU after creation, so it makes sense to exclude them in creation - except if the whole host is in disabled state. Signed-off-by: Jason Robinson --- changelog.d/6020.bugfix | 1 + synapse/api/auth.py | 11 +++++++++-- tests/api/test_auth.py | 18 ++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6020.bugfix (limited to 'tests') diff --git a/changelog.d/6020.bugfix b/changelog.d/6020.bugfix new file mode 100644 index 0000000000..58a7deba9d --- /dev/null +++ b/changelog.d/6020.bugfix @@ -0,0 +1 @@ +Ensure support users can be registered even if MAU limit is reached. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index ddc195bc32..9e445cd808 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -25,7 +25,7 @@ from twisted.internet import defer import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import EventTypes, JoinRules, Membership, UserTypes from synapse.api.errors import ( AuthError, Codes, @@ -709,7 +709,7 @@ class Auth(object): ) @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None, threepid=None): + def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag @@ -722,6 +722,9 @@ class Auth(object): with a MAU blocked server, normally they would be rejected but their threepid is on the reserved list. user_id and threepid should never be set at the same time. + + user_type(str|None): If present, is used to decide whether to check against + certain blocking reasons like MAU. """ # Never fail an auth check for the server notices users or support user @@ -759,6 +762,10 @@ class Auth(object): self.hs.config.mau_limits_reserved_threepids, threepid ): return + elif user_type == UserTypes.SUPPORT: + # If the user does not exist yet and is of type "support", + # allow registration. Support users are excluded from MAU checks. + return # Else if there is no room in the MAU bucket, bail current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index c0cb8ef296..6121efcfa9 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -21,6 +21,7 @@ from twisted.internet import defer import synapse.handlers.auth from synapse.api.auth import Auth +from synapse.api.constants import UserTypes from synapse.api.errors import ( AuthError, Codes, @@ -335,6 +336,23 @@ class AuthTestCase(unittest.TestCase): ) yield self.auth.check_auth_blocking() + @defer.inlineCallbacks + def test_blocking_mau__depending_on_user_type(self): + self.hs.config.max_mau_value = 50 + self.hs.config.limit_usage_by_mau = True + + self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) + # Support users allowed + yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) + self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) + # Bots not allowed + with self.assertRaises(ResourceLimitError): + yield self.auth.check_auth_blocking(user_type=UserTypes.BOT) + self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) + # Real users not allowed + with self.assertRaises(ResourceLimitError): + yield self.auth.check_auth_blocking() + @defer.inlineCallbacks def test_reserved_threepid(self): self.hs.config.limit_usage_by_mau = True -- cgit 1.5.1 From b617864cd9f81109e818bc5ae95bee317d917b72 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 13 Sep 2019 02:29:55 +1000 Subject: Fix for structured logging tests stomping on logs (#6023) --- MANIFEST.in | 12 +++++---- changelog.d/6023.misc | 1 + mypy.ini | 54 ++++++++++++++++++++++++++++++++++++++++ synapse/config/logger.py | 33 ++++++++++++++++++------ synapse/logging/_structured.py | 8 +++--- synapse/logging/_terse_json.py | 8 +++--- synapse/logging/opentracing.py | 4 +-- synapse/metrics/__init__.py | 5 ++-- synapse/metrics/_exposition.py | 4 ++- synapse/python_dependencies.py | 7 +++--- tests/logging/test_structured.py | 25 ++++++++++++++++--- tests/logging/test_terse_json.py | 4 +-- tox.ini | 30 +++++++++++++++++----- 13 files changed, 154 insertions(+), 41 deletions(-) create mode 100644 changelog.d/6023.misc create mode 100644 mypy.ini (limited to 'tests') diff --git a/MANIFEST.in b/MANIFEST.in index 919cd8a1cd..9c2902b8d2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -38,14 +38,16 @@ exclude sytest-blacklist include pyproject.toml recursive-include changelog.d * -prune .github -prune demo/etc -prune docker +prune .buildkite prune .circleci +prune .codecov.yml prune .coveragerc +prune .github prune debian -prune .codecov.yml -prune .buildkite +prune demo/etc +prune docker +prune mypy.ini +prune stubs exclude jenkins* recursive-exclude jenkins *.sh diff --git a/changelog.d/6023.misc b/changelog.d/6023.misc new file mode 100644 index 0000000000..d80410c22c --- /dev/null +++ b/changelog.d/6023.misc @@ -0,0 +1 @@ +Fix the structured logging tests stomping on the global log configuration for subsequent tests. diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000..8788574ee3 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,54 @@ +[mypy] +namespace_packages=True +plugins=mypy_zope:plugin +follow_imports=skip +mypy_path=stubs + +[mypy-synapse.config.homeserver] +# this is a mess because of the metaclass shenanigans +ignore_errors = True + +[mypy-zope] +ignore_missing_imports = True + +[mypy-constantly] +ignore_missing_imports = True + +[mypy-twisted.*] +ignore_missing_imports = True + +[mypy-treq.*] +ignore_missing_imports = True + +[mypy-hyperlink] +ignore_missing_imports = True + +[mypy-h11] +ignore_missing_imports = True + +[mypy-opentracing] +ignore_missing_imports = True + +[mypy-OpenSSL] +ignore_missing_imports = True + +[mypy-netaddr] +ignore_missing_imports = True + +[mypy-saml2.*] +ignore_missing_imports = True + +[mypy-unpaddedbase64] +ignore_missing_imports = True + +[mypy-canonicaljson] +ignore_missing_imports = True + +[mypy-jaeger_client] +ignore_missing_imports = True + +[mypy-jsonschema] +ignore_missing_imports = True + +[mypy-signedjson.*] +ignore_missing_imports = True diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 2704c18720..767ecfdf09 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -21,7 +21,12 @@ from string import Template import yaml -from twisted.logger import STDLibLogObserver, globalLogBeginner +from twisted.logger import ( + ILogObserver, + LogBeginner, + STDLibLogObserver, + globalLogBeginner, +) import synapse from synapse.app import _base as appbase @@ -124,7 +129,7 @@ class LoggingConfig(Config): log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file)) -def _setup_stdlib_logging(config, log_config): +def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): """ Set up Python stdlib logging. """ @@ -165,12 +170,12 @@ def _setup_stdlib_logging(config, log_config): return observer(event) - globalLogBeginner.beginLoggingTo( - [_log], redirectStandardIO=not config.no_redirect_stdio - ) + logBeginner.beginLoggingTo([_log], redirectStandardIO=not config.no_redirect_stdio) if not config.no_redirect_stdio: print("Redirected stdout/stderr to logs") + return observer + def _reload_stdlib_logging(*args, log_config=None): logger = logging.getLogger("") @@ -181,7 +186,9 @@ def _reload_stdlib_logging(*args, log_config=None): logging.config.dictConfig(log_config) -def setup_logging(hs, config, use_worker_options=False): +def setup_logging( + hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner +) -> ILogObserver: """ Set up the logging subsystem. @@ -191,6 +198,12 @@ def setup_logging(hs, config, use_worker_options=False): use_worker_options (bool): True to use the 'worker_log_config' option instead of 'log_config'. + + logBeginner: The Twisted logBeginner to use. + + Returns: + The "root" Twisted Logger observer, suitable for sending logs to from a + Logger instance. """ log_config = config.worker_log_config if use_worker_options else config.log_config @@ -210,10 +223,12 @@ def setup_logging(hs, config, use_worker_options=False): log_config_body = read_config() if log_config_body and log_config_body.get("structured") is True: - setup_structured_logging(hs, config, log_config_body) + logger = setup_structured_logging( + hs, config, log_config_body, logBeginner=logBeginner + ) appbase.register_sighup(read_config, callback=reload_structured_logging) else: - _setup_stdlib_logging(config, log_config_body) + logger = _setup_stdlib_logging(config, log_config_body, logBeginner=logBeginner) appbase.register_sighup(read_config, callback=_reload_stdlib_logging) # make sure that the first thing we log is a thing we can grep backwards @@ -221,3 +236,5 @@ def setup_logging(hs, config, use_worker_options=False): logging.warn("***** STARTING SERVER *****") logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) + + return logger diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 0367d6dfc4..3220e985a9 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -18,6 +18,7 @@ import os.path import sys import typing import warnings +from typing import List import attr from constantly import NamedConstant, Names, ValueConstant, Values @@ -33,7 +34,6 @@ from twisted.logger import ( LogLevelFilterPredicate, LogPublisher, eventAsText, - globalLogBeginner, jsonFileLogObserver, ) @@ -134,7 +134,7 @@ class PythonStdlibToTwistedLogger(logging.Handler): ) -def SynapseFileLogObserver(outFile: typing.io.TextIO) -> FileLogObserver: +def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver: """ A log observer that formats events like the traditional log formatter and sends them to `outFile`. @@ -265,7 +265,7 @@ def setup_structured_logging( hs, config, log_config: dict, - logBeginner: LogBeginner = globalLogBeginner, + logBeginner: LogBeginner, redirect_stdlib_logging: bool = True, ) -> LogPublisher: """ @@ -286,7 +286,7 @@ def setup_structured_logging( if "drains" not in log_config: raise ConfigError("The logging configuration requires a list of drains.") - observers = [] + observers = [] # type: List[ILogObserver] for observer in parse_drain_configs(log_config["drains"]): # Pipe drains diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 7f1e8f23fe..0ebbde06f2 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -21,10 +21,11 @@ import sys from collections import deque from ipaddress import IPv4Address, IPv6Address, ip_address from math import floor -from typing.io import TextIO +from typing import IO import attr from simplejson import dumps +from zope.interface import implementer from twisted.application.internet import ClientService from twisted.internet.endpoints import ( @@ -33,7 +34,7 @@ from twisted.internet.endpoints import ( TCP6ClientEndpoint, ) from twisted.internet.protocol import Factory, Protocol -from twisted.logger import FileLogObserver, Logger +from twisted.logger import FileLogObserver, ILogObserver, Logger from twisted.python.failure import Failure @@ -129,7 +130,7 @@ def flatten_event(event: dict, metadata: dict, include_time: bool = False): return new_event -def TerseJSONToConsoleLogObserver(outFile: TextIO, metadata: dict) -> FileLogObserver: +def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver: """ A log observer that formats events to a flattened JSON representation. @@ -146,6 +147,7 @@ def TerseJSONToConsoleLogObserver(outFile: TextIO, metadata: dict) -> FileLogObs @attr.s +@implementer(ILogObserver) class TerseJSONToTCPLogObserver(object): """ An IObserver that writes JSON logs to a TCP target. diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 7246253018..308a27213b 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -223,8 +223,8 @@ try: from jaeger_client import Config as JaegerConfig from synapse.logging.scopecontextmanager import LogContextScopeManager except ImportError: - JaegerConfig = None - LogContextScopeManager = None + JaegerConfig = None # type: ignore + LogContextScopeManager = None # type: ignore logger = logging.getLogger(__name__) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index b5c9595cb9..bec3b13397 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -20,6 +20,7 @@ import os import platform import threading import time +from typing import Dict, Union import six @@ -42,9 +43,7 @@ logger = logging.getLogger(__name__) METRICS_PREFIX = "/_synapse/metrics" running_on_pypy = platform.python_implementation() == "PyPy" -all_metrics = [] -all_collectors = [] -all_gauges = {} +all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge, BucketCollector]] HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index 1933ecd3e3..74d9c3ecd3 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -36,7 +36,9 @@ from twisted.web.resource import Resource try: from prometheus_client.samples import Sample except ImportError: - Sample = namedtuple("Sample", ["name", "labels", "value", "timestamp", "exemplar"]) + Sample = namedtuple( + "Sample", ["name", "labels", "value", "timestamp", "exemplar"] + ) # type: ignore CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 07345e916a..0bd563edc7 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import Set from pkg_resources import ( DistributionNotFound, @@ -97,7 +98,7 @@ CONDITIONAL_REQUIREMENTS = { "jwt": ["pyjwt>=1.6.4"], } -ALL_OPTIONAL_REQUIREMENTS = set() +ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): # Exclude systemd as it's a system-based requirement. @@ -174,8 +175,8 @@ def check_requirements(for_feature=None): pass if deps_needed: - for e in errors: - logging.error(e) + for err in errors: + logging.error(err) raise DependencyException(deps_needed) diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py index a786de0233..451d05c0f0 100644 --- a/tests/logging/test_structured.py +++ b/tests/logging/test_structured.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import os.path import shutil @@ -33,7 +34,20 @@ class FakeBeginner(object): self.observers = observers -class StructuredLoggingTestCase(HomeserverTestCase): +class StructuredLoggingTestBase(object): + """ + Test base that registers a cleanup handler to reset the stdlib log handler + to 'unset'. + """ + + def prepare(self, reactor, clock, hs): + def _cleanup(): + logging.getLogger("synapse").setLevel(logging.NOTSET) + + self.addCleanup(_cleanup) + + +class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase): """ Tests for Synapse's structured logging support. """ @@ -139,7 +153,9 @@ class StructuredLoggingTestCase(HomeserverTestCase): self.assertEqual(logs[0]["request"], "somereq") -class StructuredLoggingConfigurationFileTestCase(HomeserverTestCase): +class StructuredLoggingConfigurationFileTestCase( + StructuredLoggingTestBase, HomeserverTestCase +): def make_homeserver(self, reactor, clock): tempdir = self.mktemp() @@ -179,10 +195,11 @@ class StructuredLoggingConfigurationFileTestCase(HomeserverTestCase): """ When a structured logging config is given, Synapse will use it. """ - setup_logging(self.hs, self.hs.config) + beginner = FakeBeginner() + publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner) # Make a logger and send an event - logger = Logger(namespace="tests.logging.test_structured") + logger = Logger(namespace="tests.logging.test_structured", observer=publisher) with LoggingContext("testcontext", request="somereq"): logger.info("Hello there, {name}!", name="steve") diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 514282591d..4cf81f7128 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -23,10 +23,10 @@ from synapse.logging._structured import setup_structured_logging from tests.server import connect_client from tests.unittest import HomeserverTestCase -from .test_structured import FakeBeginner +from .test_structured import FakeBeginner, StructuredLoggingTestBase -class TerseJSONTCPTestCase(HomeserverTestCase): +class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase): def test_log_output(self): """ The Terse JSON outputter delivers simplified structured logs over TCP. diff --git a/tox.ini b/tox.ini index 7cb40847b5..1bce10a4ce 100644 --- a/tox.ini +++ b/tox.ini @@ -2,6 +2,7 @@ envlist = packaging, py35, py36, py37, check_codestyle, check_isort [base] +basepython = python3.7 deps = mock python-subunit @@ -137,18 +138,35 @@ commands = {toxinidir}/scripts-dev/generate_sample_config --check skip_install = True deps = coverage -whitelist_externals = - bash commands= coverage combine coverage report +[testenv:cov-erase] +skip_install = True +deps = + coverage +commands= + coverage erase + +[testenv:cov-html] +skip_install = True +deps = + coverage +commands= + coverage html + [testenv:mypy] -basepython = python3.5 +basepython = python3.7 +skip_install = True deps = {[base]deps} mypy + mypy-zope + typeshed +env = + MYPYPATH = stubs/ extras = all -commands = mypy --ignore-missing-imports \ - synapse/logging/_structured.py \ - synapse/logging/_terse_json.py +commands = mypy --show-traceback \ + synapse/logging/ \ + synapse/config/ -- cgit 1.5.1 From 850dcfd2d3a1d689042fb38c8a16b652244068c2 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Sat, 14 Sep 2019 04:58:38 +1000 Subject: Fix well-known lookups with the federation certificate whitelist (#5997) --- changelog.d/5996.bugfix | 1 + synapse/config/tls.py | 9 ++++- synapse/crypto/context_factory.py | 26 +++++++------- synapse/http/federation/matrix_federation_agent.py | 2 +- tests/config/test_tls.py | 40 ++++++++++++++++++++++ 5 files changed, 63 insertions(+), 15 deletions(-) create mode 100644 changelog.d/5996.bugfix (limited to 'tests') diff --git a/changelog.d/5996.bugfix b/changelog.d/5996.bugfix new file mode 100644 index 0000000000..05e31faaa2 --- /dev/null +++ b/changelog.d/5996.bugfix @@ -0,0 +1 @@ +federation_certificate_verification_whitelist now will not cause TypeErrors to be raised (a regression in 1.3). Additionally, it now supports internationalised domain names in their non-canonical representation. diff --git a/synapse/config/tls.py b/synapse/config/tls.py index c0148aa95c..fc47ba3e9a 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -110,8 +110,15 @@ class TlsConfig(Config): # Support globs (*) in whitelist values self.federation_certificate_verification_whitelist = [] for entry in fed_whitelist_entries: + try: + entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii")) + except UnicodeEncodeError: + raise ConfigError( + "IDNA domain names are not allowed in the " + "federation_certificate_verification_whitelist: %s" % (entry,) + ) + # Convert globs to regex - entry_regex = glob_to_regex(entry) self.federation_certificate_verification_whitelist.append(entry_regex) # List of custom certificate authorities for federation traffic validation diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 06e63a96b5..e93f0b3705 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -15,7 +15,6 @@ import logging -import idna from service_identity import VerificationError from service_identity.pyopenssl import verify_hostname, verify_ip_address from zope.interface import implementer @@ -114,14 +113,20 @@ class ClientTLSOptionsFactory(object): self._no_verify_ssl_context = self._no_verify_ssl.getContext() self._no_verify_ssl_context.set_info_callback(self._context_info_cb) - def get_options(self, host): + def get_options(self, host: bytes): + + # IPolicyForHTTPS.get_options takes bytes, but we want to compare + # against the str whitelist. The hostnames in the whitelist are already + # IDNA-encoded like the hosts will be here. + ascii_host = host.decode("ascii") + # Check if certificate verification has been enabled should_verify = self._config.federation_verify_certificates # Check if we've disabled certificate verification for this host if should_verify: for regex in self._config.federation_certificate_verification_whitelist: - if regex.match(host): + if regex.match(ascii_host): should_verify = False break @@ -162,7 +167,7 @@ class SSLClientConnectionCreator(object): Replaces twisted.internet.ssl.ClientTLSOptions """ - def __init__(self, hostname, ctx, verify_certs): + def __init__(self, hostname: bytes, ctx, verify_certs: bool): self._ctx = ctx self._verifier = ConnectionVerifier(hostname, verify_certs) @@ -190,21 +195,16 @@ class ConnectionVerifier(object): # This code is based on twisted.internet.ssl.ClientTLSOptions. - def __init__(self, hostname, verify_certs): + def __init__(self, hostname: bytes, verify_certs): self._verify_certs = verify_certs - if isIPAddress(hostname) or isIPv6Address(hostname): - self._hostnameBytes = hostname.encode("ascii") + _decoded = hostname.decode("ascii") + if isIPAddress(_decoded) or isIPv6Address(_decoded): self._is_ip_address = True else: - # twisted's ClientTLSOptions falls back to the stdlib impl here if - # idna is not installed, but points out that lacks support for - # IDNA2008 (http://bugs.python.org/issue17305). - # - # We can rely on having idna. - self._hostnameBytes = idna.encode(hostname) self._is_ip_address = False + self._hostnameBytes = hostname self._hostnameASCII = self._hostnameBytes.decode("ascii") def verify_context_info_cb(self, ssl_connection, where): diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index feae7de5be..647d26dc56 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -217,7 +217,7 @@ class MatrixHostnameEndpoint(object): self._tls_options = None else: self._tls_options = tls_client_options_factory.get_options( - self._parsed_uri.host.decode("ascii") + self._parsed_uri.host ) self._srv_resolver = srv_resolver diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 8e0c4b9533..b02780772a 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -16,6 +16,7 @@ import os +import idna import yaml from OpenSSL import SSL @@ -235,3 +236,42 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= ) self.assertTrue(conf.acme_enabled) + + def test_whitelist_idna_failure(self): + """ + The federation certificate whitelist will not allow IDNA domain names. + """ + config = { + "federation_certificate_verification_whitelist": [ + "example.com", + "*.ドメイン.テスト", + ] + } + t = TestConfig() + e = self.assertRaises( + ConfigError, t.read_config, config, config_dir_path="", data_dir_path="" + ) + self.assertIn("IDNA domain names", str(e)) + + def test_whitelist_idna_result(self): + """ + The federation certificate whitelist will match on IDNA encoded names. + """ + config = { + "federation_certificate_verification_whitelist": [ + "example.com", + "*.xn--eckwd4c7c.xn--zckzah", + ] + } + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cf = ClientTLSOptionsFactory(t) + + # Not in the whitelist + opts = cf.get_options(b"notexample.com") + self.assertTrue(opts._verifier._verify_certs) + + # Caught by the wildcard + opts = cf.get_options(idna.encode("テスト.ドメイン.テスト")) + self.assertFalse(opts._verifier._verify_certs) -- cgit 1.5.1 From 1e19ce00bff8d67168d39201cdf9424f7b2f22f6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 17 Sep 2019 11:41:54 +0100 Subject: Add 'failure_ts' column to 'destinations' table (#6016) Track the time that a server started failing at, for general analysis purposes. --- changelog.d/6016.misc | 1 + .../schema/delta/56/destinations_failure_ts.sql | 25 ++++ synapse/storage/transactions.py | 23 ++-- synapse/util/retryutils.py | 16 ++- tests/handlers/test_typing.py | 7 +- tests/storage/test_transactions.py | 8 +- tests/util/test_retryutils.py | 127 +++++++++++++++++++++ 7 files changed, 195 insertions(+), 12 deletions(-) create mode 100644 changelog.d/6016.misc create mode 100644 synapse/storage/schema/delta/56/destinations_failure_ts.sql create mode 100644 tests/util/test_retryutils.py (limited to 'tests') diff --git a/changelog.d/6016.misc b/changelog.d/6016.misc new file mode 100644 index 0000000000..91cf164714 --- /dev/null +++ b/changelog.d/6016.misc @@ -0,0 +1 @@ +Add a 'failure_ts' column to the 'destinations' database table. diff --git a/synapse/storage/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/schema/delta/56/destinations_failure_ts.sql new file mode 100644 index 0000000000..f00889290b --- /dev/null +++ b/synapse/storage/schema/delta/56/destinations_failure_ts.sql @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Record the timestamp when a given server started failing + */ +ALTER TABLE destinations ADD failure_ts BIGINT; + +/* as a rough approximation, we assume that the server started failing at + * retry_interval before the last retry + */ +UPDATE destinations SET failure_ts = retry_last_ts - retry_interval + WHERE retry_last_ts > 0; diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index d81ace0ece..289c117396 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -165,7 +165,7 @@ class TransactionStore(SQLBaseStore): txn, table="destinations", keyvalues={"destination": destination}, - retcols=("destination", "retry_last_ts", "retry_interval"), + retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), allow_none=True, ) @@ -174,12 +174,15 @@ class TransactionStore(SQLBaseStore): else: return None - def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval): + def set_destination_retry_timings( + self, destination, failure_ts, retry_last_ts, retry_interval + ): """Sets the current retry timings for a given destination. Both timings should be zero if retrying is no longer occuring. Args: destination (str) + failure_ts (int|None) - when the server started failing (ms since epoch) retry_last_ts (int) - time of last retry attempt in unix epoch ms retry_interval (int) - how long until next retry in ms """ @@ -189,12 +192,13 @@ class TransactionStore(SQLBaseStore): "set_destination_retry_timings", self._set_destination_retry_timings, destination, + failure_ts, retry_last_ts, retry_interval, ) def _set_destination_retry_timings( - self, txn, destination, retry_last_ts, retry_interval + self, txn, destination, failure_ts, retry_last_ts, retry_interval ): if self.database_engine.can_native_upsert: @@ -202,9 +206,12 @@ class TransactionStore(SQLBaseStore): # resetting it) or greater than the existing retry interval. sql = """ - INSERT INTO destinations (destination, retry_last_ts, retry_interval) - VALUES (?, ?, ?) + INSERT INTO destinations ( + destination, failure_ts, retry_last_ts, retry_interval + ) + VALUES (?, ?, ?, ?) ON CONFLICT (destination) DO UPDATE SET + failure_ts = EXCLUDED.failure_ts, retry_last_ts = EXCLUDED.retry_last_ts, retry_interval = EXCLUDED.retry_interval WHERE @@ -212,7 +219,7 @@ class TransactionStore(SQLBaseStore): OR destinations.retry_interval < EXCLUDED.retry_interval """ - txn.execute(sql, (destination, retry_last_ts, retry_interval)) + txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) return @@ -225,7 +232,7 @@ class TransactionStore(SQLBaseStore): txn, table="destinations", keyvalues={"destination": destination}, - retcols=("retry_last_ts", "retry_interval"), + retcols=("failure_ts", "retry_last_ts", "retry_interval"), allow_none=True, ) @@ -235,6 +242,7 @@ class TransactionStore(SQLBaseStore): table="destinations", values={ "destination": destination, + "failure_ts": failure_ts, "retry_last_ts": retry_last_ts, "retry_interval": retry_interval, }, @@ -245,6 +253,7 @@ class TransactionStore(SQLBaseStore): "destinations", keyvalues={"destination": destination}, updatevalues={ + "failure_ts": failure_ts, "retry_last_ts": retry_last_ts, "retry_interval": retry_interval, }, diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index b740913b58..a5f2fbef5c 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -80,11 +80,13 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) # We aren't ready to retry that destination. raise """ + failure_ts = None retry_last_ts, retry_interval = (0, 0) retry_timings = yield store.get_destination_retry_timings(destination) if retry_timings: + failure_ts = retry_timings["failure_ts"] retry_last_ts, retry_interval = ( retry_timings["retry_last_ts"], retry_timings["retry_interval"], @@ -108,6 +110,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) destination, clock, store, + failure_ts, retry_interval, backoff_on_failure=backoff_on_failure, **kwargs @@ -120,6 +123,7 @@ class RetryDestinationLimiter(object): destination, clock, store, + failure_ts, retry_interval, backoff_on_404=False, backoff_on_failure=True, @@ -133,6 +137,8 @@ class RetryDestinationLimiter(object): destination (str) clock (Clock) store (DataStore) + failure_ts (int|None): when this destination started failing (in ms since + the epoch), or zero if the last request was successful retry_interval (int): The next retry interval taken from the database in milliseconds, or zero if the last request was successful. @@ -145,6 +151,7 @@ class RetryDestinationLimiter(object): self.store = store self.destination = destination + self.failure_ts = failure_ts self.retry_interval = retry_interval self.backoff_on_404 = backoff_on_404 self.backoff_on_failure = backoff_on_failure @@ -186,6 +193,7 @@ class RetryDestinationLimiter(object): logger.debug( "Connection to %s was successful; clearing backoff", self.destination ) + self.failure_ts = None retry_last_ts = 0 self.retry_interval = 0 elif not self.backoff_on_failure: @@ -211,11 +219,17 @@ class RetryDestinationLimiter(object): ) retry_last_ts = int(self.clock.time_msec()) + if self.failure_ts is None: + self.failure_ts = retry_last_ts + @defer.inlineCallbacks def store_retry_timings(): try: yield self.store.set_destination_retry_timings( - self.destination, retry_last_ts, self.retry_interval + self.destination, + self.failure_ts, + retry_last_ts, + self.retry_interval, ) except Exception: logger.exception("Failed to store destination_retry_timings") diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5d5e324df2..1f2ef5d01f 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -99,7 +99,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source = hs.get_event_sources().sources["typing"] self.datastore = hs.get_datastore() - retry_timings_res = {"destination": "", "retry_last_ts": 0, "retry_interval": 0} + retry_timings_res = { + "destination": "", + "retry_last_ts": 0, + "retry_interval": 0, + "failure_ts": None, + } self.datastore.get_destination_retry_timings.return_value = defer.succeed( retry_timings_res ) diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index 14169afa96..a771d5af29 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -29,17 +29,19 @@ class TransactionStoreTestCase(HomeserverTestCase): r = self.get_success(d) self.assertIsNone(r) - d = self.store.set_destination_retry_timings("example.com", 50, 100) + d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) self.get_success(d) d = self.store.get_destination_retry_timings("example.com") r = self.get_success(d) - self.assert_dict({"retry_last_ts": 50, "retry_interval": 100}, r) + self.assert_dict( + {"retry_last_ts": 50, "retry_interval": 100, "failure_ts": 1000}, r + ) def test_initial_set_transactions(self): """Tests that we can successfully set the destination retries (there was a bug around invalidating the cache that broke this) """ - d = self.store.set_destination_retry_timings("example.com", 50, 100) + d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) self.get_success(d) diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py new file mode 100644 index 0000000000..9e348694ad --- /dev/null +++ b/tests/util/test_retryutils.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.util.retryutils import ( + MIN_RETRY_INTERVAL, + RETRY_MULTIPLIER, + NotRetryingDestination, + get_retry_limiter, +) + +from tests.unittest import HomeserverTestCase + + +class RetryLimiterTestCase(HomeserverTestCase): + def test_new_destination(self): + """A happy-path case with a new destination and a successful operation""" + store = self.hs.get_datastore() + d = get_retry_limiter("test_dest", self.clock, store) + self.pump() + limiter = self.successResultOf(d) + + # advance the clock a bit before making the request + self.pump(1) + + with limiter: + pass + + d = store.get_destination_retry_timings("test_dest") + self.pump() + new_timings = self.successResultOf(d) + self.assertIsNone(new_timings) + + def test_limiter(self): + """General test case which walks through the process of a failing request""" + store = self.hs.get_datastore() + + d = get_retry_limiter("test_dest", self.clock, store) + self.pump() + limiter = self.successResultOf(d) + + self.pump(1) + try: + with limiter: + self.pump(1) + failure_ts = self.clock.time_msec() + raise AssertionError("argh") + except AssertionError: + pass + + # wait for the update to land + self.pump() + + d = store.get_destination_retry_timings("test_dest") + self.pump() + new_timings = self.successResultOf(d) + self.assertEqual(new_timings["failure_ts"], failure_ts) + self.assertEqual(new_timings["retry_last_ts"], failure_ts) + self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL) + + # now if we try again we should get a failure + d = get_retry_limiter("test_dest", self.clock, store) + self.pump() + self.failureResultOf(d, NotRetryingDestination) + + # + # advance the clock and try again + # + + self.pump(MIN_RETRY_INTERVAL) + d = get_retry_limiter("test_dest", self.clock, store) + self.pump() + limiter = self.successResultOf(d) + + self.pump(1) + try: + with limiter: + self.pump(1) + retry_ts = self.clock.time_msec() + raise AssertionError("argh") + except AssertionError: + pass + + # wait for the update to land + self.pump() + + d = store.get_destination_retry_timings("test_dest") + self.pump() + new_timings = self.successResultOf(d) + self.assertEqual(new_timings["failure_ts"], failure_ts) + self.assertEqual(new_timings["retry_last_ts"], retry_ts) + self.assertGreaterEqual( + new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5 + ) + self.assertLessEqual( + new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0 + ) + + # + # one more go, with success + # + self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) + d = get_retry_limiter("test_dest", self.clock, store) + self.pump() + limiter = self.successResultOf(d) + + self.pump(1) + with limiter: + self.pump(1) + + # wait for the update to land + self.pump() + + d = store.get_destination_retry_timings("test_dest") + self.pump() + new_timings = self.successResultOf(d) + self.assertIsNone(new_timings) -- cgit 1.5.1 From 51d28272e20d799b2e35a8a14b3c1d9d5f555d10 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Sep 2019 16:00:18 +0100 Subject: Query devices table for last seen info. This is a) simpler than querying user_ips directly and b) means we can purge older entries from user_ips without losing the required info. The storage functions now no longer return the access_token, since it was unused. --- synapse/storage/client_ips.py | 57 ++++++---------------------------------- tests/storage/test_client_ips.py | 1 - 2 files changed, 8 insertions(+), 50 deletions(-) (limited to 'tests') diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 8839562269..a4e6d9dbe7 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -392,19 +392,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): keys giving the column names """ - res = yield self.runInteraction( - "get_last_client_ip_by_device", - self._get_last_client_ip_by_device_txn, - user_id, - device_id, - retcols=( - "user_id", - "access_token", - "ip", - "user_agent", - "device_id", - "last_seen", - ), + keyvalues = {"user_id": user_id} + if device_id: + keyvalues["device_id"] = device_id + + res = yield self._simple_select_list( + table="devices", + keyvalues=keyvalues, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), ) ret = {(d["user_id"], d["device_id"]): d for d in res} @@ -423,42 +418,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): } return ret - @classmethod - def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols): - where_clauses = [] - bindings = [] - if device_id is None: - where_clauses.append("user_id = ?") - bindings.extend((user_id,)) - else: - where_clauses.append("(user_id = ? AND device_id = ?)") - bindings.extend((user_id, device_id)) - - if not where_clauses: - return [] - - inner_select = ( - "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " - "WHERE %(where)s " - "GROUP BY user_id, device_id" - ) % {"where": " OR ".join(where_clauses)} - - sql = ( - "SELECT %(retcols)s FROM user_ips " - "JOIN (%(inner_select)s) ips ON" - " user_ips.last_seen = ips.mls AND" - " user_ips.user_id = ips.user_id AND" - " (user_ips.device_id = ips.device_id OR" - " (user_ips.device_id IS NULL AND ips.device_id IS NULL)" - " )" - ) % { - "retcols": ",".join("user_ips." + c for c in retcols), - "inner_select": inner_select, - } - - txn.execute(sql, bindings) - return cls.cursor_to_dict(txn) - @defer.inlineCallbacks def get_user_ip_and_agents(self, user): user_id = user.to_string() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 09305c3bf1..6ac4654085 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -55,7 +55,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": "device_id", - "access_token": "access_token", "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, -- cgit 1.5.1 From acb62a7cc6973618397a868289b5881f1c3c1ec3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Sep 2019 16:50:31 +0100 Subject: Test background update --- tests/storage/test_client_ips.py | 79 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) (limited to 'tests') diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 6ac4654085..76fe65b59e 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -200,6 +200,85 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + def test_devices_last_seen_bg_update(self): + # First make sure we have completed all updates. + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + # Insert a user IP + user_id = "@user:id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) + ) + + # Force persisting to disk + self.reactor.advance(200) + + # But clear the associated entry in devices table + self.get_success( + self.store._simple_update( + table="devices", + keyvalues={"user_id": user_id, "device_id": "device_id"}, + updatevalues={"last_seen": None, "ip": None, "user_agent": None}, + desc="test_devices_last_seen_bg_update", + ) + ) + + # We should now get nulls when querying + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, "device_id") + ) + + r = result[(user_id, "device_id")] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": "device_id", + "ip": None, + "user_agent": None, + "last_seen": None, + }, + r, + ) + + # Register the background update to run again. + self.get_success( + self.store._simple_insert( + table="background_updates", + values={ + "update_name": "devices_last_seen", + "progress_json": "{}", + "depends_on": None, + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store._all_done = False + + # Now let's actually drive the updates to completion + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + # We should now get the correct result again + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, "device_id") + ) + + r = result[(user_id, "device_id")] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": "device_id", + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 0, + }, + r, + ) + class ClientIpAuthTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From faac453f08046ddf00b39b90ba255f774b75c253 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 Sep 2019 15:51:42 +0100 Subject: Test that pruning of old user IPs works --- tests/storage/test_client_ips.py | 71 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) (limited to 'tests') diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 76fe65b59e..afac5dec7f 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -279,6 +279,77 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): r, ) + def test_old_user_ips_pruned(self): + # First make sure we have completed all updates. + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + # Insert a user IP + user_id = "@user:id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) + ) + + # Force persisting to disk + self.reactor.advance(200) + + # We should see that in the DB + result = self.get_success( + self.store._simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + + self.assertEqual( + result, + [ + { + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": "device_id", + "last_seen": 0, + } + ], + ) + + # Now advance by a couple of months + self.reactor.advance(60 * 24 * 60 * 60) + + # We should get no results. + result = self.get_success( + self.store._simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + + self.assertEqual(result, []) + + # But we should still get the correct values for the device + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, "device_id") + ) + + r = result[(user_id, "device_id")] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": "device_id", + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 0, + }, + r, + ) + class ClientIpAuthTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 8004d6ca2faf0f2f843fcdcaf225d7bcab847503 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 25 Sep 2019 11:32:05 +0100 Subject: Refactor code for calculating registration flows (#6106) because, frankly, it looked like it was written by an axe-murderer. This should be a non-functional change, except that where `m.login.dummy` was previously advertised *before* `m.login.terms`, it will now be advertised afterwards. AFAICT that should have no effect, and will be more consistent with the flows that involve passing a 3pid. --- changelog.d/6106.misc | 1 + synapse/rest/client/v2_alpha/register.py | 124 ++++++++++++++-------------- tests/rest/client/v2_alpha/test_register.py | 79 +++++++++++++++--- tests/test_terms_auth.py | 24 ++++-- 4 files changed, 145 insertions(+), 83 deletions(-) create mode 100644 changelog.d/6106.misc (limited to 'tests') diff --git a/changelog.d/6106.misc b/changelog.d/6106.misc new file mode 100644 index 0000000000..d732091779 --- /dev/null +++ b/changelog.d/6106.misc @@ -0,0 +1 @@ +Refactor code for calculating registration flows. diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 135a70808f..e3f3d9126f 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,6 +16,7 @@ import hmac import logging +from typing import List, Union from six import string_types @@ -31,8 +32,11 @@ from synapse.api.errors import ( ThreepidValidationError, UnrecognizedRequestError, ) +from synapse.config.captcha import CaptchaConfig +from synapse.config.consent_config import ConsentConfig from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved from synapse.http.server import finish_request from synapse.http.servlet import ( @@ -371,6 +375,8 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.clock = hs.get_clock() + self._registration_flows = _calculate_registration_flows(hs.config) + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): @@ -491,69 +497,8 @@ class RegisterRestServlet(RestServlet): assigned_user_id=registered_user_id, ) - # FIXME: need a better error than "no auth flow found" for scenarios - # where we required 3PID for registration but the user didn't give one - require_email = "email" in self.hs.config.registrations_require_3pid - require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid - - show_msisdn = True - if self.hs.config.disable_msisdn_registration: - show_msisdn = False - require_msisdn = False - - flows = [] - if self.hs.config.enable_registration_captcha: - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - # Also add a dummy flow here, otherwise if a client completes - # recaptcha first we'll assume they were going for this flow - # and complete the request, when they could have been trying to - # complete one of the flows with email/msisdn auth. - flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]]) - # only support the email-only flow if we don't require MSISDN 3PIDs - if not require_msisdn: - flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]]) - - if show_msisdn: - # only support the MSISDN-only flow if we don't require email 3PIDs - if not require_email: - flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]]) - # always let users provide both MSISDN & email - flows.extend( - [[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]] - ) - else: - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - flows.extend([[LoginType.DUMMY]]) - # only support the email-only flow if we don't require MSISDN 3PIDs - if not require_msisdn: - flows.extend([[LoginType.EMAIL_IDENTITY]]) - - if show_msisdn: - # only support the MSISDN-only flow if we don't require email 3PIDs - if not require_email or require_msisdn: - flows.extend([[LoginType.MSISDN]]) - # always let users provide both MSISDN & email - flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]) - - # Append m.login.terms to all flows if we're requiring consent - if self.hs.config.user_consent_at_registration: - new_flows = [] - for flow in flows: - inserted = False - # m.login.terms should go near the end but before msisdn or email auth - for i, stage in enumerate(flow): - if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN: - flow.insert(i, LoginType.TERMS) - inserted = True - break - if not inserted: - flow.append(LoginType.TERMS) - flows.extend(new_flows) - auth_result, params, session_id = yield self.auth_handler.check_auth( - flows, body, self.hs.get_ip_from_request(request) + self._registration_flows, body, self.hs.get_ip_from_request(request) ) # Check that we're not trying to register a denied 3pid. @@ -716,6 +661,61 @@ class RegisterRestServlet(RestServlet): ) +def _calculate_registration_flows( + # technically `config` has to provide *all* of these interfaces, not just one + config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], +) -> List[List[str]]: + """Get a suitable flows list for registration + + Args: + config: server configuration + + Returns: a list of supported flows + """ + # FIXME: need a better error than "no auth flow found" for scenarios + # where we required 3PID for registration but the user didn't give one + require_email = "email" in config.registrations_require_3pid + require_msisdn = "msisdn" in config.registrations_require_3pid + + show_msisdn = True + if config.disable_msisdn_registration: + show_msisdn = False + require_msisdn = False + + flows = [] + + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + # Add a dummy step here, otherwise if a client completes + # recaptcha first we'll assume they were going for this flow + # and complete the request, when they could have been trying to + # complete one of the flows with email/msisdn auth. + flows.append([LoginType.DUMMY]) + + # only support the email-only flow if we don't require MSISDN 3PIDs + if not require_msisdn: + flows.append([LoginType.EMAIL_IDENTITY]) + + # only support the MSISDN-only flow if we don't require email 3PIDs + if show_msisdn and not require_email: + flows.append([LoginType.MSISDN]) + + if show_msisdn: + flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) + + # Prepend m.login.terms to all flows if we're requiring consent + if config.user_consent_at_registration: + for flow in flows: + flow.insert(0, LoginType.TERMS) + + # Prepend recaptcha to all flows if we're requiring captcha + if config.enable_registration_captcha: + for flow in flows: + flow.insert(0, LoginType.RECAPTCHA) + + return flows + + def register_servlets(hs, http_server): EmailRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index ab4d7d70d0..bc2dc47973 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -34,19 +34,12 @@ from tests import unittest class RegisterRestServletTestCase(unittest.HomeserverTestCase): servlets = [register.register_servlets] + url = b"/_matrix/client/r0/register" - def make_homeserver(self, reactor, clock): - - self.url = b"/_matrix/client/r0/register" - - self.hs = self.setup_test_homeserver() - self.hs.config.enable_registration = True - self.hs.config.registrations_require_3pid = [] - self.hs.config.auto_join_rooms = [] - self.hs.config.enable_registration_captcha = False - self.hs.config.allow_guest_access = True - - return self.hs + def default_config(self, name="test"): + config = super().default_config(name) + config["allow_guest_access"] = True + return config def test_POST_appservice_registration_valid(self): user_id = "@as_user_kermit:test" @@ -199,6 +192,68 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + def test_advertised_flows(self): + request, channel = self.make_request(b"POST", self.url, b"{}") + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we expect all four combinations of 3pid + self.assertCountEqual( + [ + ["m.login.dummy"], + ["m.login.email.identity"], + ["m.login.msisdn"], + ["m.login.msisdn", "m.login.email.identity"], + ], + (f["stages"] for f in flows), + ) + + @unittest.override_config( + { + "enable_registration_captcha": True, + "user_consent": { + "version": "1", + "template_dir": "/", + "require_at_registration": True, + }, + } + ) + def test_advertised_flows_captcha_and_terms(self): + request, channel = self.make_request(b"POST", self.url, b"{}") + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + self.assertCountEqual( + [ + ["m.login.recaptcha", "m.login.terms", "m.login.dummy"], + ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"], + ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"], + [ + "m.login.recaptcha", + "m.login.terms", + "m.login.msisdn", + "m.login.email.identity", + ], + ], + (f["stages"] for f in flows), + ) + + @unittest.override_config( + {"registrations_require_3pid": ["email"], "disable_msisdn_registration": True} + ) + def test_advertised_flows_no_msisdn_email_required(self): + request, channel = self.make_request(b"POST", self.url, b"{}") + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we expect all four combinations of 3pid + self.assertCountEqual( + [["m.login.email.identity"]], (f["stages"] for f in flows) + ) + class AccountValidityTestCase(unittest.HomeserverTestCase): diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 52739fbabc..5ec5d2b358 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -28,6 +28,21 @@ from tests import unittest class TermsTestCase(unittest.HomeserverTestCase): servlets = [register_servlets] + def default_config(self, name="test"): + config = super().default_config(name) + config.update( + { + "public_baseurl": "https://example.org/", + "user_consent": { + "version": "1.0", + "policy_name": "My Cool Privacy Policy", + "template_dir": "/", + "require_at_registration": True, + }, + } + ) + return config + def prepare(self, reactor, clock, hs): self.clock = MemoryReactorClock() self.hs_clock = Clock(self.clock) @@ -35,17 +50,8 @@ class TermsTestCase(unittest.HomeserverTestCase): self.registration_handler = Mock() self.auth_handler = Mock() self.device_handler = Mock() - hs.config.enable_registration = True - hs.config.registrations_require_3pid = [] - hs.config.auto_join_rooms = [] - hs.config.enable_registration_captcha = False def test_ui_auth(self): - self.hs.config.user_consent_at_registration = True - self.hs.config.user_consent_policy_name = "My Cool Privacy Policy" - self.hs.config.public_baseurl = "https://example.org/" - self.hs.config.user_consent_version = "1.0" - # Do a UI auth request request, channel = self.make_request(b"POST", self.url, b"{}") self.render(request) -- cgit 1.5.1 From 2cd98812ba338eefe83fee4ae2390d00f5499de9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 25 Sep 2019 11:33:03 +0100 Subject: Refactor the user-interactive auth handling (#6105) Pull the checkers out to their own classes, rather than having them lost in a massive 1000-line class which does everything. This is also preparation for some more intelligent advertising of flows, as per #6100 --- changelog.d/6105.misc | 1 + synapse/handlers/auth.py | 141 ++------------------- synapse/handlers/ui_auth/__init__.py | 22 ++++ synapse/handlers/ui_auth/checkers.py | 216 ++++++++++++++++++++++++++++++++ tests/rest/client/v2_alpha/test_auth.py | 26 ++-- 5 files changed, 265 insertions(+), 141 deletions(-) create mode 100644 changelog.d/6105.misc create mode 100644 synapse/handlers/ui_auth/__init__.py create mode 100644 synapse/handlers/ui_auth/checkers.py (limited to 'tests') diff --git a/changelog.d/6105.misc b/changelog.d/6105.misc new file mode 100644 index 0000000000..2e838a35c6 --- /dev/null +++ b/changelog.d/6105.misc @@ -0,0 +1 @@ +Refactor the user-interactive auth handling. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 374372b69e..f920c2f6c1 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -21,10 +21,8 @@ import unicodedata import attr import bcrypt import pymacaroons -from canonicaljson import json from twisted.internet import defer -from twisted.web.client import PartialDownloadError import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType @@ -38,7 +36,8 @@ from synapse.api.errors import ( UserDeactivatedError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.config.emailconfig import ThreepidBehaviour +from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS +from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.logging.context import defer_to_thread from synapse.module_api import ModuleApi from synapse.types import UserID @@ -58,13 +57,12 @@ class AuthHandler(BaseHandler): hs (synapse.server.HomeServer): """ super(AuthHandler, self).__init__(hs) - self.checkers = { - LoginType.RECAPTCHA: self._check_recaptcha, - LoginType.EMAIL_IDENTITY: self._check_email_identity, - LoginType.MSISDN: self._check_msisdn, - LoginType.DUMMY: self._check_dummy_auth, - LoginType.TERMS: self._check_terms_auth, - } + + self.checkers = {} # type: dict[str, UserInteractiveAuthChecker] + for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: + inst = auth_checker_class(hs) + self.checkers[inst.AUTH_TYPE] = inst + self.bcrypt_rounds = hs.config.bcrypt_rounds # This is not a cache per se, but a store of all current sessions that @@ -292,7 +290,7 @@ class AuthHandler(BaseHandler): sess["creds"] = {} creds = sess["creds"] - result = yield self.checkers[stagetype](authdict, clientip) + result = yield self.checkers[stagetype].check_auth(authdict, clientip) if result: creds[stagetype] = result self._save_session(sess) @@ -363,7 +361,7 @@ class AuthHandler(BaseHandler): login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: - res = yield checker(authdict, clientip=clientip) + res = yield checker.check_auth(authdict, clientip=clientip) return res # build a v1-login-style dict out of the authdict and fall back to the @@ -376,125 +374,6 @@ class AuthHandler(BaseHandler): (canonical_id, callback) = yield self.validate_login(user_id, authdict) return canonical_id - @defer.inlineCallbacks - def _check_recaptcha(self, authdict, clientip, **kwargs): - try: - user_response = authdict["response"] - except KeyError: - # Client tried to provide captcha but didn't give the parameter: - # bad request. - raise LoginError( - 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED - ) - - logger.info( - "Submitting recaptcha response %s with remoteip %s", user_response, clientip - ) - - # TODO: get this from the homeserver rather than creating a new one for - # each request - try: - client = self.hs.get_simple_http_client() - resp_body = yield client.post_urlencoded_get_json( - self.hs.config.recaptcha_siteverify_api, - args={ - "secret": self.hs.config.recaptcha_private_key, - "response": user_response, - "remoteip": clientip, - }, - ) - except PartialDownloadError as pde: - # Twisted is silly - data = pde.response - resp_body = json.loads(data) - - if "success" in resp_body: - # Note that we do NOT check the hostname here: we explicitly - # intend the CAPTCHA to be presented by whatever client the - # user is using, we just care that they have completed a CAPTCHA. - logger.info( - "%s reCAPTCHA from hostname %s", - "Successful" if resp_body["success"] else "Failed", - resp_body.get("hostname"), - ) - if resp_body["success"]: - return True - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - - def _check_email_identity(self, authdict, **kwargs): - return self._check_threepid("email", authdict, **kwargs) - - def _check_msisdn(self, authdict, **kwargs): - return self._check_threepid("msisdn", authdict) - - def _check_dummy_auth(self, authdict, **kwargs): - return defer.succeed(True) - - def _check_terms_auth(self, authdict, **kwargs): - return defer.succeed(True) - - @defer.inlineCallbacks - def _check_threepid(self, medium, authdict, **kwargs): - if "threepid_creds" not in authdict: - raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) - - threepid_creds = authdict["threepid_creds"] - - identity_handler = self.hs.get_handlers().identity_handler - - logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,)) - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - if medium == "email": - threepid = yield identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds - ) - elif medium == "msisdn": - threepid = yield identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds - ) - else: - raise SynapseError(400, "Unrecognized threepid medium: %s" % (medium,)) - elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - row = yield self.store.get_threepid_validation_session( - medium, - threepid_creds["client_secret"], - sid=threepid_creds["sid"], - validated=True, - ) - - threepid = ( - { - "medium": row["medium"], - "address": row["address"], - "validated_at": row["validated_at"], - } - if row - else None - ) - - if row: - # Valid threepid returned, delete from the db - yield self.store.delete_threepid_session(threepid_creds["sid"]) - else: - raise SynapseError( - 400, "Password resets are not enabled on this homeserver" - ) - - if not threepid: - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - - if threepid["medium"] != medium: - raise LoginError( - 401, - "Expecting threepid of type '%s', got '%s'" - % (medium, threepid["medium"]), - errcode=Codes.UNAUTHORIZED, - ) - - threepid["threepid_creds"] = authdict["threepid_creds"] - - return threepid - def _get_params_recaptcha(self): return {"public_key": self.hs.config.recaptcha_public_key} diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py new file mode 100644 index 0000000000..824f37f8f8 --- /dev/null +++ b/synapse/handlers/ui_auth/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module implements user-interactive auth verification. + +TODO: move more stuff out of AuthHandler in here. + +""" + +from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401 diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py new file mode 100644 index 0000000000..fd633b7b0e --- /dev/null +++ b/synapse/handlers/ui_auth/checkers.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from canonicaljson import json + +from twisted.internet import defer +from twisted.web.client import PartialDownloadError + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes, LoginError, SynapseError +from synapse.config.emailconfig import ThreepidBehaviour + +logger = logging.getLogger(__name__) + + +class UserInteractiveAuthChecker: + """Abstract base class for an interactive auth checker""" + + def __init__(self, hs): + pass + + def check_auth(self, authdict, clientip): + """Given the authentication dict from the client, attempt to check this step + + Args: + authdict (dict): authentication dictionary from the client + clientip (str): The IP address of the client. + + Raises: + SynapseError if authentication failed + + Returns: + Deferred: the result of authentication (to pass back to the client?) + """ + raise NotImplementedError() + + +class DummyAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.DUMMY + + def check_auth(self, authdict, clientip): + return defer.succeed(True) + + +class TermsAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.TERMS + + def check_auth(self, authdict, clientip): + return defer.succeed(True) + + +class RecaptchaAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.RECAPTCHA + + def __init__(self, hs): + super().__init__(hs) + self._http_client = hs.get_simple_http_client() + self._url = hs.config.recaptcha_siteverify_api + self._secret = hs.config.recaptcha_private_key + + @defer.inlineCallbacks + def check_auth(self, authdict, clientip): + try: + user_response = authdict["response"] + except KeyError: + # Client tried to provide captcha but didn't give the parameter: + # bad request. + raise LoginError( + 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED + ) + + logger.info( + "Submitting recaptcha response %s with remoteip %s", user_response, clientip + ) + + # TODO: get this from the homeserver rather than creating a new one for + # each request + try: + resp_body = yield self._http_client.post_urlencoded_get_json( + self._url, + args={ + "secret": self._secret, + "response": user_response, + "remoteip": clientip, + }, + ) + except PartialDownloadError as pde: + # Twisted is silly + data = pde.response + resp_body = json.loads(data) + + if "success" in resp_body: + # Note that we do NOT check the hostname here: we explicitly + # intend the CAPTCHA to be presented by whatever client the + # user is using, we just care that they have completed a CAPTCHA. + logger.info( + "%s reCAPTCHA from hostname %s", + "Successful" if resp_body["success"] else "Failed", + resp_body.get("hostname"), + ) + if resp_body["success"]: + return True + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + +class _BaseThreepidAuthChecker: + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def _check_threepid(self, medium, authdict): + if "threepid_creds" not in authdict: + raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) + + threepid_creds = authdict["threepid_creds"] + + identity_handler = self.hs.get_handlers().identity_handler + + logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,)) + if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + if medium == "email": + threepid = yield identity_handler.threepid_from_creds( + self.hs.config.account_threepid_delegate_email, threepid_creds + ) + elif medium == "msisdn": + threepid = yield identity_handler.threepid_from_creds( + self.hs.config.account_threepid_delegate_msisdn, threepid_creds + ) + else: + raise SynapseError(400, "Unrecognized threepid medium: %s" % (medium,)) + elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + row = yield self.store.get_threepid_validation_session( + medium, + threepid_creds["client_secret"], + sid=threepid_creds["sid"], + validated=True, + ) + + threepid = ( + { + "medium": row["medium"], + "address": row["address"], + "validated_at": row["validated_at"], + } + if row + else None + ) + + if row: + # Valid threepid returned, delete from the db + yield self.store.delete_threepid_session(threepid_creds["sid"]) + else: + raise SynapseError( + 400, "Password resets are not enabled on this homeserver" + ) + + if not threepid: + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + if threepid["medium"] != medium: + raise LoginError( + 401, + "Expecting threepid of type '%s', got '%s'" + % (medium, threepid["medium"]), + errcode=Codes.UNAUTHORIZED, + ) + + threepid["threepid_creds"] = authdict["threepid_creds"] + + return threepid + + +class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): + AUTH_TYPE = LoginType.EMAIL_IDENTITY + + def __init__(self, hs): + UserInteractiveAuthChecker.__init__(self, hs) + _BaseThreepidAuthChecker.__init__(self, hs) + + def check_auth(self, authdict, clientip): + return self._check_threepid("email", authdict) + + +class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): + AUTH_TYPE = LoginType.MSISDN + + def __init__(self, hs): + UserInteractiveAuthChecker.__init__(self, hs) + _BaseThreepidAuthChecker.__init__(self, hs) + + def check_auth(self, authdict, clientip): + return self._check_threepid("msisdn", authdict) + + +INTERACTIVE_AUTH_CHECKERS = [ + DummyAuthChecker, + TermsAuthChecker, + RecaptchaAuthChecker, + EmailIdentityAuthChecker, + MsisdnAuthChecker, +] +"""A list of UserInteractiveAuthChecker classes""" diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index b9ef46e8fb..b6df1396ad 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -18,11 +18,22 @@ from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType +from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client.v2_alpha import auth, register from tests import unittest +class DummyRecaptchaChecker(UserInteractiveAuthChecker): + def __init__(self, hs): + super().__init__(hs) + self.recaptcha_attempts = [] + + def check_auth(self, authdict, clientip): + self.recaptcha_attempts.append((authdict, clientip)) + return succeed(True) + + class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ @@ -44,15 +55,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): + self.recaptcha_checker = DummyRecaptchaChecker(hs) auth_handler = hs.get_auth_handler() - - self.recaptcha_attempts = [] - - def _recaptcha(authdict, clientip): - self.recaptcha_attempts.append((authdict, clientip)) - return succeed(True) - - auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha + auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker @unittest.INFO def test_fallback_captcha(self): @@ -89,8 +94,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase): self.assertEqual(request.code, 200) # The recaptcha handler is called with the response given - self.assertEqual(len(self.recaptcha_attempts), 1) - self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a") + attempts = self.recaptcha_checker.recaptcha_attempts + self.assertEqual(len(attempts), 1) + self.assertEqual(attempts[0][0]["response"], "a") # also complete the dummy auth request, channel = self.make_request( -- cgit 1.5.1 From 990928abde4f3ccd7d43e6214abd7d36434953a9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 25 Sep 2019 12:10:26 +0100 Subject: Stop advertising unsupported flows for registration (#6107) If email or msisdn verification aren't supported, let's stop advertising them for registration. Fixes #6100. --- changelog.d/6107.bugfix | 1 + synapse/handlers/auth.py | 11 +++++++++- synapse/handlers/ui_auth/checkers.py | 26 +++++++++++++++++++++++ synapse/rest/client/v2_alpha/register.py | 32 ++++++++++++++++++++++++++--- tests/rest/client/v2_alpha/test_register.py | 29 +++++++++++++++----------- 5 files changed, 83 insertions(+), 16 deletions(-) create mode 100644 changelog.d/6107.bugfix (limited to 'tests') diff --git a/changelog.d/6107.bugfix b/changelog.d/6107.bugfix new file mode 100644 index 0000000000..d4b9516ac7 --- /dev/null +++ b/changelog.d/6107.bugfix @@ -0,0 +1 @@ +Ensure that servers which are not configured to support email address verification do not offer it in the registration flows. \ No newline at end of file diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f920c2f6c1..333eb30625 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -61,7 +61,8 @@ class AuthHandler(BaseHandler): self.checkers = {} # type: dict[str, UserInteractiveAuthChecker] for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) - self.checkers[inst.AUTH_TYPE] = inst + if inst.is_enabled(): + self.checkers[inst.AUTH_TYPE] = inst self.bcrypt_rounds = hs.config.bcrypt_rounds @@ -156,6 +157,14 @@ class AuthHandler(BaseHandler): return params + def get_enabled_auth_types(self): + """Return the enabled user-interactive authentication types + + Returns the UI-Auth types which are supported by the homeserver's current + config. + """ + return self.checkers.keys() + @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index fd633b7b0e..ee69223243 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -32,6 +32,13 @@ class UserInteractiveAuthChecker: def __init__(self, hs): pass + def is_enabled(self): + """Check if the configuration of the homeserver allows this checker to work + + Returns: + bool: True if this login type is enabled. + """ + def check_auth(self, authdict, clientip): """Given the authentication dict from the client, attempt to check this step @@ -51,6 +58,9 @@ class UserInteractiveAuthChecker: class DummyAuthChecker(UserInteractiveAuthChecker): AUTH_TYPE = LoginType.DUMMY + def is_enabled(self): + return True + def check_auth(self, authdict, clientip): return defer.succeed(True) @@ -58,6 +68,9 @@ class DummyAuthChecker(UserInteractiveAuthChecker): class TermsAuthChecker(UserInteractiveAuthChecker): AUTH_TYPE = LoginType.TERMS + def is_enabled(self): + return True + def check_auth(self, authdict, clientip): return defer.succeed(True) @@ -67,10 +80,14 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs): super().__init__(hs) + self._enabled = bool(hs.config.recaptcha_private_key) self._http_client = hs.get_simple_http_client() self._url = hs.config.recaptcha_siteverify_api self._secret = hs.config.recaptcha_private_key + def is_enabled(self): + return self._enabled + @defer.inlineCallbacks def check_auth(self, authdict, clientip): try: @@ -191,6 +208,12 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec UserInteractiveAuthChecker.__init__(self, hs) _BaseThreepidAuthChecker.__init__(self, hs) + def is_enabled(self): + return self.hs.config.threepid_behaviour_email in ( + ThreepidBehaviour.REMOTE, + ThreepidBehaviour.LOCAL, + ) + def check_auth(self, authdict, clientip): return self._check_threepid("email", authdict) @@ -202,6 +225,9 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): UserInteractiveAuthChecker.__init__(self, hs) _BaseThreepidAuthChecker.__init__(self, hs) + def is_enabled(self): + return bool(self.hs.config.account_threepid_delegate_msisdn) + def check_auth(self, authdict, clientip): return self._check_threepid("msisdn", authdict) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index e3f3d9126f..4f24a124a6 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -32,12 +32,14 @@ from synapse.api.errors import ( ThreepidValidationError, UnrecognizedRequestError, ) +from synapse.config import ConfigError from synapse.config.captcha import CaptchaConfig from synapse.config.consent_config import ConsentConfig from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved +from synapse.handlers.auth import AuthHandler from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -375,7 +377,9 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.clock = hs.get_clock() - self._registration_flows = _calculate_registration_flows(hs.config) + self._registration_flows = _calculate_registration_flows( + hs.config, self.auth_handler + ) @interactive_auth_handler @defer.inlineCallbacks @@ -664,11 +668,13 @@ class RegisterRestServlet(RestServlet): def _calculate_registration_flows( # technically `config` has to provide *all* of these interfaces, not just one config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], + auth_handler: AuthHandler, ) -> List[List[str]]: """Get a suitable flows list for registration Args: config: server configuration + auth_handler: authorization handler Returns: a list of supported flows """ @@ -678,10 +684,29 @@ def _calculate_registration_flows( require_msisdn = "msisdn" in config.registrations_require_3pid show_msisdn = True + show_email = True + if config.disable_msisdn_registration: show_msisdn = False require_msisdn = False + enabled_auth_types = auth_handler.get_enabled_auth_types() + if LoginType.EMAIL_IDENTITY not in enabled_auth_types: + show_email = False + if require_email: + raise ConfigError( + "Configuration requires email address at registration, but email " + "validation is not configured" + ) + + if LoginType.MSISDN not in enabled_auth_types: + show_msisdn = False + if require_msisdn: + raise ConfigError( + "Configuration requires msisdn at registration, but msisdn " + "validation is not configured" + ) + flows = [] # only support 3PIDless registration if no 3PIDs are required @@ -693,14 +718,15 @@ def _calculate_registration_flows( flows.append([LoginType.DUMMY]) # only support the email-only flow if we don't require MSISDN 3PIDs - if not require_msisdn: + if show_email and not require_msisdn: flows.append([LoginType.EMAIL_IDENTITY]) # only support the MSISDN-only flow if we don't require email 3PIDs if show_msisdn and not require_email: flows.append([LoginType.MSISDN]) - if show_msisdn: + if show_email and show_msisdn: + # always let users provide both MSISDN & email flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) # Prepend m.login.terms to all flows if we're requiring consent diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index bc2dc47973..dab87e5edf 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -198,16 +198,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] - # with the stock config, we expect all four combinations of 3pid - self.assertCountEqual( - [ - ["m.login.dummy"], - ["m.login.email.identity"], - ["m.login.msisdn"], - ["m.login.msisdn", "m.login.email.identity"], - ], - (f["stages"] for f in flows), - ) + # with the stock config, we only expect the dummy flow + self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows)) @unittest.override_config( { @@ -217,9 +209,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "template_dir": "/", "require_at_registration": True, }, + "account_threepid_delegates": { + "email": "https://id_server", + "msisdn": "https://id_server", + }, } ) - def test_advertised_flows_captcha_and_terms(self): + def test_advertised_flows_captcha_and_terms_and_3pids(self): request, channel = self.make_request(b"POST", self.url, b"{}") self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) @@ -241,7 +237,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) @unittest.override_config( - {"registrations_require_3pid": ["email"], "disable_msisdn_registration": True} + { + "public_baseurl": "https://test_server", + "registrations_require_3pid": ["email"], + "disable_msisdn_registration": True, + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } ) def test_advertised_flows_no_msisdn_email_required(self): request, channel = self.make_request(b"POST", self.url, b"{}") -- cgit 1.5.1 From 034db2ba2115d935ce62b641b4051e477a454eac Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 26 Sep 2019 11:47:53 +0100 Subject: Fix dummy event insertion consent bug (#6053) Fixes #5905 --- changelog.d/6053.bugfix | 1 + synapse/handlers/message.py | 99 ++++++++++++++++------ synapse/storage/event_federation.py | 18 +++- tests/storage/test_cleanup_extrems.py | 147 +++++++++++++++++++++++++++++++-- tests/storage/test_event_federation.py | 40 +++++++++ 5 files changed, 266 insertions(+), 39 deletions(-) create mode 100644 changelog.d/6053.bugfix (limited to 'tests') diff --git a/changelog.d/6053.bugfix b/changelog.d/6053.bugfix new file mode 100644 index 0000000000..6311157bf6 --- /dev/null +++ b/changelog.d/6053.bugfix @@ -0,0 +1 @@ +Prevent exceptions being logged when extremity-cleanup events fail due to lack of user consent to the terms of service. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 1f8272784e..0f8cce8ffe 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -222,6 +222,13 @@ class MessageHandler(object): } +# The duration (in ms) after which rooms should be removed +# `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try +# to generate a dummy event for them once more) +# +_DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000 + + class EventCreationHandler(object): def __init__(self, hs): self.hs = hs @@ -258,6 +265,13 @@ class EventCreationHandler(object): self.config.block_events_without_consent_error ) + # Rooms which should be excluded from dummy insertion. (For instance, + # those without local users who can send events into the room). + # + # map from room id to time-of-last-attempt. + # + self._rooms_to_exclude_from_dummy_event_insertion = {} # type: dict[str, int] + # we need to construct a ConsentURIBuilder here, as it checks that the necessary # config options, but *only* if we have a configuration for which we are # going to need it. @@ -888,9 +902,11 @@ class EventCreationHandler(object): """Background task to send dummy events into rooms that have a large number of extremities """ - + self._expire_rooms_to_exclude_from_dummy_event_insertion() room_ids = yield self.store.get_rooms_with_many_extremities( - min_count=10, limit=5 + min_count=10, + limit=5, + room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(), ) for room_id in room_ids: @@ -904,32 +920,61 @@ class EventCreationHandler(object): members = yield self.state.get_current_users_in_room( room_id, latest_event_ids=latest_event_ids ) + dummy_event_sent = False + for user_id in members: + if not self.hs.is_mine_id(user_id): + continue + requester = create_requester(user_id) + try: + event, context = yield self.create_event( + requester, + { + "type": "org.matrix.dummy_event", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_events_and_hashes=prev_events_and_hashes, + ) - user_id = None - for member in members: - if self.hs.is_mine_id(member): - user_id = member - break - - if not user_id: - # We don't have a joined user. - # TODO: We should do something here to stop the room from - # appearing next time. - continue + event.internal_metadata.proactively_send = False - requester = create_requester(user_id) + yield self.send_nonmember_event( + requester, event, context, ratelimit=False + ) + dummy_event_sent = True + break + except ConsentNotGivenError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of consent. Will try another user" % (room_id, user_id) + ) + except AuthError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of power. Will try another user" % (room_id, user_id) + ) - event, context = yield self.create_event( - requester, - { - "type": "org.matrix.dummy_event", - "content": {}, - "room_id": room_id, - "sender": user_id, - }, - prev_events_and_hashes=prev_events_and_hashes, + if not dummy_event_sent: + # Did not find a valid user in the room, so remove from future attempts + # Exclusion is time limited, so the room will be rechecked in the future + # dependent on _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY + logger.info( + "Failed to send dummy event into room %s. Will exclude it from " + "future attempts until cache expires" % (room_id,) + ) + now = self.clock.time_msec() + self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now + + def _expire_rooms_to_exclude_from_dummy_event_insertion(self): + expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY + to_expire = set() + for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items(): + if time < expire_before: + to_expire.add(room_id) + for room_id in to_expire: + logger.debug( + "Expiring room id %s from dummy event insertion exclusion cache", + room_id, ) - - event.internal_metadata.proactively_send = False - - yield self.send_nonmember_event(requester, event, context, ratelimit=False) + del self._rooms_to_exclude_from_dummy_event_insertion[room_id] diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 4f500d893e..f5e8c39262 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging import random @@ -190,12 +191,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas room_id, ) - def get_rooms_with_many_extremities(self, min_count, limit): + def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): """Get the top rooms with at least N extremities. Args: min_count (int): The minimum number of extremities limit (int): The maximum number of rooms to return. + room_id_filter (iterable[str]): room_ids to exclude from the results Returns: Deferred[list]: At most `limit` room IDs that have at least @@ -203,15 +205,25 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ def _get_rooms_with_many_extremities_txn(txn): + where_clause = "1=1" + if room_id_filter: + where_clause = "room_id NOT IN (%s)" % ( + ",".join("?" for _ in room_id_filter), + ) + sql = """ SELECT room_id FROM event_forward_extremities + WHERE %s GROUP BY room_id HAVING count(*) > ? ORDER BY count(*) DESC LIMIT ? - """ + """ % ( + where_clause, + ) - txn.execute(sql, (min_count, limit)) + query_args = list(itertools.chain(room_id_filter, [min_count, limit])) + txn.execute(sql, query_args) return [room_id for room_id, in txn] return self.runInteraction( diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index e9e2d5337c..34f9c72709 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -14,7 +14,13 @@ # limitations under the License. import os.path +from unittest.mock import patch +from mock import Mock + +import synapse.rest.admin +from synapse.api.constants import EventTypes +from synapse.rest.client.v1 import login, room from synapse.storage import prepare_database from synapse.types import Requester, UserID @@ -225,6 +231,14 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): class CleanupExtremDummyEventsTestCase(HomeserverTestCase): + CONSENT_VERSION = "1" + EXTREMITIES_COUNT = 50 + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + def make_homeserver(self, reactor, clock): config = self.default_config() config["cleanup_extremities_with_dummy_events"] = True @@ -233,33 +247,148 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() self.room_creator = homeserver.get_room_creation_handler() + self.event_creator_handler = homeserver.get_event_creation_handler() # Create a test user and room - self.user = UserID("alice", "test") + self.user = UserID.from_string(self.register_user("user1", "password")) + self.token1 = self.login("user1", "password") self.requester = Requester(self.user, None, False, None, None) info = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] + self.event_creator = homeserver.get_event_creation_handler() + homeserver.config.user_consent_version = self.CONSENT_VERSION def test_send_dummy_event(self): - # Create a bushy graph with 50 extremities. - - event_id_start = self.create_and_send_event(self.room_id, self.user) + self._create_extremity_rich_graph() - for _ in range(50): - self.create_and_send_event( - self.room_id, self.user, prev_event_ids=[event_id_start] - ) + # Pump the reactor repeatedly so that the background updates have a + # chance to run. + self.pump(10 * 60) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(len(latest_event_ids), 50) + self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) + def test_send_dummy_events_when_insufficient_power(self): + self._create_extremity_rich_graph() + # Criple power levels + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + body={"users": {str(self.user): -1}}, + tok=self.token1, + ) # Pump the reactor repeatedly so that the background updates have a # chance to run. self.pump(10 * 60) + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + # Check that the room has not been pruned + self.assertTrue(len(latest_event_ids) > 10) + + # New user with regular levels + user2 = self.register_user("user2", "password") + token2 = self.login("user2", "password") + self.helper.join(self.room_id, user2, tok=token2) + self.pump(10 * 60) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) + def test_send_dummy_event_without_consent(self): + self._create_extremity_rich_graph() + self._enable_consent_checking() + + # Pump the reactor repeatedly so that the background updates have a + # chance to run. Attempt to add dummy event with user that has not consented + # Check that dummy event send fails. + self.pump(10 * 60) + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT) + + # Create new user, and add consent + user2 = self.register_user("user2", "password") + token2 = self.login("user2", "password") + self.get_success( + self.store.user_set_consent_version(user2, self.CONSENT_VERSION) + ) + self.helper.join(self.room_id, user2, tok=token2) + + # Background updates should now cause a dummy event to be added to the graph + self.pump(10 * 60) + latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250) + def test_expiry_logic(self): + """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() + expires old entries correctly. + """ + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "1" + ] = 100000 + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "2" + ] = 200000 + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "3" + ] = 300000 + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() + # All entries within time frame + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 3, + ) + # Oldest room to expire + self.pump(1) + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 2, + ) + # All rooms to expire + self.pump(2) + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 0, + ) + + def _create_extremity_rich_graph(self): + """Helper method to create bushy graph on demand""" + + event_id_start = self.create_and_send_event(self.room_id, self.user) + + for _ in range(self.EXTREMITIES_COUNT): + self.create_and_send_event( + self.room_id, self.user, prev_event_ids=[event_id_start] + ) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(len(latest_event_ids), 50) + + def _enable_consent_checking(self): + """Helper method to enable consent checking""" + self.event_creator._block_events_without_consent_error = "No consent from user" + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creator._consent_uri_builder = consent_uri_builder diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 86c7ac350d..b58386994e 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -75,3 +75,43 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): el = r[i] depth = el[2] self.assertLessEqual(5, depth) + + @defer.inlineCallbacks + def test_get_rooms_with_many_extremities(self): + room1 = "#room1" + room2 = "#room2" + room3 = "#room3" + + def insert_event(txn, i, room_id): + event_id = "$event_%i:local" % i + txn.execute( + ( + "INSERT INTO event_forward_extremities (room_id, event_id) " + "VALUES (?, ?)" + ), + (room_id, event_id), + ) + + for i in range(0, 20): + yield self.store.runInteraction("insert", insert_event, i, room1) + yield self.store.runInteraction("insert", insert_event, i, room2) + yield self.store.runInteraction("insert", insert_event, i, room3) + + # Test simple case + r = yield self.store.get_rooms_with_many_extremities(5, 5, []) + self.assertEqual(len(r), 3) + + # Does filter work? + + r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1]) + self.assertTrue(room2 in r) + self.assertTrue(room3 in r) + self.assertEqual(len(r), 2) + + r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + self.assertEqual(r, [room3]) + + # Does filter and limit work? + + r = yield self.store.get_rooms_with_many_extremities(5, 1, [room1]) + self.assertTrue(r == [room2] or r == [room3]) -- cgit 1.5.1 From 132279a46fd76de8f767bc6977192900c450fec9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 Sep 2019 15:11:14 +0100 Subject: Patch inlinecallbacks for log contexts --- synapse/__init__.py | 6 +++ synapse/handlers/room_member.py | 4 +- synapse/storage/_base.py | 7 ++-- synapse/storage/push_rule.py | 2 +- tests/patch_inline_callbacks.py | 86 +++++++++++++++++++++++++++++++++++++++-- 5 files changed, 95 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/synapse/__init__.py b/synapse/__init__.py index ddfe9ec542..4401fd52f0 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -17,8 +17,11 @@ """ This is a reference implementation of a Matrix home server. """ +import os import sys +from tests.patch_inline_callbacks import do_patch + # Check that we're not running on an unsupported Python version. if sys.version_info < (3, 5): print("Synapse requires Python 3.5 or above.") @@ -36,3 +39,6 @@ except ImportError: pass __version__ = "1.4.0rc1" + +if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): + do_patch() diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8abdb1b6e6..19e44b5460 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -213,11 +213,11 @@ class RoomMemberHandler(object): if predecessor: # It is an upgraded room. Copy over old tags - self.copy_room_tags_and_direct_to_room( + yield self.copy_room_tags_and_direct_to_room( predecessor["room_id"], room_id, user_id ) # Move over old push rules - self.store.move_push_rules_from_room_to_room_for_user( + yield self.store.move_push_rules_from_room_to_room_for_user( predecessor["room_id"], room_id, user_id ) elif event.membership == Membership.LEAVE: diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index abe16334ec..06cc14fcd1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -30,7 +30,7 @@ from prometheus_client import Histogram from twisted.internet import defer from synapse.api.errors import StoreError -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id @@ -550,8 +550,9 @@ class SQLBaseStore(object): return func(conn, *args, **kwargs) - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs) + result = yield make_deferred_yieldable( + self._db_pool.runWithConnection(inner_func, *args, **kwargs) + ) return result diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index a6517c4cf3..1f878e6575 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -235,7 +235,7 @@ class PushRulesWorkerStore( (c.get("key") == "room_id" and c.get("pattern") == old_room_id) for c in conditions ): - self.move_push_rule_from_room_to_room(new_room_id, user_id, rule) + yield self.move_push_rule_from_room_to_room(new_room_id, user_id, rule) @defer.inlineCallbacks def bulk_get_push_rules_for_room(self, event, context): diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py index 220884311c..5ef0aff0c3 100644 --- a/tests/patch_inline_callbacks.py +++ b/tests/patch_inline_callbacks.py @@ -15,6 +15,7 @@ from __future__ import print_function +import inspect import functools import sys @@ -33,17 +34,19 @@ def do_patch(): orig_inline_callbacks = defer.inlineCallbacks def new_inline_callbacks(f): - - orig = orig_inline_callbacks(f) - @functools.wraps(f) def wrapped(*args, **kwargs): start_context = LoggingContext.current_context() + changes = [] + orig = orig_inline_callbacks(_check_yield_points(f, changes, start_context)) try: res = orig(*args, **kwargs) except Exception: if LoggingContext.current_context() != start_context: + for err in changes: + print(err, file=sys.stderr) + err = "%s changed context from %s to %s on exception" % ( f, start_context, @@ -55,7 +58,10 @@ def do_patch(): if not isinstance(res, Deferred) or res.called: if LoggingContext.current_context() != start_context: - err = "%s changed context from %s to %s" % ( + for err in changes: + print(err, file=sys.stderr) + + err = "Completed %s changed context from %s to %s" % ( f, start_context, LoggingContext.current_context(), @@ -76,6 +82,8 @@ def do_patch(): def check_ctx(r): if LoggingContext.current_context() != start_context: + for err in changes: + print(err, file=sys.stderr) err = "%s completion of %s changed context from %s to %s" % ( "Failure" if isinstance(r, Failure) else "Success", f, @@ -92,3 +100,73 @@ def do_patch(): return wrapped defer.inlineCallbacks = new_inline_callbacks + + +def _check_yield_points(f, changes, start_context): + from synapse.logging.context import LoggingContext + + @functools.wraps(f) + def check_yield_points_inner(*args, **kwargs): + gen = f(*args, **kwargs) + + last_yield_line_no = 1 + result = None + while True: + try: + isFailure = isinstance(result, Failure) + if isFailure: + d = result.throwExceptionIntoGenerator(gen) + else: + d = gen.send(result) + except (StopIteration, defer._DefGen_Return) as e: + if LoggingContext.current_context() != start_context: + # This happens when the context is lost sometime *after* the + # final yield and returning. E.g. we forgot to yield on a + # function that returns a deferred. + err = ( + "%s returned and changed context from %s to %s, in %s between %d and end of func" + % ( + f.__qualname__, + start_context, + LoggingContext.current_context(), + f.__code__.co_filename, + last_yield_line_no, + ) + ) + changes.append(err) + # print(err, file=sys.stderr) + # raise Exception(err) + return getattr(e, "value", None) + + try: + result = yield d + except Exception as e: + result = Failure(e) + + frame = gen.gi_frame + if frame.f_code.co_name == "check_yield_points_inner": + frame = inspect.getgeneratorlocals(gen)["gen"].gi_frame + + if LoggingContext.current_context() != start_context: + # This happens because the context is lost sometime *after* the + # previous yield and *after* the current yield. E.g. the + # deferred we waited on didn't follow the rules, or we forgot to + # yield on a function between the two yield points. + err = ( + "%s changed context from %s to %s, happened between lines %d and %d in %s" + % ( + frame.f_code.co_name, + start_context, + LoggingContext.current_context(), + last_yield_line_no, + frame.f_lineno, + frame.f_code.co_filename, + ) + ) + changes.append(err) + # print(err, file=sys.stderr) + # raise Exception(err) + + last_yield_line_no = frame.f_lineno + + return check_yield_points_inner -- cgit 1.5.1 From e94ff67903c3370fc5bc8b6c336433057e38ff05 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 27 Sep 2019 15:14:02 +0100 Subject: Add test to validate the change --- tests/rest/client/v2_alpha/test_account.py | 70 ++++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 920de41de4..69c33dfd8a 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -23,8 +23,8 @@ from email.parser import Parser import pkg_resources import synapse.rest.admin -from synapse.api.constants import LoginType -from synapse.rest.client.v1 import login +from synapse.api.constants import LoginType, Membership +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from tests import unittest @@ -244,16 +244,69 @@ class DeactivateTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, account.register_servlets, + room.register_servlets, ] def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver() - return hs + self.hs = self.setup_test_homeserver() + return self.hs def test_deactivate_account(self): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") + self.deactivate(user_id, tok) + + store = self.hs.get_datastore() + + # Check that the user has been marked as deactivated. + self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) + + # Check that this access token has been invalidated. + request, channel = self.make_request("GET", "account/whoami") + self.render(request) + self.assertEqual(request.code, 401) + + @unittest.INFO + def test_pending_invites(self): + """Tests that deactivating a user rejects every pending invite for them.""" + store = self.hs.get_datastore() + + inviter_id = self.register_user("inviter", "test") + inviter_tok = self.login("inviter", "test") + + invitee_id = self.register_user("invitee", "test") + invitee_tok = self.login("invitee", "test") + + # Make @inviter:test invite @invitee:test in a new room. + room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok) + self.helper.invite( + room=room_id, + src=inviter_id, + targ=invitee_id, + tok=inviter_tok, + ) + + # Make sure the invite is here. + pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id)) + self.assertEqual(len(pending_invites), 1, pending_invites) + self.assertEqual(pending_invites[0].room_id, room_id, pending_invites) + + # Deactivate @invitee:test. + self.deactivate(invitee_id, invitee_tok) + + # Check that the invite isn't there anymore. + pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id)) + self.assertEqual(len(pending_invites), 0, pending_invites) + + # Check that the membership of @invitee:test in the room is now "leave". + memberships = self.get_success( + store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE]) + ) + self.assertEqual(len(memberships), 1, memberships) + self.assertEqual(memberships[0].room_id, room_id, memberships) + + def deactivate(self, user_id, tok): request_data = json.dumps( { "auth": { @@ -270,12 +323,3 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(request.code, 200) - store = self.hs.get_datastore() - - # Check that the user has been marked as deactivated. - self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) - - # Check that this access token has been invalidated. - request, channel = self.make_request("GET", "account/whoami") - self.render(request) - self.assertEqual(request.code, 401) -- cgit 1.5.1 From 873fe7883cf0d7cf5346a9a55d40967a35848e33 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 27 Sep 2019 15:21:03 +0100 Subject: Lint --- synapse/handlers/deactivate_account.py | 4 +--- tests/rest/client/v2_alpha/test_account.py | 8 +------- 2 files changed, 2 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 763fea3a24..148d1424ca 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -156,9 +156,7 @@ class DeactivateAccountHandler(BaseHandler): require_consent=False, ) logger.info( - "Rejected invite for user %r in room %r", - user_id, - room.room_id, + "Rejected invite for user %r in room %r", user_id, room.room_id ) except Exception: logger.exception( diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 69c33dfd8a..434b730faf 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -280,12 +280,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): # Make @inviter:test invite @invitee:test in a new room. room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok) - self.helper.invite( - room=room_id, - src=inviter_id, - targ=invitee_id, - tok=inviter_tok, - ) + self.helper.invite(room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok) # Make sure the invite is here. pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id)) @@ -322,4 +317,3 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(request.code, 200) - -- cgit 1.5.1 From fbb8ff3088abab48bd5815a1acaeb9243ada7431 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 27 Sep 2019 15:23:07 +0100 Subject: ok --- tests/rest/client/v2_alpha/test_account.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 434b730faf..0f51895b81 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -280,7 +280,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase): # Make @inviter:test invite @invitee:test in a new room. room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok) - self.helper.invite(room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok) + self.helper.invite( + room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok + ) # Make sure the invite is here. pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id)) -- cgit 1.5.1 From 6374ca40c2ff3d2eaa41af5ebf4f8324522ecb84 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 Sep 2019 15:58:14 +0100 Subject: Update --- tests/patch_inline_callbacks.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py index 5ef0aff0c3..a35a1d3305 100644 --- a/tests/patch_inline_callbacks.py +++ b/tests/patch_inline_callbacks.py @@ -15,7 +15,6 @@ from __future__ import print_function -import inspect import functools import sys @@ -32,6 +31,8 @@ def do_patch(): from synapse.logging.context import LoggingContext orig_inline_callbacks = defer.inlineCallbacks + if hasattr(orig_inline_callbacks, "patched_by_synapse"): + return def new_inline_callbacks(f): @functools.wraps(f) @@ -100,13 +101,20 @@ def do_patch(): return wrapped defer.inlineCallbacks = new_inline_callbacks + new_inline_callbacks.patched_by_synapse = True def _check_yield_points(f, changes, start_context): + """Wraps a generator that is about to passed to defer.inlineCallbacks + checking that after every yield the log contexts are correct. + """ + from synapse.logging.context import LoggingContext @functools.wraps(f) def check_yield_points_inner(*args, **kwargs): + expected_context = start_context + gen = f(*args, **kwargs) last_yield_line_no = 1 @@ -119,12 +127,13 @@ def _check_yield_points(f, changes, start_context): else: d = gen.send(result) except (StopIteration, defer._DefGen_Return) as e: - if LoggingContext.current_context() != start_context: + if LoggingContext.current_context() != expected_context: # This happens when the context is lost sometime *after* the # final yield and returning. E.g. we forgot to yield on a # function that returns a deferred. err = ( - "%s returned and changed context from %s to %s, in %s between %d and end of func" + "Function %r returned and changed context from %s to %s," + " in %s between %d and end of func" % ( f.__qualname__, start_context, @@ -134,7 +143,6 @@ def _check_yield_points(f, changes, start_context): ) ) changes.append(err) - # print(err, file=sys.stderr) # raise Exception(err) return getattr(e, "value", None) @@ -144,10 +152,8 @@ def _check_yield_points(f, changes, start_context): result = Failure(e) frame = gen.gi_frame - if frame.f_code.co_name == "check_yield_points_inner": - frame = inspect.getgeneratorlocals(gen)["gen"].gi_frame - if LoggingContext.current_context() != start_context: + if LoggingContext.current_context() != expected_context: # This happens because the context is lost sometime *after* the # previous yield and *after* the current yield. E.g. the # deferred we waited on didn't follow the rules, or we forgot to @@ -164,9 +170,10 @@ def _check_yield_points(f, changes, start_context): ) ) changes.append(err) - # print(err, file=sys.stderr) # raise Exception(err) + expected_context = LoggingContext.current_context() + last_yield_line_no = frame.f_lineno return check_yield_points_inner -- cgit 1.5.1 From ce7a3e7e27ba4215ed86112d2967b0acca4cfb77 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Oct 2019 10:14:01 +0100 Subject: Fix fetching censored redactions from DB Fetching a censored redactions caused an exception due to the code expecting redactions to have a `redact` key, which redacted redactions don't have. --- synapse/storage/events_worker.py | 14 ++++++++++++++ tests/storage/test_redaction.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) (limited to 'tests') diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index c6fa7f82fd..57ce0304e9 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -238,6 +238,20 @@ class EventsWorkerStore(SQLBaseStore): # we have to recheck auth now. if not allow_rejected and entry.event.type == EventTypes.Redaction: + if not hasattr(entry.event, "redacts"): + # A redacted redaction doesn't have a `redacts` key, in + # which case lets just withhold the event. + # + # Note: Most of the time if the redactions has been + # redacted we still have the un-redacted event in the DB + # and so we'll still see the `redacts` key. However, this + # isn't always true e.g. if we have censored the event. + logger.debug( + "Withholding redaction event %s as we don't have redacts key", + event_id, + ) + continue + redacted_event_id = entry.event.redacts event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index deecfad9fb..427d3c49c5 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -118,6 +118,8 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.get_success(self.store.persist_event(event, context)) + return event + def test_redact(self): self.get_success( self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) @@ -361,3 +363,37 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) self.assert_dict({"content": {}}, json.loads(event_json)) + + def test_redact_redaction(self): + """Tests that we can redact a redaction and can fetch it again. + """ + + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) + + msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + + first_redact_event = self.get_success( + self.inject_redaction( + self.room1, msg_event.event_id, self.u_alice, "Redacting message" + ) + ) + + self.get_success( + self.inject_redaction( + self.room1, + first_redact_event.event_id, + self.u_alice, + "Redacting redaction", + ) + ) + + # Now lets jump to the future where we have censored the redaction event + # in the DB. + self.reactor.advance(60 * 60 * 24 * 31) + + # We just want to check that fetching the event doesn't raise an exception. + self.get_success( + self.store.get_event(first_redact_event.event_id, allow_none=True) + ) -- cgit 1.5.1 From f44f1d2e8374b7250a8a68cf3a49e6d1ac63b0fb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Oct 2019 10:36:27 +0100 Subject: Fix errors storing large retry intervals. We have set the max retry interval to a value larger than a postgres or sqlite int can hold, which caused exceptions when updating the destinations table. To fix postgres we need to change the column to a bigint, and for sqlite we lower the max interval to 2**62 (which is still incredibly long). --- .../56/destinations_retry_interval_type.sql.postgres | 18 ++++++++++++++++++ synapse/util/retryutils.py | 2 +- tests/storage/test_transactions.py | 11 +++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres (limited to 'tests') diff --git a/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres new file mode 100644 index 0000000000..b9bbb18a91 --- /dev/null +++ b/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres @@ -0,0 +1,18 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We want to store large retry intervals so we upgrade the column from INT +-- to BIGINT. We don't need to do this on SQLite. +ALTER TABLE destinations ALTER retry_interval SET DATA TYPE BIGINT; diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index a5f2fbef5c..af69587196 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -29,7 +29,7 @@ MIN_RETRY_INTERVAL = 10 * 60 * 1000 RETRY_MULTIPLIER = 5 # a cap on the backoff. (Essentially none) -MAX_RETRY_INTERVAL = 2 ** 63 +MAX_RETRY_INTERVAL = 2 ** 62 class NotRetryingDestination(Exception): diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index a771d5af29..8e817e2c7f 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.util.retryutils import MAX_RETRY_INTERVAL + from tests.unittest import HomeserverTestCase @@ -45,3 +47,12 @@ class TransactionStoreTestCase(HomeserverTestCase): """ d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) self.get_success(d) + + def test_large_destination_retry(self): + d = self.store.set_destination_retry_timings( + "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL + ) + self.get_success(d) + + d = self.store.get_destination_retry_timings("example.com") + self.get_success(d) -- cgit 1.5.1 From a5166e4d5febc0e03ba9da9db99127a797a0bc4d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Oct 2019 14:08:35 +0100 Subject: Land improved room list based on room stats (#6019) Use room_stats and room_state for room directory search --- changelog.d/6019.misc | 1 + synapse/federation/transport/server.py | 8 + synapse/handlers/room_list.py | 323 ++++++--------------- synapse/rest/client/v1/room.py | 8 + synapse/storage/room.py | 228 ++++++++++----- .../schema/delta/56/public_room_list_idx.sql | 16 + tests/handlers/test_roomlist.py | 39 --- 7 files changed, 273 insertions(+), 350 deletions(-) create mode 100644 changelog.d/6019.misc create mode 100644 synapse/storage/schema/delta/56/public_room_list_idx.sql delete mode 100644 tests/handlers/test_roomlist.py (limited to 'tests') diff --git a/changelog.d/6019.misc b/changelog.d/6019.misc new file mode 100644 index 0000000000..dfee73c28f --- /dev/null +++ b/changelog.d/6019.misc @@ -0,0 +1 @@ +Improve performance of the public room list directory. diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 7f8a16e355..0f16f21c2d 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -765,6 +765,10 @@ class PublicRoomList(BaseFederationServlet): else: network_tuple = ThirdPartyInstanceID(None, None) + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + data = await maybeDeferred( self.handler.get_local_public_room_list, limit, @@ -800,6 +804,10 @@ class PublicRoomList(BaseFederationServlet): if search_filter is None: logger.warning("Nonefilter") + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + data = await self.handler.get_local_public_room_list( limit=limit, since_token=since_token, diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index a7e55f00e5..4e1cc5460f 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -16,8 +16,7 @@ import logging from collections import namedtuple -from six import PY3, iteritems -from six.moves import range +from six import iteritems import msgpack from unpaddedbase64 import decode_base64, encode_base64 @@ -27,7 +26,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, HttpResponseException from synapse.types import ThirdPartyInstanceID -from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache @@ -37,7 +35,6 @@ logger = logging.getLogger(__name__) REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 - # This is used to indicate we should only return rooms published to the main list. EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) @@ -72,6 +69,8 @@ class RoomListHandler(BaseHandler): This can be (None, None) to indicate the main list, or a particular appservice and network id to use an appservice specific one. Setting to None returns all public rooms across all lists. + from_federation (bool): true iff the request comes from the federation + API """ if not self.enable_room_list_search: return defer.succeed({"chunk": [], "total_room_count_estimate": 0}) @@ -133,239 +132,109 @@ class RoomListHandler(BaseHandler): from_federation (bool): Whether this request originated from a federating server or a client. Used for room filtering. timeout (int|None): Amount of seconds to wait for a response before - timing out. + timing out. TODO """ - if since_token and since_token != "END": - since_token = RoomListNextBatch.from_token(since_token) - else: - since_token = None - rooms_to_order_value = {} - rooms_to_num_joined = {} + # Pagination tokens work by storing the room ID sent in the last batch, + # plus the direction (forwards or backwards). Next batch tokens always + # go forwards, prev batch tokens always go backwards. - newly_visible = [] - newly_unpublished = [] if since_token: - stream_token = since_token.stream_ordering - current_public_id = yield self.store.get_current_public_room_stream_id() - public_room_stream_id = since_token.public_room_stream_id - newly_visible, newly_unpublished = yield self.store.get_public_room_changes( - public_room_stream_id, current_public_id, network_tuple=network_tuple - ) - else: - stream_token = yield self.store.get_room_max_stream_ordering() - public_room_stream_id = yield self.store.get_current_public_room_stream_id() - - room_ids = yield self.store.get_public_room_ids_at_stream_id( - public_room_stream_id, network_tuple=network_tuple - ) - - # We want to return rooms in a particular order: the number of joined - # users. We then arbitrarily use the room_id as a tie breaker. - - @defer.inlineCallbacks - def get_order_for_room(room_id): - # Most of the rooms won't have changed between the since token and - # now (especially if the since token is "now"). So, we can ask what - # the current users are in a room (that will hit a cache) and then - # check if the room has changed since the since token. (We have to - # do it in that order to avoid races). - # If things have changed then fall back to getting the current state - # at the since token. - joined_users = yield self.store.get_users_in_room(room_id) - if self.store.has_room_changed_since(room_id, stream_token): - latest_event_ids = yield self.store.get_forward_extremeties_for_room( - room_id, stream_token - ) - - if not latest_event_ids: - return + batch_token = RoomListNextBatch.from_token(since_token) - joined_users = yield self.state_handler.get_current_users_in_room( - room_id, latest_event_ids - ) - - num_joined_users = len(joined_users) - rooms_to_num_joined[room_id] = num_joined_users + last_room_id = batch_token.last_room_id + forwards = batch_token.direction_is_forward + else: + batch_token = None - if num_joined_users == 0: - return + last_room_id = None + forwards = True - # We want larger rooms to be first, hence negating num_joined_users - rooms_to_order_value[room_id] = (-num_joined_users, room_id) + # we request one more than wanted to see if there are more pages to come + probing_limit = limit + 1 if limit is not None else None - logger.info( - "Getting ordering for %i rooms since %s", len(room_ids), stream_token + results = yield self.store.get_largest_public_rooms( + network_tuple, + search_filter, + probing_limit, + last_room_id=last_room_id, + forwards=forwards, + ignore_non_federatable=from_federation, ) - yield concurrently_execute(get_order_for_room, room_ids, 10) - sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) - sorted_rooms = [room_id for room_id, _ in sorted_entries] + def build_room_entry(room): + entry = { + "room_id": room["room_id"], + "name": room["name"], + "topic": room["topic"], + "canonical_alias": room["canonical_alias"], + "num_joined_members": room["joined_members"], + "avatar_url": room["avatar"], + "world_readable": room["history_visibility"] == "world_readable", + "guest_can_join": room["guest_access"] == "can_join", + } - # `sorted_rooms` should now be a list of all public room ids that is - # stable across pagination. Therefore, we can use indices into this - # list as our pagination tokens. + # Filter out Nones – rather omit the field altogether + return {k: v for k, v in entry.items() if v is not None} - # Filter out rooms that we don't want to return - rooms_to_scan = [ - r - for r in sorted_rooms - if r not in newly_unpublished and rooms_to_num_joined[r] > 0 - ] + results = [build_room_entry(r) for r in results] - total_room_count = len(rooms_to_scan) + response = {} + num_results = len(results) + if limit is not None: + more_to_come = num_results == probing_limit - if since_token: - # Filter out rooms we've already returned previously - # `since_token.current_limit` is the index of the last room we - # sent down, so we exclude it and everything before/after it. - if since_token.direction_is_forward: - rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :] + # Depending on direction we trim either the front or back. + if forwards: + results = results[:limit] else: - rooms_to_scan = rooms_to_scan[: since_token.current_limit] - rooms_to_scan.reverse() - - logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan)) - - # _append_room_entry_to_chunk will append to chunk but will stop if - # len(chunk) > limit - # - # Normally we will generate enough results on the first iteration here, - # but if there is a search filter, _append_room_entry_to_chunk may - # filter some results out, in which case we loop again. - # - # We don't want to scan over the entire range either as that - # would potentially waste a lot of work. - # - # XXX if there is no limit, we may end up DoSing the server with - # calls to get_current_state_ids for every single room on the - # server. Surely we should cap this somehow? - # - if limit: - step = limit + 1 + results = results[-limit:] else: - # step cannot be zero - step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1 - - chunk = [] - for i in range(0, len(rooms_to_scan), step): - if timeout and self.clock.time() > timeout: - raise Exception("Timed out searching room directory") - - batch = rooms_to_scan[i : i + step] - logger.info("Processing %i rooms for result", len(batch)) - yield concurrently_execute( - lambda r: self._append_room_entry_to_chunk( - r, - rooms_to_num_joined[r], - chunk, - limit, - search_filter, - from_federation=from_federation, - ), - batch, - 5, - ) - logger.info("Now %i rooms in result", len(chunk)) - if len(chunk) >= limit + 1: - break - - chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"])) - - # Work out the new limit of the batch for pagination, or None if we - # know there are no more results that would be returned. - # i.e., [since_token.current_limit..new_limit] is the batch of rooms - # we've returned (or the reverse if we paginated backwards) - # We tried to pull out limit + 1 rooms above, so if we have <= limit - # then we know there are no more results to return - new_limit = None - if chunk and (not limit or len(chunk) > limit): - - if not since_token or since_token.direction_is_forward: - if limit: - chunk = chunk[:limit] - last_room_id = chunk[-1]["room_id"] + more_to_come = False + + if num_results > 0: + final_room_id = results[-1]["room_id"] + initial_room_id = results[0]["room_id"] + + if forwards: + if batch_token: + # If there was a token given then we assume that there + # must be previous results. + response["prev_batch"] = RoomListNextBatch( + last_room_id=initial_room_id, direction_is_forward=False + ).to_token() + + if more_to_come: + response["next_batch"] = RoomListNextBatch( + last_room_id=final_room_id, direction_is_forward=True + ).to_token() else: - if limit: - chunk = chunk[-limit:] - last_room_id = chunk[0]["room_id"] - - new_limit = sorted_rooms.index(last_room_id) - - results = {"chunk": chunk, "total_room_count_estimate": total_room_count} - - if since_token: - results["new_rooms"] = bool(newly_visible) - - if not since_token or since_token.direction_is_forward: - if new_limit is not None: - results["next_batch"] = RoomListNextBatch( - stream_ordering=stream_token, - public_room_stream_id=public_room_stream_id, - current_limit=new_limit, - direction_is_forward=True, - ).to_token() - - if since_token: - results["prev_batch"] = since_token.copy_and_replace( - direction_is_forward=False, - current_limit=since_token.current_limit + 1, - ).to_token() - else: - if new_limit is not None: - results["prev_batch"] = RoomListNextBatch( - stream_ordering=stream_token, - public_room_stream_id=public_room_stream_id, - current_limit=new_limit, - direction_is_forward=False, - ).to_token() - - if since_token: - results["next_batch"] = since_token.copy_and_replace( - direction_is_forward=True, - current_limit=since_token.current_limit - 1, - ).to_token() - - return results - - @defer.inlineCallbacks - def _append_room_entry_to_chunk( - self, - room_id, - num_joined_users, - chunk, - limit, - search_filter, - from_federation=False, - ): - """Generate the entry for a room in the public room list and append it - to the `chunk` if it matches the search filter - - Args: - room_id (str): The ID of the room. - num_joined_users (int): The number of joined users in the room. - chunk (list) - limit (int|None): Maximum amount of rooms to display. Function will - return if length of chunk is greater than limit + 1. - search_filter (dict|None) - from_federation (bool): Whether this request originated from a - federating server or a client. Used for room filtering. - """ - if limit and len(chunk) > limit + 1: - # We've already got enough, so lets just drop it. - return + if batch_token: + response["next_batch"] = RoomListNextBatch( + last_room_id=final_room_id, direction_is_forward=True + ).to_token() + + if more_to_come: + response["prev_batch"] = RoomListNextBatch( + last_room_id=initial_room_id, direction_is_forward=False + ).to_token() + + for room in results: + # populate search result entries with additional fields, namely + # 'aliases' + room_id = room["room_id"] + + aliases = yield self.store.get_aliases_for_room(room_id) + if aliases: + room["aliases"] = aliases - result = yield self.generate_room_entry(room_id, num_joined_users) - if not result: - return + response["chunk"] = results - if from_federation and not result.get("m.federate", True): - # This is a room that other servers cannot join. Do not show them - # this room. - return + response["total_room_count_estimate"] = yield self.store.count_public_rooms( + network_tuple, ignore_non_federatable=from_federation + ) - if _matches_room_entry(result, search_filter): - chunk.append(result) + return response @cachedInlineCallbacks(num_args=1, cache_context=True) def generate_room_entry( @@ -580,32 +449,18 @@ class RoomListNextBatch( namedtuple( "RoomListNextBatch", ( - "stream_ordering", # stream_ordering of the first public room list - "public_room_stream_id", # public room stream id for first public room list - "current_limit", # The number of previous rooms returned + "last_room_id", # The room_id to get rooms after/before "direction_is_forward", # Bool if this is a next_batch, false if prev_batch ), ) ): - - KEY_DICT = { - "stream_ordering": "s", - "public_room_stream_id": "p", - "current_limit": "n", - "direction_is_forward": "d", - } + KEY_DICT = {"last_room_id": "r", "direction_is_forward": "d"} REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()} @classmethod def from_token(cls, token): - if PY3: - # The argument raw=False is only available on new versions of - # msgpack, and only really needed on Python 3. Gate it behind - # a PY3 check to avoid causing issues on Debian-packaged versions. - decoded = msgpack.loads(decode_base64(token), raw=False) - else: - decoded = msgpack.loads(decode_base64(token)) + decoded = msgpack.loads(decode_base64(token), raw=False) return RoomListNextBatch( **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()} ) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6bf924dedc..9c1d41421c 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -361,6 +361,10 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit = parse_integer(request, "limit", 0) since_token = parse_string(request, "since", None) + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + handler = self.hs.get_room_list_handler() if server: data = yield handler.get_remote_public_room_list( @@ -398,6 +402,10 @@ class PublicRoomListRestServlet(TransactionRestServlet): else: network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + handler = self.hs.get_room_list_handler() if server: data = yield handler.get_remote_public_room_list( diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 08e13f3a3b..c02787a73d 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -63,103 +64,176 @@ class RoomWorkerStore(SQLBaseStore): desc="get_public_room_ids", ) - @cached(num_args=2, max_entries=100) - def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): - """Get pulbic rooms for a particular list, or across all lists. + def count_public_rooms(self, network_tuple, ignore_non_federatable): + """Counts the number of public rooms as tracked in the room_stats_current + and room_stats_state table. Args: - stream_id (int) - network_tuple (ThirdPartyInstanceID): The list to use (None, None) - means the main list, None means all lsits. + network_tuple (ThirdPartyInstanceID|None) + ignore_non_federatable (bool): If true filters out non-federatable rooms """ - return self.runInteraction( - "get_public_room_ids_at_stream_id", - self.get_public_room_ids_at_stream_id_txn, - stream_id, - network_tuple=network_tuple, - ) - - def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple): - return { - rm - for rm, vis in self.get_published_at_stream_id_txn( - txn, stream_id, network_tuple=network_tuple - ).items() - if vis - } - def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): - if network_tuple: - # We want to get from a particular list. No aggregation required. + def _count_public_rooms_txn(txn): + query_args = [] + + if network_tuple: + if network_tuple.appservice_id: + published_sql = """ + SELECT room_id from appservice_room_list + WHERE appservice_id = ? AND network_id = ? + """ + query_args.append(network_tuple.appservice_id) + query_args.append(network_tuple.network_id) + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + """ + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + UNION SELECT room_id from appservice_room_list + """ sql = """ - SELECT room_id, visibility FROM public_room_list_stream - INNER JOIN ( - SELECT room_id, max(stream_id) AS stream_id - FROM public_room_list_stream - WHERE stream_id <= ? %s - GROUP BY room_id - ) grouped USING (room_id, stream_id) - """ + SELECT + COALESCE(COUNT(*), 0) + FROM ( + %(published_sql)s + ) published + INNER JOIN room_stats_state USING (room_id) + INNER JOIN room_stats_current USING (room_id) + WHERE + ( + join_rules = 'public' OR history_visibility = 'world_readable' + ) + AND joined_members > 0 + """ % { + "published_sql": published_sql + } - if network_tuple.appservice_id is not None: - txn.execute( - sql % ("AND appservice_id = ? AND network_id = ?",), - (stream_id, network_tuple.appservice_id, network_tuple.network_id), - ) - else: - txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,)) - return dict(txn) - else: - # We want to get from all lists, so we need to aggregate the results + txn.execute(sql, query_args) + return txn.fetchone()[0] - logger.info("Executing full list") + return self.runInteraction("count_public_rooms", _count_public_rooms_txn) - sql = """ - SELECT room_id, visibility - FROM public_room_list_stream - INNER JOIN ( - SELECT - room_id, max(stream_id) AS stream_id, appservice_id, - network_id - FROM public_room_list_stream - WHERE stream_id <= ? - GROUP BY room_id, appservice_id, network_id - ) grouped USING (room_id, stream_id) - """ + @defer.inlineCallbacks + def get_largest_public_rooms( + self, + network_tuple, + search_filter, + limit, + last_room_id, + forwards, + ignore_non_federatable=False, + ): + """Gets the largest public rooms (where largest is in terms of joined + members, as tracked in the statistics table). - txn.execute(sql, (stream_id,)) + Args: + network_tuple (ThirdPartyInstanceID|None): + search_filter (dict|None): + limit (int|None): Maxmimum number of rows to return, unlimited otherwise. + last_room_id (str|None): if present, a room ID which bounds the + result set, and is always *excluded* from the result set. + forwards (bool): true iff going forwards, going backwards otherwise + ignore_non_federatable (bool): If true filters out non-federatable rooms. - results = {} - # A room is visible if its visible on any list. - for room_id, visibility in txn: - results[room_id] = bool(visibility) or results.get(room_id, False) + Returns: + Rooms in order: biggest number of joined users first. + We then arbitrarily use the room_id as a tie breaker. - return results + """ - def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple): - def get_public_room_changes_txn(txn): - then_rooms = self.get_public_room_ids_at_stream_id_txn( - txn, prev_stream_id, network_tuple - ) + where_clauses = [] + query_args = [] - now_rooms_dict = self.get_published_at_stream_id_txn( - txn, new_stream_id, network_tuple - ) + if last_room_id: + if forwards: + where_clauses.append("room_id < ?") + else: + where_clauses.append("? < room_id") - now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis) - now_rooms_not_visible = set( - rm for rm, vis in now_rooms_dict.items() if not vis + query_args += [last_room_id] + + if search_filter and search_filter.get("generic_search_term", None): + search_term = "%" + search_filter["generic_search_term"] + "%" + + where_clauses.append( + """ + ( + name LIKE ? + OR topic LIKE ? + OR canonical_alias LIKE ? + ) + """ ) + query_args += [search_term, search_term, search_term] + + if network_tuple: + if network_tuple.appservice_id: + published_sql = """ + SELECT room_id from appservice_room_list + WHERE appservice_id = ? AND network_id = ? + """ + query_args.append(network_tuple.appservice_id) + query_args.append(network_tuple.network_id) + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + """ + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + UNION SELECT room_id from appservice_room_list + """ - newly_visible = now_rooms_visible - then_rooms - newly_unpublished = now_rooms_not_visible & then_rooms + where_clause = "" + if where_clauses: + where_clause = " AND " + " AND ".join(where_clauses) + + sql = """ + SELECT + room_id, name, topic, canonical_alias, joined_members, + avatar, history_visibility, joined_members, guest_access + FROM ( + %(published_sql)s + ) published + INNER JOIN room_stats_state USING (room_id) + INNER JOIN room_stats_current USING (room_id) + WHERE + ( + join_rules = 'public' OR history_visibility = 'world_readable' + ) + AND joined_members > 0 + %(where_clause)s + ORDER BY joined_members %(dir)s, room_id %(dir)s + """ % { + "published_sql": published_sql, + "where_clause": where_clause, + "dir": "DESC" if forwards else "ASC", + } - return newly_visible, newly_unpublished + if limit is not None: + query_args.append(limit) - return self.runInteraction( - "get_public_room_changes", get_public_room_changes_txn + sql += """ + LIMIT ? + """ + + def _get_largest_public_rooms_txn(txn): + txn.execute(sql, query_args) + + results = self.cursor_to_dict(txn) + + if not forwards: + results.reverse() + + return results + + ret_val = yield self.runInteraction( + "get_largest_public_rooms", _get_largest_public_rooms_txn ) + defer.returnValue(ret_val) @cached(max_entries=10000) def is_room_blocked(self, room_id): diff --git a/synapse/storage/schema/delta/56/public_room_list_idx.sql b/synapse/storage/schema/delta/56/public_room_list_idx.sql new file mode 100644 index 0000000000..7be31ffebb --- /dev/null +++ b/synapse/storage/schema/delta/56/public_room_list_idx.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id); diff --git a/tests/handlers/test_roomlist.py b/tests/handlers/test_roomlist.py deleted file mode 100644 index 61eebb6985..0000000000 --- a/tests/handlers/test_roomlist.py +++ /dev/null @@ -1,39 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.handlers.room_list import RoomListNextBatch - -import tests.unittest -import tests.utils - - -class RoomListTestCase(tests.unittest.TestCase): - """ Tests RoomList's RoomListNextBatch. """ - - def setUp(self): - pass - - def test_check_read_batch_tokens(self): - batch_token = RoomListNextBatch( - stream_ordering="abcdef", - public_room_stream_id="123", - current_limit=20, - direction_is_forward=True, - ).to_token() - next_batch = RoomListNextBatch.from_token(batch_token) - self.assertEquals(next_batch.stream_ordering, "abcdef") - self.assertEquals(next_batch.public_room_stream_id, "123") - self.assertEquals(next_batch.current_limit, 20) - self.assertEquals(next_batch.direction_is_forward, True) -- cgit 1.5.1 From 6527fa18c1e6f9bcb22318916b5e9534c91c84c1 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 2 Oct 2019 14:44:58 +0100 Subject: Add test case --- synapse/handlers/federation.py | 2 +- tests/handlers/test_federation.py | 83 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 tests/handlers/test_federation.py (limited to 'tests') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 75d79bb8e4..91f3a69298 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2570,7 +2570,7 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check_from_context(room_version, event, context) + yield self.auth.check_from_context(room_version, event, context) except AuthError as e: logger.warn("Denying third party invite %r because %s", event, e) raise e diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py new file mode 100644 index 0000000000..20416a0142 --- /dev/null +++ b/tests/handlers/test_federation.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError, Codes +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest + + +class FederationTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver(http_client=None) + self.handler = hs.get_handlers().federation_handler + self.store = hs.get_datastore() + return hs + + def test_exchange_revoked_invite(self): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as( + room_creator=user_id, tok=tok + ) + + # Send a 3PID invite event with an empty body so it's considered as a revoked one. + invite_token = "sometoken" + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body={}, + tok=tok, + ) + + d = self.handler.on_exchange_third_party_invite_request( + room_id=room_id, + event_dict={ + "type": EventTypes.Member, + "room_id": room_id, + "sender": user_id, + "state_key": "@someone:example.org", + "content": { + "membership": "invite", + "third_party_invite": { + "display_name": "alice", + "signed": { + "mxid": "@alice:localhost", + "token": invite_token, + "signatures": { + "magic.forest": { + "ed25519:3": "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIsNffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg" + } + } + } + } + } + } + ) + + failure = self.get_failure(d, AuthError).value + + self.assertEqual(failure.code, 403, failure) + self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure) + self.assertEqual(failure.msg, "You are not invited to this room.") -- cgit 1.5.1 From ebcb6a30d7b1bdb859a1fd22d567b163a1488763 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 3 Oct 2019 11:29:07 +0100 Subject: Lint --- tests/handlers/test_federation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 20416a0142..a18dfc0e96 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -69,11 +69,11 @@ class FederationTestCase(unittest.HomeserverTestCase): "magic.forest": { "ed25519:3": "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIsNffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg" } - } - } - } - } - } + }, + }, + }, + }, + }, ) failure = self.get_failure(d, AuthError).value -- cgit 1.5.1 From 8a5e8e829b98687ea274fae47db3aa801b6f97d3 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 3 Oct 2019 11:30:43 +0100 Subject: Lint (again) --- tests/handlers/test_federation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index a18dfc0e96..d56220f403 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -37,9 +37,7 @@ class FederationTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") - room_id = self.helper.create_room_as( - room_creator=user_id, tok=tok - ) + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) # Send a 3PID invite event with an empty body so it's considered as a revoked one. invite_token = "sometoken" -- cgit 1.5.1 From 4535a07f4af0854eed0cfd171e4032bdd3f39cbb Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 9 Oct 2019 17:54:03 -0400 Subject: make version optional in body of e2e backup version update to agree with latest version of the MSC --- synapse/handlers/e2e_room_keys.py | 4 +-- synapse/rest/client/v2_alpha/room_keys.py | 2 +- tests/handlers/test_e2e_room_keys.py | 47 ++++++++++++++++++++----------- 3 files changed, 34 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index a9d80f708c..0cea445f0d 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -352,8 +352,8 @@ class E2eRoomKeysHandler(object): A deferred of an empty dict. """ if "version" not in version_info: - raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM) - if version_info["version"] != version: + version_info["version"] = version + elif version_info["version"] != version: raise SynapseError( 400, "Version in body does not match", Codes.INVALID_PARAM ) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index df4f44cd36..d596786430 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -375,7 +375,7 @@ class RoomKeysVersionServlet(RestServlet): "ed25519:something": "hijklmnop" } }, - "version": "42" + "version": "12345" } HTTP/1.1 200 OK diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index c4503c1611..c700a2fad1 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -187,9 +187,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(res, 404) @defer.inlineCallbacks - def test_update_bad_version(self): - """Check that we get a 400 if the version in the body is missing or - doesn't match + def test_update_missing_version(self): + """Check that the update succeeds if the version is missing from the body """ version = yield self.handler.create_version( self.local_user, @@ -197,19 +196,35 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - res = None - try: - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - }, - ) - except errors.SynapseError as e: - res = e.code - self.assertEqual(res, 400) + yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) + + # check we can retrieve it as the current version + res = yield self.handler.get_version_info(self.local_user) + self.assertDictEqual( + res, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) + + @defer.inlineCallbacks + def test_update_bad_version(self): + """Check that we get a 400 if the version in the body doesn't match + """ + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) + self.assertEqual(version, "1") res = None try: -- cgit 1.5.1 From f743108a94658eb1dbaf168d39874272f756a386 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 10 Oct 2019 09:39:35 +0100 Subject: Refactor HomeserverConfig so it can be typechecked (#6137) --- changelog.d/6137.misc | 1 + mypy.ini | 16 ++- synapse/config/_base.py | 191 +++++++++++++++++++++++------- synapse/config/_base.pyi | 135 +++++++++++++++++++++ synapse/config/api.py | 2 + synapse/config/appservice.py | 2 + synapse/config/captcha.py | 2 + synapse/config/cas.py | 2 + synapse/config/consent_config.py | 3 + synapse/config/database.py | 2 + synapse/config/emailconfig.py | 2 + synapse/config/groups.py | 2 + synapse/config/homeserver.py | 68 +++++------ synapse/config/jwt_config.py | 2 + synapse/config/key.py | 2 + synapse/config/logger.py | 2 + synapse/config/metrics.py | 2 + synapse/config/password.py | 2 + synapse/config/password_auth_providers.py | 2 + synapse/config/push.py | 2 + synapse/config/ratelimiting.py | 2 + synapse/config/registration.py | 4 + synapse/config/repository.py | 2 + synapse/config/room_directory.py | 2 + synapse/config/saml2_config.py | 2 + synapse/config/server.py | 2 + synapse/config/server_notices_config.py | 2 + synapse/config/spam_checker.py | 2 + synapse/config/stats.py | 2 + synapse/config/third_party_event_rules.py | 2 + synapse/config/tls.py | 9 +- synapse/config/tracer.py | 2 + synapse/config/user_directory.py | 2 + synapse/config/voip.py | 2 + synapse/config/workers.py | 2 + tests/config/test_tls.py | 25 ++-- tox.ini | 3 +- 37 files changed, 415 insertions(+), 94 deletions(-) create mode 100644 changelog.d/6137.misc create mode 100644 synapse/config/_base.pyi (limited to 'tests') diff --git a/changelog.d/6137.misc b/changelog.d/6137.misc new file mode 100644 index 0000000000..92a02e71c3 --- /dev/null +++ b/changelog.d/6137.misc @@ -0,0 +1 @@ +Refactor configuration loading to allow better typechecking. diff --git a/mypy.ini b/mypy.ini index 8788574ee3..ffadaddc0b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -4,10 +4,6 @@ plugins=mypy_zope:plugin follow_imports=skip mypy_path=stubs -[mypy-synapse.config.homeserver] -# this is a mess because of the metaclass shenanigans -ignore_errors = True - [mypy-zope] ignore_missing_imports = True @@ -52,3 +48,15 @@ ignore_missing_imports = True [mypy-signedjson.*] ignore_missing_imports = True + +[mypy-prometheus_client.*] +ignore_missing_imports = True + +[mypy-service_identity.*] +ignore_missing_imports = True + +[mypy-daemonize] +ignore_missing_imports = True + +[mypy-sentry_sdk] +ignore_missing_imports = True diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 31f6530978..08619404bb 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,7 +18,9 @@ import argparse import errno import os +from collections import OrderedDict from textwrap import dedent +from typing import Any, MutableMapping, Optional from six import integer_types @@ -51,7 +53,56 @@ Missing mandatory `server_name` config option. """ +def path_exists(file_path): + """Check if a file exists + + Unlike os.path.exists, this throws an exception if there is an error + checking if the file exists (for example, if there is a perms error on + the parent dir). + + Returns: + bool: True if the file exists; False if not. + """ + try: + os.stat(file_path) + return True + except OSError as e: + if e.errno != errno.ENOENT: + raise e + return False + + class Config(object): + """ + A configuration section, containing configuration keys and values. + + Attributes: + section (str): The section title of this config object, such as + "tls" or "logger". This is used to refer to it on the root + logger (for example, `config.tls.some_option`). Must be + defined in subclasses. + """ + + section = None + + def __init__(self, root_config=None): + self.root = root_config + + def __getattr__(self, item: str) -> Any: + """ + Try and fetch a configuration option that does not exist on this class. + + This is so that existing configs that rely on `self.value`, where value + is actually from a different config section, continue to work. + """ + if item in ["generate_config_section", "read_config"]: + raise AttributeError(item) + + if self.root is None: + raise AttributeError(item) + else: + return self.root._get_unclassed_config(self.section, item) + @staticmethod def parse_size(value): if isinstance(value, integer_types): @@ -88,22 +139,7 @@ class Config(object): @classmethod def path_exists(cls, file_path): - """Check if a file exists - - Unlike os.path.exists, this throws an exception if there is an error - checking if the file exists (for example, if there is a perms error on - the parent dir). - - Returns: - bool: True if the file exists; False if not. - """ - try: - os.stat(file_path) - return True - except OSError as e: - if e.errno != errno.ENOENT: - raise e - return False + return path_exists(file_path) @classmethod def check_file(cls, file_path, config_name): @@ -136,42 +172,106 @@ class Config(object): with open(file_path) as file_stream: return file_stream.read() - def invoke_all(self, name, *args, **kargs): - """Invoke all instance methods with the given name and arguments in the - class's MRO. + +class RootConfig(object): + """ + Holder of an application's configuration. + + What configuration this object holds is defined by `config_classes`, a list + of Config classes that will be instantiated and given the contents of a + configuration file to read. They can then be accessed on this class by their + section name, defined in the Config or dynamically set to be the name of the + class, lower-cased and with "Config" removed. + """ + + config_classes = [] + + def __init__(self): + self._configs = OrderedDict() + + for config_class in self.config_classes: + if config_class.section is None: + raise ValueError("%r requires a section name" % (config_class,)) + + try: + conf = config_class(self) + except Exception as e: + raise Exception("Failed making %s: %r" % (config_class.section, e)) + self._configs[config_class.section] = conf + + def __getattr__(self, item: str) -> Any: + """ + Redirect lookups on this object either to config objects, or values on + config objects, so that `config.tls.blah` works, as well as legacy uses + of things like `config.server_name`. It will first look up the config + section name, and then values on those config classes. + """ + if item in self._configs.keys(): + return self._configs[item] + + return self._get_unclassed_config(None, item) + + def _get_unclassed_config(self, asking_section: Optional[str], item: str): + """ + Fetch a config value from one of the instantiated config classes that + has not been fetched directly. + + Args: + asking_section: If this check is coming from a Config child, which + one? This section will not be asked if it has the value. + item: The configuration value key. + + Raises: + AttributeError if no config classes have the config key. The body + will contain what sections were checked. + """ + for key, val in self._configs.items(): + if key == asking_section: + continue + + if item in dir(val): + return getattr(val, item) + + raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),)) + + def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]: + """ + Invoke a function on all instantiated config objects this RootConfig is + configured to use. Args: - name (str): Name of function to invoke + func_name: Name of function to invoke *args **kwargs - Returns: - list: The list of the return values from each method called + ordered dictionary of config section name and the result of the + function from it. """ - results = [] - for cls in type(self).mro(): - if name in cls.__dict__: - results.append(getattr(cls, name)(self, *args, **kargs)) - return results + res = OrderedDict() + + for name, config in self._configs.items(): + if hasattr(config, func_name): + res[name] = getattr(config, func_name)(*args, **kwargs) + + return res @classmethod - def invoke_all_static(cls, name, *args, **kargs): - """Invoke all static methods with the given name and arguments in the - class's MRO. + def invoke_all_static(cls, func_name: str, *args, **kwargs): + """ + Invoke a static function on config objects this RootConfig is + configured to use. Args: - name (str): Name of function to invoke + func_name: Name of function to invoke *args **kwargs - Returns: - list: The list of the return values from each method called + ordered dictionary of config section name and the result of the + function from it. """ - results = [] - for c in cls.mro(): - if name in c.__dict__: - results.append(getattr(c, name)(*args, **kargs)) - return results + for config in cls.config_classes: + if hasattr(config, func_name): + getattr(config, func_name)(*args, **kwargs) def generate_config( self, @@ -187,7 +287,8 @@ class Config(object): tls_private_key_path=None, acme_domain=None, ): - """Build a default configuration file + """ + Build a default configuration file This is used when the user explicitly asks us to generate a config file (eg with --generate_config). @@ -242,6 +343,7 @@ class Config(object): Returns: str: the yaml config file """ + return "\n\n".join( dedent(conf) for conf in self.invoke_all( @@ -257,7 +359,7 @@ class Config(object): tls_certificate_path=tls_certificate_path, tls_private_key_path=tls_private_key_path, acme_domain=acme_domain, - ) + ).values() ) @classmethod @@ -444,7 +546,7 @@ class Config(object): ) (config_path,) = config_files - if not cls.path_exists(config_path): + if not path_exists(config_path): print("Generating config file %s" % (config_path,)) if config_args.data_directory: @@ -469,7 +571,7 @@ class Config(object): open_private_ports=config_args.open_private_ports, ) - if not cls.path_exists(config_dir_path): + if not path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "w") as config_file: config_file.write("# vim:ft=yaml\n\n") @@ -518,7 +620,7 @@ class Config(object): return obj - def parse_config_dict(self, config_dict, config_dir_path, data_dir_path): + def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None): """Read the information from the config dict into this Config object. Args: @@ -607,3 +709,6 @@ def find_config_files(search_paths): else: config_files.append(config_path) return config_files + + +__all__ = ["Config", "RootConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi new file mode 100644 index 0000000000..86bc965ee4 --- /dev/null +++ b/synapse/config/_base.pyi @@ -0,0 +1,135 @@ +from typing import Any, List, Optional + +from synapse.config import ( + api, + appservice, + captcha, + cas, + consent_config, + database, + emailconfig, + groups, + jwt_config, + key, + logger, + metrics, + password, + password_auth_providers, + push, + ratelimiting, + registration, + repository, + room_directory, + saml2_config, + server, + server_notices_config, + spam_checker, + stats, + third_party_event_rules, + tls, + tracer, + user_directory, + voip, + workers, +) + +class ConfigError(Exception): ... + +MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str +MISSING_REPORT_STATS_SPIEL: str +MISSING_SERVER_NAME: str + +def path_exists(file_path: str): ... + +class RootConfig: + server: server.ServerConfig + tls: tls.TlsConfig + database: database.DatabaseConfig + logging: logger.LoggingConfig + ratelimit: ratelimiting.RatelimitConfig + media: repository.ContentRepositoryConfig + captcha: captcha.CaptchaConfig + voip: voip.VoipConfig + registration: registration.RegistrationConfig + metrics: metrics.MetricsConfig + api: api.ApiConfig + appservice: appservice.AppServiceConfig + key: key.KeyConfig + saml2: saml2_config.SAML2Config + cas: cas.CasConfig + jwt: jwt_config.JWTConfig + password: password.PasswordConfig + email: emailconfig.EmailConfig + worker: workers.WorkerConfig + authproviders: password_auth_providers.PasswordAuthProviderConfig + push: push.PushConfig + spamchecker: spam_checker.SpamCheckerConfig + groups: groups.GroupsConfig + userdirectory: user_directory.UserDirectoryConfig + consent: consent_config.ConsentConfig + stats: stats.StatsConfig + servernotices: server_notices_config.ServerNoticesConfig + roomdirectory: room_directory.RoomDirectoryConfig + thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig + tracer: tracer.TracerConfig + + config_classes: List = ... + def __init__(self) -> None: ... + def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ... + @classmethod + def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ... + def __getattr__(self, item: str): ... + def parse_config_dict( + self, + config_dict: Any, + config_dir_path: Optional[Any] = ..., + data_dir_path: Optional[Any] = ..., + ) -> None: ... + read_config: Any = ... + def generate_config( + self, + config_dir_path: str, + data_dir_path: str, + server_name: str, + generate_secrets: bool = ..., + report_stats: Optional[str] = ..., + open_private_ports: bool = ..., + listeners: Optional[Any] = ..., + database_conf: Optional[Any] = ..., + tls_certificate_path: Optional[str] = ..., + tls_private_key_path: Optional[str] = ..., + acme_domain: Optional[str] = ..., + ): ... + @classmethod + def load_or_generate_config(cls, description: Any, argv: Any): ... + @classmethod + def load_config(cls, description: Any, argv: Any): ... + @classmethod + def add_arguments_to_parser(cls, config_parser: Any) -> None: ... + @classmethod + def load_config_with_parser(cls, parser: Any, argv: Any): ... + def generate_missing_files( + self, config_dict: dict, config_dir_path: str + ) -> None: ... + +class Config: + root: RootConfig + def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ... + def __getattr__(self, item: str, from_root: bool = ...): ... + @staticmethod + def parse_size(value: Any): ... + @staticmethod + def parse_duration(value: Any): ... + @staticmethod + def abspath(file_path: Optional[str]): ... + @classmethod + def path_exists(cls, file_path: str): ... + @classmethod + def check_file(cls, file_path: str, config_name: str): ... + @classmethod + def ensure_directory(cls, dir_path: str): ... + @classmethod + def read_file(cls, file_path: str, config_name: str): ... + +def read_config_files(config_files: List[str]): ... +def find_config_files(search_paths: List[str]): ... diff --git a/synapse/config/api.py b/synapse/config/api.py index dddea79a8a..74cd53a8ed 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -18,6 +18,8 @@ from ._base import Config class ApiConfig(Config): + section = "api" + def read_config(self, config, **kwargs): self.room_invite_state_types = config.get( "room_invite_state_types", diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 28d36b1bc3..9b4682222d 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -30,6 +30,8 @@ logger = logging.getLogger(__name__) class AppServiceConfig(Config): + section = "appservice" + def read_config(self, config, **kwargs): self.app_service_config_files = config.get("app_service_config_files", []) self.notify_appservices = config.get("notify_appservices", True) diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index 8dac8152cf..44bd5c6799 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -16,6 +16,8 @@ from ._base import Config class CaptchaConfig(Config): + section = "captcha" + def read_config(self, config, **kwargs): self.recaptcha_private_key = config.get("recaptcha_private_key") self.recaptcha_public_key = config.get("recaptcha_public_key") diff --git a/synapse/config/cas.py b/synapse/config/cas.py index ebe34d933b..b916c3aa66 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -22,6 +22,8 @@ class CasConfig(Config): cas_server_url: URL of CAS server """ + section = "cas" + def read_config(self, config, **kwargs): cas_config = config.get("cas_config", None) if cas_config: diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index 48976e17b1..62c4c44d60 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -73,6 +73,9 @@ DEFAULT_CONFIG = """\ class ConsentConfig(Config): + + section = "consent" + def __init__(self, *args): super(ConsentConfig, self).__init__(*args) diff --git a/synapse/config/database.py b/synapse/config/database.py index 118aafbd4a..0e2509f0b1 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -21,6 +21,8 @@ from ._base import Config class DatabaseConfig(Config): + section = "database" + def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index d9b43de660..658897a77e 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -28,6 +28,8 @@ from ._base import Config, ConfigError class EmailConfig(Config): + section = "email" + def read_config(self, config, **kwargs): # TODO: We should separate better the email configuration from the notification # and account validity config. diff --git a/synapse/config/groups.py b/synapse/config/groups.py index 2a522b5f44..d6862d9a64 100644 --- a/synapse/config/groups.py +++ b/synapse/config/groups.py @@ -17,6 +17,8 @@ from ._base import Config class GroupsConfig(Config): + section = "groups" + def read_config(self, config, **kwargs): self.enable_group_creation = config.get("enable_group_creation", False) self.group_creation_prefix = config.get("group_creation_prefix", "") diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 72acad4f18..6e348671c7 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from ._base import RootConfig from .api import ApiConfig from .appservice import AppServiceConfig from .captcha import CaptchaConfig @@ -46,36 +47,37 @@ from .voip import VoipConfig from .workers import WorkerConfig -class HomeServerConfig( - ServerConfig, - TlsConfig, - DatabaseConfig, - LoggingConfig, - RatelimitConfig, - ContentRepositoryConfig, - CaptchaConfig, - VoipConfig, - RegistrationConfig, - MetricsConfig, - ApiConfig, - AppServiceConfig, - KeyConfig, - SAML2Config, - CasConfig, - JWTConfig, - PasswordConfig, - EmailConfig, - WorkerConfig, - PasswordAuthProviderConfig, - PushConfig, - SpamCheckerConfig, - GroupsConfig, - UserDirectoryConfig, - ConsentConfig, - StatsConfig, - ServerNoticesConfig, - RoomDirectoryConfig, - ThirdPartyRulesConfig, - TracerConfig, -): - pass +class HomeServerConfig(RootConfig): + + config_classes = [ + ServerConfig, + TlsConfig, + DatabaseConfig, + LoggingConfig, + RatelimitConfig, + ContentRepositoryConfig, + CaptchaConfig, + VoipConfig, + RegistrationConfig, + MetricsConfig, + ApiConfig, + AppServiceConfig, + KeyConfig, + SAML2Config, + CasConfig, + JWTConfig, + PasswordConfig, + EmailConfig, + WorkerConfig, + PasswordAuthProviderConfig, + PushConfig, + SpamCheckerConfig, + GroupsConfig, + UserDirectoryConfig, + ConsentConfig, + StatsConfig, + ServerNoticesConfig, + RoomDirectoryConfig, + ThirdPartyRulesConfig, + TracerConfig, + ] diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py index 36d87cef03..a568726985 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py @@ -23,6 +23,8 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login. class JWTConfig(Config): + section = "jwt" + def read_config(self, config, **kwargs): jwt_config = config.get("jwt_config", None) if jwt_config: diff --git a/synapse/config/key.py b/synapse/config/key.py index f039f96e9c..ec5d430afb 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -92,6 +92,8 @@ class TrustedKeyServer(object): class KeyConfig(Config): + section = "key" + def read_config(self, config, config_dir_path, **kwargs): # the signing key can be specified inline or in a separate file if "signing_key" in config: diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 767ecfdf09..d609ec111b 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -84,6 +84,8 @@ root: class LoggingConfig(Config): + section = "logging" + def read_config(self, config, **kwargs): self.log_config = self.abspath(config.get("log_config")) self.no_redirect_stdio = config.get("no_redirect_stdio", False) diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index ec35a6b868..282a43bddb 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -34,6 +34,8 @@ class MetricsFlags(object): class MetricsConfig(Config): + section = "metrics" + def read_config(self, config, **kwargs): self.enable_metrics = config.get("enable_metrics", False) self.report_stats = config.get("report_stats", None) diff --git a/synapse/config/password.py b/synapse/config/password.py index d5b5953f2f..2a634ac751 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -20,6 +20,8 @@ class PasswordConfig(Config): """Password login configuration """ + section = "password" + def read_config(self, config, **kwargs): password_config = config.get("password_config", {}) if password_config is None: diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index c50e244394..9746bbc681 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -23,6 +23,8 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider" class PasswordAuthProviderConfig(Config): + section = "authproviders" + def read_config(self, config, **kwargs): self.password_providers = [] # type: List[Any] providers = [] diff --git a/synapse/config/push.py b/synapse/config/push.py index 1b932722a5..0910958649 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -18,6 +18,8 @@ from ._base import Config class PushConfig(Config): + section = "push" + def read_config(self, config, **kwargs): push_config = config.get("push", {}) self.push_include_content = push_config.get("include_content", True) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 587e2862b7..947f653e03 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -36,6 +36,8 @@ class FederationRateLimitConfig(object): class RatelimitConfig(Config): + section = "ratelimiting" + def read_config(self, config, **kwargs): # Load the new-style messages config if it exists. Otherwise fall back diff --git a/synapse/config/registration.py b/synapse/config/registration.py index bef89e2bf4..b3e3e6dda2 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -24,6 +24,8 @@ from synapse.util.stringutils import random_string_with_symbols class AccountValidityConfig(Config): + section = "accountvalidity" + def __init__(self, config, synapse_config): self.enabled = config.get("enabled", False) self.renew_by_email_enabled = "renew_at" in config @@ -77,6 +79,8 @@ class AccountValidityConfig(Config): class RegistrationConfig(Config): + section = "registration" + def read_config(self, config, **kwargs): self.enable_registration = bool( strtobool(str(config.get("enable_registration", False))) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 14740891f3..d0205e14b9 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -78,6 +78,8 @@ def parse_thumbnail_requirements(thumbnail_sizes): class ContentRepositoryConfig(Config): + section = "media" + def read_config(self, config, **kwargs): # Only enable the media repo if either the media repo is enabled or the diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index a92693017b..7c9f05bde4 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -19,6 +19,8 @@ from ._base import Config, ConfigError class RoomDirectoryConfig(Config): + section = "roomdirectory" + def read_config(self, config, **kwargs): self.enable_room_list_search = config.get("enable_room_list_search", True) diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index ab34b41ca8..c407e13680 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -55,6 +55,8 @@ def _dict_merge(merge_dict, into_dict): class SAML2Config(Config): + section = "saml2" + def read_config(self, config, **kwargs): self.saml2_enabled = False diff --git a/synapse/config/server.py b/synapse/config/server.py index 709bd387e5..afc4d6a4ab 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -58,6 +58,8 @@ on how to configure the new listener. class ServerConfig(Config): + section = "server" + def read_config(self, config, **kwargs): self.server_name = config["server_name"] self.server_context = config.get("server_context", None) diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index 6d4285ef93..6ea2ea8869 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -59,6 +59,8 @@ class ServerNoticesConfig(Config): None if server notices are not enabled. """ + section = "servernotices" + def __init__(self, *args): super(ServerNoticesConfig, self).__init__(*args) self.server_notices_mxid = None diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py index e40797ab50..36e0ddab5c 100644 --- a/synapse/config/spam_checker.py +++ b/synapse/config/spam_checker.py @@ -19,6 +19,8 @@ from ._base import Config class SpamCheckerConfig(Config): + section = "spamchecker" + def read_config(self, config, **kwargs): self.spam_checker = None diff --git a/synapse/config/stats.py b/synapse/config/stats.py index b18ddbd1fa..62485189ea 100644 --- a/synapse/config/stats.py +++ b/synapse/config/stats.py @@ -25,6 +25,8 @@ class StatsConfig(Config): Configuration for the behaviour of synapse's stats engine """ + section = "stats" + def read_config(self, config, **kwargs): self.stats_enabled = True self.stats_bucket_size = 86400 * 1000 diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py index b3431441b9..10a99c792e 100644 --- a/synapse/config/third_party_event_rules.py +++ b/synapse/config/third_party_event_rules.py @@ -19,6 +19,8 @@ from ._base import Config class ThirdPartyRulesConfig(Config): + section = "thirdpartyrules" + def read_config(self, config, **kwargs): self.third_party_event_rules = None diff --git a/synapse/config/tls.py b/synapse/config/tls.py index fc47ba3e9a..f06341eb67 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -18,6 +18,7 @@ import os import warnings from datetime import datetime from hashlib import sha256 +from typing import List import six @@ -33,7 +34,9 @@ logger = logging.getLogger(__name__) class TlsConfig(Config): - def read_config(self, config, config_dir_path, **kwargs): + section = "tls" + + def read_config(self, config: dict, config_dir_path: str, **kwargs): acme_config = config.get("acme", None) if acme_config is None: @@ -57,7 +60,7 @@ class TlsConfig(Config): self.tls_certificate_file = self.abspath(config.get("tls_certificate_path")) self.tls_private_key_file = self.abspath(config.get("tls_private_key_path")) - if self.has_tls_listener(): + if self.root.server.has_tls_listener(): if not self.tls_certificate_file: raise ConfigError( "tls_certificate_path must be specified if TLS-enabled listeners are " @@ -108,7 +111,7 @@ class TlsConfig(Config): ) # Support globs (*) in whitelist values - self.federation_certificate_verification_whitelist = [] + self.federation_certificate_verification_whitelist = [] # type: List[str] for entry in fed_whitelist_entries: try: entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii")) diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 85d99a3166..8be1346113 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -19,6 +19,8 @@ from ._base import Config, ConfigError class TracerConfig(Config): + section = "tracing" + def read_config(self, config, **kwargs): opentracing_config = config.get("opentracing") if opentracing_config is None: diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index f6313e17d4..c8d19c5d6b 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -21,6 +21,8 @@ class UserDirectoryConfig(Config): Configuration for the behaviour of the /user_directory API """ + section = "userdirectory" + def read_config(self, config, **kwargs): self.user_directory_search_enabled = True self.user_directory_search_all_users = False diff --git a/synapse/config/voip.py b/synapse/config/voip.py index 2ca0e1cf70..a68a3068aa 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -16,6 +16,8 @@ from ._base import Config class VoipConfig(Config): + section = "voip" + def read_config(self, config, **kwargs): self.turn_uris = config.get("turn_uris", []) self.turn_shared_secret = config.get("turn_shared_secret") diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 1ec4998625..fef72ed974 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -21,6 +21,8 @@ class WorkerConfig(Config): They have their own pid_file and listener configuration. They use the replication_url to talk to the main synapse process.""" + section = "worker" + def read_config(self, config, **kwargs): self.worker_app = config.get("worker_app") diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index b02780772a..1be6ff563b 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -21,17 +21,24 @@ import yaml from OpenSSL import SSL +from synapse.config._base import Config, RootConfig from synapse.config.tls import ConfigError, TlsConfig from synapse.crypto.context_factory import ClientTLSOptionsFactory from tests.unittest import TestCase -class TestConfig(TlsConfig): +class FakeServer(Config): + section = "server" + def has_tls_listener(self): return False +class TestConfig(RootConfig): + config_classes = [FakeServer, TlsConfig] + + class TLSConfigTests(TestCase): def test_warn_self_signed(self): """ @@ -202,13 +209,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= conf = TestConfig() conf.read_config( yaml.safe_load( - TestConfig().generate_config_section( + TestConfig().generate_config( "/config_dir_path", "my_super_secure_server", "/data_dir_path", - "/tls_cert_path", - "tls_private_key", - None, # This is the acme_domain + tls_certificate_path="/tls_cert_path", + tls_private_key_path="tls_private_key", + acme_domain=None, # This is the acme_domain ) ), "/config_dir_path", @@ -223,13 +230,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= conf = TestConfig() conf.read_config( yaml.safe_load( - TestConfig().generate_config_section( + TestConfig().generate_config( "/config_dir_path", "my_super_secure_server", "/data_dir_path", - "/tls_cert_path", - "tls_private_key", - "my_supe_secure_server", # This is the acme_domain + tls_certificate_path="/tls_cert_path", + tls_private_key_path="tls_private_key", + acme_domain="my_supe_secure_server", # This is the acme_domain ) ), "/config_dir_path", diff --git a/tox.ini b/tox.ini index 1bce10a4ce..367cc2ccf2 100644 --- a/tox.ini +++ b/tox.ini @@ -163,10 +163,9 @@ deps = {[base]deps} mypy mypy-zope - typeshed env = MYPYPATH = stubs/ extras = all -commands = mypy --show-traceback \ +commands = mypy --show-traceback --check-untyped-defs --show-error-codes --follow-imports=normal \ synapse/logging/ \ synapse/config/ -- cgit 1.5.1 From 1d6dd1c2944c22147258dda8ccf2777c68b38fba Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 10 Oct 2019 10:53:06 +0100 Subject: Move patch_inline_callbacks into synapse/ --- synapse/__init__.py | 2 +- synapse/util/patch_inline_callbacks.py | 179 +++++++++++++++++++++++++++++++++ tests/__init__.py | 4 +- tests/patch_inline_callbacks.py | 179 --------------------------------- 4 files changed, 182 insertions(+), 182 deletions(-) create mode 100644 synapse/util/patch_inline_callbacks.py delete mode 100644 tests/patch_inline_callbacks.py (limited to 'tests') diff --git a/synapse/__init__.py b/synapse/__init__.py index 1055f54e00..bf102244a9 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -20,7 +20,7 @@ import os import sys -from tests.patch_inline_callbacks import do_patch +from synapse.util.patch_inline_callbacks import do_patch # Check that we're not running on an unsupported Python version. if sys.version_info < (3, 5): diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py new file mode 100644 index 0000000000..4fb49b0b2b --- /dev/null +++ b/synapse/util/patch_inline_callbacks.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import functools +import sys + +from twisted.internet import defer +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure + + +def do_patch(): + """ + Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit + """ + + from synapse.logging.context import LoggingContext + + orig_inline_callbacks = defer.inlineCallbacks + if hasattr(orig_inline_callbacks, "patched_by_synapse"): + return + + def new_inline_callbacks(f): + @functools.wraps(f) + def wrapped(*args, **kwargs): + start_context = LoggingContext.current_context() + changes = [] + orig = orig_inline_callbacks(_check_yield_points(f, changes, start_context)) + + try: + res = orig(*args, **kwargs) + except Exception: + if LoggingContext.current_context() != start_context: + for err in changes: + print(err, file=sys.stderr) + + err = "%s changed context from %s to %s on exception" % ( + f, + start_context, + LoggingContext.current_context(), + ) + print(err, file=sys.stderr) + raise Exception(err) + raise + + if not isinstance(res, Deferred) or res.called: + if LoggingContext.current_context() != start_context: + for err in changes: + print(err, file=sys.stderr) + + err = "Completed %s changed context from %s to %s" % ( + f, + start_context, + LoggingContext.current_context(), + ) + # print the error to stderr because otherwise all we + # see in travis-ci is the 500 error + print(err, file=sys.stderr) + raise Exception(err) + return res + + if LoggingContext.current_context() != LoggingContext.sentinel: + err = ( + "%s returned incomplete deferred in non-sentinel context " + "%s (start was %s)" + ) % (f, LoggingContext.current_context(), start_context) + print(err, file=sys.stderr) + raise Exception(err) + + def check_ctx(r): + if LoggingContext.current_context() != start_context: + for err in changes: + print(err, file=sys.stderr) + err = "%s completion of %s changed context from %s to %s" % ( + "Failure" if isinstance(r, Failure) else "Success", + f, + start_context, + LoggingContext.current_context(), + ) + print(err, file=sys.stderr) + raise Exception(err) + return r + + res.addBoth(check_ctx) + return res + + return wrapped + + defer.inlineCallbacks = new_inline_callbacks + new_inline_callbacks.patched_by_synapse = True + + +def _check_yield_points(f, changes, start_context): + """Wraps a generator that is about to be passed to defer.inlineCallbacks + checking that after every yield the log contexts are correct. + """ + + from synapse.logging.context import LoggingContext + + @functools.wraps(f) + def check_yield_points_inner(*args, **kwargs): + expected_context = start_context + + gen = f(*args, **kwargs) + + last_yield_line_no = 1 + result = None + while True: + try: + isFailure = isinstance(result, Failure) + if isFailure: + d = result.throwExceptionIntoGenerator(gen) + else: + d = gen.send(result) + except (StopIteration, defer._DefGen_Return) as e: + if LoggingContext.current_context() != expected_context: + # This happens when the context is lost sometime *after* the + # final yield and returning. E.g. we forgot to yield on a + # function that returns a deferred. + err = ( + "Function %r returned and changed context from %s to %s," + " in %s between %d and end of func" + % ( + f.__qualname__, + start_context, + LoggingContext.current_context(), + f.__code__.co_filename, + last_yield_line_no, + ) + ) + changes.append(err) + # raise Exception(err) + return getattr(e, "value", None) + + try: + result = yield d + except Exception as e: + result = Failure(e) + + frame = gen.gi_frame + + if LoggingContext.current_context() != expected_context: + # This happens because the context is lost sometime *after* the + # previous yield and *after* the current yield. E.g. the + # deferred we waited on didn't follow the rules, or we forgot to + # yield on a function between the two yield points. + err = ( + "%s changed context from %s to %s, happened between lines %d and %d in %s" + % ( + frame.f_code.co_name, + start_context, + LoggingContext.current_context(), + last_yield_line_no, + frame.f_lineno, + frame.f_code.co_filename, + ) + ) + changes.append(err) + # raise Exception(err) + + expected_context = LoggingContext.current_context() + + last_yield_line_no = frame.f_lineno + + return check_yield_points_inner diff --git a/tests/__init__.py b/tests/__init__.py index f7fc502f01..ed805db1c2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -16,9 +16,9 @@ from twisted.trial import util -import tests.patch_inline_callbacks +from synapse.util.patch_inline_callbacks import do_patch # attempt to do the patch before we load any synapse code -tests.patch_inline_callbacks.do_patch() +do_patch() util.DEFAULT_TIMEOUT_DURATION = 20 diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py deleted file mode 100644 index a35a1d3305..0000000000 --- a/tests/patch_inline_callbacks.py +++ /dev/null @@ -1,179 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import functools -import sys - -from twisted.internet import defer -from twisted.internet.defer import Deferred -from twisted.python.failure import Failure - - -def do_patch(): - """ - Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit - """ - - from synapse.logging.context import LoggingContext - - orig_inline_callbacks = defer.inlineCallbacks - if hasattr(orig_inline_callbacks, "patched_by_synapse"): - return - - def new_inline_callbacks(f): - @functools.wraps(f) - def wrapped(*args, **kwargs): - start_context = LoggingContext.current_context() - changes = [] - orig = orig_inline_callbacks(_check_yield_points(f, changes, start_context)) - - try: - res = orig(*args, **kwargs) - except Exception: - if LoggingContext.current_context() != start_context: - for err in changes: - print(err, file=sys.stderr) - - err = "%s changed context from %s to %s on exception" % ( - f, - start_context, - LoggingContext.current_context(), - ) - print(err, file=sys.stderr) - raise Exception(err) - raise - - if not isinstance(res, Deferred) or res.called: - if LoggingContext.current_context() != start_context: - for err in changes: - print(err, file=sys.stderr) - - err = "Completed %s changed context from %s to %s" % ( - f, - start_context, - LoggingContext.current_context(), - ) - # print the error to stderr because otherwise all we - # see in travis-ci is the 500 error - print(err, file=sys.stderr) - raise Exception(err) - return res - - if LoggingContext.current_context() != LoggingContext.sentinel: - err = ( - "%s returned incomplete deferred in non-sentinel context " - "%s (start was %s)" - ) % (f, LoggingContext.current_context(), start_context) - print(err, file=sys.stderr) - raise Exception(err) - - def check_ctx(r): - if LoggingContext.current_context() != start_context: - for err in changes: - print(err, file=sys.stderr) - err = "%s completion of %s changed context from %s to %s" % ( - "Failure" if isinstance(r, Failure) else "Success", - f, - start_context, - LoggingContext.current_context(), - ) - print(err, file=sys.stderr) - raise Exception(err) - return r - - res.addBoth(check_ctx) - return res - - return wrapped - - defer.inlineCallbacks = new_inline_callbacks - new_inline_callbacks.patched_by_synapse = True - - -def _check_yield_points(f, changes, start_context): - """Wraps a generator that is about to passed to defer.inlineCallbacks - checking that after every yield the log contexts are correct. - """ - - from synapse.logging.context import LoggingContext - - @functools.wraps(f) - def check_yield_points_inner(*args, **kwargs): - expected_context = start_context - - gen = f(*args, **kwargs) - - last_yield_line_no = 1 - result = None - while True: - try: - isFailure = isinstance(result, Failure) - if isFailure: - d = result.throwExceptionIntoGenerator(gen) - else: - d = gen.send(result) - except (StopIteration, defer._DefGen_Return) as e: - if LoggingContext.current_context() != expected_context: - # This happens when the context is lost sometime *after* the - # final yield and returning. E.g. we forgot to yield on a - # function that returns a deferred. - err = ( - "Function %r returned and changed context from %s to %s," - " in %s between %d and end of func" - % ( - f.__qualname__, - start_context, - LoggingContext.current_context(), - f.__code__.co_filename, - last_yield_line_no, - ) - ) - changes.append(err) - # raise Exception(err) - return getattr(e, "value", None) - - try: - result = yield d - except Exception as e: - result = Failure(e) - - frame = gen.gi_frame - - if LoggingContext.current_context() != expected_context: - # This happens because the context is lost sometime *after* the - # previous yield and *after* the current yield. E.g. the - # deferred we waited on didn't follow the rules, or we forgot to - # yield on a function between the two yield points. - err = ( - "%s changed context from %s to %s, happened between lines %d and %d in %s" - % ( - frame.f_code.co_name, - start_context, - LoggingContext.current_context(), - last_yield_line_no, - frame.f_lineno, - frame.f_code.co_filename, - ) - ) - changes.append(err) - # raise Exception(err) - - expected_context = LoggingContext.current_context() - - last_yield_line_no = frame.f_lineno - - return check_yield_points_inner -- cgit 1.5.1 From a139420a3cfda6a4a4ee4750611b31dd71fc33f3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 10 Oct 2019 11:29:01 +0100 Subject: Fix races in room stats (and other) updates. (#6187) Hopefully this will fix the occasional failures we were seeing in the room directory. The problem was that events are not necessarily persisted (and `current_state_delta_stream` updated) in the same order as their stream_id. So for instance current_state_delta 9 might be persisted *before* current_state_delta 8. Then, when the room stats saw stream_id 9, it assumed it had done everything up to 9, and never came back to do stream_id 8. We can solve this easily by only processing up to the stream_id where we know all events have been persisted. --- changelog.d/6187.bugfix | 1 + synapse/handlers/presence.py | 16 ++++++++++++---- synapse/handlers/stats.py | 12 +++++++----- synapse/handlers/user_directory.py | 17 ++++++++++++----- synapse/storage/state_deltas.py | 38 +++++++++++++++++++++++++++++--------- tests/handlers/test_typing.py | 2 +- tests/rest/admin/test_admin.py | 2 +- 7 files changed, 63 insertions(+), 25 deletions(-) create mode 100644 changelog.d/6187.bugfix (limited to 'tests') diff --git a/changelog.d/6187.bugfix b/changelog.d/6187.bugfix new file mode 100644 index 0000000000..6142c5b98d --- /dev/null +++ b/changelog.d/6187.bugfix @@ -0,0 +1 @@ +Fix occasional missed updates in the room and user directories. \ No newline at end of file diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 053cf66b28..2a5f1a007d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -803,17 +803,25 @@ class PresenceHandler(object): # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "presence_delta"): - deltas = yield self.store.get_current_state_deltas(self._event_pos) - if not deltas: + room_max_stream_ordering = self.store.get_room_max_stream_ordering() + if self._event_pos == room_max_stream_ordering: return + logger.debug( + "Processing presence stats %s->%s", + self._event_pos, + room_max_stream_ordering, + ) + max_pos, deltas = yield self.store.get_current_state_deltas( + self._event_pos, room_max_stream_ordering + ) yield self._handle_state_delta(deltas) - self._event_pos = deltas[-1]["stream_id"] + self._event_pos = max_pos # Expose current event processing position to prometheus synapse.metrics.event_processing_positions.labels("presence").set( - self._event_pos + max_pos ) @defer.inlineCallbacks diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index c62b113115..466daf9202 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -87,21 +87,23 @@ class StatsHandler(StateDeltasHandler): # Be sure to read the max stream_ordering *before* checking if there are any outstanding # deltas, since there is otherwise a chance that we could miss updates which arrive # after we check the deltas. - room_max_stream_ordering = yield self.store.get_room_max_stream_ordering() + room_max_stream_ordering = self.store.get_room_max_stream_ordering() if self.pos == room_max_stream_ordering: break - deltas = yield self.store.get_current_state_deltas(self.pos) + logger.debug( + "Processing room stats %s->%s", self.pos, room_max_stream_ordering + ) + max_pos, deltas = yield self.store.get_current_state_deltas( + self.pos, room_max_stream_ordering + ) if deltas: logger.debug("Handling %d state deltas", len(deltas)) room_deltas, user_deltas = yield self._handle_deltas(deltas) - - max_pos = deltas[-1]["stream_id"] else: room_deltas = {} user_deltas = {} - max_pos = room_max_stream_ordering # Then count deltas for total_events and total_event_bytes. room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes( diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index e53669e40d..624f05ab5b 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -138,21 +138,28 @@ class UserDirectoryHandler(StateDeltasHandler): # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "user_dir_delta"): - deltas = yield self.store.get_current_state_deltas(self.pos) - if not deltas: + room_max_stream_ordering = self.store.get_room_max_stream_ordering() + if self.pos == room_max_stream_ordering: return + logger.debug( + "Processing user stats %s->%s", self.pos, room_max_stream_ordering + ) + max_pos, deltas = yield self.store.get_current_state_deltas( + self.pos, room_max_stream_ordering + ) + logger.info("Handling %d state deltas", len(deltas)) yield self._handle_deltas(deltas) - self.pos = deltas[-1]["stream_id"] + self.pos = max_pos # Expose current event processing position to prometheus synapse.metrics.event_processing_positions.labels("user_dir").set( - self.pos + max_pos ) - yield self.store.update_user_directory_stream_pos(self.pos) + yield self.store.update_user_directory_stream_pos(max_pos) @defer.inlineCallbacks def _handle_deltas(self, deltas): diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py index 5fdb442104..28f33ec18f 100644 --- a/synapse/storage/state_deltas.py +++ b/synapse/storage/state_deltas.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) class StateDeltasStore(SQLBaseStore): - def get_current_state_deltas(self, prev_stream_id): + def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): """Fetch a list of room state changes since the given stream id Each entry in the result contains the following fields: @@ -36,15 +36,27 @@ class StateDeltasStore(SQLBaseStore): Args: prev_stream_id (int): point to get changes since (exclusive) + max_stream_id (int): the point that we know has been correctly persisted + - ie, an upper limit to return changes from. Returns: - Deferred[list[dict]]: results + Deferred[tuple[int, list[dict]]: A tuple consisting of: + - the stream id which these results go up to + - list of current_state_delta_stream rows. If it is empty, we are + up to date. """ prev_stream_id = int(prev_stream_id) + + # check we're not going backwards + assert prev_stream_id <= max_stream_id + if not self._curr_state_delta_stream_cache.has_any_entity_changed( prev_stream_id ): - return [] + # if the CSDs haven't changed between prev_stream_id and now, we + # know for certain that they haven't changed between prev_stream_id and + # max_stream_id. + return max_stream_id, [] def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than @@ -54,21 +66,29 @@ class StateDeltasStore(SQLBaseStore): sql = """ SELECT stream_id, count(*) FROM current_state_delta_stream - WHERE stream_id > ? + WHERE stream_id > ? AND stream_id <= ? GROUP BY stream_id ORDER BY stream_id ASC LIMIT 100 """ - txn.execute(sql, (prev_stream_id,)) + txn.execute(sql, (prev_stream_id, max_stream_id)) total = 0 - max_stream_id = prev_stream_id - for max_stream_id, count in txn: + + for stream_id, count in txn: total += count if total > 100: # We arbitarily limit to 100 entries to ensure we don't # select toooo many. + logger.debug( + "Clipping current_state_delta_stream rows to stream_id %i", + stream_id, + ) + clipped_stream_id = stream_id break + else: + # if there's no problem, we may as well go right up to the max_stream_id + clipped_stream_id = max_stream_id # Now actually get the deltas sql = """ @@ -77,8 +97,8 @@ class StateDeltasStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC """ - txn.execute(sql, (prev_stream_id, max_stream_id)) - return self.cursor_to_dict(txn) + txn.execute(sql, (prev_stream_id, clipped_stream_id)) + return clipped_stream_id, self.cursor_to_dict(txn) return self.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 1f2ef5d01f..67f1013051 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -139,7 +139,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): defer.succeed(1) ) - self.datastore.get_current_state_deltas.return_value = None + self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 5877bb2133..d3a4f717f7 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -62,7 +62,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.device_handler.check_device_registered = Mock(return_value="FAKE") self.datastore = Mock(return_value=Mock()) - self.datastore.get_current_state_deltas = Mock(return_value=[]) + self.datastore.get_current_state_deltas = Mock(return_value=(0, [])) self.secrets = Mock() -- cgit 1.5.1 From 2efd050c9db2e96fd96535dc9b1c6f54acbd163d Mon Sep 17 00:00:00 2001 From: krombel Date: Thu, 10 Oct 2019 13:59:55 +0200 Subject: send 404 as http-status when filter-id is unknown to the server (#2380) This fixed the weirdness of 400 vs 404 as http status code in the case the filter id is not known by the server. As e.g. matrix-js-sdk expects 404 to catch this situation this leads to unwanted behaviour. --- changelog.d/2380.bugfix | 1 + synapse/rest/client/v2_alpha/filter.py | 12 +++++---- synapse/rest/client/v2_alpha/sync.py | 41 ++++++++++++++++++------------- tests/rest/client/v2_alpha/test_filter.py | 2 +- 4 files changed, 33 insertions(+), 23 deletions(-) create mode 100644 changelog.d/2380.bugfix (limited to 'tests') diff --git a/changelog.d/2380.bugfix b/changelog.d/2380.bugfix new file mode 100644 index 0000000000..eae3206031 --- /dev/null +++ b/changelog.d/2380.bugfix @@ -0,0 +1 @@ +Return an HTTP 404 instead of 400 when requesting a filter by ID that is unknown to the server. Thanks to @krombel for contributing this! diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index c6ddf24c8d..17a8bc7366 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError, Codes, StoreError, SynapseError +from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID @@ -52,13 +52,15 @@ class GetFilterRestServlet(RestServlet): raise SynapseError(400, "Invalid filter_id") try: - filter = yield self.filtering.get_user_filter( + filter_collection = yield self.filtering.get_user_filter( user_localpart=target_user.localpart, filter_id=filter_id ) + except StoreError as e: + if e.code != 404: + raise + raise NotFoundError("No such filter") - return 200, filter.get_filter_json() - except (KeyError, StoreError): - raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND) + return 200, filter_collection.get_filter_json() class CreateFilterRestServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index c98c5a3802..a883c8adda 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -21,7 +21,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import PresenceState -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.events.utils import ( format_event_for_client_v2_without_room_id, @@ -119,25 +119,32 @@ class SyncRestServlet(RestServlet): request_key = (user, timeout, since, filter_id, full_state, device_id) - if filter_id: - if filter_id.startswith("{"): - try: - filter_object = json.loads(filter_id) - set_timeline_upper_limit( - filter_object, self.hs.config.filter_timeline_limit - ) - except Exception: - raise SynapseError(400, "Invalid filter JSON") - self.filtering.check_valid_filter(filter_object) - filter = FilterCollection(filter_object) - else: - filter = yield self.filtering.get_user_filter(user.localpart, filter_id) + if filter_id is None: + filter_collection = DEFAULT_FILTER_COLLECTION + elif filter_id.startswith("{"): + try: + filter_object = json.loads(filter_id) + set_timeline_upper_limit( + filter_object, self.hs.config.filter_timeline_limit + ) + except Exception: + raise SynapseError(400, "Invalid filter JSON") + self.filtering.check_valid_filter(filter_object) + filter_collection = FilterCollection(filter_object) else: - filter = DEFAULT_FILTER_COLLECTION + try: + filter_collection = yield self.filtering.get_user_filter( + user.localpart, filter_id + ) + except StoreError as err: + if err.code != 404: + raise + # fix up the description and errcode to be more useful + raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM) sync_config = SyncConfig( user=user, - filter_collection=filter, + filter_collection=filter_collection, is_guest=requester.is_guest, request_key=request_key, device_id=device_id, @@ -171,7 +178,7 @@ class SyncRestServlet(RestServlet): time_now = self.clock.time_msec() response_content = yield self.encode_response( - time_now, sync_result, requester.access_token_id, filter + time_now, sync_result, requester.access_token_id, filter_collection ) return 200, response_content diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index f42a8efbf4..e0e9e94fbf 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -92,7 +92,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.render(request) - self.assertEqual(channel.result["code"], b"400") + self.assertEqual(channel.result["code"], b"404") self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode -- cgit 1.5.1 From b5b03b7079a9baa34a25915d6a569e383e8307c3 Mon Sep 17 00:00:00 2001 From: werner291 Date: Thu, 10 Oct 2019 14:05:48 +0200 Subject: Add domain validation when creating room with list of invitees (#6121) --- changelog.d/4088.bugfix | 1 + synapse/handlers/room.py | 4 +++- tests/rest/client/v1/test_rooms.py | 9 +++++++++ 3 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 changelog.d/4088.bugfix (limited to 'tests') diff --git a/changelog.d/4088.bugfix b/changelog.d/4088.bugfix new file mode 100644 index 0000000000..61722b6224 --- /dev/null +++ b/changelog.d/4088.bugfix @@ -0,0 +1 @@ +Added domain validation when including a list of invitees upon room creation. \ No newline at end of file diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 970be3c846..2816bd8f87 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -28,6 +28,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils @@ -554,7 +555,8 @@ class RoomCreationHandler(BaseHandler): invite_list = config.get("invite", []) for i in invite_list: try: - UserID.from_string(i) + uid = UserID.from_string(i) + parse_and_validate_server_name(uid.domain) except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index fe741637f5..2f2ca74611 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -484,6 +484,15 @@ class RoomsCreateTestCase(RoomBase): self.render(request) self.assertEquals(400, channel.code) + def test_post_room_invitees_invalid_mxid(self): + # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 + # Note the trailing space in the MXID here! + request, channel = self.make_request( + "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' + ) + self.render(request) + self.assertEquals(400, channel.code) + class RoomTopicTestCase(RoomBase): """ Tests /rooms/$room_id/topic REST events. """ -- cgit 1.5.1 From 5373de6cced56c983098c82872cf17c311abdb96 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 10 Oct 2019 08:54:07 -0400 Subject: change test name to be unique --- tests/handlers/test_e2e_room_keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index c700a2fad1..0bb96674a2 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -187,7 +187,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(res, 404) @defer.inlineCallbacks - def test_update_missing_version(self): + def test_update_omitted_version(self): """Check that the update succeeds if the version is missing from the body """ version = yield self.handler.create_version( -- cgit 1.5.1 From bc244627ac759bbd4691a2a20ac16383b4c2348c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 10 Oct 2019 15:37:53 +0100 Subject: Fix postgres unit tests --- tests/storage/test_event_federation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index b58386994e..2fe50377f8 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -57,7 +57,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): "(event_id, algorithm, hash) " "VALUES (?, 'sha256', ?)" ), - (event_id, b"ffff"), + (event_id, bytearray(b"ffff")), ) for i in range(0, 11): -- cgit 1.5.1 From 4908fb3b30ac007fda5993521448804067751a6d Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 10 Oct 2019 15:56:00 -0400 Subject: make storage layer in charge of interpreting the device key data --- synapse/handlers/e2e_keys.py | 11 ----------- synapse/storage/end_to_end_keys.py | 11 +++++++++-- tests/storage/test_end_to_end_keys.py | 12 ++++++------ 3 files changed, 15 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 056fb97acb..6708d983ac 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -248,17 +248,6 @@ class E2eKeysHandler(object): results = yield self.store.get_e2e_device_keys(local_query) - # Build the result structure, un-jsonify the results, and add the - # "unsigned" section - for user_id, device_keys in results.items(): - for device_id, device_info in device_keys.items(): - r = dict(device_info["keys"]) - r["unsigned"] = {} - display_name = device_info["device_display_name"] - if display_name is not None: - r["unsigned"]["device_display_name"] = display_name - result_dict[user_id][device_id] = r - log_kv(results) return result_dict diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 33e3a84933..d802d7a485 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -40,7 +40,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): This option only takes effect if include_all_devices is true. Returns: Dict mapping from user-id to dict mapping from device_id to - dict containing "key_json", "device_display_name". + key data. """ set_tag("query_list", query_list) if not query_list: @@ -54,9 +54,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore): include_deleted_devices, ) + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section for user_id, device_keys in iteritems(results): for device_id, device_info in iteritems(device_keys): - device_info["keys"] = db_to_json(device_info.pop("key_json")) + r = db_to_json(device_info.pop("key_json")) + r["unsigned"] = {} + display_name = device_info["device_display_name"] + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + results[user_id][device_id] = r return results diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index c8ece15284..398d546280 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -38,7 +38,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev) + self.assertDictContainsSubset(json, dev) @defer.inlineCallbacks def test_reupload_key(self): @@ -68,7 +68,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): self.assertIn("device", res["user"]) dev = res["user"]["device"] self.assertDictContainsSubset( - {"keys": json, "device_display_name": "display_name"}, dev + {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev ) @defer.inlineCallbacks @@ -80,10 +80,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.store_device("user2", "device1", None) yield self.store.store_device("user2", "device2", None) - yield self.store.set_e2e_device_keys("user1", "device1", now, "json11") - yield self.store.set_e2e_device_keys("user1", "device2", now, "json12") - yield self.store.set_e2e_device_keys("user2", "device1", now, "json21") - yield self.store.set_e2e_device_keys("user2", "device2", now, "json22") + yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) + yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) + yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) + yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) res = yield self.store.get_e2e_device_keys( (("user1", "device1"), ("user2", "device2")) -- cgit 1.5.1 From a0d0ba7862e38588aa0d0ac29a720fdf06f1ab8d Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Fri, 11 Oct 2019 09:38:26 +0100 Subject: Fix MAU reaping where reserved users are specified. (#6168) --- changelog.d/6168.bugfix | 1 + synapse/app/homeserver.py | 6 +- synapse/storage/monthly_active_users.py | 101 ++++++++++++++++++----------- tests/storage/test_monthly_active_users.py | 58 ++++++++++++++--- 4 files changed, 115 insertions(+), 51 deletions(-) create mode 100644 changelog.d/6168.bugfix (limited to 'tests') diff --git a/changelog.d/6168.bugfix b/changelog.d/6168.bugfix new file mode 100644 index 0000000000..39e8e9d019 --- /dev/null +++ b/changelog.d/6168.bugfix @@ -0,0 +1 @@ +Fix monthly active user reaping where reserved users are specified. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 774326dff9..eb54f56853 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -605,13 +605,13 @@ def run(hs): @defer.inlineCallbacks def generate_monthly_active_users(): current_mau_count = 0 - reserved_count = 0 + reserved_users = () store = hs.get_datastore() if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: current_mau_count = yield store.get_monthly_active_count() - reserved_count = yield store.get_registered_reserved_users_count() + reserved_users = yield store.get_registered_reserved_users() current_mau_gauge.set(float(current_mau_count)) - registered_reserved_users_mau_gauge.set(float(reserved_count)) + registered_reserved_users_mau_gauge.set(float(len(reserved_users))) max_mau_gauge.set(float(hs.config.max_mau_value)) def start_generate_monthly_active_users(): diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 752e9788a2..3803604be7 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -32,7 +32,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): super(MonthlyActiveUsersStore, self).__init__(None, hs) self._clock = hs.get_clock() self.hs = hs - self.reserved_users = () # Do not add more reserved users than the total allowable number self._new_transaction( dbconn, @@ -51,7 +50,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): txn (cursor): threepids (list[dict]): List of threepid dicts to reserve """ - reserved_user_list = [] for tp in threepids: user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) @@ -60,10 +58,8 @@ class MonthlyActiveUsersStore(SQLBaseStore): is_support = self.is_support_user_txn(txn, user_id) if not is_support: self.upsert_monthly_active_user_txn(txn, user_id) - reserved_user_list.append(user_id) else: logger.warning("mau limit reserved threepid %s not found in db" % tp) - self.reserved_users = tuple(reserved_user_list) @defer.inlineCallbacks def reap_monthly_active_users(self): @@ -74,8 +70,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): Deferred[] """ - def _reap_users(txn): - # Purge stale users + def _reap_users(txn, reserved_users): + """ + Args: + reserved_users (tuple): reserved users to preserve + """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) query_args = [thirty_days_ago] @@ -83,20 +82,19 @@ class MonthlyActiveUsersStore(SQLBaseStore): # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres # when len(reserved_users) == 0. Works fine on sqlite. - if len(self.reserved_users) > 0: + if len(reserved_users) > 0: # questionmarks is a hack to overcome sqlite not supporting # tuples in 'WHERE IN %s' - questionmarks = "?" * len(self.reserved_users) + question_marks = ",".join("?" * len(reserved_users)) - query_args.extend(self.reserved_users) - sql = base_sql + """ AND user_id NOT IN ({})""".format( - ",".join(questionmarks) - ) + query_args.extend(reserved_users) + sql = base_sql + " AND user_id NOT IN ({})".format(question_marks) else: sql = base_sql txn.execute(sql, query_args) + max_mau_value = self.hs.config.max_mau_value if self.hs.config.limit_usage_by_mau: # If MAU user count still exceeds the MAU threshold, then delete on # a least recently active basis. @@ -106,31 +104,52 @@ class MonthlyActiveUsersStore(SQLBaseStore): # While Postgres does not require 'LIMIT', but also does not support # negative LIMIT values. So there is no way to write it that both can # support - safe_guard = self.hs.config.max_mau_value - len(self.reserved_users) - # Must be greater than zero for postgres - safe_guard = safe_guard if safe_guard > 0 else 0 - query_args = [safe_guard] - - base_sql = """ - DELETE FROM monthly_active_users - WHERE user_id NOT IN ( - SELECT user_id FROM monthly_active_users - ORDER BY timestamp DESC - LIMIT ? + if len(reserved_users) == 0: + sql = """ + DELETE FROM monthly_active_users + WHERE user_id NOT IN ( + SELECT user_id FROM monthly_active_users + ORDER BY timestamp DESC + LIMIT ? ) - """ + """ + txn.execute(sql, (max_mau_value,)) # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres # when len(reserved_users) == 0. Works fine on sqlite. - if len(self.reserved_users) > 0: - query_args.extend(self.reserved_users) - sql = base_sql + """ AND user_id NOT IN ({})""".format( - ",".join(questionmarks) - ) else: - sql = base_sql - txn.execute(sql, query_args) + # Must be >= 0 for postgres + num_of_non_reserved_users_to_remove = max( + max_mau_value - len(reserved_users), 0 + ) + + # It is important to filter reserved users twice to guard + # against the case where the reserved user is present in the + # SELECT, meaning that a legitmate mau is deleted. + sql = """ + DELETE FROM monthly_active_users + WHERE user_id NOT IN ( + SELECT user_id FROM monthly_active_users + WHERE user_id NOT IN ({}) + ORDER BY timestamp DESC + LIMIT ? + ) + AND user_id NOT IN ({}) + """.format( + question_marks, question_marks + ) + + query_args = [ + *reserved_users, + num_of_non_reserved_users_to_remove, + *reserved_users, + ] - yield self.runInteraction("reap_monthly_active_users", _reap_users) + txn.execute(sql, query_args) + + reserved_users = yield self.get_registered_reserved_users() + yield self.runInteraction( + "reap_monthly_active_users", _reap_users, reserved_users + ) # It seems poor to invalidate the whole cache, Postgres supports # 'Returning' which would allow me to invalidate only the # specific users, but sqlite has no way to do this and instead @@ -159,21 +178,25 @@ class MonthlyActiveUsersStore(SQLBaseStore): return self.runInteraction("count_users", _count_users) @defer.inlineCallbacks - def get_registered_reserved_users_count(self): - """Of the reserved threepids defined in config, how many are associated + def get_registered_reserved_users(self): + """Of the reserved threepids defined in config, which are associated with registered users? Returns: - Defered[int]: Number of real reserved users + Defered[list]: Real reserved users """ - count = 0 - for tp in self.hs.config.mau_limits_reserved_threepids: + users = [] + + for tp in self.hs.config.mau_limits_reserved_threepids[ + : self.hs.config.max_mau_value + ]: user_id = yield self.hs.get_datastore().get_user_id_by_threepid( tp["medium"], tp["address"] ) if user_id: - count = count + 1 - return count + users.append(user_id) + + return users @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 1494650d10..90a63dc477 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -50,6 +50,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": user2_email}, {"medium": "email", "address": user3_email}, ] + self.hs.config.mau_limits_reserved_threepids = threepids # -1 because user3 is a support user and does not count user_num = len(threepids) - 1 @@ -84,6 +85,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.hs.config.max_mau_value = 0 self.reactor.advance(FORTY_DAYS) + self.hs.config.max_mau_value = 5 self.store.reap_monthly_active_users() self.pump() @@ -147,9 +149,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.reap_monthly_active_users() self.pump() count = self.store.get_monthly_active_count() - self.assertEquals( - self.get_success(count), initial_users - self.hs.config.max_mau_value - ) + self.assertEquals(self.get_success(count), self.hs.config.max_mau_value) self.reactor.advance(FORTY_DAYS) self.store.reap_monthly_active_users() @@ -158,6 +158,44 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.store.get_monthly_active_count() self.assertEquals(self.get_success(count), 0) + def test_reap_monthly_active_users_reserved_users(self): + """ Tests that reaping correctly handles reaping where reserved users are + present""" + + self.hs.config.max_mau_value = 5 + initial_users = 5 + reserved_user_number = initial_users - 1 + threepids = [] + for i in range(initial_users): + user = "@user%d:server" % i + email = "user%d@example.com" % i + self.get_success(self.store.upsert_monthly_active_user(user)) + threepids.append({"medium": "email", "address": email}) + # Need to ensure that the most recent entries in the + # monthly_active_users table are reserved + now = int(self.hs.get_clock().time_msec()) + if i != 0: + self.get_success( + self.store.register_user(user_id=user, password_hash=None) + ) + self.get_success( + self.store.user_add_threepid(user, "email", email, now, now) + ) + + self.hs.config.mau_limits_reserved_threepids = threepids + self.store.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) + count = self.store.get_monthly_active_count() + self.assertTrue(self.get_success(count), initial_users) + + users = self.store.get_registered_reserved_users() + self.assertEquals(len(self.get_success(users)), reserved_user_number) + + self.get_success(self.store.reap_monthly_active_users()) + count = self.store.get_monthly_active_count() + self.assertEquals(self.get_success(count), self.hs.config.max_mau_value) + def test_populate_monthly_users_is_guest(self): # Test that guest users are not added to mau list user_id = "@user_id:host" @@ -192,12 +230,13 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_get_reserved_real_user_account(self): # Test no reserved users, or reserved threepids - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), 0) + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEquals(len(users), 0) # Test reserved users but no registered users user1 = "@user1:example.com" user2 = "@user2:example.com" + user1_email = "user1@example.com" user2_email = "user2@example.com" threepids = [ @@ -210,8 +249,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.pump() - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), 0) + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEquals(len(users), 0) # Test reserved registed users self.store.register_user(user_id=user1, password_hash=None) @@ -221,8 +260,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), len(threepids)) + + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEquals(len(users), len(threepids)) def test_support_user_not_add_to_mau_limits(self): support_user_id = "@support:test" -- cgit 1.5.1 From 5859a5c569c03f3b7c578fe4dbf2274e37af03bb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 18 Oct 2019 07:42:26 +0200 Subject: Fix presence timeouts when synchrotron restarts. (#6212) * Fix presence timeouts when synchrotron restarts. Handling timeouts would fail if there was an external process that had timed out, e.g. a synchrotron restarting. This was due to a couple of variable name typoes. Fixes #3715. --- changelog.d/6212.bugfix | 1 + synapse/handlers/presence.py | 13 +++++++++---- tests/handlers/test_presence.py | 39 +++++++++++++++++++++++++++++++++++++++ tox.ini | 2 +- 4 files changed, 50 insertions(+), 5 deletions(-) create mode 100644 changelog.d/6212.bugfix (limited to 'tests') diff --git a/changelog.d/6212.bugfix b/changelog.d/6212.bugfix new file mode 100644 index 0000000000..918755fee0 --- /dev/null +++ b/changelog.d/6212.bugfix @@ -0,0 +1 @@ +Fix bug where presence would not get timed out correctly if a synchrotron worker is used and restarted. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2a5f1a007d..eda15bc623 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -24,6 +24,7 @@ The methods that define policy are: import logging from contextlib import contextmanager +from typing import Dict, Set from six import iteritems, itervalues @@ -179,8 +180,9 @@ class PresenceHandler(object): # we assume that all the sync requests on that process have stopped. # Stored as a dict from process_id to set of user_id, and a dict of # process_id to millisecond timestamp last updated. - self.external_process_to_current_syncs = {} - self.external_process_last_updated_ms = {} + self.external_process_to_current_syncs = {} # type: Dict[int, Set[str]] + self.external_process_last_updated_ms = {} # type: Dict[int, int] + self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") # Start a LoopingCall in 30s that fires every 5s. @@ -349,10 +351,13 @@ class PresenceHandler(object): if now - last_update > EXTERNAL_PROCESS_EXPIRY ] for process_id in expired_process_ids: + # For each expired process drop tracking info and check the users + # that were syncing on that process to see if they need to be timed + # out. users_to_check.update( - self.external_process_last_updated_ms.pop(process_id, ()) + self.external_process_to_current_syncs.pop(process_id, ()) ) - self.external_process_last_update.pop(process_id) + self.external_process_last_updated_ms.pop(process_id) states = [ self.user_to_current_state.get(user_id, UserPresenceState.default(user_id)) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index f70c6e7d65..d4293b4312 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.events import room_version_to_event_format from synapse.events.builder import EventBuilder from synapse.handlers.presence import ( + EXTERNAL_PROCESS_EXPIRY, FEDERATION_PING_INTERVAL, FEDERATION_TIMEOUT, IDLE_TIMER, @@ -413,6 +414,44 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEquals(state, new_state) +class PresenceHandlerTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.presence_handler = hs.get_presence_handler() + self.clock = hs.get_clock() + + def test_external_process_timeout(self): + """Test that if an external process doesn't update the records for a while + we time out their syncing users presence. + """ + process_id = 1 + user_id = "@test:server" + + # Notify handler that a user is now syncing. + self.get_success( + self.presence_handler.update_external_syncs_row( + process_id, user_id, True, self.clock.time_msec() + ) + ) + + # Check that if we wait a while without telling the handler the user has + # stopped syncing that their presence state doesn't get timed out. + self.reactor.advance(EXTERNAL_PROCESS_EXPIRY / 2) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + + # Check that if the external process timeout fires, then the syncing + # user gets timed out + self.reactor.advance(EXTERNAL_PROCESS_EXPIRY) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + + class PresenceJoinTestCase(unittest.HomeserverTestCase): """Tests remote servers get told about presence of users in the room when they join and when new local users join. diff --git a/tox.ini b/tox.ini index 367cc2ccf2..7ba6f6339f 100644 --- a/tox.ini +++ b/tox.ini @@ -161,7 +161,7 @@ basepython = python3.7 skip_install = True deps = {[base]deps} - mypy + mypy==0.730 mypy-zope env = MYPYPATH = stubs/ -- cgit 1.5.1 From c66a06ac6b69b0a03f5c6284ded980399e9df94e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 21 Oct 2019 12:56:42 +0100 Subject: Move storage classes into a main "data store". This is in preparation for having multiple data stores that offer different functionality, e.g. splitting out state or event storage. --- synapse/app/event_creator.py | 2 +- synapse/app/media_repository.py | 2 +- synapse/app/synchrotron.py | 2 +- synapse/app/user_dir.py | 2 +- synapse/federation/sender/per_destination_queue.py | 2 +- synapse/replication/slave/storage/account_data.py | 4 +- synapse/replication/slave/storage/appservice.py | 2 +- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/replication/slave/storage/deviceinbox.py | 2 +- synapse/replication/slave/storage/devices.py | 4 +- synapse/replication/slave/storage/directory.py | 2 +- synapse/replication/slave/storage/events.py | 20 +- synapse/replication/slave/storage/filtering.py | 2 +- synapse/replication/slave/storage/keys.py | 2 +- synapse/replication/slave/storage/presence.py | 2 +- synapse/replication/slave/storage/profile.py | 2 +- synapse/replication/slave/storage/push_rule.py | 2 +- synapse/replication/slave/storage/pushers.py | 2 +- synapse/replication/slave/storage/receipts.py | 2 +- synapse/replication/slave/storage/registration.py | 2 +- synapse/replication/slave/storage/room.py | 2 +- synapse/replication/slave/storage/transactions.py | 2 +- synapse/storage/__init__.py | 504 +--- synapse/storage/account_data.py | 391 --- synapse/storage/appservice.py | 368 --- synapse/storage/client_ips.py | 581 ----- synapse/storage/data_stores/__init__.py | 14 + synapse/storage/data_stores/main/__init__.py | 524 +++++ synapse/storage/data_stores/main/account_data.py | 391 +++ synapse/storage/data_stores/main/appservice.py | 367 +++ synapse/storage/data_stores/main/client_ips.py | 580 +++++ synapse/storage/data_stores/main/deviceinbox.py | 457 ++++ synapse/storage/data_stores/main/devices.py | 937 ++++++++ synapse/storage/data_stores/main/directory.py | 173 ++ synapse/storage/data_stores/main/e2e_room_keys.py | 336 +++ .../storage/data_stores/main/end_to_end_keys.py | 322 +++ .../storage/data_stores/main/event_federation.py | 672 ++++++ .../storage/data_stores/main/event_push_actions.py | 960 ++++++++ synapse/storage/data_stores/main/events.py | 2489 ++++++++++++++++++++ .../storage/data_stores/main/events_bg_updates.py | 505 ++++ synapse/storage/data_stores/main/events_worker.py | 882 +++++++ synapse/storage/data_stores/main/filtering.py | 74 + synapse/storage/data_stores/main/group_server.py | 1180 ++++++++++ synapse/storage/data_stores/main/keys.py | 214 ++ .../storage/data_stores/main/media_repository.py | 378 +++ .../data_stores/main/monthly_active_users.py | 328 +++ synapse/storage/data_stores/main/openid.py | 31 + synapse/storage/data_stores/main/presence.py | 150 ++ synapse/storage/data_stores/main/profile.py | 178 ++ synapse/storage/data_stores/main/push_rule.py | 713 ++++++ synapse/storage/data_stores/main/pusher.py | 371 +++ synapse/storage/data_stores/main/receipts.py | 536 +++++ synapse/storage/data_stores/main/registration.py | 1499 ++++++++++++ synapse/storage/data_stores/main/rejections.py | 42 + synapse/storage/data_stores/main/relations.py | 385 +++ synapse/storage/data_stores/main/room.py | 681 ++++++ synapse/storage/data_stores/main/roommember.py | 1145 +++++++++ .../data_stores/main/schema/delta/12/v12.sql | 63 + .../data_stores/main/schema/delta/13/v13.sql | 19 + .../data_stores/main/schema/delta/14/v14.sql | 23 + .../main/schema/delta/15/appservice_txns.sql | 31 + .../main/schema/delta/15/presence_indices.sql | 2 + .../data_stores/main/schema/delta/15/v15.sql | 24 + .../main/schema/delta/16/events_order_index.sql | 4 + .../schema/delta/16/remote_media_cache_index.sql | 2 + .../main/schema/delta/16/remove_duplicates.sql | 9 + .../main/schema/delta/16/room_alias_index.sql | 3 + .../main/schema/delta/16/unique_constraints.sql | 72 + .../data_stores/main/schema/delta/16/users.sql | 56 + .../main/schema/delta/17/drop_indexes.sql | 18 + .../main/schema/delta/17/server_keys.sql | 24 + .../main/schema/delta/17/user_threepids.sql | 9 + .../schema/delta/18/server_keys_bigger_ints.sql | 32 + .../main/schema/delta/19/event_index.sql | 19 + .../data_stores/main/schema/delta/20/dummy.sql | 1 + .../data_stores/main/schema/delta/20/pushers.py | 88 + .../main/schema/delta/21/end_to_end_keys.sql | 34 + .../data_stores/main/schema/delta/21/receipts.sql | 38 + .../main/schema/delta/22/receipts_index.sql | 22 + .../main/schema/delta/22/user_threepids_unique.sql | 19 + .../main/schema/delta/23/drop_state_index.sql | 16 + .../main/schema/delta/24/stats_reporting.sql | 18 + .../main/schema/delta/25/00background_updates.sql | 21 + .../data_stores/main/schema/delta/25/fts.py | 82 + .../main/schema/delta/25/guest_access.sql | 25 + .../main/schema/delta/25/history_visibility.sql | 25 + .../data_stores/main/schema/delta/25/tags.sql | 38 + .../main/schema/delta/26/account_data.sql | 17 + .../main/schema/delta/27/account_data.sql | 36 + .../main/schema/delta/27/forgotten_memberships.sql | 26 + .../storage/data_stores/main/schema/delta/27/ts.py | 61 + .../main/schema/delta/28/event_push_actions.sql | 27 + .../main/schema/delta/28/events_room_stream.sql | 20 + .../main/schema/delta/28/public_roms_index.sql | 20 + .../schema/delta/28/receipts_user_id_index.sql | 22 + .../main/schema/delta/28/upgrade_times.sql | 21 + .../main/schema/delta/28/users_is_guest.sql | 22 + .../main/schema/delta/29/push_actions.sql | 35 + .../main/schema/delta/30/alias_creator.sql | 16 + .../data_stores/main/schema/delta/30/as_users.py | 69 + .../main/schema/delta/30/deleted_pushers.sql | 25 + .../main/schema/delta/30/presence_stream.sql | 30 + .../main/schema/delta/30/public_rooms.sql | 23 + .../main/schema/delta/30/push_rule_stream.sql | 38 + .../main/schema/delta/30/state_stream.sql | 33 + .../delta/30/threepid_guest_access_tokens.sql | 24 + .../data_stores/main/schema/delta/31/invites.sql | 42 + .../delta/31/local_media_repository_url_cache.sql | 27 + .../data_stores/main/schema/delta/31/pushers.py | 87 + .../main/schema/delta/31/pushers_index.sql | 22 + .../main/schema/delta/31/search_update.py | 66 + .../data_stores/main/schema/delta/32/events.sql | 16 + .../data_stores/main/schema/delta/32/openid.sql | 9 + .../main/schema/delta/32/pusher_throttle.sql | 23 + .../main/schema/delta/32/remove_indices.sql | 34 + .../data_stores/main/schema/delta/32/reports.sql | 25 + .../schema/delta/33/access_tokens_device_index.sql | 17 + .../data_stores/main/schema/delta/33/devices.sql | 21 + .../main/schema/delta/33/devices_for_e2e_keys.sql | 19 + .../devices_for_e2e_keys_clear_unknown_device.sql | 20 + .../main/schema/delta/33/event_fields.py | 61 + .../main/schema/delta/33/remote_media_ts.py | 30 + .../main/schema/delta/33/user_ips_index.sql | 17 + .../main/schema/delta/34/appservice_stream.sql | 23 + .../main/schema/delta/34/cache_stream.py | 46 + .../main/schema/delta/34/device_inbox.sql | 24 + .../schema/delta/34/push_display_name_rename.sql | 20 + .../main/schema/delta/34/received_txn_purge.py | 32 + .../main/schema/delta/35/add_state_index.sql | 17 + .../main/schema/delta/35/contains_url.sql | 17 + .../main/schema/delta/35/device_outbox.sql | 39 + .../main/schema/delta/35/device_stream_id.sql | 21 + .../schema/delta/35/event_push_actions_index.sql | 17 + .../delta/35/public_room_list_change_stream.sql | 33 + .../data_stores/main/schema/delta/35/state.sql | 22 + .../main/schema/delta/35/state_dedupe.sql | 17 + .../schema/delta/35/stream_order_to_extrem.sql | 37 + .../main/schema/delta/36/readd_public_rooms.sql | 26 + .../main/schema/delta/37/remove_auth_idx.py | 85 + .../main/schema/delta/37/user_threepids.sql | 52 + .../main/schema/delta/38/postgres_fts_gist.sql | 19 + .../main/schema/delta/39/appservice_room_list.sql | 29 + .../delta/39/device_federation_stream_idx.sql | 16 + .../main/schema/delta/39/event_push_index.sql | 17 + .../schema/delta/39/federation_out_position.sql | 22 + .../main/schema/delta/39/membership_profile.sql | 20 + .../main/schema/delta/40/current_state_idx.sql | 17 + .../main/schema/delta/40/device_inbox.sql | 21 + .../main/schema/delta/40/device_list_streams.sql | 60 + .../main/schema/delta/40/event_push_summary.sql | 37 + .../data_stores/main/schema/delta/40/pushers.sql | 39 + .../schema/delta/41/device_list_stream_idx.sql | 17 + .../main/schema/delta/41/device_outbound_index.sql | 16 + .../schema/delta/41/event_search_event_id_idx.sql | 17 + .../data_stores/main/schema/delta/41/ratelimit.sql | 22 + .../main/schema/delta/42/current_state_delta.sql | 26 + .../main/schema/delta/42/device_list_last_id.sql | 33 + .../main/schema/delta/42/event_auth_state_only.sql | 17 + .../data_stores/main/schema/delta/42/user_dir.py | 84 + .../main/schema/delta/43/blocked_rooms.sql | 21 + .../main/schema/delta/43/quarantine_media.sql | 17 + .../data_stores/main/schema/delta/43/url_cache.sql | 16 + .../main/schema/delta/43/user_share.sql | 33 + .../main/schema/delta/44/expire_url_cache.sql | 41 + .../main/schema/delta/45/group_server.sql | 167 ++ .../main/schema/delta/45/profile_cache.sql | 28 + .../main/schema/delta/46/drop_refresh_tokens.sql | 17 + .../delta/46/drop_unique_deleted_pushers.sql | 35 + .../main/schema/delta/46/group_server.sql | 32 + .../delta/46/local_media_repository_url_idx.sql | 24 + .../schema/delta/46/user_dir_null_room_ids.sql | 35 + .../main/schema/delta/46/user_dir_typos.sql | 24 + .../main/schema/delta/47/last_access_media.sql | 16 + .../main/schema/delta/47/postgres_fts_gin.sql | 17 + .../main/schema/delta/47/push_actions_staging.sql | 28 + .../main/schema/delta/47/state_group_seq.py | 34 + .../main/schema/delta/48/add_user_consent.sql | 18 + .../delta/48/add_user_ips_last_seen_index.sql | 17 + .../main/schema/delta/48/deactivated_users.sql | 25 + .../main/schema/delta/48/group_unique_indexes.py | 63 + .../main/schema/delta/48/groups_joinable.sql | 22 + .../49/add_user_consent_server_notice_sent.sql | 20 + .../main/schema/delta/49/add_user_daily_visits.sql | 21 + .../delta/49/add_user_ips_last_seen_only_index.sql | 17 + .../delta/50/add_creation_ts_users_index.sql | 19 + .../main/schema/delta/50/erasure_store.sql | 21 + .../schema/delta/50/make_event_content_nullable.py | 96 + .../main/schema/delta/51/e2e_room_keys.sql | 39 + .../main/schema/delta/51/monthly_active_users.sql | 27 + .../delta/52/add_event_to_state_group_index.sql | 19 + .../delta/52/device_list_streams_unique_idx.sql | 36 + .../main/schema/delta/52/e2e_room_keys.sql | 53 + .../schema/delta/53/add_user_type_to_users.sql | 19 + .../schema/delta/53/drop_sent_transactions.sql | 16 + .../main/schema/delta/53/event_format_version.sql | 16 + .../main/schema/delta/53/user_dir_populate.sql | 30 + .../main/schema/delta/53/user_ips_index.sql | 30 + .../main/schema/delta/53/user_share.sql | 44 + .../main/schema/delta/53/user_threepid_id.sql | 29 + .../main/schema/delta/53/users_in_public_rooms.sql | 28 + .../delta/54/account_validity_with_renewal.sql | 30 + .../delta/54/add_validity_to_server_keys.sql | 23 + .../schema/delta/54/delete_forward_extremities.sql | 23 + .../main/schema/delta/54/drop_legacy_tables.sql | 30 + .../main/schema/delta/54/drop_presence_list.sql | 16 + .../data_stores/main/schema/delta/54/relations.sql | 27 + .../data_stores/main/schema/delta/54/stats.sql | 80 + .../data_stores/main/schema/delta/54/stats2.sql | 28 + .../main/schema/delta/55/access_token_expiry.sql | 18 + .../schema/delta/55/track_threepid_validations.sql | 31 + .../schema/delta/55/users_alter_deactivated.sql | 19 + .../schema/delta/56/add_spans_to_device_lists.sql | 20 + .../delta/56/current_state_events_membership.sql | 22 + .../56/current_state_events_membership_mk2.sql | 24 + .../schema/delta/56/destinations_failure_ts.sql | 25 + .../destinations_retry_interval_type.sql.postgres | 18 + .../main/schema/delta/56/devices_last_seen.sql | 24 + .../schema/delta/56/drop_unused_event_tables.sql | 20 + .../main/schema/delta/56/fix_room_keys_index.sql | 18 + .../main/schema/delta/56/public_room_list_idx.sql | 16 + .../main/schema/delta/56/redaction_censor.sql | 17 + .../main/schema/delta/56/redaction_censor2.sql | 20 + .../56/redaction_censor3_fix_update.sql.postgres | 25 + .../main/schema/delta/56/room_membership_idx.sql | 18 + .../main/schema/delta/56/stats_separated.sql | 152 ++ .../schema/delta/56/unique_user_filter_index.py | 52 + .../main/schema/delta/56/user_external_ids.sql | 24 + .../schema/delta/56/users_in_public_rooms_idx.sql | 17 + .../full_schemas/16/application_services.sql | 37 + .../main/schema/full_schemas/16/event_edges.sql | 70 + .../schema/full_schemas/16/event_signatures.sql | 38 + .../data_stores/main/schema/full_schemas/16/im.sql | 120 + .../main/schema/full_schemas/16/keys.sql | 26 + .../schema/full_schemas/16/media_repository.sql | 68 + .../main/schema/full_schemas/16/presence.sql | 32 + .../main/schema/full_schemas/16/profiles.sql | 20 + .../main/schema/full_schemas/16/push.sql | 74 + .../main/schema/full_schemas/16/redactions.sql | 22 + .../main/schema/full_schemas/16/room_aliases.sql | 29 + .../main/schema/full_schemas/16/state.sql | 40 + .../main/schema/full_schemas/16/transactions.sql | 44 + .../main/schema/full_schemas/16/users.sql | 42 + .../main/schema/full_schemas/54/full.sql.postgres | 2035 ++++++++++++++++ .../main/schema/full_schemas/54/full.sql.sqlite | 259 ++ .../schema/full_schemas/54/stream_positions.sql | 7 + .../main/schema/full_schemas/README.txt | 19 + synapse/storage/data_stores/main/search.py | 712 ++++++ synapse/storage/data_stores/main/signatures.py | 101 + synapse/storage/data_stores/main/state.py | 1244 ++++++++++ synapse/storage/data_stores/main/state_deltas.py | 119 + synapse/storage/data_stores/main/stats.py | 881 +++++++ synapse/storage/data_stores/main/stream.py | 948 ++++++++ synapse/storage/data_stores/main/tags.py | 265 +++ synapse/storage/data_stores/main/transactions.py | 273 +++ synapse/storage/data_stores/main/user_directory.py | 827 +++++++ .../storage/data_stores/main/user_erasure_store.py | 91 + synapse/storage/deviceinbox.py | 457 ---- synapse/storage/devices.py | 937 -------- synapse/storage/directory.py | 174 -- synapse/storage/e2e_room_keys.py | 337 --- synapse/storage/end_to_end_keys.py | 323 --- synapse/storage/event_federation.py | 672 ------ synapse/storage/event_push_actions.py | 960 -------- synapse/storage/events.py | 2489 -------------------- synapse/storage/events_bg_updates.py | 505 ---- synapse/storage/events_worker.py | 882 ------- synapse/storage/filtering.py | 75 - synapse/storage/group_server.py | 1181 ---------- synapse/storage/keys.py | 194 -- synapse/storage/media_repository.py | 378 --- synapse/storage/monthly_active_users.py | 329 --- synapse/storage/openid.py | 31 - synapse/storage/presence.py | 134 -- synapse/storage/profile.py | 179 -- synapse/storage/push_rule.py | 698 ------ synapse/storage/pusher.py | 372 --- synapse/storage/receipts.py | 536 ----- synapse/storage/registration.py | 1499 ------------ synapse/storage/rejections.py | 42 - synapse/storage/relations.py | 359 --- synapse/storage/room.py | 681 ------ synapse/storage/roommember.py | 1119 --------- synapse/storage/schema/delta/12/v12.sql | 63 - synapse/storage/schema/delta/13/v13.sql | 19 - synapse/storage/schema/delta/14/v14.sql | 23 - .../storage/schema/delta/15/appservice_txns.sql | 31 - .../storage/schema/delta/15/presence_indices.sql | 2 - synapse/storage/schema/delta/15/v15.sql | 24 - .../storage/schema/delta/16/events_order_index.sql | 4 - .../schema/delta/16/remote_media_cache_index.sql | 2 - .../storage/schema/delta/16/remove_duplicates.sql | 9 - .../storage/schema/delta/16/room_alias_index.sql | 3 - .../storage/schema/delta/16/unique_constraints.sql | 72 - synapse/storage/schema/delta/16/users.sql | 56 - synapse/storage/schema/delta/17/drop_indexes.sql | 18 - synapse/storage/schema/delta/17/server_keys.sql | 24 - synapse/storage/schema/delta/17/user_threepids.sql | 9 - .../schema/delta/18/server_keys_bigger_ints.sql | 32 - synapse/storage/schema/delta/19/event_index.sql | 19 - synapse/storage/schema/delta/20/dummy.sql | 1 - synapse/storage/schema/delta/20/pushers.py | 88 - .../storage/schema/delta/21/end_to_end_keys.sql | 34 - synapse/storage/schema/delta/21/receipts.sql | 38 - synapse/storage/schema/delta/22/receipts_index.sql | 22 - .../schema/delta/22/user_threepids_unique.sql | 19 - .../storage/schema/delta/23/drop_state_index.sql | 16 - .../storage/schema/delta/24/stats_reporting.sql | 18 - synapse/storage/schema/delta/25/fts.py | 82 - synapse/storage/schema/delta/25/guest_access.sql | 25 - .../storage/schema/delta/25/history_visibility.sql | 25 - synapse/storage/schema/delta/25/tags.sql | 38 - synapse/storage/schema/delta/26/account_data.sql | 17 - synapse/storage/schema/delta/27/account_data.sql | 36 - .../schema/delta/27/forgotten_memberships.sql | 26 - synapse/storage/schema/delta/27/ts.py | 61 - .../storage/schema/delta/28/event_push_actions.sql | 27 - .../storage/schema/delta/28/events_room_stream.sql | 20 - .../storage/schema/delta/28/public_roms_index.sql | 20 - .../schema/delta/28/receipts_user_id_index.sql | 22 - synapse/storage/schema/delta/28/upgrade_times.sql | 21 - synapse/storage/schema/delta/28/users_is_guest.sql | 22 - synapse/storage/schema/delta/29/push_actions.sql | 35 - synapse/storage/schema/delta/30/alias_creator.sql | 16 - synapse/storage/schema/delta/30/as_users.py | 69 - .../storage/schema/delta/30/deleted_pushers.sql | 25 - .../storage/schema/delta/30/presence_stream.sql | 30 - synapse/storage/schema/delta/30/public_rooms.sql | 23 - .../storage/schema/delta/30/push_rule_stream.sql | 38 - synapse/storage/schema/delta/30/state_stream.sql | 33 - .../delta/30/threepid_guest_access_tokens.sql | 24 - synapse/storage/schema/delta/31/invites.sql | 42 - .../delta/31/local_media_repository_url_cache.sql | 27 - synapse/storage/schema/delta/31/pushers.py | 87 - synapse/storage/schema/delta/31/pushers_index.sql | 22 - synapse/storage/schema/delta/31/search_update.py | 66 - synapse/storage/schema/delta/32/events.sql | 16 - synapse/storage/schema/delta/32/openid.sql | 9 - .../storage/schema/delta/32/pusher_throttle.sql | 23 - synapse/storage/schema/delta/32/remove_indices.sql | 34 - synapse/storage/schema/delta/32/reports.sql | 25 - .../schema/delta/33/access_tokens_device_index.sql | 17 - synapse/storage/schema/delta/33/devices.sql | 21 - .../schema/delta/33/devices_for_e2e_keys.sql | 19 - .../devices_for_e2e_keys_clear_unknown_device.sql | 20 - synapse/storage/schema/delta/33/event_fields.py | 61 - synapse/storage/schema/delta/33/remote_media_ts.py | 30 - synapse/storage/schema/delta/33/user_ips_index.sql | 17 - .../storage/schema/delta/34/appservice_stream.sql | 23 - synapse/storage/schema/delta/34/cache_stream.py | 46 - synapse/storage/schema/delta/34/device_inbox.sql | 24 - .../schema/delta/34/push_display_name_rename.sql | 20 - .../storage/schema/delta/34/received_txn_purge.py | 32 - .../delta/35/00background_updates_add_col.sql | 17 + .../storage/schema/delta/35/add_state_index.sql | 20 - synapse/storage/schema/delta/35/contains_url.sql | 17 - synapse/storage/schema/delta/35/device_outbox.sql | 39 - .../storage/schema/delta/35/device_stream_id.sql | 21 - .../schema/delta/35/event_push_actions_index.sql | 17 - .../delta/35/public_room_list_change_stream.sql | 33 - synapse/storage/schema/delta/35/state.sql | 22 - synapse/storage/schema/delta/35/state_dedupe.sql | 17 - .../schema/delta/35/stream_order_to_extrem.sql | 37 - .../storage/schema/delta/36/readd_public_rooms.sql | 26 - synapse/storage/schema/delta/37/remove_auth_idx.py | 85 - synapse/storage/schema/delta/37/user_threepids.sql | 52 - .../storage/schema/delta/38/postgres_fts_gist.sql | 19 - .../schema/delta/39/appservice_room_list.sql | 29 - .../delta/39/device_federation_stream_idx.sql | 16 - .../storage/schema/delta/39/event_push_index.sql | 17 - .../schema/delta/39/federation_out_position.sql | 22 - .../storage/schema/delta/39/membership_profile.sql | 20 - .../storage/schema/delta/40/current_state_idx.sql | 17 - synapse/storage/schema/delta/40/device_inbox.sql | 21 - .../schema/delta/40/device_list_streams.sql | 60 - .../storage/schema/delta/40/event_push_summary.sql | 37 - synapse/storage/schema/delta/40/pushers.sql | 39 - .../schema/delta/41/device_list_stream_idx.sql | 17 - .../schema/delta/41/device_outbound_index.sql | 16 - .../schema/delta/41/event_search_event_id_idx.sql | 17 - synapse/storage/schema/delta/41/ratelimit.sql | 22 - .../schema/delta/42/current_state_delta.sql | 26 - .../schema/delta/42/device_list_last_id.sql | 33 - .../schema/delta/42/event_auth_state_only.sql | 17 - synapse/storage/schema/delta/42/user_dir.py | 84 - synapse/storage/schema/delta/43/blocked_rooms.sql | 21 - .../storage/schema/delta/43/quarantine_media.sql | 17 - synapse/storage/schema/delta/43/url_cache.sql | 16 - synapse/storage/schema/delta/43/user_share.sql | 33 - .../storage/schema/delta/44/expire_url_cache.sql | 41 - synapse/storage/schema/delta/45/group_server.sql | 167 -- synapse/storage/schema/delta/45/profile_cache.sql | 28 - .../schema/delta/46/drop_refresh_tokens.sql | 17 - .../delta/46/drop_unique_deleted_pushers.sql | 35 - synapse/storage/schema/delta/46/group_server.sql | 32 - .../delta/46/local_media_repository_url_idx.sql | 24 - .../schema/delta/46/user_dir_null_room_ids.sql | 35 - synapse/storage/schema/delta/46/user_dir_typos.sql | 24 - .../storage/schema/delta/47/last_access_media.sql | 16 - .../storage/schema/delta/47/postgres_fts_gin.sql | 17 - .../schema/delta/47/push_actions_staging.sql | 28 - synapse/storage/schema/delta/47/state_group_seq.py | 34 - .../storage/schema/delta/48/add_user_consent.sql | 18 - .../delta/48/add_user_ips_last_seen_index.sql | 17 - .../storage/schema/delta/48/deactivated_users.sql | 25 - .../schema/delta/48/group_unique_indexes.py | 63 - .../storage/schema/delta/48/groups_joinable.sql | 22 - .../49/add_user_consent_server_notice_sent.sql | 20 - .../schema/delta/49/add_user_daily_visits.sql | 21 - .../delta/49/add_user_ips_last_seen_only_index.sql | 17 - .../delta/50/add_creation_ts_users_index.sql | 19 - synapse/storage/schema/delta/50/erasure_store.sql | 21 - .../schema/delta/50/make_event_content_nullable.py | 96 - synapse/storage/schema/delta/51/e2e_room_keys.sql | 39 - .../schema/delta/51/monthly_active_users.sql | 27 - .../delta/52/add_event_to_state_group_index.sql | 19 - .../delta/52/device_list_streams_unique_idx.sql | 36 - synapse/storage/schema/delta/52/e2e_room_keys.sql | 53 - .../schema/delta/53/add_user_type_to_users.sql | 19 - .../schema/delta/53/drop_sent_transactions.sql | 16 - .../schema/delta/53/event_format_version.sql | 16 - .../storage/schema/delta/53/user_dir_populate.sql | 30 - synapse/storage/schema/delta/53/user_ips_index.sql | 30 - synapse/storage/schema/delta/53/user_share.sql | 44 - .../storage/schema/delta/53/user_threepid_id.sql | 29 - .../schema/delta/53/users_in_public_rooms.sql | 28 - .../delta/54/account_validity_with_renewal.sql | 30 - .../delta/54/add_validity_to_server_keys.sql | 23 - .../schema/delta/54/delete_forward_extremities.sql | 23 - .../storage/schema/delta/54/drop_legacy_tables.sql | 30 - .../storage/schema/delta/54/drop_presence_list.sql | 16 - synapse/storage/schema/delta/54/relations.sql | 27 - synapse/storage/schema/delta/54/stats.sql | 80 - synapse/storage/schema/delta/54/stats2.sql | 28 - .../schema/delta/55/access_token_expiry.sql | 18 - .../schema/delta/55/track_threepid_validations.sql | 31 - .../schema/delta/55/users_alter_deactivated.sql | 19 - .../schema/delta/56/add_spans_to_device_lists.sql | 20 - .../delta/56/current_state_events_membership.sql | 22 - .../56/current_state_events_membership_mk2.sql | 24 - .../schema/delta/56/destinations_failure_ts.sql | 25 - .../destinations_retry_interval_type.sql.postgres | 18 - .../storage/schema/delta/56/devices_last_seen.sql | 24 - .../schema/delta/56/drop_unused_event_tables.sql | 20 - .../schema/delta/56/fix_room_keys_index.sql | 18 - .../schema/delta/56/public_room_list_idx.sql | 16 - .../storage/schema/delta/56/redaction_censor.sql | 17 - .../storage/schema/delta/56/redaction_censor2.sql | 20 - .../56/redaction_censor3_fix_update.sql.postgres | 25 - .../schema/delta/56/room_membership_idx.sql | 18 - .../storage/schema/delta/56/stats_separated.sql | 152 -- .../schema/delta/56/unique_user_filter_index.py | 52 - .../storage/schema/delta/56/user_external_ids.sql | 24 - .../schema/delta/56/users_in_public_rooms_idx.sql | 17 - .../full_schemas/16/application_services.sql | 37 - .../storage/schema/full_schemas/16/event_edges.sql | 70 - .../schema/full_schemas/16/event_signatures.sql | 38 - synapse/storage/schema/full_schemas/16/im.sql | 120 - synapse/storage/schema/full_schemas/16/keys.sql | 26 - .../schema/full_schemas/16/media_repository.sql | 68 - .../storage/schema/full_schemas/16/presence.sql | 32 - .../storage/schema/full_schemas/16/profiles.sql | 20 - synapse/storage/schema/full_schemas/16/push.sql | 74 - .../storage/schema/full_schemas/16/redactions.sql | 22 - .../schema/full_schemas/16/room_aliases.sql | 29 - synapse/storage/schema/full_schemas/16/state.sql | 40 - .../schema/full_schemas/16/transactions.sql | 44 - synapse/storage/schema/full_schemas/16/users.sql | 42 - synapse/storage/schema/full_schemas/54/full.sql | 8 + .../schema/full_schemas/54/full.sql.postgres | 2052 ---------------- .../storage/schema/full_schemas/54/full.sql.sqlite | 260 -- .../schema/full_schemas/54/stream_positions.sql | 7 - synapse/storage/schema/full_schemas/README.txt | 19 - synapse/storage/search.py | 713 ------ synapse/storage/signatures.py | 102 - synapse/storage/state.py | 1221 ---------- synapse/storage/state_deltas.py | 119 - synapse/storage/stats.py | 881 ------- synapse/storage/stream.py | 948 -------- synapse/storage/tags.py | 265 --- synapse/storage/transactions.py | 274 --- synapse/storage/user_directory.py | 827 ------- synapse/storage/user_erasure_store.py | 91 - tests/handlers/test_stats.py | 8 +- tests/storage/test_appservice.py | 2 +- tests/storage/test_cleanup_extrems.py | 2 + tests/storage/test_profile.py | 2 +- tests/storage/test_user_directory.py | 2 +- 487 files changed, 31168 insertions(+), 30990 deletions(-) delete mode 100644 synapse/storage/account_data.py delete mode 100644 synapse/storage/appservice.py delete mode 100644 synapse/storage/client_ips.py create mode 100644 synapse/storage/data_stores/__init__.py create mode 100644 synapse/storage/data_stores/main/__init__.py create mode 100644 synapse/storage/data_stores/main/account_data.py create mode 100644 synapse/storage/data_stores/main/appservice.py create mode 100644 synapse/storage/data_stores/main/client_ips.py create mode 100644 synapse/storage/data_stores/main/deviceinbox.py create mode 100644 synapse/storage/data_stores/main/devices.py create mode 100644 synapse/storage/data_stores/main/directory.py create mode 100644 synapse/storage/data_stores/main/e2e_room_keys.py create mode 100644 synapse/storage/data_stores/main/end_to_end_keys.py create mode 100644 synapse/storage/data_stores/main/event_federation.py create mode 100644 synapse/storage/data_stores/main/event_push_actions.py create mode 100644 synapse/storage/data_stores/main/events.py create mode 100644 synapse/storage/data_stores/main/events_bg_updates.py create mode 100644 synapse/storage/data_stores/main/events_worker.py create mode 100644 synapse/storage/data_stores/main/filtering.py create mode 100644 synapse/storage/data_stores/main/group_server.py create mode 100644 synapse/storage/data_stores/main/keys.py create mode 100644 synapse/storage/data_stores/main/media_repository.py create mode 100644 synapse/storage/data_stores/main/monthly_active_users.py create mode 100644 synapse/storage/data_stores/main/openid.py create mode 100644 synapse/storage/data_stores/main/presence.py create mode 100644 synapse/storage/data_stores/main/profile.py create mode 100644 synapse/storage/data_stores/main/push_rule.py create mode 100644 synapse/storage/data_stores/main/pusher.py create mode 100644 synapse/storage/data_stores/main/receipts.py create mode 100644 synapse/storage/data_stores/main/registration.py create mode 100644 synapse/storage/data_stores/main/rejections.py create mode 100644 synapse/storage/data_stores/main/relations.py create mode 100644 synapse/storage/data_stores/main/room.py create mode 100644 synapse/storage/data_stores/main/roommember.py create mode 100644 synapse/storage/data_stores/main/schema/delta/12/v12.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/13/v13.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/14/v14.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/15/v15.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/16/users.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/17/server_keys.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/19/event_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/20/dummy.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/20/pushers.py create mode 100644 synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/21/receipts.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/25/fts.py create mode 100644 synapse/storage/data_stores/main/schema/delta/25/guest_access.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/25/tags.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/26/account_data.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/27/account_data.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/27/ts.py create mode 100644 synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/29/push_actions.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/30/as_users.py create mode 100644 synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/30/state_stream.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/31/invites.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/31/pushers.py create mode 100644 synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/31/search_update.py create mode 100644 synapse/storage/data_stores/main/schema/delta/32/events.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/32/openid.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/32/reports.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/33/devices.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/33/event_fields.py create mode 100644 synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py create mode 100644 synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/34/cache_stream.py create mode 100644 synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py create mode 100644 synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/35/contains_url.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/35/state.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py create mode 100644 synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/40/pushers.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/42/user_dir.py create mode 100644 synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/43/url_cache.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/43/user_share.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/45/group_server.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/46/group_server.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py create mode 100644 synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py create mode 100644 synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py create mode 100644 synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/53/user_share.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/54/relations.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/54/stats.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/54/stats2.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres create mode 100644 synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres create mode 100644 synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py create mode 100644 synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/im.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/push.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/state.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/16/users.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/README.txt create mode 100644 synapse/storage/data_stores/main/search.py create mode 100644 synapse/storage/data_stores/main/signatures.py create mode 100644 synapse/storage/data_stores/main/state.py create mode 100644 synapse/storage/data_stores/main/state_deltas.py create mode 100644 synapse/storage/data_stores/main/stats.py create mode 100644 synapse/storage/data_stores/main/stream.py create mode 100644 synapse/storage/data_stores/main/tags.py create mode 100644 synapse/storage/data_stores/main/transactions.py create mode 100644 synapse/storage/data_stores/main/user_directory.py create mode 100644 synapse/storage/data_stores/main/user_erasure_store.py delete mode 100644 synapse/storage/deviceinbox.py delete mode 100644 synapse/storage/devices.py delete mode 100644 synapse/storage/directory.py delete mode 100644 synapse/storage/e2e_room_keys.py delete mode 100644 synapse/storage/end_to_end_keys.py delete mode 100644 synapse/storage/event_federation.py delete mode 100644 synapse/storage/event_push_actions.py delete mode 100644 synapse/storage/events.py delete mode 100644 synapse/storage/events_bg_updates.py delete mode 100644 synapse/storage/events_worker.py delete mode 100644 synapse/storage/filtering.py delete mode 100644 synapse/storage/group_server.py delete mode 100644 synapse/storage/media_repository.py delete mode 100644 synapse/storage/monthly_active_users.py delete mode 100644 synapse/storage/openid.py delete mode 100644 synapse/storage/profile.py delete mode 100644 synapse/storage/pusher.py delete mode 100644 synapse/storage/receipts.py delete mode 100644 synapse/storage/registration.py delete mode 100644 synapse/storage/rejections.py delete mode 100644 synapse/storage/room.py delete mode 100644 synapse/storage/schema/delta/12/v12.sql delete mode 100644 synapse/storage/schema/delta/13/v13.sql delete mode 100644 synapse/storage/schema/delta/14/v14.sql delete mode 100644 synapse/storage/schema/delta/15/appservice_txns.sql delete mode 100644 synapse/storage/schema/delta/15/presence_indices.sql delete mode 100644 synapse/storage/schema/delta/15/v15.sql delete mode 100644 synapse/storage/schema/delta/16/events_order_index.sql delete mode 100644 synapse/storage/schema/delta/16/remote_media_cache_index.sql delete mode 100644 synapse/storage/schema/delta/16/remove_duplicates.sql delete mode 100644 synapse/storage/schema/delta/16/room_alias_index.sql delete mode 100644 synapse/storage/schema/delta/16/unique_constraints.sql delete mode 100644 synapse/storage/schema/delta/16/users.sql delete mode 100644 synapse/storage/schema/delta/17/drop_indexes.sql delete mode 100644 synapse/storage/schema/delta/17/server_keys.sql delete mode 100644 synapse/storage/schema/delta/17/user_threepids.sql delete mode 100644 synapse/storage/schema/delta/18/server_keys_bigger_ints.sql delete mode 100644 synapse/storage/schema/delta/19/event_index.sql delete mode 100644 synapse/storage/schema/delta/20/dummy.sql delete mode 100644 synapse/storage/schema/delta/20/pushers.py delete mode 100644 synapse/storage/schema/delta/21/end_to_end_keys.sql delete mode 100644 synapse/storage/schema/delta/21/receipts.sql delete mode 100644 synapse/storage/schema/delta/22/receipts_index.sql delete mode 100644 synapse/storage/schema/delta/22/user_threepids_unique.sql delete mode 100644 synapse/storage/schema/delta/23/drop_state_index.sql delete mode 100644 synapse/storage/schema/delta/24/stats_reporting.sql delete mode 100644 synapse/storage/schema/delta/25/fts.py delete mode 100644 synapse/storage/schema/delta/25/guest_access.sql delete mode 100644 synapse/storage/schema/delta/25/history_visibility.sql delete mode 100644 synapse/storage/schema/delta/25/tags.sql delete mode 100644 synapse/storage/schema/delta/26/account_data.sql delete mode 100644 synapse/storage/schema/delta/27/account_data.sql delete mode 100644 synapse/storage/schema/delta/27/forgotten_memberships.sql delete mode 100644 synapse/storage/schema/delta/27/ts.py delete mode 100644 synapse/storage/schema/delta/28/event_push_actions.sql delete mode 100644 synapse/storage/schema/delta/28/events_room_stream.sql delete mode 100644 synapse/storage/schema/delta/28/public_roms_index.sql delete mode 100644 synapse/storage/schema/delta/28/receipts_user_id_index.sql delete mode 100644 synapse/storage/schema/delta/28/upgrade_times.sql delete mode 100644 synapse/storage/schema/delta/28/users_is_guest.sql delete mode 100644 synapse/storage/schema/delta/29/push_actions.sql delete mode 100644 synapse/storage/schema/delta/30/alias_creator.sql delete mode 100644 synapse/storage/schema/delta/30/as_users.py delete mode 100644 synapse/storage/schema/delta/30/deleted_pushers.sql delete mode 100644 synapse/storage/schema/delta/30/presence_stream.sql delete mode 100644 synapse/storage/schema/delta/30/public_rooms.sql delete mode 100644 synapse/storage/schema/delta/30/push_rule_stream.sql delete mode 100644 synapse/storage/schema/delta/30/state_stream.sql delete mode 100644 synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql delete mode 100644 synapse/storage/schema/delta/31/invites.sql delete mode 100644 synapse/storage/schema/delta/31/local_media_repository_url_cache.sql delete mode 100644 synapse/storage/schema/delta/31/pushers.py delete mode 100644 synapse/storage/schema/delta/31/pushers_index.sql delete mode 100644 synapse/storage/schema/delta/31/search_update.py delete mode 100644 synapse/storage/schema/delta/32/events.sql delete mode 100644 synapse/storage/schema/delta/32/openid.sql delete mode 100644 synapse/storage/schema/delta/32/pusher_throttle.sql delete mode 100644 synapse/storage/schema/delta/32/remove_indices.sql delete mode 100644 synapse/storage/schema/delta/32/reports.sql delete mode 100644 synapse/storage/schema/delta/33/access_tokens_device_index.sql delete mode 100644 synapse/storage/schema/delta/33/devices.sql delete mode 100644 synapse/storage/schema/delta/33/devices_for_e2e_keys.sql delete mode 100644 synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql delete mode 100644 synapse/storage/schema/delta/33/event_fields.py delete mode 100644 synapse/storage/schema/delta/33/remote_media_ts.py delete mode 100644 synapse/storage/schema/delta/33/user_ips_index.sql delete mode 100644 synapse/storage/schema/delta/34/appservice_stream.sql delete mode 100644 synapse/storage/schema/delta/34/cache_stream.py delete mode 100644 synapse/storage/schema/delta/34/device_inbox.sql delete mode 100644 synapse/storage/schema/delta/34/push_display_name_rename.sql delete mode 100644 synapse/storage/schema/delta/34/received_txn_purge.py create mode 100644 synapse/storage/schema/delta/35/00background_updates_add_col.sql delete mode 100644 synapse/storage/schema/delta/35/add_state_index.sql delete mode 100644 synapse/storage/schema/delta/35/contains_url.sql delete mode 100644 synapse/storage/schema/delta/35/device_outbox.sql delete mode 100644 synapse/storage/schema/delta/35/device_stream_id.sql delete mode 100644 synapse/storage/schema/delta/35/event_push_actions_index.sql delete mode 100644 synapse/storage/schema/delta/35/public_room_list_change_stream.sql delete mode 100644 synapse/storage/schema/delta/35/state.sql delete mode 100644 synapse/storage/schema/delta/35/state_dedupe.sql delete mode 100644 synapse/storage/schema/delta/35/stream_order_to_extrem.sql delete mode 100644 synapse/storage/schema/delta/36/readd_public_rooms.sql delete mode 100644 synapse/storage/schema/delta/37/remove_auth_idx.py delete mode 100644 synapse/storage/schema/delta/37/user_threepids.sql delete mode 100644 synapse/storage/schema/delta/38/postgres_fts_gist.sql delete mode 100644 synapse/storage/schema/delta/39/appservice_room_list.sql delete mode 100644 synapse/storage/schema/delta/39/device_federation_stream_idx.sql delete mode 100644 synapse/storage/schema/delta/39/event_push_index.sql delete mode 100644 synapse/storage/schema/delta/39/federation_out_position.sql delete mode 100644 synapse/storage/schema/delta/39/membership_profile.sql delete mode 100644 synapse/storage/schema/delta/40/current_state_idx.sql delete mode 100644 synapse/storage/schema/delta/40/device_inbox.sql delete mode 100644 synapse/storage/schema/delta/40/device_list_streams.sql delete mode 100644 synapse/storage/schema/delta/40/event_push_summary.sql delete mode 100644 synapse/storage/schema/delta/40/pushers.sql delete mode 100644 synapse/storage/schema/delta/41/device_list_stream_idx.sql delete mode 100644 synapse/storage/schema/delta/41/device_outbound_index.sql delete mode 100644 synapse/storage/schema/delta/41/event_search_event_id_idx.sql delete mode 100644 synapse/storage/schema/delta/41/ratelimit.sql delete mode 100644 synapse/storage/schema/delta/42/current_state_delta.sql delete mode 100644 synapse/storage/schema/delta/42/device_list_last_id.sql delete mode 100644 synapse/storage/schema/delta/42/event_auth_state_only.sql delete mode 100644 synapse/storage/schema/delta/42/user_dir.py delete mode 100644 synapse/storage/schema/delta/43/blocked_rooms.sql delete mode 100644 synapse/storage/schema/delta/43/quarantine_media.sql delete mode 100644 synapse/storage/schema/delta/43/url_cache.sql delete mode 100644 synapse/storage/schema/delta/43/user_share.sql delete mode 100644 synapse/storage/schema/delta/44/expire_url_cache.sql delete mode 100644 synapse/storage/schema/delta/45/group_server.sql delete mode 100644 synapse/storage/schema/delta/45/profile_cache.sql delete mode 100644 synapse/storage/schema/delta/46/drop_refresh_tokens.sql delete mode 100644 synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql delete mode 100644 synapse/storage/schema/delta/46/group_server.sql delete mode 100644 synapse/storage/schema/delta/46/local_media_repository_url_idx.sql delete mode 100644 synapse/storage/schema/delta/46/user_dir_null_room_ids.sql delete mode 100644 synapse/storage/schema/delta/46/user_dir_typos.sql delete mode 100644 synapse/storage/schema/delta/47/last_access_media.sql delete mode 100644 synapse/storage/schema/delta/47/postgres_fts_gin.sql delete mode 100644 synapse/storage/schema/delta/47/push_actions_staging.sql delete mode 100644 synapse/storage/schema/delta/47/state_group_seq.py delete mode 100644 synapse/storage/schema/delta/48/add_user_consent.sql delete mode 100644 synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql delete mode 100644 synapse/storage/schema/delta/48/deactivated_users.sql delete mode 100644 synapse/storage/schema/delta/48/group_unique_indexes.py delete mode 100644 synapse/storage/schema/delta/48/groups_joinable.sql delete mode 100644 synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql delete mode 100644 synapse/storage/schema/delta/49/add_user_daily_visits.sql delete mode 100644 synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql delete mode 100644 synapse/storage/schema/delta/50/add_creation_ts_users_index.sql delete mode 100644 synapse/storage/schema/delta/50/erasure_store.sql delete mode 100644 synapse/storage/schema/delta/50/make_event_content_nullable.py delete mode 100644 synapse/storage/schema/delta/51/e2e_room_keys.sql delete mode 100644 synapse/storage/schema/delta/51/monthly_active_users.sql delete mode 100644 synapse/storage/schema/delta/52/add_event_to_state_group_index.sql delete mode 100644 synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql delete mode 100644 synapse/storage/schema/delta/52/e2e_room_keys.sql delete mode 100644 synapse/storage/schema/delta/53/add_user_type_to_users.sql delete mode 100644 synapse/storage/schema/delta/53/drop_sent_transactions.sql delete mode 100644 synapse/storage/schema/delta/53/event_format_version.sql delete mode 100644 synapse/storage/schema/delta/53/user_dir_populate.sql delete mode 100644 synapse/storage/schema/delta/53/user_ips_index.sql delete mode 100644 synapse/storage/schema/delta/53/user_share.sql delete mode 100644 synapse/storage/schema/delta/53/user_threepid_id.sql delete mode 100644 synapse/storage/schema/delta/53/users_in_public_rooms.sql delete mode 100644 synapse/storage/schema/delta/54/account_validity_with_renewal.sql delete mode 100644 synapse/storage/schema/delta/54/add_validity_to_server_keys.sql delete mode 100644 synapse/storage/schema/delta/54/delete_forward_extremities.sql delete mode 100644 synapse/storage/schema/delta/54/drop_legacy_tables.sql delete mode 100644 synapse/storage/schema/delta/54/drop_presence_list.sql delete mode 100644 synapse/storage/schema/delta/54/relations.sql delete mode 100644 synapse/storage/schema/delta/54/stats.sql delete mode 100644 synapse/storage/schema/delta/54/stats2.sql delete mode 100644 synapse/storage/schema/delta/55/access_token_expiry.sql delete mode 100644 synapse/storage/schema/delta/55/track_threepid_validations.sql delete mode 100644 synapse/storage/schema/delta/55/users_alter_deactivated.sql delete mode 100644 synapse/storage/schema/delta/56/add_spans_to_device_lists.sql delete mode 100644 synapse/storage/schema/delta/56/current_state_events_membership.sql delete mode 100644 synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql delete mode 100644 synapse/storage/schema/delta/56/destinations_failure_ts.sql delete mode 100644 synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres delete mode 100644 synapse/storage/schema/delta/56/devices_last_seen.sql delete mode 100644 synapse/storage/schema/delta/56/drop_unused_event_tables.sql delete mode 100644 synapse/storage/schema/delta/56/fix_room_keys_index.sql delete mode 100644 synapse/storage/schema/delta/56/public_room_list_idx.sql delete mode 100644 synapse/storage/schema/delta/56/redaction_censor.sql delete mode 100644 synapse/storage/schema/delta/56/redaction_censor2.sql delete mode 100644 synapse/storage/schema/delta/56/redaction_censor3_fix_update.sql.postgres delete mode 100644 synapse/storage/schema/delta/56/room_membership_idx.sql delete mode 100644 synapse/storage/schema/delta/56/stats_separated.sql delete mode 100644 synapse/storage/schema/delta/56/unique_user_filter_index.py delete mode 100644 synapse/storage/schema/delta/56/user_external_ids.sql delete mode 100644 synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql delete mode 100644 synapse/storage/schema/full_schemas/16/application_services.sql delete mode 100644 synapse/storage/schema/full_schemas/16/event_edges.sql delete mode 100644 synapse/storage/schema/full_schemas/16/event_signatures.sql delete mode 100644 synapse/storage/schema/full_schemas/16/im.sql delete mode 100644 synapse/storage/schema/full_schemas/16/keys.sql delete mode 100644 synapse/storage/schema/full_schemas/16/media_repository.sql delete mode 100644 synapse/storage/schema/full_schemas/16/presence.sql delete mode 100644 synapse/storage/schema/full_schemas/16/profiles.sql delete mode 100644 synapse/storage/schema/full_schemas/16/push.sql delete mode 100644 synapse/storage/schema/full_schemas/16/redactions.sql delete mode 100644 synapse/storage/schema/full_schemas/16/room_aliases.sql delete mode 100644 synapse/storage/schema/full_schemas/16/state.sql delete mode 100644 synapse/storage/schema/full_schemas/16/transactions.sql delete mode 100644 synapse/storage/schema/full_schemas/16/users.sql create mode 100644 synapse/storage/schema/full_schemas/54/full.sql delete mode 100644 synapse/storage/schema/full_schemas/54/full.sql.postgres delete mode 100644 synapse/storage/schema/full_schemas/54/full.sql.sqlite delete mode 100644 synapse/storage/schema/full_schemas/54/stream_positions.sql delete mode 100644 synapse/storage/schema/full_schemas/README.txt delete mode 100644 synapse/storage/search.py delete mode 100644 synapse/storage/signatures.py delete mode 100644 synapse/storage/state_deltas.py delete mode 100644 synapse/storage/stats.py delete mode 100644 synapse/storage/stream.py delete mode 100644 synapse/storage/tags.py delete mode 100644 synapse/storage/transactions.py delete mode 100644 synapse/storage/user_directory.py delete mode 100644 synapse/storage/user_erasure_store.py (limited to 'tests') diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index c67fe69a50..f20d810ece 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -56,8 +56,8 @@ from synapse.rest.client.v1.room import ( RoomStateEventRestServlet, ) from synapse.server import HomeServer +from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.storage.engines import create_engine -from synapse.storage.user_directory import UserDirectoryStore from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 2ac783ffa3..6bc7202f33 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -39,8 +39,8 @@ from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer +from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore from synapse.storage.engines import create_engine -from synapse.storage.media_repository import MediaRepositoryStore from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 473026fce5..6a7e2fa707 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -54,8 +54,8 @@ from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.room import RoomInitialSyncRestServlet from synapse.rest.client.v2_alpha import sync from synapse.server import HomeServer +from synapse.storage.data_stores.main.presence import UserPresenceState from synapse.storage.engines import create_engine -from synapse.storage.presence import UserPresenceState from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.stringutils import random_string diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index e01afb39f2..a5d6dc7915 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -42,8 +42,8 @@ from synapse.replication.tcp.streams.events import ( ) from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer +from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.storage.engines import create_engine -from synapse.storage.user_directory import UserDirectoryStore from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index fad980b893..cc75c39476 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -30,7 +30,7 @@ from synapse.federation.units import Edu from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage import UserPresenceState +from synapse.storage.presence import UserPresenceState from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter # This is defined in the Matrix spec and enforced by the receiver. diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 3c44d1d48d..bc2f6a12ae 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -16,8 +16,8 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.account_data import AccountDataWorkerStore -from synapse.storage.tags import TagsWorkerStore +from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore +from synapse.storage.data_stores.main.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index cda12ea70d..a67fbeffb7 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.appservice import ( +from synapse.storage.data_stores.main.appservice import ( ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, ) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 14ced32333..b4f58cea19 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.client_ips import LAST_SEEN_GRANULARITY +from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 284fd30d89..9fb6c5c6ff 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -15,7 +15,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index d9300fce33..f856c72d84 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -15,8 +15,8 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.devices import DeviceWorkerStore -from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.storage.data_stores.main.devices import DeviceWorkerStore +from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 1d1d48709a..8b9717c46f 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.directory import DirectoryWorkerStore +from synapse.storage.data_stores.main.directory import DirectoryWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index ab5937e638..d0a0eaf75b 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -20,15 +20,17 @@ from synapse.replication.tcp.streams.events import ( EventsStreamCurrentStateRow, EventsStreamEventRow, ) -from synapse.storage.event_federation import EventFederationWorkerStore -from synapse.storage.event_push_actions import EventPushActionsWorkerStore -from synapse.storage.events_worker import EventsWorkerStore -from synapse.storage.relations import RelationsWorkerStore -from synapse.storage.roommember import RoomMemberWorkerStore -from synapse.storage.signatures import SignatureWorkerStore -from synapse.storage.state import StateGroupWorkerStore -from synapse.storage.stream import StreamWorkerStore -from synapse.storage.user_erasure_store import UserErasureWorkerStore +from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore +from synapse.storage.data_stores.main.event_push_actions import ( + EventPushActionsWorkerStore, +) +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.data_stores.main.relations import RelationsWorkerStore +from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.data_stores.main.signatures import SignatureWorkerStore +from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.data_stores.main.stream import StreamWorkerStore +from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 456a14cd5c..5c84ebd125 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.filtering import FilteringStore +from synapse.storage.data_stores.main.filtering import FilteringStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py index cc6f7f009f..3def367ae9 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import KeyStore +from synapse.storage.data_stores.main.keys import KeyStore # KeyStore isn't really safe to use from a worker, but for now we do so and hope that # the races it creates aren't too bad. diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 82d808af4c..747ced0c84 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.storage import DataStore -from synapse.storage.presence import PresenceStore +from synapse.storage.data_stores.main.presence import PresenceStore from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py index 46c28d4171..28c508aad3 100644 --- a/synapse/replication/slave/storage/profile.py +++ b/synapse/replication/slave/storage/profile.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.storage.profile import ProfileWorkerStore +from synapse.storage.data_stores.main.profile import ProfileWorkerStore class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index af7012702e..3655f05e54 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.push_rule import PushRulesWorkerStore +from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 8eeb267d61..b4331d0799 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.pusher import PusherWorkerStore +from synapse.storage.data_stores.main.pusher import PusherWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 91afa5a72b..43d823c601 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.receipts import ReceiptsWorkerStore +from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index 408d91df1c..4b8553e250 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.registration import RegistrationWorkerStore +from synapse.storage.data_stores.main.registration import RegistrationWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index f68b3378e3..d9ad386b28 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.room import RoomWorkerStore +from synapse.storage.data_stores.main.room import RoomWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py index 3527beb3c9..ac88e6b8c3 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.transactions import TransactionStore +from synapse.storage.data_stores.main.transactions import TransactionStore from ._base import BaseSlavedStore diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e7f6ea7286..e42fba45a1 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -14,509 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import calendar -import logging -import time - -from twisted.internet import defer - -from synapse.api.constants import PresenceState -from synapse.storage.devices import DeviceStore -from synapse.storage.user_erasure_store import UserErasureStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - -from .account_data import AccountDataStore -from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore -from .client_ips import ClientIpStore -from .deviceinbox import DeviceInboxStore -from .directory import DirectoryStore -from .e2e_room_keys import EndToEndRoomKeyStore -from .end_to_end_keys import EndToEndKeyStore -from .engines import PostgresEngine -from .event_federation import EventFederationStore -from .event_push_actions import EventPushActionsStore -from .events import EventsStore -from .events_bg_updates import EventsBackgroundUpdatesStore -from .filtering import FilteringStore -from .group_server import GroupServerStore -from .keys import KeyStore -from .media_repository import MediaRepositoryStore -from .monthly_active_users import MonthlyActiveUsersStore -from .openid import OpenIdStore -from .presence import PresenceStore, UserPresenceState -from .profile import ProfileStore -from .push_rule import PushRuleStore -from .pusher import PusherStore -from .receipts import ReceiptsStore -from .registration import RegistrationStore -from .rejections import RejectionsStore -from .relations import RelationsStore -from .room import RoomStore -from .roommember import RoomMemberStore -from .search import SearchStore -from .signatures import SignatureStore -from .state import StateStore -from .stats import StatsStore -from .stream import StreamStore -from .tags import TagsStore -from .transactions import TransactionStore -from .user_directory import UserDirectoryStore -from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator - -logger = logging.getLogger(__name__) - - -class DataStore( - EventsBackgroundUpdatesStore, - RoomMemberStore, - RoomStore, - RegistrationStore, - StreamStore, - ProfileStore, - PresenceStore, - TransactionStore, - DirectoryStore, - KeyStore, - StateStore, - SignatureStore, - ApplicationServiceStore, - EventsStore, - EventFederationStore, - MediaRepositoryStore, - RejectionsStore, - FilteringStore, - PusherStore, - PushRuleStore, - ApplicationServiceTransactionStore, - ReceiptsStore, - EndToEndKeyStore, - EndToEndRoomKeyStore, - SearchStore, - TagsStore, - AccountDataStore, - EventPushActionsStore, - OpenIdStore, - ClientIpStore, - DeviceStore, - DeviceInboxStore, - UserDirectoryStore, - GroupServerStore, - UserErasureStore, - MonthlyActiveUsersStore, - StatsStore, - RelationsStore, -): - def __init__(self, db_conn, hs): - self.hs = hs - self._clock = hs.get_clock() - self.database_engine = hs.database_engine - - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - extra_tables=[("local_invites", "stream_id")], - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - self._presence_id_gen = StreamIdGenerator( - db_conn, "presence_stream", "stream_id" - ) - self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_max_stream_id", "stream_id" - ) - self._public_room_id_gen = StreamIdGenerator( - db_conn, "public_room_list_stream", "stream_id" - ) - self._device_list_id_gen = StreamIdGenerator( - db_conn, "device_lists_stream", "stream_id" - ) - - self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") - self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") - self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") - self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - self._push_rules_stream_id_gen = ChainedIdGenerator( - self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" - ) - self._pushers_id_gen = StreamIdGenerator( - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] - ) - self._group_updates_id_gen = StreamIdGenerator( - db_conn, "local_group_updates", "stream_id" - ) - - if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = StreamIdGenerator( - db_conn, "cache_invalidation_stream", "stream_id" - ) - else: - self._cache_id_gen = None - - self._presence_on_startup = self._get_active_presence(db_conn) - - presence_cache_prefill, min_presence_val = self._get_cache_dict( - db_conn, - "presence_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._presence_id_gen.get_current_token(), - ) - self.presence_stream_cache = StreamChangeCache( - "PresenceStreamChangeCache", - min_presence_val, - prefilled_cache=presence_cache_prefill, - ) - - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( - db_conn, - "device_inbox", - entity_column="user_id", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_inbox_stream_cache = StreamChangeCache( - "DeviceInboxStreamChangeCache", - min_device_inbox_id, - prefilled_cache=device_inbox_prefill, - ) - # The federation outbox and the local device inbox uses the same - # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( - db_conn, - "device_federation_outbox", - entity_column="destination", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_federation_outbox_stream_cache = StreamChangeCache( - "DeviceFederationOutboxStreamChangeCache", - min_device_outbox_id, - prefilled_cache=device_outbox_prefill, - ) - - device_list_max = self._device_list_id_gen.get_current_token() - self._device_list_stream_cache = StreamChangeCache( - "DeviceListStreamChangeCache", device_list_max - ) - self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", device_list_max - ) - - events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( - db_conn, - "current_state_delta_stream", - entity_column="room_id", - stream_column="stream_id", - max_value=events_max, # As we share the stream id with events token - limit=1000, - ) - self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", - min_curr_state_delta_id, - prefilled_cache=curr_state_delta_prefill, - ) - - _group_updates_prefill, min_group_updates_id = self._get_cache_dict( - db_conn, - "local_group_updates", - entity_column="user_id", - stream_column="stream_id", - max_value=self._group_updates_id_gen.get_current_token(), - limit=1000, - ) - self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", - min_group_updates_id, - prefilled_cache=_group_updates_prefill, - ) - - self._stream_order_on_start = self.get_room_max_stream_ordering() - self._min_stream_order_on_start = self.get_room_min_stream_ordering() - - # Used in _generate_user_daily_visits to keep track of progress - self._last_user_visit_update = self._get_start_of_day() - - super(DataStore, self).__init__(db_conn, hs) - - def take_presence_startup_info(self): - active_on_startup = self._presence_on_startup - self._presence_on_startup = None - return active_on_startup - - def _get_active_presence(self, db_conn): - """Fetch non-offline presence from the database so that we can register - the appropriate time outs. - """ - - sql = ( - "SELECT user_id, state, last_active_ts, last_federation_update_ts," - " last_user_sync_ts, status_msg, currently_active FROM presence_stream" - " WHERE state != ?" - ) - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.cursor_to_dict(txn) - txn.close() - - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return [UserPresenceState(**row) for row in rows] - - def count_daily_users(self): - """ - Counts the number of users who used this homeserver in the last 24 hours. - """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.runInteraction("count_daily_users", self._count_users, yesterday) - - def count_monthly_users(self): - """ - Counts the number of users who used this homeserver in the last 30 days. - Note this method is intended for phonehome metrics only and is different - from the mau figure in synapse.storage.monthly_active_users which, - amongst other things, includes a 3 day grace period before a user counts. - """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.runInteraction( - "count_monthly_users", self._count_users, thirty_days_ago - ) - - def _count_users(self, txn, time_from): - """ - Returns number of users seen in the past time_from period - """ - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - txn.execute(sql, (time_from,)) - count, = txn.fetchone() - return count - - def count_r30_users(self): - """ - Counts the number of 30 day retained users, defined as:- - * Users who have created their accounts more than 30 days ago - * Where last seen at most 30 days ago - * Where account creation and last_seen are > 30 days apart - - Returns counts globaly for a given user as well as breaking - by platform - """ - - def _count_r30_users(txn): - thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) - thirty_days_ago_in_secs = now - thirty_days_in_secs - - sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( - SELECT - users.name, platform, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen, - CASE - WHEN user_agent LIKE '%%Android%%' THEN 'android' - WHEN user_agent LIKE '%%iOS%%' THEN 'ios' - WHEN user_agent LIKE '%%Electron%%' THEN 'electron' - WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' - WHEN user_agent LIKE '%%Gecko%%' THEN 'web' - ELSE 'unknown' - END - AS platform - FROM user_ips - ) uip - ON users.name = uip.user_id - AND users.appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, platform, users.creation_ts - ) u GROUP BY platform - """ - - results = {} - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - for row in txn: - if row[0] == "unknown": - pass - results[row[0]] = row[1] - - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT users.name, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen - FROM user_ips - ) uip - ON users.name = uip.user_id - AND appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, users.creation_ts - ) u - """ - - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - count, = txn.fetchone() - results["all"] = count - - return results - - return self.runInteraction("count_r30_users", _count_r30_users) - - def _get_start_of_day(self): - """ - Returns millisecond unixtime for start of UTC day. - """ - now = time.gmtime() - today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) - return today_start * 1000 - - def generate_user_daily_visits(self): - """ - Generates daily visit data for use in cohort/ retention analysis - """ - - def _generate_user_daily_visits(txn): - logger.info("Calling _generate_user_daily_visits") - today_start = self._get_start_of_day() - a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self.clock.time_msec() - - sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? - FROM user_ips AS u - LEFT JOIN ( - SELECT user_id, device_id, timestamp FROM user_daily_visits - WHERE timestamp = ? - ) udv - ON u.user_id = udv.user_id AND u.device_id=udv.device_id - INNER JOIN users ON users.name=u.user_id - WHERE last_seen > ? AND last_seen <= ? - AND udv.timestamp IS NULL AND users.is_guest=0 - AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id - """ - - # This means that the day has rolled over but there could still - # be entries from the previous day. There is an edge case - # where if the user logs in at 23:59 and overwrites their - # last_seen at 00:01 then they will not be counted in the - # previous day's stats - it is important that the query is run - # often to minimise this case. - if today_start > self._last_user_visit_update: - yesterday_start = today_start - a_day_in_milliseconds - txn.execute( - sql, - ( - yesterday_start, - yesterday_start, - self._last_user_visit_update, - today_start, - ), - ) - self._last_user_visit_update = today_start - - txn.execute( - sql, (today_start, today_start, self._last_user_visit_update, now) - ) - # Update _last_user_visit_update to now. The reason to do this - # rather just clamping to the beginning of the day is to limit - # the size of the join - meaning that the query can be run more - # frequently - self._last_user_visit_update = now - - return self.runInteraction( - "generate_user_daily_visits", _generate_user_daily_visits - ) - - def get_users(self): - """Function to reterive a list of users in users table. - - Args: - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self._simple_select_list( - table="users", - keyvalues={}, - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], - desc="get_users", - ) - - @defer.inlineCallbacks - def get_users_paginate(self, order, start, limit): - """Function to reterive a paginated list of users from - users list. This will return a json object, which contains - list of users and the total number of users in users table. - - Args: - order (str): column name to order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive - Returns: - defer.Deferred: resolves to json object {list[dict[str, Any]], count} - """ - users = yield self.runInteraction( - "get_users_paginate", - self._simple_select_list_paginate_txn, - table="users", - keyvalues={"is_guest": False}, - orderby=order, - start=start, - limit=limit, - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], - ) - count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) - retval = {"users": users, "total": count} - return retval - - def search_users(self, term): - """Function to search users list for one or more users with - the matched term. - - Args: - term (str): search term - col (str): column to query term should be matched to - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self._simple_search_list( - table="users", - term=term, - col="name", - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], - desc="search_users", - ) +from synapse.storage.data_stores.main import DataStore # noqa: F401 def are_all_users_on_domain(txn, database_engine, domain): diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py deleted file mode 100644 index 6afbfc0d74..0000000000 --- a/synapse/storage/account_data.py +++ /dev/null @@ -1,391 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import logging - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore -from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from synapse.util.caches.stream_change_cache import StreamChangeCache - -logger = logging.getLogger(__name__) - - -class AccountDataWorkerStore(SQLBaseStore): - """This is an abstract base class where subclasses must implement - `get_max_account_data_stream_id` which can be called in the initializer. - """ - - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - - def __init__(self, db_conn, hs): - account_max = self.get_max_account_data_stream_id() - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", account_max - ) - - super(AccountDataWorkerStore, self).__init__(db_conn, hs) - - @abc.abstractmethod - def get_max_account_data_stream_id(self): - """Get the current max stream ID for account data stream - - Returns: - int - """ - raise NotImplementedError() - - @cached() - def get_account_data_for_user(self, user_id): - """Get all the client account_data for a user. - - Args: - user_id(str): The user to get the account_data for. - Returns: - A deferred pair of a dict of global account_data and a dict - mapping from room_id string to per room account_data dicts. - """ - - def get_account_data_for_user_txn(txn): - rows = self._simple_select_list_txn( - txn, - "account_data", - {"user_id": user_id}, - ["account_data_type", "content"], - ) - - global_account_data = { - row["account_data_type"]: json.loads(row["content"]) for row in rows - } - - rows = self._simple_select_list_txn( - txn, - "room_account_data", - {"user_id": user_id}, - ["room_id", "account_data_type", "content"], - ) - - by_room = {} - for row in rows: - room_data = by_room.setdefault(row["room_id"], {}) - room_data[row["account_data_type"]] = json.loads(row["content"]) - - return global_account_data, by_room - - return self.runInteraction( - "get_account_data_for_user", get_account_data_for_user_txn - ) - - @cachedInlineCallbacks(num_args=2, max_entries=5000) - def get_global_account_data_by_type_for_user(self, data_type, user_id): - """ - Returns: - Deferred: A dict - """ - result = yield self._simple_select_one_onecol( - table="account_data", - keyvalues={"user_id": user_id, "account_data_type": data_type}, - retcol="content", - desc="get_global_account_data_by_type_for_user", - allow_none=True, - ) - - if result: - return json.loads(result) - else: - return None - - @cached(num_args=2) - def get_account_data_for_room(self, user_id, room_id): - """Get all the client account_data for a user for a room. - - Args: - user_id(str): The user to get the account_data for. - room_id(str): The room to get the account_data for. - Returns: - A deferred dict of the room account_data - """ - - def get_account_data_for_room_txn(txn): - rows = self._simple_select_list_txn( - txn, - "room_account_data", - {"user_id": user_id, "room_id": room_id}, - ["account_data_type", "content"], - ) - - return { - row["account_data_type"]: json.loads(row["content"]) for row in rows - } - - return self.runInteraction( - "get_account_data_for_room", get_account_data_for_room_txn - ) - - @cached(num_args=3, max_entries=5000) - def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): - """Get the client account_data of given type for a user for a room. - - Args: - user_id(str): The user to get the account_data for. - room_id(str): The room to get the account_data for. - account_data_type (str): The account data type to get. - Returns: - A deferred of the room account_data for that type, or None if - there isn't any set. - """ - - def get_account_data_for_room_and_type_txn(txn): - content_json = self._simple_select_one_onecol_txn( - txn, - table="room_account_data", - keyvalues={ - "user_id": user_id, - "room_id": room_id, - "account_data_type": account_data_type, - }, - retcol="content", - allow_none=True, - ) - - return json.loads(content_json) if content_json else None - - return self.runInteraction( - "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn - ) - - def get_all_updated_account_data( - self, last_global_id, last_room_id, current_id, limit - ): - """Get all the client account_data that has changed on the server - Args: - last_global_id(int): The position to fetch from for top level data - last_room_id(int): The position to fetch from for per room data - current_id(int): The position to fetch up to. - Returns: - A deferred pair of lists of tuples of stream_id int, user_id string, - room_id string, type string, and content string. - """ - if last_room_id == current_id and last_global_id == current_id: - return defer.succeed(([], [])) - - def get_updated_account_data_txn(txn): - sql = ( - "SELECT stream_id, user_id, account_data_type, content" - " FROM account_data WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_global_id, current_id, limit)) - global_results = txn.fetchall() - - sql = ( - "SELECT stream_id, user_id, room_id, account_data_type, content" - " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_room_id, current_id, limit)) - room_results = txn.fetchall() - return global_results, room_results - - return self.runInteraction( - "get_all_updated_account_data_txn", get_updated_account_data_txn - ) - - def get_updated_account_data_for_user(self, user_id, stream_id): - """Get all the client account_data for a that's changed for a user - - Args: - user_id(str): The user to get the account_data for. - stream_id(int): The point in the stream since which to get updates - Returns: - A deferred pair of a dict of global account_data and a dict - mapping from room_id string to per room account_data dicts. - """ - - def get_updated_account_data_for_user_txn(txn): - sql = ( - "SELECT account_data_type, content FROM account_data" - " WHERE user_id = ? AND stream_id > ?" - ) - - txn.execute(sql, (user_id, stream_id)) - - global_account_data = {row[0]: json.loads(row[1]) for row in txn} - - sql = ( - "SELECT room_id, account_data_type, content FROM room_account_data" - " WHERE user_id = ? AND stream_id > ?" - ) - - txn.execute(sql, (user_id, stream_id)) - - account_data_by_room = {} - for row in txn: - room_account_data = account_data_by_room.setdefault(row[0], {}) - room_account_data[row[1]] = json.loads(row[2]) - - return global_account_data, account_data_by_room - - changed = self._account_data_stream_cache.has_entity_changed( - user_id, int(stream_id) - ) - if not changed: - return {}, {} - - return self.runInteraction( - "get_updated_account_data_for_user", get_updated_account_data_for_user_txn - ) - - @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) - def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): - ignored_account_data = yield self.get_global_account_data_by_type_for_user( - "m.ignored_user_list", - ignorer_user_id, - on_invalidate=cache_context.invalidate, - ) - if not ignored_account_data: - return False - - return ignored_user_id in ignored_account_data.get("ignored_users", {}) - - -class AccountDataStore(AccountDataWorkerStore): - def __init__(self, db_conn, hs): - self._account_data_id_gen = StreamIdGenerator( - db_conn, "account_data_max_stream_id", "stream_id" - ) - - super(AccountDataStore, self).__init__(db_conn, hs) - - def get_max_account_data_stream_id(self): - """Get the current max stream id for the private user data stream - - Returns: - A deferred int. - """ - return self._account_data_id_gen.get_current_token() - - @defer.inlineCallbacks - def add_account_data_to_room(self, user_id, room_id, account_data_type, content): - """Add some account_data to a room for a user. - Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. - Returns: - A deferred that completes once the account_data has been added. - """ - content_json = json.dumps(content) - - with self._account_data_id_gen.get_next() as next_id: - # no need to lock here as room_account_data has a unique constraint - # on (user_id, room_id, account_data_type) so _simple_upsert will - # retry if there is a conflict. - yield self._simple_upsert( - desc="add_room_account_data", - table="room_account_data", - keyvalues={ - "user_id": user_id, - "room_id": room_id, - "account_data_type": account_data_type, - }, - values={"stream_id": next_id, "content": content_json}, - lock=False, - ) - - # it's theoretically possible for the above to succeed and the - # below to fail - in which case we might reuse a stream id on - # restart, and the above update might not get propagated. That - # doesn't sound any worse than the whole update getting lost, - # which is what would happen if we combined the two into one - # transaction. - yield self._update_max_stream_id(next_id) - - self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) - self.get_account_data_for_room.invalidate((user_id, room_id)) - self.get_account_data_for_room_and_type.prefill( - (user_id, room_id, account_data_type), content - ) - - result = self._account_data_id_gen.get_current_token() - return result - - @defer.inlineCallbacks - def add_account_data_for_user(self, user_id, account_data_type, content): - """Add some account_data to a room for a user. - Args: - user_id(str): The user to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. - Returns: - A deferred that completes once the account_data has been added. - """ - content_json = json.dumps(content) - - with self._account_data_id_gen.get_next() as next_id: - # no need to lock here as account_data has a unique constraint on - # (user_id, account_data_type) so _simple_upsert will retry if - # there is a conflict. - yield self._simple_upsert( - desc="add_user_account_data", - table="account_data", - keyvalues={"user_id": user_id, "account_data_type": account_data_type}, - values={"stream_id": next_id, "content": content_json}, - lock=False, - ) - - # it's theoretically possible for the above to succeed and the - # below to fail - in which case we might reuse a stream id on - # restart, and the above update might not get propagated. That - # doesn't sound any worse than the whole update getting lost, - # which is what would happen if we combined the two into one - # transaction. - yield self._update_max_stream_id(next_id) - - self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) - self.get_global_account_data_by_type_for_user.invalidate( - (account_data_type, user_id) - ) - - result = self._account_data_id_gen.get_current_token() - return result - - def _update_max_stream_id(self, next_id): - """Update the max stream_id - - Args: - next_id(int): The the revision to advance to. - """ - - def _update(txn): - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" - ) - txn.execute(update_max_id_sql, (next_id, next_id)) - - return self.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py deleted file mode 100644 index 435b2acd4d..0000000000 --- a/synapse/storage/appservice.py +++ /dev/null @@ -1,368 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import re - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.appservice import AppServiceTransaction -from synapse.config.appservice import load_appservices -from synapse.storage.events_worker import EventsWorkerStore - -from ._base import SQLBaseStore - -logger = logging.getLogger(__name__) - - -def _make_exclusive_regex(services_cache): - # We precompie a regex constructed from all the regexes that the AS's - # have registered for exclusive users. - exclusive_user_regexes = [ - regex.pattern - for service in services_cache - for regex in service.get_exlusive_user_regexes() - ] - if exclusive_user_regexes: - exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) - exclusive_user_regex = re.compile(exclusive_user_regex) - else: - # We handle this case specially otherwise the constructed regex - # will always match - exclusive_user_regex = None - - return exclusive_user_regex - - -class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - self.services_cache = load_appservices( - hs.hostname, hs.config.app_service_config_files - ) - self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - - super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) - - def get_app_services(self): - return self.services_cache - - def get_if_app_services_interested_in_user(self, user_id): - """Check if the user is one associated with an app service (exclusively) - """ - if self.exclusive_user_regex: - return bool(self.exclusive_user_regex.match(user_id)) - else: - return False - - def get_app_service_by_user_id(self, user_id): - """Retrieve an application service from their user ID. - - All application services have associated with them a particular user ID. - There is no distinguishing feature on the user ID which indicates it - represents an application service. This function allows you to map from - a user ID to an application service. - - Args: - user_id(str): The user ID to see if it is an application service. - Returns: - synapse.appservice.ApplicationService or None. - """ - for service in self.services_cache: - if service.sender == user_id: - return service - return None - - def get_app_service_by_token(self, token): - """Get the application service with the given appservice token. - - Args: - token (str): The application service token. - Returns: - synapse.appservice.ApplicationService or None. - """ - for service in self.services_cache: - if service.token == token: - return service - return None - - def get_app_service_by_id(self, as_id): - """Get the application service with the given appservice ID. - - Args: - as_id (str): The application service ID. - Returns: - synapse.appservice.ApplicationService or None. - """ - for service in self.services_cache: - if service.id == as_id: - return service - return None - - -class ApplicationServiceStore(ApplicationServiceWorkerStore): - # This is currently empty due to there not being any AS storage functions - # that can't be run on the workers. Since this may change in future, and - # to keep consistency with the other stores, we keep this empty class for - # now. - pass - - -class ApplicationServiceTransactionWorkerStore( - ApplicationServiceWorkerStore, EventsWorkerStore -): - @defer.inlineCallbacks - def get_appservices_by_state(self, state): - """Get a list of application services based on their state. - - Args: - state(ApplicationServiceState): The state to filter on. - Returns: - A Deferred which resolves to a list of ApplicationServices, which - may be empty. - """ - results = yield self._simple_select_list( - "application_services_state", dict(state=state), ["as_id"] - ) - # NB: This assumes this class is linked with ApplicationServiceStore - as_list = self.get_app_services() - services = [] - - for res in results: - for service in as_list: - if service.id == res["as_id"]: - services.append(service) - return services - - @defer.inlineCallbacks - def get_appservice_state(self, service): - """Get the application service state. - - Args: - service(ApplicationService): The service whose state to set. - Returns: - A Deferred which resolves to ApplicationServiceState. - """ - result = yield self._simple_select_one( - "application_services_state", - dict(as_id=service.id), - ["state"], - allow_none=True, - desc="get_appservice_state", - ) - if result: - return result.get("state") - return None - - def set_appservice_state(self, service, state): - """Set the application service state. - - Args: - service(ApplicationService): The service whose state to set. - state(ApplicationServiceState): The connectivity state to apply. - Returns: - A Deferred which resolves when the state was set successfully. - """ - return self._simple_upsert( - "application_services_state", dict(as_id=service.id), dict(state=state) - ) - - def create_appservice_txn(self, service, events): - """Atomically creates a new transaction for this application service - with the given list of events. - - Args: - service(ApplicationService): The service who the transaction is for. - events(list): A list of events to put in the transaction. - Returns: - AppServiceTransaction: A new transaction. - """ - - def _create_appservice_txn(txn): - # work out new txn id (highest txn id for this service += 1) - # The highest id may be the last one sent (in which case it is last_txn) - # or it may be the highest in the txns list (which are waiting to be/are - # being sent) - last_txn_id = self._get_last_txn(txn, service.id) - - txn.execute( - "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", - (service.id,), - ) - highest_txn_id = txn.fetchone()[0] - if highest_txn_id is None: - highest_txn_id = 0 - - new_txn_id = max(highest_txn_id, last_txn_id) + 1 - - # Insert new txn into txn table - event_ids = json.dumps([e.event_id for e in events]) - txn.execute( - "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " - "VALUES(?,?,?)", - (service.id, new_txn_id, event_ids), - ) - return AppServiceTransaction(service=service, id=new_txn_id, events=events) - - return self.runInteraction("create_appservice_txn", _create_appservice_txn) - - def complete_appservice_txn(self, txn_id, service): - """Completes an application service transaction. - - Args: - txn_id(str): The transaction ID being completed. - service(ApplicationService): The application service which was sent - this transaction. - Returns: - A Deferred which resolves if this transaction was stored - successfully. - """ - txn_id = int(txn_id) - - def _complete_appservice_txn(txn): - # Debugging query: Make sure the txn being completed is EXACTLY +1 from - # what was there before. If it isn't, we've got problems (e.g. the AS - # has probably missed some events), so whine loudly but still continue, - # since it shouldn't fail completion of the transaction. - last_txn_id = self._get_last_txn(txn, service.id) - if (last_txn_id + 1) != txn_id: - logger.error( - "appservice: Completing a transaction which has an ID > 1 from " - "the last ID sent to this AS. We've either dropped events or " - "sent it to the AS out of order. FIX ME. last_txn=%s " - "completing_txn=%s service_id=%s", - last_txn_id, - txn_id, - service.id, - ) - - # Set current txn_id for AS to 'txn_id' - self._simple_upsert_txn( - txn, - "application_services_state", - dict(as_id=service.id), - dict(last_txn=txn_id), - ) - - # Delete txn - self._simple_delete_txn( - txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) - ) - - return self.runInteraction("complete_appservice_txn", _complete_appservice_txn) - - @defer.inlineCallbacks - def get_oldest_unsent_txn(self, service): - """Get the oldest transaction which has not been sent for this - service. - - Args: - service(ApplicationService): The app service to get the oldest txn. - Returns: - A Deferred which resolves to an AppServiceTransaction or - None. - """ - - def _get_oldest_unsent_txn(txn): - # Monotonically increasing txn ids, so just select the smallest - # one in the txns table (we delete them when they are sent) - txn.execute( - "SELECT * FROM application_services_txns WHERE as_id=?" - " ORDER BY txn_id ASC LIMIT 1", - (service.id,), - ) - rows = self.cursor_to_dict(txn) - if not rows: - return None - - entry = rows[0] - - return entry - - entry = yield self.runInteraction( - "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn - ) - - if not entry: - return None - - event_ids = json.loads(entry["event_ids"]) - - events = yield self.get_events_as_list(event_ids) - - return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) - - def _get_last_txn(self, txn, service_id): - txn.execute( - "SELECT last_txn FROM application_services_state WHERE as_id=?", - (service_id,), - ) - last_txn_id = txn.fetchone() - if last_txn_id is None or last_txn_id[0] is None: # no row exists - return 0 - else: - return int(last_txn_id[0]) # select 'last_txn' col - - def set_appservice_last_pos(self, pos): - def set_appservice_last_pos_txn(txn): - txn.execute( - "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) - ) - - return self.runInteraction( - "set_appservice_last_pos", set_appservice_last_pos_txn - ) - - @defer.inlineCallbacks - def get_new_events_for_appservice(self, current_id, limit): - """Get all new evnets""" - - def get_new_events_for_appservice_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id" - " FROM events AS e" - " WHERE" - " (SELECT stream_ordering FROM appservice_stream_position)" - " < e.stream_ordering" - " AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - " LIMIT ?" - ) - - txn.execute(sql, (current_id, limit)) - rows = txn.fetchall() - - upper_bound = current_id - if len(rows) == limit: - upper_bound = rows[-1][0] - - return upper_bound, [row[1] for row in rows] - - upper_bound, event_ids = yield self.runInteraction( - "get_new_events_for_appservice", get_new_events_for_appservice_txn - ) - - events = yield self.get_events_as_list(event_ids) - - return upper_bound, events - - -class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): - # This is currently empty due to there not being any AS storage functions - # that can't be run on the workers. Since this may change in future, and - # to keep consistency with the other stores, we keep this empty class for - # now. - pass diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py deleted file mode 100644 index 067820a5da..0000000000 --- a/synapse/storage/client_ips.py +++ /dev/null @@ -1,581 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from six import iteritems - -from twisted.internet import defer - -from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.util.caches import CACHE_SIZE_FACTOR - -from . import background_updates -from ._base import Cache - -logger = logging.getLogger(__name__) - -# Number of msec of granularity to store the user IP 'last seen' time. Smaller -# times give more inserts into the database even for readonly API hits -# 120 seconds == 2 minutes -LAST_SEEN_GRANULARITY = 120 * 1000 - - -class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) - - self.register_background_index_update( - "user_ips_device_index", - index_name="user_ips_device_id", - table="user_ips", - columns=["user_id", "device_id", "last_seen"], - ) - - self.register_background_index_update( - "user_ips_last_seen_index", - index_name="user_ips_last_seen", - table="user_ips", - columns=["user_id", "last_seen"], - ) - - self.register_background_index_update( - "user_ips_last_seen_only_index", - index_name="user_ips_last_seen_only", - table="user_ips", - columns=["last_seen"], - ) - - self.register_background_update_handler( - "user_ips_analyze", self._analyze_user_ip - ) - - self.register_background_update_handler( - "user_ips_remove_dupes", self._remove_user_ip_dupes - ) - - # Register a unique index - self.register_background_index_update( - "user_ips_device_unique_index", - index_name="user_ips_user_token_ip_unique_index", - table="user_ips", - columns=["user_id", "access_token", "ip"], - unique=True, - ) - - # Drop the old non-unique index - self.register_background_update_handler( - "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique - ) - - # Update the last seen info in devices. - self.register_background_update_handler( - "devices_last_seen", self._devices_last_seen_update - ) - - @defer.inlineCallbacks - def _remove_user_ip_nonunique(self, progress, batch_size): - def f(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") - txn.close() - - yield self.runWithConnection(f) - yield self._end_background_update("user_ips_drop_nonunique_index") - return 1 - - @defer.inlineCallbacks - def _analyze_user_ip(self, progress, batch_size): - # Background update to analyze user_ips table before we run the - # deduplication background update. The table may not have been analyzed - # for ages due to the table locks. - # - # This will lock out the naive upserts to user_ips while it happens, but - # the analyze should be quick (28GB table takes ~10s) - def user_ips_analyze(txn): - txn.execute("ANALYZE user_ips") - - yield self.runInteraction("user_ips_analyze", user_ips_analyze) - - yield self._end_background_update("user_ips_analyze") - - return 1 - - @defer.inlineCallbacks - def _remove_user_ip_dupes(self, progress, batch_size): - # This works function works by scanning the user_ips table in batches - # based on `last_seen`. For each row in a batch it searches the rest of - # the table to see if there are any duplicates, if there are then they - # are removed and replaced with a suitable row. - - # Fetch the start of the batch - begin_last_seen = progress.get("last_seen", 0) - - def get_last_seen(txn): - txn.execute( - """ - SELECT last_seen FROM user_ips - WHERE last_seen > ? - ORDER BY last_seen - LIMIT 1 - OFFSET ? - """, - (begin_last_seen, batch_size), - ) - row = txn.fetchone() - if row: - return row[0] - else: - return None - - # Get a last seen that has roughly `batch_size` since `begin_last_seen` - end_last_seen = yield self.runInteraction( - "user_ips_dups_get_last_seen", get_last_seen - ) - - # If it returns None, then we're processing the last batch - last = end_last_seen is None - - logger.info( - "Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s", - begin_last_seen, - end_last_seen, - ) - - def remove(txn): - # This works by looking at all entries in the given time span, and - # then for each (user_id, access_token, ip) tuple in that range - # checking for any duplicates in the rest of the table (via a join). - # It then only returns entries which have duplicates, and the max - # last_seen across all duplicates, which can the be used to delete - # all other duplicates. - # It is efficient due to the existence of (user_id, access_token, - # ip) and (last_seen) indices. - - # Define the search space, which requires handling the last batch in - # a different way - if last: - clause = "? <= last_seen" - args = (begin_last_seen,) - else: - clause = "? <= last_seen AND last_seen < ?" - args = (begin_last_seen, end_last_seen) - - # (Note: The DISTINCT in the inner query is important to ensure that - # the COUNT(*) is accurate, otherwise double counting may happen due - # to the join effectively being a cross product) - txn.execute( - """ - SELECT user_id, access_token, ip, - MAX(device_id), MAX(user_agent), MAX(last_seen), - COUNT(*) - FROM ( - SELECT DISTINCT user_id, access_token, ip - FROM user_ips - WHERE {} - ) c - INNER JOIN user_ips USING (user_id, access_token, ip) - GROUP BY user_id, access_token, ip - HAVING count(*) > 1 - """.format( - clause - ), - args, - ) - res = txn.fetchall() - - # We've got some duplicates - for i in res: - user_id, access_token, ip, device_id, user_agent, last_seen, count = i - - # We want to delete the duplicates so we end up with only a - # single row. - # - # The naive way of doing this would be just to delete all rows - # and reinsert a constructed row. However, if there are a lot of - # duplicate rows this can cause the table to grow a lot, which - # can be problematic in two ways: - # 1. If user_ips is already large then this can cause the - # table to rapidly grow, potentially filling the disk. - # 2. Reinserting a lot of rows can confuse the table - # statistics for postgres, causing it to not use the - # correct indices for the query above, resulting in a full - # table scan. This is incredibly slow for large tables and - # can kill database performance. (This seems to mainly - # happen for the last query where the clause is simply `? < - # last_seen`) - # - # So instead we want to delete all but *one* of the duplicate - # rows. That is hard to do reliably, so we cheat and do a two - # step process: - # 1. Delete all rows with a last_seen strictly less than the - # max last_seen. This hopefully results in deleting all but - # one row the majority of the time, but there may be - # duplicate last_seen - # 2. If multiple rows remain, we fall back to the naive method - # and simply delete all rows and reinsert. - # - # Note that this relies on no new duplicate rows being inserted, - # but if that is happening then this entire process is futile - # anyway. - - # Do step 1: - - txn.execute( - """ - DELETE FROM user_ips - WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ? - """, - (user_id, access_token, ip, last_seen), - ) - if txn.rowcount == count - 1: - # We deleted all but one of the duplicate rows, i.e. there - # is exactly one remaining and so there is nothing left to - # do. - continue - elif txn.rowcount >= count: - raise Exception( - "We deleted more duplicate rows from 'user_ips' than expected" - ) - - # The previous step didn't delete enough rows, so we fallback to - # step 2: - - # Drop all the duplicates - txn.execute( - """ - DELETE FROM user_ips - WHERE user_id = ? AND access_token = ? AND ip = ? - """, - (user_id, access_token, ip), - ) - - # Add in one to be the last_seen - txn.execute( - """ - INSERT INTO user_ips - (user_id, access_token, ip, device_id, user_agent, last_seen) - VALUES (?, ?, ?, ?, ?, ?) - """, - (user_id, access_token, ip, device_id, user_agent, last_seen), - ) - - self._background_update_progress_txn( - txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} - ) - - yield self.runInteraction("user_ips_dups_remove", remove) - - if last: - yield self._end_background_update("user_ips_remove_dupes") - - return batch_size - - @defer.inlineCallbacks - def _devices_last_seen_update(self, progress, batch_size): - """Background update to insert last seen info into devices table - """ - - last_user_id = progress.get("last_user_id", "") - last_device_id = progress.get("last_device_id", "") - - def _devices_last_seen_update_txn(txn): - # This consists of two queries: - # - # 1. The sub-query searches for the next N devices and joins - # against user_ips to find the max last_seen associated with - # that device. - # 2. The outer query then joins again against user_ips on - # user/device/last_seen. This *should* hopefully only - # return one row, but if it does return more than one then - # we'll just end up updating the same device row multiple - # times, which is fine. - - if self.database_engine.supports_tuple_comparison: - where_clause = "(user_id, device_id) > (?, ?)" - where_args = [last_user_id, last_device_id] - else: - # We explicitly do a `user_id >= ? AND (...)` here to ensure - # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)` - # makes it hard for query optimiser to tell that it can use the - # index on user_id - where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)" - where_args = [last_user_id, last_user_id, last_device_id] - - sql = """ - SELECT - last_seen, ip, user_agent, user_id, device_id - FROM ( - SELECT - user_id, device_id, MAX(u.last_seen) AS last_seen - FROM devices - INNER JOIN user_ips AS u USING (user_id, device_id) - WHERE %(where_clause)s - GROUP BY user_id, device_id - ORDER BY user_id ASC, device_id ASC - LIMIT ? - ) c - INNER JOIN user_ips AS u USING (user_id, device_id, last_seen) - """ % { - "where_clause": where_clause - } - txn.execute(sql, where_args + [batch_size]) - - rows = txn.fetchall() - if not rows: - return 0 - - sql = """ - UPDATE devices - SET last_seen = ?, ip = ?, user_agent = ? - WHERE user_id = ? AND device_id = ? - """ - txn.execute_batch(sql, rows) - - _, _, _, user_id, device_id = rows[-1] - self._background_update_progress_txn( - txn, - "devices_last_seen", - {"last_user_id": user_id, "last_device_id": device_id}, - ) - - return len(rows) - - updated = yield self.runInteraction( - "_devices_last_seen_update", _devices_last_seen_update_txn - ) - - if not updated: - yield self._end_background_update("devices_last_seen") - - return updated - - -class ClientIpStore(ClientIpBackgroundUpdateStore): - def __init__(self, db_conn, hs): - - self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR - ) - - super(ClientIpStore, self).__init__(db_conn, hs) - - self.user_ips_max_age = hs.config.user_ips_max_age - - # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) - self._batch_row_update = {} - - self._client_ip_looper = self._clock.looping_call( - self._update_client_ips_batch, 5 * 1000 - ) - self.hs.get_reactor().addSystemEventTrigger( - "before", "shutdown", self._update_client_ips_batch - ) - - if self.user_ips_max_age: - self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - - @defer.inlineCallbacks - def insert_client_ip( - self, user_id, access_token, ip, user_agent, device_id, now=None - ): - if not now: - now = int(self._clock.time_msec()) - key = (user_id, access_token, ip) - - try: - last_seen = self.client_ip_last_seen.get(key) - except KeyError: - last_seen = None - yield self.populate_monthly_active_users(user_id) - # Rate-limited inserts - if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: - return - - self.client_ip_last_seen.prefill(key, now) - - self._batch_row_update[key] = (user_agent, device_id, now) - - @wrap_as_background_process("update_client_ips") - def _update_client_ips_batch(self): - - # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: - return - - to_update = self._batch_row_update - self._batch_row_update = {} - - return self.runInteraction( - "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update - ) - - def _update_client_ips_batch_txn(self, txn, to_update): - if "user_ips" in self._unsafe_to_upsert_tables or ( - not self.database_engine.can_native_upsert - ): - self.database_engine.lock_table(txn, "user_ips") - - for entry in iteritems(to_update): - (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry - - try: - self._simple_upsert_txn( - txn, - table="user_ips", - keyvalues={ - "user_id": user_id, - "access_token": access_token, - "ip": ip, - }, - values={ - "user_agent": user_agent, - "device_id": device_id, - "last_seen": last_seen, - }, - lock=False, - ) - - # Technically an access token might not be associated with - # a device so we need to check. - if device_id: - self._simple_upsert_txn( - txn, - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={ - "user_agent": user_agent, - "last_seen": last_seen, - "ip": ip, - }, - lock=False, - ) - except Exception as e: - # Failed to upsert, log and continue - logger.error("Failed to insert client IP %r: %r", entry, e) - - @defer.inlineCallbacks - def get_last_client_ip_by_device(self, user_id, device_id): - """For each device_id listed, give the user_ip it was last seen on - - Args: - user_id (str) - device_id (str): If None fetches all devices for the user - - Returns: - defer.Deferred: resolves to a dict, where the keys - are (user_id, device_id) tuples. The values are also dicts, with - keys giving the column names - """ - - keyvalues = {"user_id": user_id} - if device_id is not None: - keyvalues["device_id"] = device_id - - res = yield self._simple_select_list( - table="devices", - keyvalues=keyvalues, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), - ) - - ret = {(d["user_id"], d["device_id"]): d for d in res} - for key in self._batch_row_update: - uid, access_token, ip = key - if uid == user_id: - user_agent, did, last_seen = self._batch_row_update[key] - if not device_id or did == device_id: - ret[(user_id, device_id)] = { - "user_id": user_id, - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "device_id": did, - "last_seen": last_seen, - } - return ret - - @defer.inlineCallbacks - def get_user_ip_and_agents(self, user): - user_id = user.to_string() - results = {} - - for key in self._batch_row_update: - uid, access_token, ip, = key - if uid == user_id: - user_agent, _, last_seen = self._batch_row_update[key] - results[(access_token, ip)] = (user_agent, last_seen) - - rows = yield self._simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "last_seen"], - desc="get_user_ip_and_agents", - ) - - results.update( - ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) - for row in rows - ) - return list( - { - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "last_seen": last_seen, - } - for (access_token, ip), (user_agent, last_seen) in iteritems(results) - ) - - @wrap_as_background_process("prune_old_user_ips") - async def _prune_old_user_ips(self): - """Removes entries in user IPs older than the configured period. - """ - - if self.user_ips_max_age is None: - # Nothing to do - return - - if not await self.has_completed_background_update("devices_last_seen"): - # Only start pruning if we have finished populating the devices - # last seen info. - return - - # We do a slightly funky SQL delete to ensure we don't try and delete - # too much at once (as the table may be very large from before we - # started pruning). - # - # This works by finding the max last_seen that is less than the given - # time, but has no more than N rows before it, deleting all rows with - # a lesser last_seen time. (We COALESCE so that the sub-SELECT always - # returns exactly one row). - sql = """ - DELETE FROM user_ips - WHERE last_seen <= ( - SELECT COALESCE(MAX(last_seen), -1) - FROM ( - SELECT last_seen FROM user_ips - WHERE last_seen <= ? - ORDER BY last_seen ASC - LIMIT 5000 - ) AS u - ) - """ - - timestamp = self.clock.time_msec() - self.user_ips_max_age - - def _prune_old_user_ips_txn(txn): - txn.execute(sql, (timestamp,)) - - await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py new file mode 100644 index 0000000000..56094078ed --- /dev/null +++ b/synapse/storage/data_stores/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py new file mode 100644 index 0000000000..d29135588f --- /dev/null +++ b/synapse/storage/data_stores/main/__init__.py @@ -0,0 +1,524 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import calendar +import logging +import time + +from twisted.internet import defer + +from synapse.api.constants import PresenceState +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import ( + ChainedIdGenerator, + IdGenerator, + StreamIdGenerator, +) +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from .account_data import AccountDataStore +from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore +from .client_ips import ClientIpStore +from .deviceinbox import DeviceInboxStore +from .devices import DeviceStore +from .directory import DirectoryStore +from .e2e_room_keys import EndToEndRoomKeyStore +from .end_to_end_keys import EndToEndKeyStore +from .event_federation import EventFederationStore +from .event_push_actions import EventPushActionsStore +from .events import EventsStore +from .events_bg_updates import EventsBackgroundUpdatesStore +from .filtering import FilteringStore +from .group_server import GroupServerStore +from .keys import KeyStore +from .media_repository import MediaRepositoryStore +from .monthly_active_users import MonthlyActiveUsersStore +from .openid import OpenIdStore +from .presence import PresenceStore, UserPresenceState +from .profile import ProfileStore +from .push_rule import PushRuleStore +from .pusher import PusherStore +from .receipts import ReceiptsStore +from .registration import RegistrationStore +from .rejections import RejectionsStore +from .relations import RelationsStore +from .room import RoomStore +from .roommember import RoomMemberStore +from .search import SearchStore +from .signatures import SignatureStore +from .state import StateStore +from .stats import StatsStore +from .stream import StreamStore +from .tags import TagsStore +from .transactions import TransactionStore +from .user_directory import UserDirectoryStore +from .user_erasure_store import UserErasureStore + +logger = logging.getLogger(__name__) + + +class DataStore( + EventsBackgroundUpdatesStore, + RoomMemberStore, + RoomStore, + RegistrationStore, + StreamStore, + ProfileStore, + PresenceStore, + TransactionStore, + DirectoryStore, + KeyStore, + StateStore, + SignatureStore, + ApplicationServiceStore, + EventsStore, + EventFederationStore, + MediaRepositoryStore, + RejectionsStore, + FilteringStore, + PusherStore, + PushRuleStore, + ApplicationServiceTransactionStore, + ReceiptsStore, + EndToEndKeyStore, + EndToEndRoomKeyStore, + SearchStore, + TagsStore, + AccountDataStore, + EventPushActionsStore, + OpenIdStore, + ClientIpStore, + DeviceStore, + DeviceInboxStore, + UserDirectoryStore, + GroupServerStore, + UserErasureStore, + MonthlyActiveUsersStore, + StatsStore, + RelationsStore, +): + def __init__(self, db_conn, hs): + self.hs = hs + self._clock = hs.get_clock() + self.database_engine = hs.database_engine + + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + extra_tables=[("local_invites", "stream_id")], + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) + self._device_inbox_id_gen = StreamIdGenerator( + db_conn, "device_max_stream_id", "stream_id" + ) + self._public_room_id_gen = StreamIdGenerator( + db_conn, "public_room_list_stream", "stream_id" + ) + self._device_list_id_gen = StreamIdGenerator( + db_conn, "device_lists_stream", "stream_id" + ) + + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") + self._push_rules_stream_id_gen = ChainedIdGenerator( + self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" + ) + self._pushers_id_gen = StreamIdGenerator( + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] + ) + self._group_updates_id_gen = StreamIdGenerator( + db_conn, "local_group_updates", "stream_id" + ) + + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = StreamIdGenerator( + db_conn, "cache_invalidation_stream", "stream_id" + ) + else: + self._cache_id_gen = None + + self._presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self._get_cache_dict( + db_conn, + "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_current_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", + min_presence_val, + prefilled_cache=presence_cache_prefill, + ) + + max_device_inbox_id = self._device_inbox_id_gen.get_current_token() + device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( + db_conn, + "device_inbox", + entity_column="user_id", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_inbox_stream_cache = StreamChangeCache( + "DeviceInboxStreamChangeCache", + min_device_inbox_id, + prefilled_cache=device_inbox_prefill, + ) + # The federation outbox and the local device inbox uses the same + # stream_id generator. + device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( + db_conn, + "device_federation_outbox", + entity_column="destination", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_federation_outbox_stream_cache = StreamChangeCache( + "DeviceFederationOutboxStreamChangeCache", + min_device_outbox_id, + prefilled_cache=device_outbox_prefill, + ) + + device_list_max = self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max + ) + + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + db_conn, + "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache = StreamChangeCache( + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, + ) + + _group_updates_prefill, min_group_updates_id = self._get_cache_dict( + db_conn, + "local_group_updates", + entity_column="user_id", + stream_column="stream_id", + max_value=self._group_updates_id_gen.get_current_token(), + limit=1000, + ) + self._group_updates_stream_cache = StreamChangeCache( + "_group_updates_stream_cache", + min_group_updates_id, + prefilled_cache=_group_updates_prefill, + ) + + self._stream_order_on_start = self.get_room_max_stream_ordering() + self._min_stream_order_on_start = self.get_room_min_stream_ordering() + + # Used in _generate_user_daily_visits to keep track of progress + self._last_user_visit_update = self._get_start_of_day() + + super(DataStore, self).__init__(db_conn, hs) + + def take_presence_startup_info(self): + active_on_startup = self._presence_on_startup + self._presence_on_startup = None + return active_on_startup + + def _get_active_presence(self, db_conn): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active_ts, last_federation_update_ts," + " last_user_sync_ts, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.cursor_to_dict(txn) + txn.close() + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + + def count_daily_users(self): + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + return self.runInteraction("count_daily_users", self._count_users, yesterday) + + def count_monthly_users(self): + """ + Counts the number of users who used this homeserver in the last 30 days. + Note this method is intended for phonehome metrics only and is different + from the mau figure in synapse.storage.monthly_active_users which, + amongst other things, includes a 3 day grace period before a user counts. + """ + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + return self.runInteraction( + "count_monthly_users", self._count_users, thirty_days_ago + ) + + def _count_users(self, txn, time_from): + """ + Returns number of users seen in the past time_from period + """ + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + txn.execute(sql, (time_from,)) + count, = txn.fetchone() + return count + + def count_r30_users(self): + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns counts globaly for a given user as well as breaking + by platform + """ + + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + for row in txn: + if row[0] == "unknown": + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + count, = txn.fetchone() + results["all"] = count + + return results + + return self.runInteraction("count_r30_users", _count_r30_users) + + def _get_start_of_day(self): + """ + Returns millisecond unixtime for start of UTC day. + """ + now = time.gmtime() + today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) + return today_start * 1000 + + def generate_user_daily_visits(self): + """ + Generates daily visit data for use in cohort/ retention analysis + """ + + def _generate_user_daily_visits(txn): + logger.info("Calling _generate_user_daily_visits") + today_start = self._get_start_of_day() + a_day_in_milliseconds = 24 * 60 * 60 * 1000 + now = self.clock.time_msec() + + sql = """ + INSERT INTO user_daily_visits (user_id, device_id, timestamp) + SELECT u.user_id, u.device_id, ? + FROM user_ips AS u + LEFT JOIN ( + SELECT user_id, device_id, timestamp FROM user_daily_visits + WHERE timestamp = ? + ) udv + ON u.user_id = udv.user_id AND u.device_id=udv.device_id + INNER JOIN users ON users.name=u.user_id + WHERE last_seen > ? AND last_seen <= ? + AND udv.timestamp IS NULL AND users.is_guest=0 + AND users.appservice_id IS NULL + GROUP BY u.user_id, u.device_id + """ + + # This means that the day has rolled over but there could still + # be entries from the previous day. There is an edge case + # where if the user logs in at 23:59 and overwrites their + # last_seen at 00:01 then they will not be counted in the + # previous day's stats - it is important that the query is run + # often to minimise this case. + if today_start > self._last_user_visit_update: + yesterday_start = today_start - a_day_in_milliseconds + txn.execute( + sql, + ( + yesterday_start, + yesterday_start, + self._last_user_visit_update, + today_start, + ), + ) + self._last_user_visit_update = today_start + + txn.execute( + sql, (today_start, today_start, self._last_user_visit_update, now) + ) + # Update _last_user_visit_update to now. The reason to do this + # rather just clamping to the beginning of the day is to limit + # the size of the join - meaning that the query can be run more + # frequently + self._last_user_visit_update = now + + return self.runInteraction( + "generate_user_daily_visits", _generate_user_daily_visits + ) + + def get_users(self): + """Function to reterive a list of users in users table. + + Args: + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self._simple_select_list( + table="users", + keyvalues={}, + retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + desc="get_users", + ) + + @defer.inlineCallbacks + def get_users_paginate(self, order, start, limit): + """Function to reterive a paginated list of users from + users list. This will return a json object, which contains + list of users and the total number of users in users table. + + Args: + order (str): column name to order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + Returns: + defer.Deferred: resolves to json object {list[dict[str, Any]], count} + """ + users = yield self.runInteraction( + "get_users_paginate", + self._simple_select_list_paginate_txn, + table="users", + keyvalues={"is_guest": False}, + orderby=order, + start=start, + limit=limit, + retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + ) + count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) + retval = {"users": users, "total": count} + return retval + + def search_users(self, term): + """Function to search users list for one or more users with + the matched term. + + Args: + term (str): search term + col (str): column to query term should be matched to + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self._simple_search_list( + table="users", + term=term, + col="name", + retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + desc="search_users", + ) diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py new file mode 100644 index 0000000000..6afbfc0d74 --- /dev/null +++ b/synapse/storage/data_stores/main/account_data.py @@ -0,0 +1,391 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.stream_change_cache import StreamChangeCache + +logger = logging.getLogger(__name__) + + +class AccountDataWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_account_data_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + account_max = self.get_max_account_data_stream_id() + self._account_data_stream_cache = StreamChangeCache( + "AccountDataAndTagsChangeCache", account_max + ) + + super(AccountDataWorkerStore, self).__init__(db_conn, hs) + + @abc.abstractmethod + def get_max_account_data_stream_id(self): + """Get the current max stream ID for account data stream + + Returns: + int + """ + raise NotImplementedError() + + @cached() + def get_account_data_for_user(self, user_id): + """Get all the client account_data for a user. + + Args: + user_id(str): The user to get the account_data for. + Returns: + A deferred pair of a dict of global account_data and a dict + mapping from room_id string to per room account_data dicts. + """ + + def get_account_data_for_user_txn(txn): + rows = self._simple_select_list_txn( + txn, + "account_data", + {"user_id": user_id}, + ["account_data_type", "content"], + ) + + global_account_data = { + row["account_data_type"]: json.loads(row["content"]) for row in rows + } + + rows = self._simple_select_list_txn( + txn, + "room_account_data", + {"user_id": user_id}, + ["room_id", "account_data_type", "content"], + ) + + by_room = {} + for row in rows: + room_data = by_room.setdefault(row["room_id"], {}) + room_data[row["account_data_type"]] = json.loads(row["content"]) + + return global_account_data, by_room + + return self.runInteraction( + "get_account_data_for_user", get_account_data_for_user_txn + ) + + @cachedInlineCallbacks(num_args=2, max_entries=5000) + def get_global_account_data_by_type_for_user(self, data_type, user_id): + """ + Returns: + Deferred: A dict + """ + result = yield self._simple_select_one_onecol( + table="account_data", + keyvalues={"user_id": user_id, "account_data_type": data_type}, + retcol="content", + desc="get_global_account_data_by_type_for_user", + allow_none=True, + ) + + if result: + return json.loads(result) + else: + return None + + @cached(num_args=2) + def get_account_data_for_room(self, user_id, room_id): + """Get all the client account_data for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + Returns: + A deferred dict of the room account_data + """ + + def get_account_data_for_room_txn(txn): + rows = self._simple_select_list_txn( + txn, + "room_account_data", + {"user_id": user_id, "room_id": room_id}, + ["account_data_type", "content"], + ) + + return { + row["account_data_type"]: json.loads(row["content"]) for row in rows + } + + return self.runInteraction( + "get_account_data_for_room", get_account_data_for_room_txn + ) + + @cached(num_args=3, max_entries=5000) + def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): + """Get the client account_data of given type for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + account_data_type (str): The account data type to get. + Returns: + A deferred of the room account_data for that type, or None if + there isn't any set. + """ + + def get_account_data_for_room_and_type_txn(txn): + content_json = self._simple_select_one_onecol_txn( + txn, + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + retcol="content", + allow_none=True, + ) + + return json.loads(content_json) if content_json else None + + return self.runInteraction( + "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn + ) + + def get_all_updated_account_data( + self, last_global_id, last_room_id, current_id, limit + ): + """Get all the client account_data that has changed on the server + Args: + last_global_id(int): The position to fetch from for top level data + last_room_id(int): The position to fetch from for per room data + current_id(int): The position to fetch up to. + Returns: + A deferred pair of lists of tuples of stream_id int, user_id string, + room_id string, type string, and content string. + """ + if last_room_id == current_id and last_global_id == current_id: + return defer.succeed(([], [])) + + def get_updated_account_data_txn(txn): + sql = ( + "SELECT stream_id, user_id, account_data_type, content" + " FROM account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_global_id, current_id, limit)) + global_results = txn.fetchall() + + sql = ( + "SELECT stream_id, user_id, room_id, account_data_type, content" + " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_room_id, current_id, limit)) + room_results = txn.fetchall() + return global_results, room_results + + return self.runInteraction( + "get_all_updated_account_data_txn", get_updated_account_data_txn + ) + + def get_updated_account_data_for_user(self, user_id, stream_id): + """Get all the client account_data for a that's changed for a user + + Args: + user_id(str): The user to get the account_data for. + stream_id(int): The point in the stream since which to get updates + Returns: + A deferred pair of a dict of global account_data and a dict + mapping from room_id string to per room account_data dicts. + """ + + def get_updated_account_data_for_user_txn(txn): + sql = ( + "SELECT account_data_type, content FROM account_data" + " WHERE user_id = ? AND stream_id > ?" + ) + + txn.execute(sql, (user_id, stream_id)) + + global_account_data = {row[0]: json.loads(row[1]) for row in txn} + + sql = ( + "SELECT room_id, account_data_type, content FROM room_account_data" + " WHERE user_id = ? AND stream_id > ?" + ) + + txn.execute(sql, (user_id, stream_id)) + + account_data_by_room = {} + for row in txn: + room_account_data = account_data_by_room.setdefault(row[0], {}) + room_account_data[row[1]] = json.loads(row[2]) + + return global_account_data, account_data_by_room + + changed = self._account_data_stream_cache.has_entity_changed( + user_id, int(stream_id) + ) + if not changed: + return {}, {} + + return self.runInteraction( + "get_updated_account_data_for_user", get_updated_account_data_for_user_txn + ) + + @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) + def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): + ignored_account_data = yield self.get_global_account_data_by_type_for_user( + "m.ignored_user_list", + ignorer_user_id, + on_invalidate=cache_context.invalidate, + ) + if not ignored_account_data: + return False + + return ignored_user_id in ignored_account_data.get("ignored_users", {}) + + +class AccountDataStore(AccountDataWorkerStore): + def __init__(self, db_conn, hs): + self._account_data_id_gen = StreamIdGenerator( + db_conn, "account_data_max_stream_id", "stream_id" + ) + + super(AccountDataStore, self).__init__(db_conn, hs) + + def get_max_account_data_stream_id(self): + """Get the current max stream id for the private user data stream + + Returns: + A deferred int. + """ + return self._account_data_id_gen.get_current_token() + + @defer.inlineCallbacks + def add_account_data_to_room(self, user_id, room_id, account_data_type, content): + """Add some account_data to a room for a user. + Args: + user_id(str): The user to add a tag for. + room_id(str): The room to add a tag for. + account_data_type(str): The type of account_data to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the account_data has been added. + """ + content_json = json.dumps(content) + + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as room_account_data has a unique constraint + # on (user_id, room_id, account_data_type) so _simple_upsert will + # retry if there is a conflict. + yield self._simple_upsert( + desc="add_room_account_data", + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + values={"stream_id": next_id, "content": content_json}, + lock=False, + ) + + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type), content + ) + + result = self._account_data_id_gen.get_current_token() + return result + + @defer.inlineCallbacks + def add_account_data_for_user(self, user_id, account_data_type, content): + """Add some account_data to a room for a user. + Args: + user_id(str): The user to add a tag for. + account_data_type(str): The type of account_data to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the account_data has been added. + """ + content_json = json.dumps(content) + + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as account_data has a unique constraint on + # (user_id, account_data_type) so _simple_upsert will retry if + # there is a conflict. + yield self._simple_upsert( + desc="add_user_account_data", + table="account_data", + keyvalues={"user_id": user_id, "account_data_type": account_data_type}, + values={"stream_id": next_id, "content": content_json}, + lock=False, + ) + + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.invalidate( + (account_data_type, user_id) + ) + + result = self._account_data_id_gen.get_current_token() + return result + + def _update_max_stream_id(self, next_id): + """Update the max stream_id + + Args: + next_id(int): The the revision to advance to. + """ + + def _update(txn): + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + + return self.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py new file mode 100644 index 0000000000..81babf2029 --- /dev/null +++ b/synapse/storage/data_stores/main/appservice.py @@ -0,0 +1,367 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.appservice import AppServiceTransaction +from synapse.config.appservice import load_appservices +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore + +logger = logging.getLogger(__name__) + + +def _make_exclusive_regex(services_cache): + # We precompie a regex constructed from all the regexes that the AS's + # have registered for exclusive users. + exclusive_user_regexes = [ + regex.pattern + for service in services_cache + for regex in service.get_exlusive_user_regexes() + ] + if exclusive_user_regexes: + exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) + exclusive_user_regex = re.compile(exclusive_user_regex) + else: + # We handle this case specially otherwise the constructed regex + # will always match + exclusive_user_regex = None + + return exclusive_user_regex + + +class ApplicationServiceWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): + self.services_cache = load_appservices( + hs.hostname, hs.config.app_service_config_files + ) + self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) + + super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) + + def get_app_services(self): + return self.services_cache + + def get_if_app_services_interested_in_user(self, user_id): + """Check if the user is one associated with an app service (exclusively) + """ + if self.exclusive_user_regex: + return bool(self.exclusive_user_regex.match(user_id)) + else: + return False + + def get_app_service_by_user_id(self, user_id): + """Retrieve an application service from their user ID. + + All application services have associated with them a particular user ID. + There is no distinguishing feature on the user ID which indicates it + represents an application service. This function allows you to map from + a user ID to an application service. + + Args: + user_id(str): The user ID to see if it is an application service. + Returns: + synapse.appservice.ApplicationService or None. + """ + for service in self.services_cache: + if service.sender == user_id: + return service + return None + + def get_app_service_by_token(self, token): + """Get the application service with the given appservice token. + + Args: + token (str): The application service token. + Returns: + synapse.appservice.ApplicationService or None. + """ + for service in self.services_cache: + if service.token == token: + return service + return None + + def get_app_service_by_id(self, as_id): + """Get the application service with the given appservice ID. + + Args: + as_id (str): The application service ID. + Returns: + synapse.appservice.ApplicationService or None. + """ + for service in self.services_cache: + if service.id == as_id: + return service + return None + + +class ApplicationServiceStore(ApplicationServiceWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass + + +class ApplicationServiceTransactionWorkerStore( + ApplicationServiceWorkerStore, EventsWorkerStore +): + @defer.inlineCallbacks + def get_appservices_by_state(self, state): + """Get a list of application services based on their state. + + Args: + state(ApplicationServiceState): The state to filter on. + Returns: + A Deferred which resolves to a list of ApplicationServices, which + may be empty. + """ + results = yield self._simple_select_list( + "application_services_state", dict(state=state), ["as_id"] + ) + # NB: This assumes this class is linked with ApplicationServiceStore + as_list = self.get_app_services() + services = [] + + for res in results: + for service in as_list: + if service.id == res["as_id"]: + services.append(service) + return services + + @defer.inlineCallbacks + def get_appservice_state(self, service): + """Get the application service state. + + Args: + service(ApplicationService): The service whose state to set. + Returns: + A Deferred which resolves to ApplicationServiceState. + """ + result = yield self._simple_select_one( + "application_services_state", + dict(as_id=service.id), + ["state"], + allow_none=True, + desc="get_appservice_state", + ) + if result: + return result.get("state") + return None + + def set_appservice_state(self, service, state): + """Set the application service state. + + Args: + service(ApplicationService): The service whose state to set. + state(ApplicationServiceState): The connectivity state to apply. + Returns: + A Deferred which resolves when the state was set successfully. + """ + return self._simple_upsert( + "application_services_state", dict(as_id=service.id), dict(state=state) + ) + + def create_appservice_txn(self, service, events): + """Atomically creates a new transaction for this application service + with the given list of events. + + Args: + service(ApplicationService): The service who the transaction is for. + events(list): A list of events to put in the transaction. + Returns: + AppServiceTransaction: A new transaction. + """ + + def _create_appservice_txn(txn): + # work out new txn id (highest txn id for this service += 1) + # The highest id may be the last one sent (in which case it is last_txn) + # or it may be the highest in the txns list (which are waiting to be/are + # being sent) + last_txn_id = self._get_last_txn(txn, service.id) + + txn.execute( + "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", + (service.id,), + ) + highest_txn_id = txn.fetchone()[0] + if highest_txn_id is None: + highest_txn_id = 0 + + new_txn_id = max(highest_txn_id, last_txn_id) + 1 + + # Insert new txn into txn table + event_ids = json.dumps([e.event_id for e in events]) + txn.execute( + "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " + "VALUES(?,?,?)", + (service.id, new_txn_id, event_ids), + ) + return AppServiceTransaction(service=service, id=new_txn_id, events=events) + + return self.runInteraction("create_appservice_txn", _create_appservice_txn) + + def complete_appservice_txn(self, txn_id, service): + """Completes an application service transaction. + + Args: + txn_id(str): The transaction ID being completed. + service(ApplicationService): The application service which was sent + this transaction. + Returns: + A Deferred which resolves if this transaction was stored + successfully. + """ + txn_id = int(txn_id) + + def _complete_appservice_txn(txn): + # Debugging query: Make sure the txn being completed is EXACTLY +1 from + # what was there before. If it isn't, we've got problems (e.g. the AS + # has probably missed some events), so whine loudly but still continue, + # since it shouldn't fail completion of the transaction. + last_txn_id = self._get_last_txn(txn, service.id) + if (last_txn_id + 1) != txn_id: + logger.error( + "appservice: Completing a transaction which has an ID > 1 from " + "the last ID sent to this AS. We've either dropped events or " + "sent it to the AS out of order. FIX ME. last_txn=%s " + "completing_txn=%s service_id=%s", + last_txn_id, + txn_id, + service.id, + ) + + # Set current txn_id for AS to 'txn_id' + self._simple_upsert_txn( + txn, + "application_services_state", + dict(as_id=service.id), + dict(last_txn=txn_id), + ) + + # Delete txn + self._simple_delete_txn( + txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) + ) + + return self.runInteraction("complete_appservice_txn", _complete_appservice_txn) + + @defer.inlineCallbacks + def get_oldest_unsent_txn(self, service): + """Get the oldest transaction which has not been sent for this + service. + + Args: + service(ApplicationService): The app service to get the oldest txn. + Returns: + A Deferred which resolves to an AppServiceTransaction or + None. + """ + + def _get_oldest_unsent_txn(txn): + # Monotonically increasing txn ids, so just select the smallest + # one in the txns table (we delete them when they are sent) + txn.execute( + "SELECT * FROM application_services_txns WHERE as_id=?" + " ORDER BY txn_id ASC LIMIT 1", + (service.id,), + ) + rows = self.cursor_to_dict(txn) + if not rows: + return None + + entry = rows[0] + + return entry + + entry = yield self.runInteraction( + "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn + ) + + if not entry: + return None + + event_ids = json.loads(entry["event_ids"]) + + events = yield self.get_events_as_list(event_ids) + + return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) + + def _get_last_txn(self, txn, service_id): + txn.execute( + "SELECT last_txn FROM application_services_state WHERE as_id=?", + (service_id,), + ) + last_txn_id = txn.fetchone() + if last_txn_id is None or last_txn_id[0] is None: # no row exists + return 0 + else: + return int(last_txn_id[0]) # select 'last_txn' col + + def set_appservice_last_pos(self, pos): + def set_appservice_last_pos_txn(txn): + txn.execute( + "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) + ) + + return self.runInteraction( + "set_appservice_last_pos", set_appservice_last_pos_txn + ) + + @defer.inlineCallbacks + def get_new_events_for_appservice(self, current_id, limit): + """Get all new evnets""" + + def get_new_events_for_appservice_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id" + " FROM events AS e" + " WHERE" + " (SELECT stream_ordering FROM appservice_stream_position)" + " < e.stream_ordering" + " AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + + txn.execute(sql, (current_id, limit)) + rows = txn.fetchall() + + upper_bound = current_id + if len(rows) == limit: + upper_bound = rows[-1][0] + + return upper_bound, [row[1] for row in rows] + + upper_bound, event_ids = yield self.runInteraction( + "get_new_events_for_appservice", get_new_events_for_appservice_txn + ) + + events = yield self.get_events_as_list(event_ids) + + return upper_bound, events + + +class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py new file mode 100644 index 0000000000..706c6a1f3f --- /dev/null +++ b/synapse/storage/data_stores/main/client_ips.py @@ -0,0 +1,580 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage import background_updates +from synapse.storage._base import Cache +from synapse.util.caches import CACHE_SIZE_FACTOR + +logger = logging.getLogger(__name__) + +# Number of msec of granularity to store the user IP 'last seen' time. Smaller +# times give more inserts into the database even for readonly API hits +# 120 seconds == 2 minutes +LAST_SEEN_GRANULARITY = 120 * 1000 + + +class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + "user_ips_device_index", + index_name="user_ips_device_id", + table="user_ips", + columns=["user_id", "device_id", "last_seen"], + ) + + self.register_background_index_update( + "user_ips_last_seen_index", + index_name="user_ips_last_seen", + table="user_ips", + columns=["user_id", "last_seen"], + ) + + self.register_background_index_update( + "user_ips_last_seen_only_index", + index_name="user_ips_last_seen_only", + table="user_ips", + columns=["last_seen"], + ) + + self.register_background_update_handler( + "user_ips_analyze", self._analyze_user_ip + ) + + self.register_background_update_handler( + "user_ips_remove_dupes", self._remove_user_ip_dupes + ) + + # Register a unique index + self.register_background_index_update( + "user_ips_device_unique_index", + index_name="user_ips_user_token_ip_unique_index", + table="user_ips", + columns=["user_id", "access_token", "ip"], + unique=True, + ) + + # Drop the old non-unique index + self.register_background_update_handler( + "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique + ) + + # Update the last seen info in devices. + self.register_background_update_handler( + "devices_last_seen", self._devices_last_seen_update + ) + + @defer.inlineCallbacks + def _remove_user_ip_nonunique(self, progress, batch_size): + def f(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") + txn.close() + + yield self.runWithConnection(f) + yield self._end_background_update("user_ips_drop_nonunique_index") + return 1 + + @defer.inlineCallbacks + def _analyze_user_ip(self, progress, batch_size): + # Background update to analyze user_ips table before we run the + # deduplication background update. The table may not have been analyzed + # for ages due to the table locks. + # + # This will lock out the naive upserts to user_ips while it happens, but + # the analyze should be quick (28GB table takes ~10s) + def user_ips_analyze(txn): + txn.execute("ANALYZE user_ips") + + yield self.runInteraction("user_ips_analyze", user_ips_analyze) + + yield self._end_background_update("user_ips_analyze") + + return 1 + + @defer.inlineCallbacks + def _remove_user_ip_dupes(self, progress, batch_size): + # This works function works by scanning the user_ips table in batches + # based on `last_seen`. For each row in a batch it searches the rest of + # the table to see if there are any duplicates, if there are then they + # are removed and replaced with a suitable row. + + # Fetch the start of the batch + begin_last_seen = progress.get("last_seen", 0) + + def get_last_seen(txn): + txn.execute( + """ + SELECT last_seen FROM user_ips + WHERE last_seen > ? + ORDER BY last_seen + LIMIT 1 + OFFSET ? + """, + (begin_last_seen, batch_size), + ) + row = txn.fetchone() + if row: + return row[0] + else: + return None + + # Get a last seen that has roughly `batch_size` since `begin_last_seen` + end_last_seen = yield self.runInteraction( + "user_ips_dups_get_last_seen", get_last_seen + ) + + # If it returns None, then we're processing the last batch + last = end_last_seen is None + + logger.info( + "Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s", + begin_last_seen, + end_last_seen, + ) + + def remove(txn): + # This works by looking at all entries in the given time span, and + # then for each (user_id, access_token, ip) tuple in that range + # checking for any duplicates in the rest of the table (via a join). + # It then only returns entries which have duplicates, and the max + # last_seen across all duplicates, which can the be used to delete + # all other duplicates. + # It is efficient due to the existence of (user_id, access_token, + # ip) and (last_seen) indices. + + # Define the search space, which requires handling the last batch in + # a different way + if last: + clause = "? <= last_seen" + args = (begin_last_seen,) + else: + clause = "? <= last_seen AND last_seen < ?" + args = (begin_last_seen, end_last_seen) + + # (Note: The DISTINCT in the inner query is important to ensure that + # the COUNT(*) is accurate, otherwise double counting may happen due + # to the join effectively being a cross product) + txn.execute( + """ + SELECT user_id, access_token, ip, + MAX(device_id), MAX(user_agent), MAX(last_seen), + COUNT(*) + FROM ( + SELECT DISTINCT user_id, access_token, ip + FROM user_ips + WHERE {} + ) c + INNER JOIN user_ips USING (user_id, access_token, ip) + GROUP BY user_id, access_token, ip + HAVING count(*) > 1 + """.format( + clause + ), + args, + ) + res = txn.fetchall() + + # We've got some duplicates + for i in res: + user_id, access_token, ip, device_id, user_agent, last_seen, count = i + + # We want to delete the duplicates so we end up with only a + # single row. + # + # The naive way of doing this would be just to delete all rows + # and reinsert a constructed row. However, if there are a lot of + # duplicate rows this can cause the table to grow a lot, which + # can be problematic in two ways: + # 1. If user_ips is already large then this can cause the + # table to rapidly grow, potentially filling the disk. + # 2. Reinserting a lot of rows can confuse the table + # statistics for postgres, causing it to not use the + # correct indices for the query above, resulting in a full + # table scan. This is incredibly slow for large tables and + # can kill database performance. (This seems to mainly + # happen for the last query where the clause is simply `? < + # last_seen`) + # + # So instead we want to delete all but *one* of the duplicate + # rows. That is hard to do reliably, so we cheat and do a two + # step process: + # 1. Delete all rows with a last_seen strictly less than the + # max last_seen. This hopefully results in deleting all but + # one row the majority of the time, but there may be + # duplicate last_seen + # 2. If multiple rows remain, we fall back to the naive method + # and simply delete all rows and reinsert. + # + # Note that this relies on no new duplicate rows being inserted, + # but if that is happening then this entire process is futile + # anyway. + + # Do step 1: + + txn.execute( + """ + DELETE FROM user_ips + WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ? + """, + (user_id, access_token, ip, last_seen), + ) + if txn.rowcount == count - 1: + # We deleted all but one of the duplicate rows, i.e. there + # is exactly one remaining and so there is nothing left to + # do. + continue + elif txn.rowcount >= count: + raise Exception( + "We deleted more duplicate rows from 'user_ips' than expected" + ) + + # The previous step didn't delete enough rows, so we fallback to + # step 2: + + # Drop all the duplicates + txn.execute( + """ + DELETE FROM user_ips + WHERE user_id = ? AND access_token = ? AND ip = ? + """, + (user_id, access_token, ip), + ) + + # Add in one to be the last_seen + txn.execute( + """ + INSERT INTO user_ips + (user_id, access_token, ip, device_id, user_agent, last_seen) + VALUES (?, ?, ?, ?, ?, ?) + """, + (user_id, access_token, ip, device_id, user_agent, last_seen), + ) + + self._background_update_progress_txn( + txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} + ) + + yield self.runInteraction("user_ips_dups_remove", remove) + + if last: + yield self._end_background_update("user_ips_remove_dupes") + + return batch_size + + @defer.inlineCallbacks + def _devices_last_seen_update(self, progress, batch_size): + """Background update to insert last seen info into devices table + """ + + last_user_id = progress.get("last_user_id", "") + last_device_id = progress.get("last_device_id", "") + + def _devices_last_seen_update_txn(txn): + # This consists of two queries: + # + # 1. The sub-query searches for the next N devices and joins + # against user_ips to find the max last_seen associated with + # that device. + # 2. The outer query then joins again against user_ips on + # user/device/last_seen. This *should* hopefully only + # return one row, but if it does return more than one then + # we'll just end up updating the same device row multiple + # times, which is fine. + + if self.database_engine.supports_tuple_comparison: + where_clause = "(user_id, device_id) > (?, ?)" + where_args = [last_user_id, last_device_id] + else: + # We explicitly do a `user_id >= ? AND (...)` here to ensure + # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)` + # makes it hard for query optimiser to tell that it can use the + # index on user_id + where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)" + where_args = [last_user_id, last_user_id, last_device_id] + + sql = """ + SELECT + last_seen, ip, user_agent, user_id, device_id + FROM ( + SELECT + user_id, device_id, MAX(u.last_seen) AS last_seen + FROM devices + INNER JOIN user_ips AS u USING (user_id, device_id) + WHERE %(where_clause)s + GROUP BY user_id, device_id + ORDER BY user_id ASC, device_id ASC + LIMIT ? + ) c + INNER JOIN user_ips AS u USING (user_id, device_id, last_seen) + """ % { + "where_clause": where_clause + } + txn.execute(sql, where_args + [batch_size]) + + rows = txn.fetchall() + if not rows: + return 0 + + sql = """ + UPDATE devices + SET last_seen = ?, ip = ?, user_agent = ? + WHERE user_id = ? AND device_id = ? + """ + txn.execute_batch(sql, rows) + + _, _, _, user_id, device_id = rows[-1] + self._background_update_progress_txn( + txn, + "devices_last_seen", + {"last_user_id": user_id, "last_device_id": device_id}, + ) + + return len(rows) + + updated = yield self.runInteraction( + "_devices_last_seen_update", _devices_last_seen_update_txn + ) + + if not updated: + yield self._end_background_update("devices_last_seen") + + return updated + + +class ClientIpStore(ClientIpBackgroundUpdateStore): + def __init__(self, db_conn, hs): + + self.client_ip_last_seen = Cache( + name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + ) + + super(ClientIpStore, self).__init__(db_conn, hs) + + self.user_ips_max_age = hs.config.user_ips_max_age + + # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) + self._batch_row_update = {} + + self._client_ip_looper = self._clock.looping_call( + self._update_client_ips_batch, 5 * 1000 + ) + self.hs.get_reactor().addSystemEventTrigger( + "before", "shutdown", self._update_client_ips_batch + ) + + if self.user_ips_max_age: + self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + + @defer.inlineCallbacks + def insert_client_ip( + self, user_id, access_token, ip, user_agent, device_id, now=None + ): + if not now: + now = int(self._clock.time_msec()) + key = (user_id, access_token, ip) + + try: + last_seen = self.client_ip_last_seen.get(key) + except KeyError: + last_seen = None + yield self.populate_monthly_active_users(user_id) + # Rate-limited inserts + if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: + return + + self.client_ip_last_seen.prefill(key, now) + + self._batch_row_update[key] = (user_agent, device_id, now) + + @wrap_as_background_process("update_client_ips") + def _update_client_ips_batch(self): + + # If the DB pool has already terminated, don't try updating + if not self.hs.get_db_pool().running: + return + + to_update = self._batch_row_update + self._batch_row_update = {} + + return self.runInteraction( + "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update + ) + + def _update_client_ips_batch_txn(self, txn, to_update): + if "user_ips" in self._unsafe_to_upsert_tables or ( + not self.database_engine.can_native_upsert + ): + self.database_engine.lock_table(txn, "user_ips") + + for entry in iteritems(to_update): + (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry + + try: + self._simple_upsert_txn( + txn, + table="user_ips", + keyvalues={ + "user_id": user_id, + "access_token": access_token, + "ip": ip, + }, + values={ + "user_agent": user_agent, + "device_id": device_id, + "last_seen": last_seen, + }, + lock=False, + ) + + # Technically an access token might not be associated with + # a device so we need to check. + if device_id: + self._simple_upsert_txn( + txn, + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={ + "user_agent": user_agent, + "last_seen": last_seen, + "ip": ip, + }, + lock=False, + ) + except Exception as e: + # Failed to upsert, log and continue + logger.error("Failed to insert client IP %r: %r", entry, e) + + @defer.inlineCallbacks + def get_last_client_ip_by_device(self, user_id, device_id): + """For each device_id listed, give the user_ip it was last seen on + + Args: + user_id (str) + device_id (str): If None fetches all devices for the user + + Returns: + defer.Deferred: resolves to a dict, where the keys + are (user_id, device_id) tuples. The values are also dicts, with + keys giving the column names + """ + + keyvalues = {"user_id": user_id} + if device_id is not None: + keyvalues["device_id"] = device_id + + res = yield self._simple_select_list( + table="devices", + keyvalues=keyvalues, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ) + + ret = {(d["user_id"], d["device_id"]): d for d in res} + for key in self._batch_row_update: + uid, access_token, ip = key + if uid == user_id: + user_agent, did, last_seen = self._batch_row_update[key] + if not device_id or did == device_id: + ret[(user_id, device_id)] = { + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": did, + "last_seen": last_seen, + } + return ret + + @defer.inlineCallbacks + def get_user_ip_and_agents(self, user): + user_id = user.to_string() + results = {} + + for key in self._batch_row_update: + uid, access_token, ip, = key + if uid == user_id: + user_agent, _, last_seen = self._batch_row_update[key] + results[(access_token, ip)] = (user_agent, last_seen) + + rows = yield self._simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "last_seen"], + desc="get_user_ip_and_agents", + ) + + results.update( + ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) + for row in rows + ) + return list( + { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } + for (access_token, ip), (user_agent, last_seen) in iteritems(results) + ) + + @wrap_as_background_process("prune_old_user_ips") + async def _prune_old_user_ips(self): + """Removes entries in user IPs older than the configured period. + """ + + if self.user_ips_max_age is None: + # Nothing to do + return + + if not await self.has_completed_background_update("devices_last_seen"): + # Only start pruning if we have finished populating the devices + # last seen info. + return + + # We do a slightly funky SQL delete to ensure we don't try and delete + # too much at once (as the table may be very large from before we + # started pruning). + # + # This works by finding the max last_seen that is less than the given + # time, but has no more than N rows before it, deleting all rows with + # a lesser last_seen time. (We COALESCE so that the sub-SELECT always + # returns exactly one row). + sql = """ + DELETE FROM user_ips + WHERE last_seen <= ( + SELECT COALESCE(MAX(last_seen), -1) + FROM ( + SELECT last_seen FROM user_ips + WHERE last_seen <= ? + ORDER BY last_seen ASC + LIMIT 5000 + ) AS u + ) + """ + + timestamp = self.clock.time_msec() - self.user_ips_max_age + + def _prune_old_user_ips_txn(txn): + txn.execute(sql, (timestamp,)) + + await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py new file mode 100644 index 0000000000..f04aad0743 --- /dev/null +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -0,0 +1,457 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.util.caches.expiringcache import ExpiringCache + +logger = logging.getLogger(__name__) + + +class DeviceInboxWorkerStore(SQLBaseStore): + def get_to_device_stream_token(self): + return self._device_inbox_id_gen.get_current_token() + + def get_new_messages_for_device( + self, user_id, device_id, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + current_stream_id(int): The current position of the to device + message stream. + Returns: + Deferred ([dict], int): List of messages for the device and where + in the stream the messages got to. + """ + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_stream_id + ) + if not has_changed: + return defer.succeed(([], current_stream_id)) + + def get_new_messages_for_device_txn(txn): + sql = ( + "SELECT stream_id, message_json FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute( + sql, (user_id, device_id, last_stream_id, current_stream_id, limit) + ) + messages = [] + for row in txn: + stream_pos = row[0] + messages.append(json.loads(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return messages, stream_pos + + return self.runInteraction( + "get_new_messages_for_device", get_new_messages_for_device_txn + ) + + @trace + @defer.inlineCallbacks + def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves to the number of messages deleted. + """ + # If we have cached the last stream id we've deleted up to, we can + # check if there is likely to be anything that needs deleting + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), None + ) + + set_tag("last_deleted_stream_id", last_deleted_stream_id) + + if last_deleted_stream_id: + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_deleted_stream_id + ) + if not has_changed: + log_kv({"message": "No changes in cache since last check"}) + return 0 + + def delete_messages_for_device_txn(txn): + sql = ( + "DELETE FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (user_id, device_id, up_to_stream_id)) + return txn.rowcount + + count = yield self.runInteraction( + "delete_messages_for_device", delete_messages_for_device_txn + ) + + log_kv( + {"message": "deleted {} messages for device".format(count), "count": count} + ) + + # Update the cache, ensuring that we only ever increase the value + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), 0 + ) + self._last_device_delete_cache[(user_id, device_id)] = max( + last_deleted_stream_id, up_to_stream_id + ) + + return count + + @trace + def get_new_device_msgs_for_remote( + self, destination, last_stream_id, current_stream_id, limit + ): + """ + Args: + destination(str): The name of the remote server. + last_stream_id(int|long): The last position of the device message stream + that the server sent up to. + current_stream_id(int|long): The current position of the device + message stream. + Returns: + Deferred ([dict], int|long): List of messages for the device and where + in the stream the messages got to. + """ + + set_tag("destination", destination) + set_tag("last_stream_id", last_stream_id) + set_tag("current_stream_id", current_stream_id) + set_tag("limit", limit) + + has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( + destination, last_stream_id + ) + if not has_changed or last_stream_id == current_stream_id: + log_kv({"message": "No new messages in stream"}) + return defer.succeed(([], current_stream_id)) + + if limit <= 0: + # This can happen if we run out of room for EDUs in the transaction. + return defer.succeed(([], last_stream_id)) + + @trace + def get_new_messages_for_remote_destination_txn(txn): + sql = ( + "SELECT stream_id, messages_json FROM device_federation_outbox" + " WHERE destination = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (destination, last_stream_id, current_stream_id, limit)) + messages = [] + for row in txn: + stream_pos = row[0] + messages.append(json.loads(row[1])) + if len(messages) < limit: + log_kv({"message": "Set stream position to current position"}) + stream_pos = current_stream_id + return messages, stream_pos + + return self.runInteraction( + "get_new_device_msgs_for_remote", + get_new_messages_for_remote_destination_txn, + ) + + @trace + def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + """Used to delete messages when the remote destination acknowledges + their receipt. + + Args: + destination(str): The destination server_name + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves when the messages have been deleted. + """ + + def delete_messages_for_remote_destination_txn(txn): + sql = ( + "DELETE FROM device_federation_outbox" + " WHERE destination = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (destination, up_to_stream_id)) + + return self.runInteraction( + "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn + ) + + +class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, db_conn, hs): + super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + "device_inbox_stream_index", + index_name="device_inbox_stream_id_user_id", + table="device_inbox", + columns=["stream_id", "user_id"], + ) + + self.register_background_update_handler( + self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox + ) + + @defer.inlineCallbacks + def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") + txn.close() + + yield self.runWithConnection(reindex_txn) + + yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + return 1 + + +class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, db_conn, hs): + super(DeviceInboxStore, self).__init__(db_conn, hs) + + # Map of (user_id, device_id) to the last stream_id that has been + # deleted up to. This is so that we can no op deletions. + self._last_device_delete_cache = ExpiringCache( + cache_name="last_device_delete_cache", + clock=self._clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + ) + + @trace + @defer.inlineCallbacks + def add_messages_to_device_inbox( + self, local_messages_by_user_then_device, remote_messages_by_destination + ): + """Used to send messages from this server. + + Args: + sender_user_id(str): The ID of the user sending these messages. + local_messages_by_user_and_device(dict): + Dictionary of user_id to device_id to message. + remote_messages_by_destination(dict): + Dictionary of destination server_name to the EDU JSON to send. + Returns: + A deferred stream_id that resolves when the messages have been + inserted. + """ + + def add_messages_txn(txn, now_ms, stream_id): + # Add the local messages directly to the local inbox. + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device + ) + + # Add the remote messages to the federation outbox. + # We'll send them to a remote server when we next send a + # federation transaction to that destination. + sql = ( + "INSERT INTO device_federation_outbox" + " (destination, stream_id, queued_ts, messages_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for destination, edu in remote_messages_by_destination.items(): + edu_json = json.dumps(edu) + rows.append((destination, stream_id, now_ms, edu_json)) + txn.executemany(sql, rows) + + with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_msec() + yield self.runInteraction( + "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id + ) + for user_id in local_messages_by_user_then_device.keys(): + self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) + for destination in remote_messages_by_destination.keys(): + self._device_federation_outbox_stream_cache.entity_has_changed( + destination, stream_id + ) + + return self._device_inbox_id_gen.get_current_token() + + @defer.inlineCallbacks + def add_messages_from_remote_to_device_inbox( + self, origin, message_id, local_messages_by_user_then_device + ): + def add_messages_txn(txn, now_ms, stream_id): + # Check if we've already inserted a matching message_id for that + # origin. This can happen if the origin doesn't receive our + # acknowledgement from the first time we received the message. + already_inserted = self._simple_select_one_txn( + txn, + table="device_federation_inbox", + keyvalues={"origin": origin, "message_id": message_id}, + retcols=("message_id",), + allow_none=True, + ) + if already_inserted is not None: + return + + # Add an entry for this message_id so that we know we've processed + # it. + self._simple_insert_txn( + txn, + table="device_federation_inbox", + values={ + "origin": origin, + "message_id": message_id, + "received_ts": now_ms, + }, + ) + + # Add the messages to the approriate local device inboxes so that + # they'll be sent to the devices when they next sync. + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device + ) + + with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_msec() + yield self.runInteraction( + "add_messages_from_remote_to_device_inbox", + add_messages_txn, + now_ms, + stream_id, + ) + for user_id in local_messages_by_user_then_device.keys(): + self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) + + return stream_id + + def _add_messages_to_local_device_inbox_txn( + self, txn, stream_id, messages_by_user_then_device + ): + sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" + txn.execute(sql, (stream_id, stream_id)) + + local_by_user_then_device = {} + for user_id, messages_by_device in messages_by_user_then_device.items(): + messages_json_for_user = {} + devices = list(messages_by_device.keys()) + if len(devices) == 1 and devices[0] == "*": + # Handle wildcard device_ids. + sql = "SELECT device_id FROM devices" " WHERE user_id = ?" + txn.execute(sql, (user_id,)) + message_json = json.dumps(messages_by_device["*"]) + for row in txn: + # Add the message for all devices for this user on this + # server. + device = row[0] + messages_json_for_user[device] = message_json + else: + if not devices: + continue + + clause, args = make_in_list_sql_clause( + txn.database_engine, "device_id", devices + ) + sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause + + # TODO: Maybe this needs to be done in batches if there are + # too many local devices for a given user. + txn.execute(sql, [user_id] + list(args)) + for row in txn: + # Only insert into the local inbox if the device exists on + # this server + device = row[0] + message_json = json.dumps(messages_by_device[device]) + messages_json_for_user[device] = message_json + + if messages_json_for_user: + local_by_user_then_device[user_id] = messages_json_for_user + + if not local_by_user_then_device: + return + + sql = ( + "INSERT INTO device_inbox" + " (user_id, device_id, stream_id, message_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for user_id, messages_by_device in local_by_user_then_device.items(): + for device_id, message_json in messages_by_device.items(): + rows.append((user_id, device_id, stream_id, message_json)) + + txn.executemany(sql, rows) + + def get_all_new_device_messages(self, last_pos, current_pos, limit): + """ + Args: + last_pos(int): + current_pos(int): + limit(int): + Returns: + A deferred list of rows from the device inbox + """ + if last_pos == current_pos: + return defer.succeed([]) + + def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_pos, last_pos + limit) + sql = ( + "SELECT max(stream_id), user_id" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY user_id" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows = txn.fetchall() + + sql = ( + "SELECT max(stream_id), destination" + " FROM device_federation_outbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY destination" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows.extend(txn) + + # Order by ascending stream ordering + rows.sort() + + return rows + + return self.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py new file mode 100644 index 0000000000..ac5239e50a --- /dev/null +++ b/synapse/storage/data_stores/main/devices.py @@ -0,0 +1,937 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from six import iteritems + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.logging.opentracing import ( + get_active_span_text_map, + set_tag, + trace, + whitelisted_homeserver, +) +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import ( + Cache, + SQLBaseStore, + db_to_json, + make_in_list_sql_clause, +) +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.util import batch_iter +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList + +logger = logging.getLogger(__name__) + +DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( + "drop_device_list_streams_non_unique_indexes" +) + + +class DeviceWorkerStore(SQLBaseStore): + def get_device(self, user_id, device_id): + """Retrieve a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred for a dict containing the device information + Raises: + StoreError: if the device is not found + """ + return self._simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + ) + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """Retrieve all of a user's registered devices. + + Args: + user_id (str): + Returns: + defer.Deferred: resolves to a dict from device_id to a dict + containing "device_id", "user_id" and "display_name" for each + device. + """ + devices = yield self._simple_select_list( + table="devices", + keyvalues={"user_id": user_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_devices_by_user", + ) + + return {d["device_id"]: d for d in devices} + + @trace + @defer.inlineCallbacks + def get_devices_by_remote(self, destination, from_stream_id, limit): + """Get stream of updates to send to remote servers + + Returns: + Deferred[tuple[int, list[dict]]]: + current stream id (ie, the stream id of the last update included in the + response), and the list of updates + """ + now_stream_id = self._device_list_id_gen.get_current_token() + + has_changed = self._device_list_federation_stream_cache.has_entity_changed( + destination, int(from_stream_id) + ) + if not has_changed: + return now_stream_id, [] + + # We retrieve n+1 devices from the list of outbound pokes where n is + # our outbound device update limit. We then check if the very last + # device has the same stream_id as the second-to-last device. If so, + # then we ignore all devices with that stream_id and only send the + # devices with a lower stream_id. + # + # If when culling the list we end up with no devices afterwards, we + # consider the device update to be too large, and simply skip the + # stream_id; the rationale being that such a large device list update + # is likely an error. + updates = yield self.runInteraction( + "get_devices_by_remote", + self._get_devices_by_remote_txn, + destination, + from_stream_id, + now_stream_id, + limit + 1, + ) + + # Return an empty list if there are no updates + if not updates: + return now_stream_id, [] + + # if we have exceeded the limit, we need to exclude any results with the + # same stream_id as the last row. + if len(updates) > limit: + stream_id_cutoff = updates[-1][2] + now_stream_id = stream_id_cutoff - 1 + else: + stream_id_cutoff = None + + # Perform the equivalent of a GROUP BY + # + # Iterate through the updates list and copy non-duplicate + # (user_id, device_id) entries into a map, with the value being + # the max stream_id across each set of duplicate entries + # + # maps (user_id, device_id) -> (stream_id, opentracing_context) + # as long as their stream_id does not match that of the last row + # + # opentracing_context contains the opentracing metadata for the request + # that created the poke + # + # The most recent request's opentracing_context is used as the + # context which created the Edu. + + query_map = {} + for update in updates: + if stream_id_cutoff is not None and update[2] >= stream_id_cutoff: + # Stop processing updates + break + + key = (update[0], update[1]) + + update_context = update[3] + update_stream_id = update[2] + + previous_update_stream_id, _ = query_map.get(key, (0, None)) + + if update_stream_id > previous_update_stream_id: + query_map[key] = (update_stream_id, update_context) + + # If we didn't find any updates with a stream_id lower than the cutoff, it + # means that there are more than limit updates all of which have the same + # steam_id. + + # That should only happen if a client is spamming the server with new + # devices, in which case E2E isn't going to work well anyway. We'll just + # skip that stream_id and return an empty list, and continue with the next + # stream_id next time. + if not query_map: + return stream_id_cutoff, [] + + results = yield self._get_device_update_edus_by_remote( + destination, from_stream_id, query_map + ) + + return now_stream_id, results + + def _get_devices_by_remote_txn( + self, txn, destination, from_stream_id, now_stream_id, limit + ): + """Return device update information for a given remote destination + + Args: + txn (LoggingTransaction): The transaction to execute + destination (str): The host the device updates are intended for + from_stream_id (int): The minimum stream_id to filter updates by, exclusive + now_stream_id (int): The maximum stream_id to filter updates by, inclusive + limit (int): Maximum number of device updates to return + + Returns: + List: List of device updates + """ + sql = """ + SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes + WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? + ORDER BY stream_id + LIMIT ? + """ + txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit)) + + return list(txn) + + @defer.inlineCallbacks + def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): + """Returns a list of device update EDUs as well as E2EE keys + + Args: + destination (str): The host the device updates are intended for + from_stream_id (int): The minimum stream_id to filter updates by, exclusive + query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping + user_id/device_id to update stream_id and the relevent json-encoded + opentracing context + + Returns: + List[Dict]: List of objects representing an device update EDU + + """ + devices = yield self.runInteraction( + "_get_e2e_device_keys_txn", + self._get_e2e_device_keys_txn, + query_map.keys(), + include_all_devices=True, + include_deleted_devices=True, + ) + + results = [] + for user_id, user_devices in iteritems(devices): + # The prev_id for the first row is always the last row before + # `from_stream_id` + prev_id = yield self._get_last_device_update_for_remote_user( + destination, user_id, from_stream_id + ) + for device_id, device in iteritems(user_devices): + stream_id, opentracing_context = query_map[(user_id, device_id)] + result = { + "user_id": user_id, + "device_id": device_id, + "prev_id": [prev_id] if prev_id else [], + "stream_id": stream_id, + "org.matrix.opentracing_context": opentracing_context, + } + + prev_id = stream_id + + if device is not None: + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + else: + result["deleted"] = True + + results.append(result) + + return results + + def _get_last_device_update_for_remote_user( + self, destination, user_id, from_stream_id + ): + def f(txn): + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? AND stream_id <= ? + """ + txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) + rows = txn.fetchall() + return rows[0][0] + + return self.runInteraction("get_last_device_update_for_remote_user", f) + + def mark_as_sent_devices_by_remote(self, destination, stream_id): + """Mark that updates have successfully been sent to the destination. + """ + return self.runInteraction( + "mark_as_sent_devices_by_remote", + self._mark_as_sent_devices_by_remote_txn, + destination, + stream_id, + ) + + def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + # We update the device_lists_outbound_last_success with the successfully + # poked users. We do the join to see which users need to be inserted and + # which updated. + sql = """ + SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) + FROM device_lists_outbound_pokes as o + LEFT JOIN device_lists_outbound_last_success as s + USING (destination, user_id) + WHERE destination = ? AND o.stream_id <= ? + GROUP BY user_id + """ + txn.execute(sql, (destination, stream_id)) + rows = txn.fetchall() + + sql = """ + UPDATE device_lists_outbound_last_success + SET stream_id = ? + WHERE destination = ? AND user_id = ? + """ + txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2])) + + sql = """ + INSERT INTO device_lists_outbound_last_success + (destination, user_id, stream_id) VALUES (?, ?, ?) + """ + txn.executemany( + sql, ((destination, row[0], row[1]) for row in rows if not row[2]) + ) + + # Delete all sent outbound pokes + sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id <= ? + """ + txn.execute(sql, (destination, stream_id)) + + def get_device_stream_token(self): + return self._device_list_id_gen.get_current_token() + + @trace + @defer.inlineCallbacks + def get_user_devices_from_cache(self, query_list): + """Get the devices (and keys if any) for remote users from the cache. + + Args: + query_list(list): List of (user_id, device_ids), if device_ids is + falsey then return all device ids for that user. + + Returns: + (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is + a set of user_ids and results_map is a mapping of + user_id -> device_id -> device_info + """ + user_ids = set(user_id for user_id, _ in query_list) + user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_ids_in_cache = set( + user_id for user_id, stream_id in user_map.items() if stream_id + ) + user_ids_not_in_cache = user_ids - user_ids_in_cache + + results = {} + for user_id, device_id in query_list: + if user_id not in user_ids_in_cache: + continue + + if device_id: + device = yield self._get_cached_user_device(user_id, device_id) + results.setdefault(user_id, {})[device_id] = device + else: + results[user_id] = yield self._get_cached_devices_for_user(user_id) + + set_tag("in_cache", results) + set_tag("not_in_cache", user_ids_not_in_cache) + + return user_ids_not_in_cache, results + + @cachedInlineCallbacks(num_args=2, tree=True) + def _get_cached_user_device(self, user_id, device_id): + content = yield self._simple_select_one_onecol( + table="device_lists_remote_cache", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="content", + desc="_get_cached_user_device", + ) + return db_to_json(content) + + @cachedInlineCallbacks() + def _get_cached_devices_for_user(self, user_id): + devices = yield self._simple_select_list( + table="device_lists_remote_cache", + keyvalues={"user_id": user_id}, + retcols=("device_id", "content"), + desc="_get_cached_devices_for_user", + ) + return { + device["device_id"]: db_to_json(device["content"]) for device in devices + } + + def get_devices_with_keys_by_user(self, user_id): + """Get all devices (with any device keys) for a user + + Returns: + (stream_id, devices) + """ + return self.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, + user_id, + ) + + def _get_devices_with_keys_by_user_txn(self, txn, user_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + devices = self._get_e2e_device_keys_txn( + txn, [(user_id, None)], include_all_devices=True + ) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in iteritems(user_devices): + result = {"device_id": device_id} + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + + def get_users_whose_devices_changed(self, from_key, user_ids): + """Get set of users whose devices have changed since `from_key` that + are in the given list of user_ids. + + Args: + from_key (str): The device lists stream token + user_ids (Iterable[str]) + + Returns: + Deferred[set[str]]: The set of user_ids whose devices have changed + since `from_key` + """ + from_key = int(from_key) + + # Get set of users who *may* have changed. Users not in the returned + # list have definitely not changed. + to_check = list( + self._device_list_stream_cache.get_entities_changed(user_ids, from_key) + ) + + if not to_check: + return defer.succeed(set()) + + def _get_users_whose_devices_changed_txn(txn): + changes = set() + + sql = """ + SELECT DISTINCT user_id FROM device_lists_stream + WHERE stream_id > ? + AND + """ + + for chunk in batch_iter(to_check, 100): + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", chunk + ) + txn.execute(sql + clause, (from_key,) + tuple(args)) + changes.update(user_id for user_id, in txn) + + return changes + + return self.runInteraction( + "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn + ) + + def get_all_device_list_changes_for_remotes(self, from_key, to_key): + """Return a list of `(stream_id, user_id, destination)` which is the + combined list of changes to devices, and which destinations need to be + poked. `destination` may be None if no destinations need to be poked. + """ + # We do a group by here as there can be a large number of duplicate + # entries, since we throw away device IDs. + sql = """ + SELECT MAX(stream_id) AS stream_id, user_id, destination + FROM device_lists_stream + LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) + WHERE ? < stream_id AND stream_id <= ? + GROUP BY user_id, destination + """ + return self._execute( + "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key + ) + + @cached(max_entries=10000) + def get_device_list_last_stream_id_for_remote(self, user_id): + """Get the last stream_id we got for a user. May be None if we haven't + got any information for them. + """ + return self._simple_select_one_onecol( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + retcol="stream_id", + desc="get_device_list_last_stream_id_for_remote", + allow_none=True, + ) + + @cachedList( + cached_method_name="get_device_list_last_stream_id_for_remote", + list_name="user_ids", + inlineCallbacks=True, + ) + def get_device_list_last_stream_id_for_remotes(self, user_ids): + rows = yield self._simple_select_many_batch( + table="device_lists_remote_extremeties", + column="user_id", + iterable=user_ids, + retcols=("user_id", "stream_id"), + desc="get_device_list_last_stream_id_for_remotes", + ) + + results = {user_id: None for user_id in user_ids} + results.update({row["user_id"]: row["stream_id"] for row in rows}) + + return results + + +class DeviceBackgroundUpdateStore(BackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + "device_lists_stream_idx", + index_name="device_lists_stream_user_id", + table="device_lists_stream", + columns=["user_id", "device_id"], + ) + + # create a unique index on device_lists_remote_cache + self.register_background_index_update( + "device_lists_remote_cache_unique_idx", + index_name="device_lists_remote_cache_unique_id", + table="device_lists_remote_cache", + columns=["user_id", "device_id"], + unique=True, + ) + + # And one on device_lists_remote_extremeties + self.register_background_index_update( + "device_lists_remote_extremeties_unique_idx", + index_name="device_lists_remote_extremeties_unique_idx", + table="device_lists_remote_extremeties", + columns=["user_id"], + unique=True, + ) + + # once they complete, we can remove the old non-unique indexes. + self.register_background_update_handler( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, + self._drop_device_list_streams_non_unique_indexes, + ) + + @defer.inlineCallbacks + def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): + def f(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") + txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") + txn.close() + + yield self.runWithConnection(f) + yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) + return 1 + + +class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(DeviceStore, self).__init__(db_conn, hs) + + # Map of (user_id, device_id) -> bool. If there is an entry that implies + # the device exists. + self.device_id_exists_cache = Cache( + name="device_id_exists", keylen=2, max_entries=10000 + ) + + self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) + + @defer.inlineCallbacks + def store_device(self, user_id, device_id, initial_device_display_name): + """Ensure the given device is known; add it to the store if not + + Args: + user_id (str): id of user associated with the device + device_id (str): id of device + initial_device_display_name (str): initial displayname of the + device. Ignored if device exists. + Returns: + defer.Deferred: boolean whether the device was inserted or an + existing device existed with that ID. + """ + key = (user_id, device_id) + if self.device_id_exists_cache.get(key, None): + return False + + try: + inserted = yield self._simple_insert( + "devices", + values={ + "user_id": user_id, + "device_id": device_id, + "display_name": initial_device_display_name, + }, + desc="store_device", + or_ignore=True, + ) + self.device_id_exists_cache.prefill(key, True) + return inserted + except Exception as e: + logger.error( + "store_device with device_id=%s(%r) user_id=%s(%r)" + " display_name=%s(%r) failed: %s", + type(device_id).__name__, + device_id, + type(user_id).__name__, + user_id, + type(initial_device_display_name).__name__, + initial_device_display_name, + e, + ) + raise StoreError(500, "Problem storing device.") + + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """Delete a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to delete + Returns: + defer.Deferred + """ + yield self._simple_delete_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_device", + ) + + self.device_id_exists_cache.invalidate((user_id, device_id)) + + @defer.inlineCallbacks + def delete_devices(self, user_id, device_ids): + """Deletes several devices. + + Args: + user_id (str): The ID of the user which owns the devices + device_ids (list): The IDs of the devices to delete + Returns: + defer.Deferred + """ + yield self._simple_delete_many( + table="devices", + column="device_id", + iterable=device_ids, + keyvalues={"user_id": user_id}, + desc="delete_devices", + ) + for device_id in device_ids: + self.device_id_exists_cache.invalidate((user_id, device_id)) + + def update_device(self, user_id, device_id, new_display_name=None): + """Update a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to update + new_display_name (str|None): new displayname for device; None + to leave unchanged + Raises: + StoreError: if the device is not found + Returns: + defer.Deferred + """ + updates = {} + if new_display_name is not None: + updates["display_name"] = new_display_name + if not updates: + return defer.succeed(None) + return self._simple_update_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues=updates, + desc="update_device", + ) + + @defer.inlineCallbacks + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for remote user. + """ + yield self._simple_delete( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + desc="mark_remote_user_device_list_as_unsubscribed", + ) + self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) + + def update_remote_device_list_cache_entry( + self, user_id, device_id, content, stream_id + ): + """Updates a single device in the cache of a remote user's devicelist. + + Note: assumes that we are the only thread that can be updating this user's + device list. + + Args: + user_id (str): User to update device list for + device_id (str): ID of decivice being updated + content (dict): new data on this device + stream_id (int): the version of the device list + + Returns: + Deferred[None] + """ + return self.runInteraction( + "update_remote_device_list_cache_entry", + self._update_remote_device_list_cache_entry_txn, + user_id, + device_id, + content, + stream_id, + ) + + def _update_remote_device_list_cache_entry_txn( + self, txn, user_id, device_id, content, stream_id + ): + if content.get("deleted"): + self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + + txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) + else: + self._simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"content": json.dumps(content)}, + # we don't need to lock, because we assume we are the only thread + # updating this user's devices. + lock=False, + ) + + txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id)) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + values={"stream_id": stream_id}, + # again, we can assume we are the only thread updating this user's + # extremity. + lock=False, + ) + + def update_remote_device_list_cache(self, user_id, devices, stream_id): + """Replace the entire cache of the remote user's devices. + + Note: assumes that we are the only thread that can be updating this user's + device list. + + Args: + user_id (str): User to update device list for + devices (list[dict]): list of device objects supplied over federation + stream_id (int): the version of the device list + + Returns: + Deferred[None] + """ + return self.runInteraction( + "update_remote_device_list_cache", + self._update_remote_device_list_cache_txn, + user_id, + devices, + stream_id, + ) + + def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): + self._simple_delete_txn( + txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_remote_cache", + values=[ + { + "user_id": user_id, + "device_id": content["device_id"], + "content": json.dumps(content), + } + for content in devices + ], + ) + + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + values={"stream_id": stream_id}, + # we don't need to lock, because we can assume we are the only thread + # updating this user's extremity. + lock=False, + ) + + @defer.inlineCallbacks + def add_device_change_to_streams(self, user_id, device_ids, hosts): + """Persist that a user's devices have been updated, and which hosts + (if any) should be poked. + """ + with self._device_list_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_device_change_to_streams", + self._add_device_change_txn, + user_id, + device_ids, + hosts, + stream_id, + ) + return stream_id + + def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): + now = self._clock.time_msec() + + txn.call_after( + self._device_list_stream_cache.entity_has_changed, user_id, stream_id + ) + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, + stream_id, + ) + + # Delete older entries in the table, as we really only care about + # when the latest change happened. + txn.executemany( + """ + DELETE FROM device_lists_stream + WHERE user_id = ? AND device_id = ? AND stream_id < ? + """, + [(user_id, device_id, stream_id) for device_id in device_ids], + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_stream", + values=[ + {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} + for device_id in device_ids + ], + ) + + context = get_active_span_text_map() + + self._simple_insert_many_txn( + txn, + table="device_lists_outbound_pokes", + values=[ + { + "destination": destination, + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + "sent": False, + "ts": now, + "opentracing_context": json.dumps(context) + if whitelisted_homeserver(destination) + else "{}", + } + for destination in hosts + for device_id in device_ids + ], + ) + + def _prune_old_outbound_device_pokes(self): + """Delete old entries out of the device_lists_outbound_pokes to ensure + that we don't fill up due to dead servers. We keep one entry per + (destination, user_id) tuple to ensure that the prev_ids remain correct + if the server does come back. + """ + yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 + + def _prune_txn(txn): + select_sql = """ + SELECT destination, user_id, max(stream_id) as stream_id + FROM device_lists_outbound_pokes + GROUP BY destination, user_id + HAVING min(ts) < ? AND count(*) > 1 + """ + + txn.execute(select_sql, (yesterday,)) + rows = txn.fetchall() + + if not rows: + return + + delete_sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? + """ + + txn.executemany( + delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows) + ) + + # Since we've deleted unsent deltas, we need to remove the entry + # of last successful sent so that the prev_ids are correctly set. + sql = """ + DELETE FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? + """ + txn.executemany(sql, ((row[0], row[1]) for row in rows)) + + logger.info("Pruned %d device list outbound pokes", txn.rowcount) + + return run_as_background_process( + "prune_old_outbound_device_pokes", + self.runInteraction, + "_prune_old_outbound_device_pokes", + _prune_txn, + ) diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py new file mode 100644 index 0000000000..297966d9f4 --- /dev/null +++ b/synapse/storage/data_stores/main/directory.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore +from synapse.util.caches.descriptors import cached + +RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) + + +class DirectoryWorkerStore(SQLBaseStore): + @defer.inlineCallbacks + def get_association_from_room_alias(self, room_alias): + """ Get's the room_id and server list for a given room_alias + + Args: + room_alias (RoomAlias) + + Returns: + Deferred: results in namedtuple with keys "room_id" and + "servers" or None if no association can be found + """ + room_id = yield self._simple_select_one_onecol( + "room_aliases", + {"room_alias": room_alias.to_string()}, + "room_id", + allow_none=True, + desc="get_association_from_room_alias", + ) + + if not room_id: + return None + + servers = yield self._simple_select_onecol( + "room_alias_servers", + {"room_alias": room_alias.to_string()}, + "server", + desc="get_association_from_room_alias", + ) + + if not servers: + return None + + return RoomAliasMapping(room_id, room_alias.to_string(), servers) + + def get_room_alias_creator(self, room_alias): + return self._simple_select_one_onecol( + table="room_aliases", + keyvalues={"room_alias": room_alias}, + retcol="creator", + desc="get_room_alias_creator", + ) + + @cached(max_entries=5000) + def get_aliases_for_room(self, room_id): + return self._simple_select_onecol( + "room_aliases", + {"room_id": room_id}, + "room_alias", + desc="get_aliases_for_room", + ) + + +class DirectoryStore(DirectoryWorkerStore): + @defer.inlineCallbacks + def create_room_alias_association(self, room_alias, room_id, servers, creator=None): + """ Creates an association between a room alias and room_id/servers + + Args: + room_alias (RoomAlias) + room_id (str) + servers (list) + creator (str): Optional user_id of creator. + + Returns: + Deferred + """ + + def alias_txn(txn): + self._simple_insert_txn( + txn, + "room_aliases", + { + "room_alias": room_alias.to_string(), + "room_id": room_id, + "creator": creator, + }, + ) + + self._simple_insert_many_txn( + txn, + table="room_alias_servers", + values=[ + {"room_alias": room_alias.to_string(), "server": server} + for server in servers + ], + ) + + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (room_id,) + ) + + try: + ret = yield self.runInteraction("create_room_alias_association", alias_txn) + except self.database_engine.module.IntegrityError: + raise SynapseError( + 409, "Room alias %s already exists" % room_alias.to_string() + ) + return ret + + @defer.inlineCallbacks + def delete_room_alias(self, room_alias): + room_id = yield self.runInteraction( + "delete_room_alias", self._delete_room_alias_txn, room_alias + ) + + return room_id + + def _delete_room_alias_txn(self, txn, room_alias): + txn.execute( + "SELECT room_id FROM room_aliases WHERE room_alias = ?", + (room_alias.to_string(),), + ) + + res = txn.fetchone() + if res: + room_id = res[0] + else: + return None + + txn.execute( + "DELETE FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),) + ) + + txn.execute( + "DELETE FROM room_alias_servers WHERE room_alias = ?", + (room_alias.to_string(),), + ) + + self._invalidate_cache_and_stream(txn, self.get_aliases_for_room, (room_id,)) + + return room_id + + def update_aliases_for_room(self, old_room_id, new_room_id, creator): + def _update_aliases_for_room_txn(txn): + sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" + txn.execute(sql, (new_room_id, creator, old_room_id)) + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (old_room_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (new_room_id,) + ) + + return self.runInteraction( + "_update_aliases_for_room_txn", _update_aliases_for_room_txn + ) diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py new file mode 100644 index 0000000000..ef88e79293 --- /dev/null +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -0,0 +1,336 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.logging.opentracing import log_kv, trace +from synapse.storage._base import SQLBaseStore + + +class EndToEndRoomKeyStore(SQLBaseStore): + @defer.inlineCallbacks + def get_e2e_room_key(self, user_id, version, room_id, session_id): + """Get the encrypted E2E room key for a given session from a given + backup version of room_keys. We only store the 'best' room key for a given + session at a given time, as determined by the handler. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup for the set of keys we're querying + room_id(str): the ID of the room whose keys we're querying. + This is a bit redundant as it's implied by the session_id, but + we include for consistency with the rest of the API. + session_id(str): the session whose room_key we're querying. + + Returns: + A deferred dict giving the session_data and message metadata for + this room key. + """ + + row = yield self._simple_select_one( + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + }, + retcols=( + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_key", + ) + + row["session_data"] = json.loads(row["session_data"]) + + return row + + @defer.inlineCallbacks + def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces or inserts the encrypted E2E room key for a given session in + a given backup + + Args: + user_id(str): the user whose backup we're setting + version(str): the version ID of the backup we're updating + room_id(str): the ID of the room whose keys we're setting + session_id(str): the session whose room_key we're setting + room_key(dict): the room_key being set + Raises: + StoreError + """ + + yield self._simple_upsert( + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + }, + values={ + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), + }, + lock=False, + ) + log_kv( + { + "message": "Set room key", + "room_id": room_id, + "session_id": session_id, + "room_key": room_key, + } + ) + + @trace + @defer.inlineCallbacks + def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + """Bulk get the E2E room keys for a given backup, optionally filtered to a given + room, or a given session. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup for the set of keys we're querying + room_id(str): Optional. the ID of the room whose keys we're querying, if any. + If not specified, we return the keys for all the rooms in the backup. + session_id(str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we return all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred list of dicts giving the session_data and message metadata for + these room keys. + """ + + try: + version = int(version) + except ValueError: + return {"rooms": {}} + + keyvalues = {"user_id": user_id, "version": version} + if room_id: + keyvalues["room_id"] = room_id + if session_id: + keyvalues["session_id"] = session_id + + rows = yield self._simple_select_list( + table="e2e_room_keys", + keyvalues=keyvalues, + retcols=( + "user_id", + "room_id", + "session_id", + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_keys", + ) + + sessions = {"rooms": {}} + for row in rows: + room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) + room_entry["sessions"][row["session_id"]] = { + "first_message_index": row["first_message_index"], + "forwarded_count": row["forwarded_count"], + "is_verified": row["is_verified"], + "session_data": json.loads(row["session_data"]), + } + + return sessions + + @trace + @defer.inlineCallbacks + def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + """Bulk delete the E2E room keys for a given backup, optionally filtered to a given + room or a given session. + + Args: + user_id(str): the user whose backup we're deleting from + version(str): the version ID of the backup for the set of keys we're deleting + room_id(str): Optional. the ID of the room whose keys we're deleting, if any. + If not specified, we delete the keys for all the rooms in the backup. + session_id(str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we delete all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred of the deletion transaction + """ + + keyvalues = {"user_id": user_id, "version": int(version)} + if room_id: + keyvalues["room_id"] = room_id + if session_id: + keyvalues["session_id"] = session_id + + yield self._simple_delete( + table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" + ) + + @staticmethod + def _get_current_version(txn, user_id): + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions " + "WHERE user_id=? AND deleted=0", + (user_id,), + ) + row = txn.fetchone() + if not row: + raise StoreError(404, "No current backup version") + return row[0] + + def get_e2e_room_keys_version_info(self, user_id, version=None): + """Get info metadata about a version of our room_keys backup. + + Args: + user_id(str): the user whose backup we're querying + version(str): Optional. the version ID of the backup we're querying about + If missing, we return the information about the current version. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present + Returns: + A deferred dict giving the info metadata for this backup version, with + fields including: + version(str) + algorithm(str) + auth_data(object): opaque dict supplied by the client + """ + + def _get_e2e_room_keys_version_info_txn(txn): + if version is None: + this_version = self._get_current_version(txn, user_id) + else: + try: + this_version = int(version) + except ValueError: + # Our versions are all ints so if we can't convert it to an integer, + # it isn't there. + raise StoreError(404, "No row found") + + result = self._simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, + retcols=("version", "algorithm", "auth_data"), + ) + result["auth_data"] = json.loads(result["auth_data"]) + result["version"] = str(result["version"]) + return result + + return self.runInteraction( + "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn + ) + + @trace + def create_e2e_room_keys_version(self, user_id, info): + """Atomically creates a new version of this user's e2e_room_keys store + with the given version info. + + Args: + user_id(str): the user whose backup we're creating a version + info(dict): the info about the backup version to be created + + Returns: + A deferred string for the newly created version ID + """ + + def _create_e2e_room_keys_version_txn(txn): + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", + (user_id,), + ) + current_version = txn.fetchone()[0] + if current_version is None: + current_version = "0" + + new_version = str(int(current_version) + 1) + + self._simple_insert_txn( + txn, + table="e2e_room_keys_versions", + values={ + "user_id": user_id, + "version": new_version, + "algorithm": info["algorithm"], + "auth_data": json.dumps(info["auth_data"]), + }, + ) + + return new_version + + return self.runInteraction( + "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn + ) + + @trace + def update_e2e_room_keys_version(self, user_id, version, info): + """Update a given backup version + + Args: + user_id(str): the user whose backup version we're updating + version(str): the version ID of the backup version we're updating + info(dict): the new backup version info to store + """ + + return self._simple_update( + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": version}, + updatevalues={"auth_data": json.dumps(info["auth_data"])}, + desc="update_e2e_room_keys_version", + ) + + @trace + def delete_e2e_room_keys_version(self, user_id, version=None): + """Delete a given backup version of the user's room keys. + Doesn't delete their actual key data. + + Args: + user_id(str): the user whose backup version we're deleting + version(str): Optional. the version ID of the backup version we're deleting + If missing, we delete the current backup version info. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present, + or if the version requested doesn't exist. + """ + + def _delete_e2e_room_keys_version_txn(txn): + if version is None: + this_version = self._get_current_version(txn, user_id) + else: + this_version = version + + return self._simple_update_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": this_version}, + updatevalues={"deleted": 1}, + ) + + return self.runInteraction( + "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn + ) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py new file mode 100644 index 0000000000..7e44f41046 --- /dev/null +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from six import iteritems + +from canonicaljson import encode_canonical_json + +from twisted.internet import defer + +from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util.caches.descriptors import cached + + +class EndToEndKeyWorkerStore(SQLBaseStore): + @trace + @defer.inlineCallbacks + def get_e2e_device_keys( + self, query_list, include_all_devices=False, include_deleted_devices=False + ): + """Fetch a list of device keys. + Args: + query_list(list): List of pairs of user_ids and device_ids. + include_all_devices (bool): whether to include entries for devices + that don't have device keys + include_deleted_devices (bool): whether to include null entries for + devices which no longer exist (but were in the query_list). + This option only takes effect if include_all_devices is true. + Returns: + Dict mapping from user-id to dict mapping from device_id to + key data. The key data will be a dict in the same format as the + DeviceKeys type returned by POST /_matrix/client/r0/keys/query. + """ + set_tag("query_list", query_list) + if not query_list: + return {} + + results = yield self.runInteraction( + "get_e2e_device_keys", + self._get_e2e_device_keys_txn, + query_list, + include_all_devices, + include_deleted_devices, + ) + + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section + rv = {} + for user_id, device_keys in iteritems(results): + rv[user_id] = {} + for device_id, device_info in iteritems(device_keys): + r = db_to_json(device_info.pop("key_json")) + r["unsigned"] = {} + display_name = device_info["device_display_name"] + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + rv[user_id][device_id] = r + + return rv + + @trace + def _get_e2e_device_keys_txn( + self, txn, query_list, include_all_devices=False, include_deleted_devices=False + ): + set_tag("include_all_devices", include_all_devices) + set_tag("include_deleted_devices", include_deleted_devices) + + query_clauses = [] + query_params = [] + + if include_all_devices is False: + include_deleted_devices = False + + if include_deleted_devices: + deleted_devices = set(query_list) + + for (user_id, device_id) in query_list: + query_clause = "user_id = ?" + query_params.append(user_id) + + if device_id is not None: + query_clause += " AND device_id = ?" + query_params.append(device_id) + + query_clauses.append(query_clause) + + sql = ( + "SELECT user_id, device_id, " + " d.display_name AS device_display_name, " + " k.key_json" + " FROM devices d" + " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" + " WHERE %s" + ) % ( + "LEFT" if include_all_devices else "INNER", + " OR ".join("(" + q + ")" for q in query_clauses), + ) + + txn.execute(sql, query_params) + rows = self.cursor_to_dict(txn) + + result = {} + for row in rows: + if include_deleted_devices: + deleted_devices.remove((row["user_id"], row["device_id"])) + result.setdefault(row["user_id"], {})[row["device_id"]] = row + + if include_deleted_devices: + for user_id, device_id in deleted_devices: + result.setdefault(user_id, {})[device_id] = None + + log_kv(result) + return result + + @defer.inlineCallbacks + def get_e2e_one_time_keys(self, user_id, device_id, key_ids): + """Retrieve a number of one-time keys for a user + + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + key_ids(list[str]): list of key ids (excluding algorithm) to + retrieve + + Returns: + deferred resolving to Dict[(str, str), str]: map from (algorithm, + key_id) to json string for key + """ + + rows = yield self._simple_select_many_batch( + table="e2e_one_time_keys_json", + column="key_id", + iterable=key_ids, + retcols=("algorithm", "key_id", "key_json"), + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="add_e2e_one_time_keys_check", + ) + result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} + log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) + return result + + @defer.inlineCallbacks + def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): + """Insert some new one time keys for a device. Errors if any of the + keys already exist. + + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + time_now(long): insertion time to record (ms since epoch) + new_keys(iterable[(str, str, str)]: keys to add - each a tuple of + (algorithm, key_id, key json) + """ + + def _add_e2e_one_time_keys(txn): + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("new_keys", new_keys) + # We are protected from race between lookup and insertion due to + # a unique constraint. If there is a race of two calls to + # `add_e2e_one_time_keys` then they'll conflict and we will only + # insert one set. + self._simple_insert_many_txn( + txn, + table="e2e_one_time_keys_json", + values=[ + { + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + "ts_added_ms": time_now, + "key_json": json_bytes, + } + for algorithm, key_id, json_bytes in new_keys + ], + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + + yield self.runInteraction( + "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys + ) + + @cached(max_entries=10000) + def count_e2e_one_time_keys(self, user_id, device_id): + """ Count the number of one time keys the server has for a device + Returns: + Dict mapping from algorithm to number of keys for that algorithm. + """ + + def _count_e2e_one_time_keys(txn): + sql = ( + "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ?" + " GROUP BY algorithm" + ) + txn.execute(sql, (user_id, device_id)) + result = {} + for algorithm, key_count in txn: + result[algorithm] = key_count + return result + + return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) + + +class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. + """ + + def _set_e2e_device_keys_txn(txn): + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("time_now", time_now) + set_tag("device_keys", device_keys) + + old_key_json = self._simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="key_json", + allow_none=True, + ) + + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_key_json = encode_canonical_json(device_keys).decode("utf-8") + + if old_key_json == new_key_json: + log_kv({"Message": "Device key already stored."}) + return False + + self._simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time_now, "key_json": new_key_json}, + ) + log_kv({"message": "Device keys stored."}) + return True + + return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) + + def claim_e2e_one_time_keys(self, query_list): + """Take a list of one time keys out of the database""" + + @trace + def _claim_e2e_one_time_keys(txn): + sql = ( + "SELECT key_id, key_json FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " LIMIT 1" + ) + result = {} + delete = [] + for user_id, device_id, algorithm in query_list: + user_result = result.setdefault(user_id, {}) + device_result = user_result.setdefault(device_id, {}) + txn.execute(sql, (user_id, device_id, algorithm)) + for key_id, key_json in txn: + device_result[algorithm + ":" + key_id] = key_json + delete.append((user_id, device_id, algorithm, key_id)) + sql = ( + "DELETE FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " AND key_id = ?" + ) + for user_id, device_id, algorithm, key_id in delete: + log_kv( + { + "message": "Executing claim e2e_one_time_keys transaction on database." + } + ) + txn.execute(sql, (user_id, device_id, algorithm, key_id)) + log_kv({"message": "finished executing and invalidating cache"}) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + return result + + return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys) + + def delete_e2e_keys_by_device(self, user_id, device_id): + def delete_e2e_keys_by_device_txn(txn): + log_kv( + { + "message": "Deleting keys for device", + "device_id": device_id, + "user_id": user_id, + } + ) + self._simple_delete_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._simple_delete_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + + return self.runInteraction( + "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn + ) diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py new file mode 100644 index 0000000000..a470a48e0f --- /dev/null +++ b/synapse/storage/data_stores/main/event_federation.py @@ -0,0 +1,672 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +import logging +import random + +from six.moves import range +from six.moves.queue import Empty, PriorityQueue + +from unpaddedbase64 import encode_base64 + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.data_stores.main.signatures import SignatureWorkerStore +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + + +class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): + def get_auth_chain(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. + + Args: + event_ids (list): state events + include_given (bool): include the given events in result + + Returns: + list of events + """ + return self.get_auth_chain_ids( + event_ids, include_given=include_given + ).addCallback(self.get_events_as_list) + + def get_auth_chain_ids(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. + + Args: + event_ids (list): state events + include_given (bool): include the given events in result + + Returns: + list of event_ids + """ + return self.runInteraction( + "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given + ) + + def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): + if include_given: + results = set(event_ids) + else: + results = set() + + base_sql = "SELECT auth_id FROM event_auth WHERE " + + front = set(event_ids) + while front: + new_front = set() + front_list = list(front) + chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)] + for chunk in chunks: + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", chunk + ) + txn.execute(base_sql + clause, list(args)) + new_front.update([r[0] for r in txn]) + + new_front -= results + + front = new_front + results.update(front) + + return list(results) + + def get_oldest_events_in_room(self, room_id): + return self.runInteraction( + "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id + ) + + def get_oldest_events_with_depth_in_room(self, room_id): + return self.runInteraction( + "get_oldest_events_with_depth_in_room", + self.get_oldest_events_with_depth_in_room_txn, + room_id, + ) + + def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): + sql = ( + "SELECT b.event_id, MAX(e.depth) FROM events as e" + " INNER JOIN event_edges as g" + " ON g.event_id = e.event_id" + " INNER JOIN event_backward_extremities as b" + " ON g.prev_event_id = b.event_id" + " WHERE b.room_id = ? AND g.is_state is ?" + " GROUP BY b.event_id" + ) + + txn.execute(sql, (room_id, False)) + + return dict(txn) + + @defer.inlineCallbacks + def get_max_depth_of(self, event_ids): + """Returns the max depth of a set of event IDs + + Args: + event_ids (list[str]) + + Returns + Deferred[int] + """ + rows = yield self._simple_select_many_batch( + table="events", + column="event_id", + iterable=event_ids, + retcols=("depth",), + desc="get_max_depth_of", + ) + + if not rows: + return 0 + else: + return max(row["depth"] for row in rows) + + def _get_oldest_events_in_room_txn(self, txn, room_id): + return self._simple_select_onecol_txn( + txn, + table="event_backward_extremities", + keyvalues={"room_id": room_id}, + retcol="event_id", + ) + + @defer.inlineCallbacks + def get_prev_events_for_room(self, room_id): + """ + Gets a subset of the current forward extremities in the given room. + + Limits the result to 10 extremities, so that we can avoid creating + events which refer to hundreds of prev_events. + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + res = yield self.get_latest_event_ids_and_hashes_in_room(room_id) + if len(res) > 10: + # Sort by reverse depth, so we point to the most recent. + res.sort(key=lambda a: -a[2]) + + # we use half of the limit for the actual most recent events, and + # the other half to randomly point to some of the older events, to + # make sure that we don't completely ignore the older events. + res = res[0:5] + random.sample(res[5:], 5) + + return res + + def get_latest_event_ids_and_hashes_in_room(self, room_id): + """ + Gets the current forward extremities in the given room + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + + return self.runInteraction( + "get_latest_event_ids_and_hashes_in_room", + self._get_latest_event_ids_and_hashes_in_room, + room_id, + ) + + def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): + """Get the top rooms with at least N extremities. + + Args: + min_count (int): The minimum number of extremities + limit (int): The maximum number of rooms to return. + room_id_filter (iterable[str]): room_ids to exclude from the results + + Returns: + Deferred[list]: At most `limit` room IDs that have at least + `min_count` extremities, sorted by extremity count. + """ + + def _get_rooms_with_many_extremities_txn(txn): + where_clause = "1=1" + if room_id_filter: + where_clause = "room_id NOT IN (%s)" % ( + ",".join("?" for _ in room_id_filter), + ) + + sql = """ + SELECT room_id FROM event_forward_extremities + WHERE %s + GROUP BY room_id + HAVING count(*) > ? + ORDER BY count(*) DESC + LIMIT ? + """ % ( + where_clause, + ) + + query_args = list(itertools.chain(room_id_filter, [min_count, limit])) + txn.execute(sql, query_args) + return [room_id for room_id, in txn] + + return self.runInteraction( + "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn + ) + + @cached(max_entries=5000, iterable=True) + def get_latest_event_ids_in_room(self, room_id): + return self._simple_select_onecol( + table="event_forward_extremities", + keyvalues={"room_id": room_id}, + retcol="event_id", + desc="get_latest_event_ids_in_room", + ) + + def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id): + sql = ( + "SELECT e.event_id, e.depth FROM events as e " + "INNER JOIN event_forward_extremities as f " + "ON e.event_id = f.event_id " + "AND e.room_id = f.room_id " + "WHERE f.room_id = ?" + ) + + txn.execute(sql, (room_id,)) + + results = [] + for event_id, depth in txn.fetchall(): + hashes = self._get_event_reference_hashes_txn(txn, event_id) + prev_hashes = { + k: encode_base64(v) for k, v in hashes.items() if k == "sha256" + } + results.append((event_id, prev_hashes, depth)) + + return results + + def get_min_depth(self, room_id): + """ For hte given room, get the minimum depth we have seen for it. + """ + return self.runInteraction( + "get_min_depth", self._get_min_depth_interaction, room_id + ) + + def _get_min_depth_interaction(self, txn, room_id): + min_depth = self._simple_select_one_onecol_txn( + txn, + table="room_depth", + keyvalues={"room_id": room_id}, + retcol="min_depth", + allow_none=True, + ) + + return int(min_depth) if min_depth is not None else None + + def get_forward_extremeties_for_room(self, room_id, stream_ordering): + """For a given room_id and stream_ordering, return the forward + extremeties of the room at that point in "time". + + Throws a StoreError if we have since purged the index for + stream_orderings from that point. + + Args: + room_id (str): + stream_ordering (int): + + Returns: + deferred, which resolves to a list of event_ids + """ + # We want to make the cache more effective, so we clamp to the last + # change before the given ordering. + last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) + + # We don't always have a full stream_to_exterm_id table, e.g. after + # the upgrade that introduced it, so we make sure we never ask for a + # stream_ordering from before a restart + last_change = max(self._stream_order_on_start, last_change) + + # provided the last_change is recent enough, we now clamp the requested + # stream_ordering to it. + if last_change > self.stream_ordering_month_ago: + stream_ordering = min(last_change, stream_ordering) + + return self._get_forward_extremeties_for_room(room_id, stream_ordering) + + @cached(max_entries=5000, num_args=2) + def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + """For a given room_id and stream_ordering, return the forward + extremeties of the room at that point in "time". + + Throws a StoreError if we have since purged the index for + stream_orderings from that point. + """ + + if stream_ordering <= self.stream_ordering_month_ago: + raise StoreError(400, "stream_ordering too old") + + sql = """ + SELECT event_id FROM stream_ordering_to_exterm + INNER JOIN ( + SELECT room_id, MAX(stream_ordering) AS stream_ordering + FROM stream_ordering_to_exterm + WHERE stream_ordering <= ? GROUP BY room_id + ) AS rms USING (room_id, stream_ordering) + WHERE room_id = ? + """ + + def get_forward_extremeties_for_room_txn(txn): + txn.execute(sql, (stream_ordering, room_id)) + return [event_id for event_id, in txn] + + return self.runInteraction( + "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn + ) + + def get_backfill_events(self, room_id, event_list, limit): + """Get a list of Events for a given topic that occurred before (and + including) the events in event_list. Return a list of max size `limit` + + Args: + txn + room_id (str) + event_list (list) + limit (int) + """ + return ( + self.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, + ) + .addCallback(self.get_events_as_list) + .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + ) + + def _get_backfill_events(self, txn, room_id, event_list, limit): + logger.debug( + "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit + ) + + event_results = set() + + # We want to make sure that we do a breadth-first, "depth" ordered + # search. + + query = ( + "SELECT depth, prev_event_id FROM event_edges" + " INNER JOIN events" + " ON prev_event_id = events.event_id" + " WHERE event_edges.event_id = ?" + " AND event_edges.is_state = ?" + " LIMIT ?" + ) + + queue = PriorityQueue() + + for event_id in event_list: + depth = self._simple_select_one_onecol_txn( + txn, + table="events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcol="depth", + allow_none=True, + ) + + if depth: + queue.put((-depth, event_id)) + + while not queue.empty() and len(event_results) < limit: + try: + _, event_id = queue.get_nowait() + except Empty: + break + + if event_id in event_results: + continue + + event_results.add(event_id) + + txn.execute(query, (event_id, False, limit - len(event_results))) + + for row in txn: + if row[1] not in event_results: + queue.put((-row[0], row[1])) + + return event_results + + @defer.inlineCallbacks + def get_missing_events(self, room_id, earliest_events, latest_events, limit): + ids = yield self.runInteraction( + "get_missing_events", + self._get_missing_events, + room_id, + earliest_events, + latest_events, + limit, + ) + events = yield self.get_events_as_list(ids) + return events + + def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): + + seen_events = set(earliest_events) + front = set(latest_events) - seen_events + event_results = [] + + query = ( + "SELECT prev_event_id FROM event_edges " + "WHERE room_id = ? AND event_id = ? AND is_state = ? " + "LIMIT ?" + ) + + while front and len(event_results) < limit: + new_front = set() + for event_id in front: + txn.execute( + query, (room_id, event_id, False, limit - len(event_results)) + ) + + new_results = set(t[0] for t in txn) - seen_events + + new_front |= new_results + seen_events |= new_results + event_results.extend(new_results) + + front = new_front + + # we built the list working backwards from latest_events; we now need to + # reverse it so that the events are approximately chronological. + event_results.reverse() + return event_results + + @defer.inlineCallbacks + def get_successor_events(self, event_ids): + """Fetch all events that have the given events as a prev event + + Args: + event_ids (iterable[str]) + + Returns: + Deferred[list[str]] + """ + rows = yield self._simple_select_many_batch( + table="event_edges", + column="prev_event_id", + iterable=event_ids, + retcols=("event_id",), + desc="get_successor_events", + ) + + return [row["event_id"] for row in rows] + + +class EventFederationStore(EventFederationWorkerStore): + """ Responsible for storing and serving up the various graphs associated + with an event. Including the main event graph and the auth chains for an + event. + + Also has methods for getting the front (latest) and back (oldest) edges + of the event graphs. These are used to generate the parents for new events + and backfilling from another server respectively. + """ + + EVENT_AUTH_STATE_ONLY = "event_auth_state_only" + + def __init__(self, db_conn, hs): + super(EventFederationStore, self).__init__(db_conn, hs) + + self.register_background_update_handler( + self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth + ) + + hs.get_clock().looping_call( + self._delete_old_forward_extrem_cache, 60 * 60 * 1000 + ) + + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self._get_min_depth_interaction(txn, room_id) + + if min_depth and depth >= min_depth: + return + + self._simple_upsert_txn( + txn, + table="room_depth", + keyvalues={"room_id": room_id}, + values={"min_depth": depth}, + ) + + def _handle_mult_prev_events(self, txn, events): + """ + For the given event, update the event edges table and forward and + backward extremities tables. + """ + self._simple_insert_many_txn( + txn, + table="event_edges", + values=[ + { + "event_id": ev.event_id, + "prev_event_id": e_id, + "room_id": ev.room_id, + "is_state": False, + } + for ev in events + for e_id in ev.prev_event_ids() + ], + ) + + self._update_backward_extremeties(txn, events) + + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated + events being persisted. + + This is called for new events *and* for events that were outliers, but + are now being persisted as non-outliers. + + Forward extremities are handled when we first start persisting the events. + """ + events_by_room = {} + for ev in events: + events_by_room.setdefault(ev.room_id, []).append(ev) + + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" + " )" + ) + + txn.executemany( + query, + [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events + for e_id in ev.prev_event_ids() + if not ev.internal_metadata.is_outlier() + ], + ) + + query = ( + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + ) + txn.executemany( + query, + [ + (ev.event_id, ev.room_id) + for ev in events + if not ev.internal_metadata.is_outlier() + ], + ) + + def _delete_old_forward_extrem_cache(self): + def _delete_old_forward_extrem_cache_txn(txn): + # Delete entries older than a month, while making sure we don't delete + # the only entries for a room. + sql = """ + DELETE FROM stream_ordering_to_exterm + WHERE + room_id IN ( + SELECT room_id + FROM stream_ordering_to_exterm + WHERE stream_ordering > ? + ) AND stream_ordering < ? + """ + txn.execute( + sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) + ) + + return run_as_background_process( + "delete_old_forward_extrem_cache", + self.runInteraction, + "_delete_old_forward_extrem_cache", + _delete_old_forward_extrem_cache_txn, + ) + + def clean_room_for_join(self, room_id): + return self.runInteraction( + "clean_room_for_join", self._clean_room_for_join_txn, room_id + ) + + def _clean_room_for_join_txn(self, txn, room_id): + query = "DELETE FROM event_forward_extremities WHERE room_id = ?" + + txn.execute(query, (room_id,)) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) + + @defer.inlineCallbacks + def _background_delete_non_state_event_auth(self, progress, batch_size): + def delete_event_auth(txn): + target_min_stream_id = progress.get("target_min_stream_id_inclusive") + max_stream_id = progress.get("max_stream_id_exclusive") + + if not target_min_stream_id or not max_stream_id: + txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events") + rows = txn.fetchall() + target_min_stream_id = rows[0][0] + + txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events") + rows = txn.fetchall() + max_stream_id = rows[0][0] + + min_stream_id = max_stream_id - batch_size + + sql = """ + DELETE FROM event_auth + WHERE event_id IN ( + SELECT event_id FROM events + LEFT JOIN state_events USING (room_id, event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND state_key IS null + ) + """ + + txn.execute(sql, (min_stream_id, max_stream_id)) + + new_progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self._background_update_progress_txn( + txn, self.EVENT_AUTH_STATE_ONLY, new_progress + ) + + return min_stream_id >= target_min_stream_id + + result = yield self.runInteraction( + self.EVENT_AUTH_STATE_ONLY, delete_event_auth + ) + + if not result: + yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) + + return batch_size diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py new file mode 100644 index 0000000000..22025effbc --- /dev/null +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -0,0 +1,960 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] +DEFAULT_HIGHLIGHT_ACTION = [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, +] + + +def _serialize_action(actions, is_highlight): + """Custom serializer for actions. This allows us to "compress" common actions. + + We use the fact that most users have the same actions for notifs (and for + highlights). + We store these default actions as the empty string rather than the full JSON. + Since the empty string isn't valid JSON there is no risk of this clashing with + any real JSON actions + """ + if is_highlight: + if actions == DEFAULT_HIGHLIGHT_ACTION: + return "" # We use empty string as the column is non-NULL + else: + if actions == DEFAULT_NOTIF_ACTION: + return "" + return json.dumps(actions) + + +def _deserialize_action(actions, is_highlight): + """Custom deserializer for actions. This allows us to "compress" common actions + """ + if actions: + return json.loads(actions) + + if is_highlight: + return DEFAULT_HIGHLIGHT_ACTION + else: + return DEFAULT_NOTIF_ACTION + + +class EventPushActionsWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) + + # These get correctly set by _find_stream_orderings_for_times_txn + self.stream_ordering_month_ago = None + self.stream_ordering_day_ago = None + + cur = LoggingTransaction( + db_conn.cursor(), + name="_find_stream_orderings_for_times_txn", + database_engine=self.database_engine, + ) + self._find_stream_orderings_for_times_txn(cur) + cur.close() + + self.find_stream_orderings_looping_call = self._clock.looping_call( + self._find_stream_orderings_for_times, 10 * 60 * 1000 + ) + self._rotate_delay = 3 + self._rotate_count = 10000 + + @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) + def get_unread_event_push_actions_by_room_for_user( + self, room_id, user_id, last_read_event_id + ): + ret = yield self.runInteraction( + "get_unread_event_push_actions_by_room", + self._get_unread_counts_by_receipt_txn, + room_id, + user_id, + last_read_event_id, + ) + return ret + + def _get_unread_counts_by_receipt_txn( + self, txn, room_id, user_id, last_read_event_id + ): + sql = ( + "SELECT stream_ordering" + " FROM events" + " WHERE room_id = ? AND event_id = ?" + ) + txn.execute(sql, (room_id, last_read_event_id)) + results = txn.fetchall() + if len(results) == 0: + return {"notify_count": 0, "highlight_count": 0} + + stream_ordering = results[0][0] + + return self._get_unread_counts_by_pos_txn( + txn, room_id, user_id, stream_ordering + ) + + def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): + + # First get number of notifications. + # We don't need to put a notif=1 clause as all rows always have + # notif=1 + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " user_id = ?" + " AND room_id = ?" + " AND stream_ordering > ?" + ) + + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + notify_count = row[0] if row else 0 + + txn.execute( + """ + SELECT notif_count FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering > ? + """, + (room_id, user_id, stream_ordering), + ) + rows = txn.fetchall() + if rows: + notify_count += rows[0][0] + + # Now get the number of highlights + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " highlight = 1" + " AND user_id = ?" + " AND room_id = ?" + " AND stream_ordering > ?" + ) + + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + highlight_count = row[0] if row else 0 + + return {"notify_count": notify_count, "highlight_count": highlight_count} + + @defer.inlineCallbacks + def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): + def f(txn): + sql = ( + "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" + " stream_ordering >= ? AND stream_ordering <= ?" + ) + txn.execute(sql, (min_stream_ordering, max_stream_ordering)) + return [r[0] for r in txn] + + ret = yield self.runInteraction("get_push_action_users_in_range", f) + return ret + + @defer.inlineCallbacks + def get_unread_push_actions_for_user_in_range_for_http( + self, user_id, min_stream_ordering, max_stream_ordering, limit=20 + ): + """Get a list of the most recent unread push actions for a given user, + within the given stream ordering range. Called by the httppusher. + + Args: + user_id (str): The user to fetch push actions for. + min_stream_ordering(int): The exclusive lower bound on the + stream ordering of event push actions to fetch. + max_stream_ordering(int): The inclusive upper bound on the + stream ordering of event push actions to fetch. + limit (int): The maximum number of rows to return. + Returns: + A promise which resolves to a list of dicts with the keys "event_id", + "room_id", "stream_ordering", "actions". + The list will be ordered by ascending stream_ordering. + The list will have between 0~limit entries. + """ + # find rooms that have a read receipt in them and return the next + # push actions + def get_after_receipt(txn): + # find rooms that have a read receipt in them and return the next + # push actions + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight " + " FROM (" + " SELECT room_id," + " MAX(stream_ordering) as stream_ordering" + " FROM events" + " INNER JOIN receipts_linearized USING (room_id, event_id)" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + ") AS rl," + " event_push_actions AS ep" + " WHERE" + " ep.room_id = rl.room_id" + " AND ep.stream_ordering > rl.stream_ordering" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering ASC LIMIT ?" + ) + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + txn.execute(sql, args) + return txn.fetchall() + + after_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt + ) + + # There are rooms with push actions in them but you don't have a read receipt in + # them e.g. rooms you've been invited to, so get push actions for rooms which do + # not have read receipts in them too. + def get_no_receipt(txn): + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight " + " FROM event_push_actions AS ep" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id NOT IN (" + " SELECT room_id FROM receipts_linearized" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering ASC LIMIT ?" + ) + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + txn.execute(sql, args) + return txn.fetchall() + + no_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt + ) + + notifs = [ + { + "event_id": row[0], + "room_id": row[1], + "stream_ordering": row[2], + "actions": _deserialize_action(row[3], row[4]), + } + for row in after_read_receipt + no_read_receipt + ] + + # Now sort it so it's ordered correctly, since currently it will + # contain results from the first query, correctly ordered, followed + # by results from the second query, but we want them all ordered + # by stream_ordering, oldest first. + notifs.sort(key=lambda r: r["stream_ordering"]) + + # Take only up to the limit. We have to stop at the limit because + # one of the subqueries may have hit the limit. + return notifs[:limit] + + @defer.inlineCallbacks + def get_unread_push_actions_for_user_in_range_for_email( + self, user_id, min_stream_ordering, max_stream_ordering, limit=20 + ): + """Get a list of the most recent unread push actions for a given user, + within the given stream ordering range. Called by the emailpusher + + Args: + user_id (str): The user to fetch push actions for. + min_stream_ordering(int): The exclusive lower bound on the + stream ordering of event push actions to fetch. + max_stream_ordering(int): The inclusive upper bound on the + stream ordering of event push actions to fetch. + limit (int): The maximum number of rows to return. + Returns: + A promise which resolves to a list of dicts with the keys "event_id", + "room_id", "stream_ordering", "actions", "received_ts". + The list will be ordered by descending received_ts. + The list will have between 0~limit entries. + """ + # find rooms that have a read receipt in them and return the most recent + # push actions + def get_after_receipt(txn): + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight, e.received_ts" + " FROM (" + " SELECT room_id," + " MAX(stream_ordering) as stream_ordering" + " FROM events" + " INNER JOIN receipts_linearized USING (room_id, event_id)" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + ") AS rl," + " event_push_actions AS ep" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id = rl.room_id" + " AND ep.stream_ordering > rl.stream_ordering" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering DESC LIMIT ?" + ) + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + txn.execute(sql, args) + return txn.fetchall() + + after_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt + ) + + # There are rooms with push actions in them but you don't have a read receipt in + # them e.g. rooms you've been invited to, so get push actions for rooms which do + # not have read receipts in them too. + def get_no_receipt(txn): + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight, e.received_ts" + " FROM event_push_actions AS ep" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id NOT IN (" + " SELECT room_id FROM receipts_linearized" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering DESC LIMIT ?" + ) + args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + txn.execute(sql, args) + return txn.fetchall() + + no_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt + ) + + # Make a list of dicts from the two sets of results. + notifs = [ + { + "event_id": row[0], + "room_id": row[1], + "stream_ordering": row[2], + "actions": _deserialize_action(row[3], row[4]), + "received_ts": row[5], + } + for row in after_read_receipt + no_read_receipt + ] + + # Now sort it so it's ordered correctly, since currently it will + # contain results from the first query, correctly ordered, followed + # by results from the second query, but we want them all ordered + # by received_ts (most recent first) + notifs.sort(key=lambda r: -(r["received_ts"] or 0)) + + # Now return the first `limit` + return notifs[:limit] + + def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): + """A fast check to see if there might be something to push for the + user since the given stream ordering. May return false positives. + + Useful to know whether to bother starting a pusher on start up or not. + + Args: + user_id (str) + min_stream_ordering (int) + + Returns: + Deferred[bool]: True if there may be push to process, False if + there definitely isn't. + """ + + def _get_if_maybe_push_in_range_for_user_txn(txn): + sql = """ + SELECT 1 FROM event_push_actions + WHERE user_id = ? AND stream_ordering > ? + LIMIT 1 + """ + + txn.execute(sql, (user_id, min_stream_ordering)) + return bool(txn.fetchone()) + + return self.runInteraction( + "get_if_maybe_push_in_range_for_user", + _get_if_maybe_push_in_range_for_user_txn, + ) + + def add_push_actions_to_staging(self, event_id, user_id_actions): + """Add the push actions for the event to the push action staging area. + + Args: + event_id (str) + user_id_actions (dict[str, list[dict|str])]): A dictionary mapping + user_id to list of push actions, where an action can either be + a string or dict. + + Returns: + Deferred + """ + + if not user_id_actions: + return + + # This is a helper function for generating the necessary tuple that + # can be used to inert into the `event_push_actions_staging` table. + def _gen_entry(user_id, actions): + is_highlight = 1 if _action_has_highlight(actions) else 0 + return ( + event_id, # event_id column + user_id, # user_id column + _serialize_action(actions, is_highlight), # actions column + 1, # notif column + is_highlight, # highlight column + ) + + def _add_push_actions_to_staging_txn(txn): + # We don't use _simple_insert_many here to avoid the overhead + # of generating lists of dicts. + + sql = """ + INSERT INTO event_push_actions_staging + (event_id, user_id, actions, notif, highlight) + VALUES (?, ?, ?, ?, ?) + """ + + txn.executemany( + sql, + ( + _gen_entry(user_id, actions) + for user_id, actions in iteritems(user_id_actions) + ), + ) + + return self.runInteraction( + "add_push_actions_to_staging", _add_push_actions_to_staging_txn + ) + + @defer.inlineCallbacks + def remove_push_actions_from_staging(self, event_id): + """Called if we failed to persist the event to ensure that stale push + actions don't build up in the DB + + Args: + event_id (str) + """ + + try: + res = yield self._simple_delete( + table="event_push_actions_staging", + keyvalues={"event_id": event_id}, + desc="remove_push_actions_from_staging", + ) + return res + except Exception: + # this method is called from an exception handler, so propagating + # another exception here really isn't helpful - there's nothing + # the caller can do about it. Just log the exception and move on. + logger.exception( + "Error removing push actions after event persistence failure" + ) + + def _find_stream_orderings_for_times(self): + return run_as_background_process( + "event_push_action_stream_orderings", + self.runInteraction, + "_find_stream_orderings_for_times", + self._find_stream_orderings_for_times_txn, + ) + + def _find_stream_orderings_for_times_txn(self, txn): + logger.info("Searching for stream ordering 1 month ago") + self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago + ) + logger.info("Searching for stream ordering 1 day ago") + self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago + ) + + def find_first_stream_ordering_after_ts(self, ts): + """Gets the stream ordering corresponding to a given timestamp. + + Specifically, finds the stream_ordering of the first event that was + received on or after the timestamp. This is done by a binary search on + the events table, since there is no index on received_ts, so is + relatively slow. + + Args: + ts (int): timestamp in millis + + Returns: + Deferred[int]: stream ordering of the first event received on/after + the timestamp + """ + return self.runInteraction( + "_find_first_stream_ordering_after_ts_txn", + self._find_first_stream_ordering_after_ts_txn, + ts, + ) + + @staticmethod + def _find_first_stream_ordering_after_ts_txn(txn, ts): + """ + Find the stream_ordering of the first event that was received on or + after a given timestamp. This is relatively slow as there is no index + on received_ts but we can then use this to delete push actions before + this. + + received_ts must necessarily be in the same order as stream_ordering + and stream_ordering is indexed, so we manually binary search using + stream_ordering + + Args: + txn (twisted.enterprise.adbapi.Transaction): + ts (int): timestamp to search for + + Returns: + int: stream ordering + """ + txn.execute("SELECT MAX(stream_ordering) FROM events") + max_stream_ordering = txn.fetchone()[0] + + if max_stream_ordering is None: + return 0 + + # We want the first stream_ordering in which received_ts is greater + # than or equal to ts. Call this point X. + # + # We maintain the invariants: + # + # range_start <= X <= range_end + # + range_start = 0 + range_end = max_stream_ordering + 1 + + # Given a stream_ordering, look up the timestamp at that + # stream_ordering. + # + # The array may be sparse (we may be missing some stream_orderings). + # We treat the gaps as the same as having the same value as the + # preceding entry, because we will pick the lowest stream_ordering + # which satisfies our requirement of received_ts >= ts. + # + # For example, if our array of events indexed by stream_ordering is + # [10, , 20], we should treat this as being equivalent to + # [10, 10, 20]. + # + sql = ( + "SELECT received_ts FROM events" + " WHERE stream_ordering <= ?" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + + while range_end - range_start > 0: + middle = (range_end + range_start) // 2 + txn.execute(sql, (middle,)) + row = txn.fetchone() + if row is None: + # no rows with stream_ordering<=middle + range_start = middle + 1 + continue + + middle_ts = row[0] + if ts > middle_ts: + # we got a timestamp lower than the one we were looking for. + # definitely need to look higher: X > middle. + range_start = middle + 1 + else: + # we got a timestamp higher than (or the same as) the one we + # were looking for. We aren't yet sure about the point we + # looked up, but we can be sure that X <= middle. + range_end = middle + + return range_end + + +class EventPushActionsStore(EventPushActionsWorkerStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + + def __init__(self, db_conn, hs): + super(EventPushActionsStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + + self.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1", + ) + + self._doing_notif_rotation = False + self._rotate_notif_loop = self._clock.looping_call( + self._start_rotate_notifs, 30 * 60 * 1000 + ) + + def _set_push_actions_for_event_and_users_txn( + self, txn, events_and_contexts, all_events_and_contexts + ): + """Handles moving push actions from staging table to main + event_push_actions table for all events in `events_and_contexts`. + + Also ensures that all events in `all_events_and_contexts` are removed + from the push action staging area. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + """ + + sql = """ + INSERT INTO event_push_actions ( + room_id, event_id, user_id, actions, stream_ordering, + topological_ordering, notif, highlight + ) + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + FROM event_push_actions_staging + WHERE event_id = ? + """ + + if events_and_contexts: + txn.executemany( + sql, + ( + ( + event.room_id, + event.internal_metadata.stream_ordering, + event.depth, + event.event_id, + ) + for event, _ in events_and_contexts + ), + ) + + for event, _ in events_and_contexts: + user_ids = self._simple_select_onecol_txn( + txn, + table="event_push_actions_staging", + keyvalues={"event_id": event.event_id}, + retcol="user_id", + ) + + for uid in user_ids: + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (event.room_id, uid), + ) + + # Now we delete the staging area for *all* events that were being + # persisted. + txn.executemany( + "DELETE FROM event_push_actions_staging WHERE event_id = ?", + ((event.event_id,) for event, _ in all_events_and_contexts), + ) + + @defer.inlineCallbacks + def get_push_actions_for_user( + self, user_id, before=None, limit=50, only_highlight=False + ): + def f(txn): + before_clause = "" + if before: + before_clause = "AND epa.stream_ordering < ?" + args = [user_id, before, limit] + else: + args = [user_id, limit] + + if only_highlight: + if len(before_clause) > 0: + before_clause += " " + before_clause += "AND epa.highlight = 1" + + # NB. This assumes event_ids are globally unique since + # it makes the query easier to index + sql = ( + "SELECT epa.event_id, epa.room_id," + " epa.stream_ordering, epa.topological_ordering," + " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" + " FROM event_push_actions epa, events e" + " WHERE epa.event_id = e.event_id" + " AND epa.user_id = ? %s" + " ORDER BY epa.stream_ordering DESC" + " LIMIT ?" % (before_clause,) + ) + txn.execute(sql, args) + return self.cursor_to_dict(txn) + + push_actions = yield self.runInteraction("get_push_actions_for_user", f) + for pa in push_actions: + pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) + return push_actions + + @defer.inlineCallbacks + def get_time_of_last_push_action_before(self, stream_ordering): + def f(txn): + sql = ( + "SELECT e.received_ts" + " FROM event_push_actions AS ep" + " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" + " WHERE ep.stream_ordering > ?" + " ORDER BY ep.stream_ordering ASC" + " LIMIT 1" + ) + txn.execute(sql, (stream_ordering,)) + return txn.fetchone() + + result = yield self.runInteraction("get_time_of_last_push_action_before", f) + return result[0] if result else None + + @defer.inlineCallbacks + def get_latest_push_action_stream_ordering(self): + def f(txn): + txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") + return txn.fetchone() + + result = yield self.runInteraction("get_latest_push_action_stream_ordering", f) + return result[0] or 0 + + def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): + # Sad that we have to blow away the cache for the whole room here + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id,), + ) + txn.execute( + "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", + (room_id, event_id), + ) + + def _remove_old_push_actions_before_txn( + self, txn, room_id, user_id, stream_ordering + ): + """ + Purges old push actions for a user and room before a given + stream_ordering. + + We however keep a months worth of highlighted notifications, so that + users can still get a list of recent highlights. + + Args: + txn: The transcation + room_id: Room ID to delete from + user_id: user ID to delete for + stream_ordering: The lowest stream ordering which will + not be deleted. + """ + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id, user_id), + ) + + # We need to join on the events table to get the received_ts for + # event_push_actions and sqlite won't let us use a join in a delete so + # we can't just delete where received_ts < x. Furthermore we can + # only identify event_push_actions by a tuple of room_id, event_id + # we we can't use a subquery. + # Instead, we look up the stream ordering for the last event in that + # room received before the threshold time and delete event_push_actions + # in the room with a stream_odering before that. + txn.execute( + "DELETE FROM event_push_actions " + " WHERE user_id = ? AND room_id = ? AND " + " stream_ordering <= ?" + " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", + (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), + ) + + txn.execute( + """ + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? + """, + (room_id, user_id, stream_ordering), + ) + + def _start_rotate_notifs(self): + return run_as_background_process("rotate_notifs", self._rotate_notifs) + + @defer.inlineCallbacks + def _rotate_notifs(self): + if self._doing_notif_rotation or self.stream_ordering_day_ago is None: + return + self._doing_notif_rotation = True + + try: + while True: + logger.info("Rotating notifications") + + caught_up = yield self.runInteraction( + "_rotate_notifs", self._rotate_notifs_txn + ) + if caught_up: + break + yield self.hs.get_clock().sleep(self._rotate_delay) + finally: + self._doing_notif_rotation = False + + def _rotate_notifs_txn(self, txn): + """Archives older notifications into event_push_summary. Returns whether + the archiving process has caught up or not. + """ + + old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + # We don't to try and rotate millions of rows at once, so we cap the + # maximum stream ordering we'll rotate before. + txn.execute( + """ + SELECT stream_ordering FROM event_push_actions + WHERE stream_ordering > ? + ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? + """, + (old_rotate_stream_ordering, self._rotate_count), + ) + stream_row = txn.fetchone() + if stream_row: + offset_stream_ordering, = stream_row + rotate_to_stream_ordering = min( + self.stream_ordering_day_ago, offset_stream_ordering + ) + caught_up = offset_stream_ordering >= self.stream_ordering_day_ago + else: + rotate_to_stream_ordering = self.stream_ordering_day_ago + caught_up = True + + logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) + + self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) + + # We have caught up iff we were limited by `stream_ordering_day_ago` + return caught_up + + def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): + old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + # Calculate the new counts that should be upserted into event_push_summary + sql = """ + SELECT user_id, room_id, + coalesce(old.notif_count, 0) + upd.notif_count, + upd.stream_ordering, + old.user_id + FROM ( + SELECT user_id, room_id, count(*) as notif_count, + max(stream_ordering) as stream_ordering + FROM event_push_actions + WHERE ? <= stream_ordering AND stream_ordering < ? + AND highlight = 0 + GROUP BY user_id, room_id + ) AS upd + LEFT JOIN event_push_summary AS old USING (user_id, room_id) + """ + + txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) + rows = txn.fetchall() + + logger.info("Rotating notifications, handling %d rows", len(rows)) + + # If the `old.user_id` above is NULL then we know there isn't already an + # entry in the table, so we simply insert it. Otherwise we update the + # existing table. + self._simple_insert_many_txn( + txn, + table="event_push_summary", + values=[ + { + "user_id": row[0], + "room_id": row[1], + "notif_count": row[2], + "stream_ordering": row[3], + } + for row in rows + if row[4] is None + ], + ) + + txn.executemany( + """ + UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? + WHERE user_id = ? AND room_id = ? + """, + ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), + ) + + txn.execute( + "DELETE FROM event_push_actions" + " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", + (old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) + + txn.execute( + "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", + (rotate_to_stream_ordering,), + ) + + +def _action_has_highlight(actions): + for action in actions: + try: + if action.get("set_tweak", None) == "highlight": + return action.get("value", True) + except AttributeError: + pass + + return False diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py new file mode 100644 index 0000000000..03b5111c5d --- /dev/null +++ b/synapse/storage/data_stores/main/events.py @@ -0,0 +1,2489 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging +from collections import Counter as c_counter, OrderedDict, deque, namedtuple +from functools import wraps + +from six import iteritems, text_type +from six.moves import range + +from canonicaljson import json +from prometheus_client import Counter, Histogram + +from twisted.internet import defer + +import synapse.metrics +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError +from synapse.events import EventBase # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.events.utils import prune_event_dict +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.logging.utils import log_function +from synapse.metrics import BucketCollector +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.state import StateResolutionStore +from synapse.storage._base import make_in_list_sql_clause +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.data_stores.main.event_federation import EventFederationStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.util import batch_iter +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + +persist_event_counter = Counter("synapse_storage_events_persisted_events", "") +event_counter = Counter( + "synapse_storage_events_persisted_events_sep", + "", + ["type", "origin_type", "origin_entity"], +) + +# The number of times we are recalculating the current state +state_delta_counter = Counter("synapse_storage_events_state_delta", "") + +# The number of times we are recalculating state when there is only a +# single forward extremity +state_delta_single_event_counter = Counter( + "synapse_storage_events_state_delta_single_event", "" +) + +# The number of times we are reculating state when we could have resonably +# calculated the delta when we calculated the state for an event we were +# persisting. +state_delta_reuse_delta_counter = Counter( + "synapse_storage_events_state_delta_reuse_delta", "" +) + +# The number of forward extremities for each new event. +forward_extremities_counter = Histogram( + "synapse_storage_events_forward_extremities_persisted", + "Number of forward extremities for each new event", + buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + +# The number of stale forward extremities for each new event. Stale extremities +# are those that were in the previous set of extremities as well as the new. +stale_forward_extremities_counter = Histogram( + "synapse_storage_events_stale_forward_extremities_persisted", + "Number of unchanged forward extremities for each new event", + buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + + +def encode_json(json_object): + """ + Encode a Python object as JSON and return it in a Unicode string. + """ + out = frozendict_json_encoder.encode(json_object) + if isinstance(out, bytes): + out = out.decode("utf8") + return out + + +class _EventPeristenceQueue(object): + """Queues up events so that they can be persisted in bulk with only one + concurrent transaction per room. + """ + + _EventPersistQueueItem = namedtuple( + "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") + ) + + def __init__(self): + self._event_persist_queues = {} + self._currently_persisting_rooms = set() + + def add_to_queue(self, room_id, events_and_contexts, backfilled): + """Add events to the queue, with the given persist_event options. + + NB: due to the normal usage pattern of this method, it does *not* + follow the synapse logcontext rules, and leaves the logcontext in + place whether or not the returned deferred is ready. + + Args: + room_id (str): + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + + Returns: + defer.Deferred: a deferred which will resolve once the events are + persisted. Runs its callbacks *without* a logcontext. + """ + queue = self._event_persist_queues.setdefault(room_id, deque()) + if queue: + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. + end_item = queue[-1] + if end_item.backfilled == backfilled: + end_item.events_and_contexts.extend(events_and_contexts) + return end_item.deferred.observe() + + deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) + + queue.append( + self._EventPersistQueueItem( + events_and_contexts=events_and_contexts, + backfilled=backfilled, + deferred=deferred, + ) + ) + + return deferred.observe() + + def handle_queue(self, room_id, per_item_callback): + """Attempts to handle the queue for a room if not already being handled. + + The given callback will be invoked with for each item in the queue, + of type _EventPersistQueueItem. The per_item_callback will continuously + be called with new items, unless the queue becomnes empty. The return + value of the function will be given to the deferreds waiting on the item, + exceptions will be passed to the deferreds as well. + + This function should therefore be called whenever anything is added + to the queue. + + If another callback is currently handling the queue then it will not be + invoked. + """ + + if room_id in self._currently_persisting_rooms: + return + + self._currently_persisting_rooms.add(room_id) + + @defer.inlineCallbacks + def handle_queue_loop(): + try: + queue = self._get_drainining_queue(room_id) + for item in queue: + try: + ret = yield per_item_callback(item) + except Exception: + with PreserveLoggingContext(): + item.deferred.errback() + else: + with PreserveLoggingContext(): + item.deferred.callback(ret) + finally: + queue = self._event_persist_queues.pop(room_id, None) + if queue: + self._event_persist_queues[room_id] = queue + self._currently_persisting_rooms.discard(room_id) + + # set handle_queue_loop off in the background + run_as_background_process("persist_events", handle_queue_loop) + + def _get_drainining_queue(self, room_id): + queue = self._event_persist_queues.setdefault(room_id, deque()) + + try: + while True: + yield queue.popleft() + except IndexError: + # Queue has been drained. + pass + + +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + +def _retry_on_integrity_error(func): + """Wraps a database function so that it gets retried on IntegrityError, + with `delete_existing=True` passed in. + + Args: + func: function that returns a Deferred and accepts a `delete_existing` arg + """ + + @wraps(func) + @defer.inlineCallbacks + def f(self, *args, **kwargs): + try: + res = yield func(self, *args, **kwargs) + except self.database_engine.module.IntegrityError: + logger.exception("IntegrityError, retrying.") + res = yield func(self, *args, delete_existing=True, **kwargs) + return res + + return f + + +# inherits from EventFederationStore so that we can call _update_backward_extremities +# and _handle_mult_prev_events (though arguably those could both be moved in here) +class EventsStore( + StateGroupWorkerStore, + EventFederationStore, + EventsWorkerStore, + BackgroundUpdateStore, +): + def __init__(self, db_conn, hs): + super(EventsStore, self).__init__(db_conn, hs) + + self._event_persist_queue = _EventPeristenceQueue() + self._state_resolution_handler = hs.get_state_resolution_handler() + + # Collect metrics on the number of forward extremities that exist. + # Counter of number of extremities to count + self._current_forward_extremities_amount = c_counter() + + BucketCollector( + "synapse_forward_extremities", + lambda: self._current_forward_extremities_amount, + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], + ) + + # Read the extrems every 60 minutes + def read_forward_extremities(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "read_forward_extremities", self._read_forward_extremities + ) + + hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + + def _censor_redactions(): + return run_as_background_process( + "_censor_redactions", self._censor_redactions + ) + + if self.hs.config.redaction_retention_period is not None: + hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + + @defer.inlineCallbacks + def _read_forward_extremities(self): + def fetch(txn): + txn.execute( + """ + select count(*) c from event_forward_extremities + group by room_id + """ + ) + return txn.fetchall() + + res = yield self.runInteraction("read_forward_extremities", fetch) + self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) + + @defer.inlineCallbacks + def persist_events(self, events_and_contexts, backfilled=False): + """ + Write events to the database + Args: + events_and_contexts: list of tuples of (event, context) + backfilled (bool): Whether the results are retrieved from federation + via backfill or not. Used to determine if they're "new" events + which might update the current state etc. + + Returns: + Deferred[int]: the stream ordering of the latest persisted event + """ + partitioned = {} + for event, ctx in events_and_contexts: + partitioned.setdefault(event.room_id, []).append((event, ctx)) + + deferreds = [] + for room_id, evs_ctxs in iteritems(partitioned): + d = self._event_persist_queue.add_to_queue( + room_id, evs_ctxs, backfilled=backfilled + ) + deferreds.append(d) + + for room_id in partitioned: + self._maybe_start_persisting(room_id) + + yield make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) + + max_persisted_id = yield self._stream_id_gen.get_current_token() + + return max_persisted_id + + @defer.inlineCallbacks + @log_function + def persist_event(self, event, context, backfilled=False): + """ + + Args: + event (EventBase): + context (EventContext): + backfilled (bool): + + Returns: + Deferred: resolves to (int, int): the stream ordering of ``event``, + and the stream ordering of the latest persisted event + """ + deferred = self._event_persist_queue.add_to_queue( + event.room_id, [(event, context)], backfilled=backfilled + ) + + self._maybe_start_persisting(event.room_id) + + yield make_deferred_yieldable(deferred) + + max_persisted_id = yield self._stream_id_gen.get_current_token() + return (event.internal_metadata.stream_ordering, max_persisted_id) + + def _maybe_start_persisting(self, room_id): + @defer.inlineCallbacks + def persisting_queue(item): + with Measure(self._clock, "persist_events"): + yield self._persist_events( + item.events_and_contexts, backfilled=item.backfilled + ) + + self._event_persist_queue.handle_queue(room_id, persisting_queue) + + @_retry_on_integrity_error + @defer.inlineCallbacks + def _persist_events( + self, events_and_contexts, backfilled=False, delete_existing=False + ): + """Persist events to db + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + delete_existing (bool): + + Returns: + Deferred: resolves when the events have been persisted + """ + if not events_and_contexts: + return + + chunks = [ + events_and_contexts[x : x + 100] + for x in range(0, len(events_and_contexts), 100) + ] + + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + + # NB: Assumes that we are only persisting events for one room + # at a time. + + # map room_id->list[event_ids] giving the new forward + # extremities in each room + new_forward_extremeties = {} + + # map room_id->(type,state_key)->event_id tracking the full + # state in each room after adding these events. + # This is simply used to prefill the get_current_state_ids + # cache + current_state_for_room = {} + + # map room_id->(to_delete, to_insert) where to_delete is a list + # of type/state keys to remove from current state, and to_insert + # is a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room = {} + + if not backfilled: + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in iteritems(events_by_room): + latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = yield self._calculate_new_extremities( + room_id, ev_ctx_rm, latest_event_ids + ) + + latest_event_ids = set(latest_event_ids) + if new_latest_event_ids == latest_event_ids: + # No change in extremities, so no change in state + continue + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + + new_forward_extremeties[room_id] = new_latest_event_ids + + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 + ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_event_ids()) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm + ) + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: + continue + + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(ev.prev_event_ids()) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.info("Calculating state delta for room %s", room_id) + with Measure( + self._clock, "persist_events.get_new_state_after_events" + ): + res = yield self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, + ) + current_state, delta_ids = res + + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + state_delta_for_room[room_id] = ([], delta_ids) + elif current_state is not None: + with Measure( + self._clock, "persist_events.calculate_state_delta" + ): + delta = yield self._calculate_state_delta( + room_id, current_state + ) + state_delta_for_room[room_id] = delta + + # If we have the current_state then lets prefill + # the cache with it. + if current_state is not None: + current_state_for_room[room_id] = current_state + + # We want to calculate the stream orderings as late as possible, as + # we only notify after all events with a lesser stream ordering have + # been persisted. I.e. if we spend 10s inside the with block then + # that will delay all subsequent events from being notified about. + # Hence why we do it down here rather than wrapping the entire + # function. + # + # Its safe to do this after calculating the state deltas etc as we + # only need to protect the *persistence* of the events. This is to + # ensure that queries of the form "fetch events since X" don't + # return events and stream positions after events that are still in + # flight, as otherwise subsequent requests "fetch event since Y" + # will not return those events. + # + # Note: Multiple instances of this function cannot be in flight at + # the same time for the same room. + if backfilled: + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(chunk) + ) + else: + stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk)) + + with stream_ordering_manager as stream_orderings: + for (event, context), stream in zip(chunk, stream_orderings): + event.internal_metadata.stream_ordering = stream + + yield self.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=chunk, + backfilled=backfilled, + delete_existing=delete_existing, + state_delta_for_room=state_delta_for_room, + new_forward_extremeties=new_forward_extremeties, + ) + persist_event_counter.inc(len(chunk)) + + if not backfilled: + # backfilled events have negative stream orderings, so we don't + # want to set the event_persisted_position to that. + synapse.metrics.event_persisted_position.set( + chunk[-1][0].internal_metadata.stream_ordering + ) + + for event, context in chunk: + if context.app_service: + origin_type = "local" + origin_entity = context.app_service.id + elif self.hs.is_mine_id(event.sender): + origin_type = "local" + origin_entity = "*client*" + else: + origin_type = "remote" + origin_entity = get_domain_from_id(event.sender) + + event_counter.labels(event.type, origin_type, origin_entity).inc() + + for room_id, new_state in iteritems(current_state_for_room): + self.get_current_state_ids.prefill((room_id,), new_state) + + for room_id, latest_event_ids in iteritems(new_forward_extremeties): + self.get_latest_event_ids_in_room.prefill( + (room_id,), list(latest_event_ids) + ) + + @defer.inlineCallbacks + def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): + """Calculates the new forward extremities for a room given events to + persist. + + Assumes that we are only persisting events for one room at a time. + """ + + # we're only interested in new events which aren't outliers and which aren't + # being rejected. + new_events = [ + event + for event, ctx in event_contexts + if not event.internal_metadata.is_outlier() + and not ctx.rejected + and not event.internal_metadata.is_soft_failed() + ] + + latest_event_ids = set(latest_event_ids) + + # start with the existing forward extremities + result = set(latest_event_ids) + + # add all the new events to the list + result.update(event.event_id for event in new_events) + + # Now remove all events which are prev_events of any of the new events + result.difference_update( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + + # Remove any events which are prev_events of any existing events. + existing_prevs = yield self._get_events_which_are_prevs(result) + result.difference_update(existing_prevs) + + # Finally handle the case where the new events have soft-failed prev + # events. If they do we need to remove them and their prev events, + # otherwise we end up with dangling extremities. + existing_prevs = yield self._get_prevs_before_rejected( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + result.difference_update(existing_prevs) + + # We only update metrics for events that change forward extremities + # (e.g. we ignore backfill/outliers/etc) + if result != latest_event_ids: + forward_extremities_counter.observe(len(result)) + stale = latest_event_ids & result + stale_forward_extremities_counter.observe(len(stale)) + + return result + + @defer.inlineCallbacks + def _get_events_which_are_prevs(self, event_ids): + """Filter the supplied list of event_ids to get those which are prev_events of + existing (non-outlier/rejected) events. + + Args: + event_ids (Iterable[str]): event ids to filter + + Returns: + Deferred[List[str]]: filtered event ids + """ + results = [] + + def _get_events_which_are_prevs_txn(txn, batch): + sql = """ + SELECT prev_event_id, internal_metadata + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) + WHERE + NOT events.outlier + AND rejections.event_id IS NULL + AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "prev_event_id", batch + ) + + txn.execute(sql + clause, args) + results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) + + for chunk in batch_iter(event_ids, 100): + yield self.runInteraction( + "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk + ) + + return results + + @defer.inlineCallbacks + def _get_prevs_before_rejected(self, event_ids): + """Get soft-failed ancestors to remove from the extremities. + + Given a set of events, find all those that have been soft-failed or + rejected. Returns those soft failed/rejected events and their prev + events (whether soft-failed/rejected or not), and recurses up the + prev-event graph until it finds no more soft-failed/rejected events. + + This is used to find extremities that are ancestors of new events, but + are separated by soft failed events. + + Args: + event_ids (Iterable[str]): Events to find prev events for. Note + that these must have already been persisted. + + Returns: + Deferred[set[str]] + """ + + # The set of event_ids to return. This includes all soft-failed events + # and their prev events. + existing_prevs = set() + + def _get_prevs_before_rejected_txn(txn, batch): + to_recursively_check = batch + + while to_recursively_check: + sql = """ + SELECT + event_id, prev_event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) + WHERE + NOT events.outlier + AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", to_recursively_check + ) + + txn.execute(sql + clause, args) + to_recursively_check = [] + + for event_id, prev_event_id, metadata, rejected in txn: + if prev_event_id in existing_prevs: + continue + + soft_failed = json.loads(metadata).get("soft_failed") + if soft_failed or rejected: + to_recursively_check.append(prev_event_id) + existing_prevs.add(prev_event_id) + + for chunk in batch_iter(event_ids, 100): + yield self.runInteraction( + "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk + ) + + return existing_prevs + + @defer.inlineCallbacks + def _get_new_state_after_events( + self, room_id, events_context, old_latest_event_ids, new_latest_event_ids + ): + """Calculate the current state dict after adding some new events to + a room + + Args: + room_id (str): + room to which the events are being added. Used for logging etc + + events_context (list[(EventBase, EventContext)]): + events and contexts which are being added to the room + + old_latest_event_ids (iterable[str]): + the old forward extremities for the room. + + new_latest_event_ids (iterable[str]): + the new forward extremities for the room. + + Returns: + Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]: + Returns a tuple of two state maps, the first being the full new current + state and the second being the delta to the existing current state. + If both are None then there has been no change. + + If there has been a change then we only return the delta if its + already been calculated. Conversely if we do know the delta then + the new current state is only returned if we've already calculated + it. + """ + # map from state_group to ((type, key) -> event_id) state map + state_groups_map = {} + + # Map from (prev state group, new state group) -> delta state dict + state_group_deltas = {} + + for ev, ctx in events_context: + if ctx.state_group is None: + # This should only happen for outlier events. + if not ev.internal_metadata.is_outlier(): + raise Exception( + "Context for new event %s has no state " + "group" % (ev.event_id,) + ) + continue + + if ctx.state_group in state_groups_map: + continue + + # We're only interested in pulling out state that has already + # been cached in the context. We'll pull stuff out of the DB later + # if necessary. + current_state_ids = ctx.get_cached_current_state_ids() + if current_state_ids is not None: + state_groups_map[ctx.state_group] = current_state_ids + + if ctx.prev_group: + state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids + + # We need to map the event_ids to their state groups. First, let's + # check if the event is one we're persisting, in which case we can + # pull the state group from its context. + # Otherwise we need to pull the state group from the database. + + # Set of events we need to fetch groups for. (We know none of the old + # extremities are going to be in events_context). + missing_event_ids = set(old_latest_event_ids) + + event_id_to_state_group = {} + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding. + for ev, ctx in events_context: + if event_id == ev.event_id and ctx.state_group is not None: + event_id_to_state_group[event_id] = ctx.state_group + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + missing_event_ids.add(event_id) + + if missing_event_ids: + # Now pull out the state groups for any missing events from DB + event_to_groups = yield self._get_state_group_for_events(missing_event_ids) + event_id_to_state_group.update(event_to_groups) + + # State groups of old_latest_event_ids + old_state_groups = set( + event_id_to_state_group[evid] for evid in old_latest_event_ids + ) + + # State groups of new_latest_event_ids + new_state_groups = set( + event_id_to_state_group[evid] for evid in new_latest_event_ids + ) + + # If they old and new groups are the same then we don't need to do + # anything. + if old_state_groups == new_state_groups: + return None, None + + if len(new_state_groups) == 1 and len(old_state_groups) == 1: + # If we're going from one state group to another, lets check if + # we have a delta for that transition. If we do then we can just + # return that. + + new_state_group = next(iter(new_state_groups)) + old_state_group = next(iter(old_state_groups)) + + delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) + if delta_ids is not None: + # We have a delta from the existing to new current state, + # so lets just return that. If we happen to already have + # the current state in memory then lets also return that, + # but it doesn't matter if we don't. + new_state = state_groups_map.get(new_state_group) + return new_state, delta_ids + + # Now that we have calculated new_state_groups we need to get + # their state IDs so we can resolve to a single state set. + missing_state = new_state_groups - set(state_groups_map) + if missing_state: + group_to_state = yield self._get_state_for_groups(missing_state) + state_groups_map.update(group_to_state) + + if len(new_state_groups) == 1: + # If there is only one state group, then we know what the current + # state is. + return state_groups_map[new_state_groups.pop()], None + + # Ok, we need to defer to the state handler to resolve our state sets. + + state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} + + events_map = {ev.event_id: ev for ev, _ in events_context} + + # We need to get the room version, which is in the create event. + # Normally that'd be in the database, but its also possible that we're + # currently trying to persist it. + room_version = None + for ev, _ in events_context: + if ev.type == EventTypes.Create and ev.state_key == "": + room_version = ev.content.get("room_version", "1") + break + + if not room_version: + room_version = yield self.get_room_version(room_id) + + logger.debug("calling resolve_state_groups from preserve_events") + res = yield self._state_resolution_handler.resolve_state_groups( + room_id, + room_version, + state_groups, + events_map, + state_res_store=StateResolutionStore(self), + ) + + return res.state, None + + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, current_state): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + tuple[list, dict] (to_delete, to_insert): where to_delete are the + type/state_keys to remove from current_state_events and `to_insert` + are the updates to current_state_events. + """ + existing_state = yield self.get_current_state_ids(room_id) + + to_delete = [key for key in existing_state if key not in current_state] + + to_insert = { + key: ev_id + for key, ev_id in iteritems(current_state) + if ev_id != existing_state.get(key) + } + + return to_delete, to_insert + + @log_function + def _persist_events_txn( + self, + txn, + events_and_contexts, + backfilled, + delete_existing=False, + state_delta_for_room={}, + new_forward_extremeties={}, + ): + """Insert some number of room events into the necessary database tables. + + Rejected events are only inserted into the events table, the events_json table, + and the rejections table. Things reading from those table will need to check + whether the event was rejected. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): + events to persist + backfilled (bool): True if the events were backfilled + delete_existing (bool): True to purge existing table rows for the + events from the database. This is useful when retrying due to + IntegrityError. + state_delta_for_room (dict[str, (list, dict)]): + The current-state delta for each room. For each room, a tuple + (to_delete, to_insert), being a list of type/state keys to be + removed from the current state, and a state set to be added to + the current state. + new_forward_extremeties (dict[str, list[str]]): + The new forward extremities for each room. For each room, a + list of the event ids which are the forward extremities. + + """ + all_events_and_contexts = events_and_contexts + + min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + + self._update_forward_extremities_txn( + txn, + new_forward_extremities=new_forward_extremeties, + max_stream_order=max_stream_order, + ) + + # Ensure that we don't have the same event twice. + events_and_contexts = self._filter_events_and_contexts_for_duplicates( + events_and_contexts + ) + + self._update_room_depths_txn( + txn, events_and_contexts=events_and_contexts, backfilled=backfilled + ) + + # _update_outliers_txn filters out any events which have already been + # persisted, and returns the filtered list. + events_and_contexts = self._update_outliers_txn( + txn, events_and_contexts=events_and_contexts + ) + + # From this point onwards the events are only events that we haven't + # seen before. + + if delete_existing: + # For paranoia reasons, we go and delete all the existing entries + # for these events so we can reinsert them. + # This gets around any problems with some tables already having + # entries. + self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts) + + self._store_event_txn(txn, events_and_contexts=events_and_contexts) + + # Insert into event_to_state_groups. + self._store_event_state_mappings_txn(txn, events_and_contexts) + + # We want to store event_auth mappings for rejected events, as they're + # used in state res v2. + # This is only necessary if the rejected event appears in an accepted + # event's auth chain, but its easier for now just to store them (and + # it doesn't take much storage compared to storing the entire event + # anyway). + self._simple_insert_many_txn( + txn, + table="event_auth", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "auth_id": auth_id, + } + for event, _ in events_and_contexts + for auth_id in event.auth_event_ids() + if event.is_state() + ], + ) + + # _store_rejected_events_txn filters out any events which were + # rejected, and returns the filtered list. + events_and_contexts = self._store_rejected_events_txn( + txn, events_and_contexts=events_and_contexts + ) + + # From this point onwards the events are only ones that weren't + # rejected. + + self._update_metadata_tables_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + backfilled=backfilled, + ) + + # We call this last as it assumes we've inserted the events into + # room_memberships, where applicable. + self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + + def _update_current_state_txn(self, txn, state_delta_by_room, stream_id): + for room_id, current_state_tuple in iteritems(state_delta_by_room): + to_delete, to_insert = current_state_tuple + + # First we add entries to the current_state_delta_stream. We + # do this before updating the current_state_events table so + # that we can use it to calculate the `prev_event_id`. (This + # allows us to not have to pull out the existing state + # unnecessarily). + # + # The stream_id for the update is chosen to be the minimum of the stream_ids + # for the batch of the events that we are persisting; that means we do not + # end up in a situation where workers see events before the + # current_state_delta updates. + # + sql = """ + INSERT INTO current_state_delta_stream + (stream_id, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, ?, ?, ?, ?, ( + SELECT event_id FROM current_state_events + WHERE room_id = ? AND type = ? AND state_key = ? + ) + """ + txn.executemany( + sql, + ( + ( + stream_id, + room_id, + etype, + state_key, + None, + room_id, + etype, + state_key, + ) + for etype, state_key in to_delete + # We sanity check that we're deleting rather than updating + if (etype, state_key) not in to_insert + ), + ) + txn.executemany( + sql, + ( + ( + stream_id, + room_id, + etype, + state_key, + ev_id, + room_id, + etype, + state_key, + ) + for (etype, state_key), ev_id in iteritems(to_insert) + ), + ) + + # Now we actually update the current_state_events table + + txn.executemany( + "DELETE FROM current_state_events" + " WHERE room_id = ? AND type = ? AND state_key = ?", + ( + (room_id, etype, state_key) + for etype, state_key in itertools.chain(to_delete, to_insert) + ), + ) + + # We include the membership in the current state table, hence we do + # a lookup when we insert. This assumes that all events have already + # been inserted into room_memberships. + txn.executemany( + """INSERT INTO current_state_events + (room_id, type, state_key, event_id, membership) + VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + """, + [ + (room_id, key[0], key[1], ev_id, ev_id) + for key, ev_id in iteritems(to_insert) + ], + ) + + txn.call_after( + self._curr_state_delta_stream_cache.entity_has_changed, + room_id, + stream_id, + ) + + # Invalidate the various caches + + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invlidate the caches for all + # those users. + members_changed = set( + state_key + for ev_type, state_key in itertools.chain(to_delete, to_insert) + if ev_type == EventTypes.Member + ) + + for member in members_changed: + txn.call_after( + self.get_rooms_for_user_with_stream_ordering.invalidate, (member,) + ) + + self._invalidate_state_caches_and_stream(txn, room_id, members_changed) + + def _update_forward_extremities_txn( + self, txn, new_forward_extremities, max_stream_order + ): + for room_id, new_extrem in iteritems(new_forward_extremities): + self._simple_delete_txn( + txn, table="event_forward_extremities", keyvalues={"room_id": room_id} + ) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) + + self._simple_insert_many_txn( + txn, + table="event_forward_extremities", + values=[ + {"event_id": ev_id, "room_id": room_id} + for room_id, new_extrem in iteritems(new_forward_extremities) + for ev_id in new_extrem + ], + ) + # We now insert into stream_ordering_to_exterm a mapping from room_id, + # new stream_ordering to new forward extremeties in the room. + # This allows us to later efficiently look up the forward extremeties + # for a room before a given stream_ordering + self._simple_insert_many_txn( + txn, + table="stream_ordering_to_exterm", + values=[ + { + "room_id": room_id, + "event_id": event_id, + "stream_ordering": max_stream_order, + } + for room_id, new_extrem in iteritems(new_forward_extremities) + for event_id in new_extrem + ], + ) + + @classmethod + def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): + """Ensure that we don't have the same event twice. + + Pick the earliest non-outlier if there is one, else the earliest one. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + Returns: + list[(EventBase, EventContext)]: filtered list + """ + new_events_and_contexts = OrderedDict() + for event, context in events_and_contexts: + prev_event_context = new_events_and_contexts.get(event.event_id) + if prev_event_context: + if not event.internal_metadata.is_outlier(): + if prev_event_context[0].internal_metadata.is_outlier(): + # To ensure correct ordering we pop, as OrderedDict is + # ordered by first insertion. + new_events_and_contexts.pop(event.event_id, None) + new_events_and_contexts[event.event_id] = (event, context) + else: + new_events_and_contexts[event.event_id] = (event, context) + return list(new_events_and_contexts.values()) + + def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): + """Update min_depth for each room + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + backfilled (bool): True if the events were backfilled + """ + depth_updates = {} + for event, context in events_and_contexts: + # Remove the any existing cache entries for the event_ids + txn.call_after(self._invalidate_get_event_cache, event.event_id) + if not backfilled: + txn.call_after( + self._events_stream_cache.entity_has_changed, + event.room_id, + event.internal_metadata.stream_ordering, + ) + + if not event.internal_metadata.is_outlier() and not context.rejected: + depth_updates[event.room_id] = max( + event.depth, depth_updates.get(event.room_id, event.depth) + ) + + for room_id, depth in iteritems(depth_updates): + self._update_min_depth_for_room_txn(txn, room_id, depth) + + def _update_outliers_txn(self, txn, events_and_contexts): + """Update any outliers with new event info. + + This turns outliers into ex-outliers (unless the new event was + rejected). + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without events which + are already in the events table. + """ + txn.execute( + "SELECT event_id, outlier FROM events WHERE event_id in (%s)" + % (",".join(["?"] * len(events_and_contexts)),), + [event.event_id for event, _ in events_and_contexts], + ) + + have_persisted = {event_id: outlier for event_id, outlier in txn} + + to_remove = set() + for event, context in events_and_contexts: + if event.event_id not in have_persisted: + continue + + to_remove.add(event) + + if context.rejected: + # If the event is rejected then we don't care if the event + # was an outlier or not. + continue + + outlier_persisted = have_persisted[event.event_id] + if not event.internal_metadata.is_outlier() and outlier_persisted: + # We received a copy of an event that we had already stored as + # an outlier in the database. We now have some state at that + # so we need to update the state_groups table with that state. + + # insert into event_to_state_groups. + try: + self._store_event_state_mappings_txn(txn, ((event, context),)) + except Exception: + logger.exception("") + raise + + metadata_json = encode_json(event.internal_metadata.get_dict()) + + sql = ( + "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?" + ) + txn.execute(sql, (metadata_json, event.event_id)) + + # Add an entry to the ex_outlier_stream table to replicate the + # change in outlier status to our workers. + stream_order = event.internal_metadata.stream_ordering + state_group_id = context.state_group + self._simple_insert_txn( + txn, + table="ex_outlier_stream", + values={ + "event_stream_ordering": stream_order, + "event_id": event.event_id, + "state_group": state_group_id, + }, + ) + + sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?" + txn.execute(sql, (False, event.event_id)) + + # Update the event_backward_extremities table now that this + # event isn't an outlier any more. + self._update_backward_extremeties(txn, [event]) + + return [ec for ec in events_and_contexts if ec[0] not in to_remove] + + @classmethod + def _delete_existing_rows_txn(cls, txn, events_and_contexts): + if not events_and_contexts: + # nothing to do here + return + + logger.info("Deleting existing") + + for table in ( + "events", + "event_auth", + "event_json", + "event_edges", + "event_forward_extremities", + "event_reference_hashes", + "event_search", + "event_to_state_groups", + "local_invites", + "state_events", + "rejections", + "redactions", + "room_memberships", + ): + txn.executemany( + "DELETE FROM %s WHERE event_id = ?" % (table,), + [(ev.event_id,) for ev, _ in events_and_contexts], + ) + + for table in ("event_push_actions",): + txn.executemany( + "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,), + [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts], + ) + + def _store_event_txn(self, txn, events_and_contexts): + """Insert new events into the event and event_json tables + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + """ + + if not events_and_contexts: + # nothing to do here + return + + def event_dict(event): + d = event.get_dict() + d.pop("redacted", None) + d.pop("redacted_because", None) + return d + + self._simple_insert_many_txn( + txn, + table="event_json", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "internal_metadata": encode_json( + event.internal_metadata.get_dict() + ), + "json": encode_json(event_dict(event)), + "format_version": event.format_version, + } + for event, _ in events_and_contexts + ], + ) + + self._simple_insert_many_txn( + txn, + table="events", + values=[ + { + "stream_ordering": event.internal_metadata.stream_ordering, + "topological_ordering": event.depth, + "depth": event.depth, + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "processed": True, + "outlier": event.internal_metadata.is_outlier(), + "origin_server_ts": int(event.origin_server_ts), + "received_ts": self._clock.time_msec(), + "sender": event.sender, + "contains_url": ( + "url" in event.content + and isinstance(event.content["url"], text_type) + ), + } + for event, _ in events_and_contexts + ], + ) + + for event, _ in events_and_contexts: + if not event.internal_metadata.is_redacted(): + # If we're persisting an unredacted event we go and ensure + # that we mark any redactions that reference this event as + # requiring censoring. + self._simple_update_txn( + txn, + table="redactions", + keyvalues={"redacts": event.event_id}, + updatevalues={"have_censored": False}, + ) + + def _store_rejected_events_txn(self, txn, events_and_contexts): + """Add rows to the 'rejections' table for received events which were + rejected + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without the rejected + events. + """ + # Remove the rejected events from the list now that we've added them + # to the events table and the events_json table. + to_remove = set() + for event, context in events_and_contexts: + if context.rejected: + # Insert the event_id into the rejections table + self._store_rejections_txn(txn, event.event_id, context.rejected) + to_remove.add(event) + + return [ec for ec in events_and_contexts if ec[0] not in to_remove] + + def _update_metadata_tables_txn( + self, txn, events_and_contexts, all_events_and_contexts, backfilled + ): + """Update all the miscellaneous tables for new events + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + backfilled (bool): True if the events were backfilled + """ + + # Insert all the push actions into the event_push_actions table. + self._set_push_actions_for_event_and_users_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + ) + + if not events_and_contexts: + # nothing to do here + return + + for event, context in events_and_contexts: + if event.type == EventTypes.Redaction and event.redacts is not None: + # Remove the entries in the event_push_actions table for the + # redacted event. + self._remove_push_actions_for_event_id_txn( + txn, event.room_id, event.redacts + ) + + # Remove from relations table. + self._handle_redaction(txn, event.redacts) + + # Update the event_forward_extremities, event_backward_extremities and + # event_edges tables. + self._handle_mult_prev_events( + txn, events=[event for event, _ in events_and_contexts] + ) + + for event, _ in events_and_contexts: + if event.type == EventTypes.Name: + # Insert into the event_search table. + self._store_room_name_txn(txn, event) + elif event.type == EventTypes.Topic: + # Insert into the event_search table. + self._store_room_topic_txn(txn, event) + elif event.type == EventTypes.Message: + # Insert into the event_search table. + self._store_room_message_txn(txn, event) + elif event.type == EventTypes.Redaction: + # Insert into the redactions table. + self._store_redaction(txn, event) + + self._handle_event_relations(txn, event) + + # Insert into the room_memberships table. + self._store_room_members_txn( + txn, + [ + event + for event, _ in events_and_contexts + if event.type == EventTypes.Member + ], + backfilled=backfilled, + ) + + # Insert event_reference_hashes table. + self._store_event_reference_hashes_txn( + txn, [event for event, _ in events_and_contexts] + ) + + state_events_and_contexts = [ + ec for ec in events_and_contexts if ec[0].is_state() + ] + + state_values = [] + for event, context in state_events_and_contexts: + vals = { + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + } + + # TODO: How does this work with backfilling? + if hasattr(event, "replaces_state"): + vals["prev_state"] = event.replaces_state + + state_values.append(vals) + + self._simple_insert_many_txn(txn, table="state_events", values=state_values) + + # Prefill the event cache + self._add_to_cache(txn, events_and_contexts) + + def _add_to_cache(self, txn, events_and_contexts): + to_prefill = [] + + rows = [] + N = 200 + for i in range(0, len(events_and_contexts), N): + ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]} + if not ev_map: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM events as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE " + ) + + clause, args = make_in_list_sql_clause( + self.database_engine, "e.event_id", list(ev_map) + ) + + txn.execute(sql + clause, args) + rows = self.cursor_to_dict(txn) + for row in rows: + event = ev_map[row["event_id"]] + if not row["rejects"] and not row["redacts"]: + to_prefill.append( + _EventCacheEntry(event=event, redacted_event=None) + ) + + def prefill(): + for cache_entry in to_prefill: + self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry) + + txn.call_after(prefill) + + def _store_redaction(self, txn, event): + # invalidate the cache for the redacted event + txn.call_after(self._invalidate_get_event_cache, event.redacts) + + self._simple_insert_txn( + txn, + table="redactions", + values={ + "event_id": event.event_id, + "redacts": event.redacts, + "received_ts": self._clock.time_msec(), + }, + ) + + @defer.inlineCallbacks + def _censor_redactions(self): + """Censors all redactions older than the configured period that haven't + been censored yet. + + By censor we mean update the event_json table with the redacted event. + + Returns: + Deferred + """ + + if self.hs.config.redaction_retention_period is None: + return + + before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period + + # We fetch all redactions that: + # 1. point to an event we have, + # 2. has a received_ts from before the cut off, and + # 3. we haven't yet censored. + # + # This is limited to 100 events to ensure that we don't try and do too + # much at once. We'll get called again so this should eventually catch + # up. + sql = """ + SELECT redactions.event_id, redacts FROM redactions + LEFT JOIN events AS original_event ON ( + redacts = original_event.event_id + ) + WHERE NOT have_censored + AND redactions.received_ts <= ? + ORDER BY redactions.received_ts ASC + LIMIT ? + """ + + rows = yield self._execute( + "_censor_redactions_fetch", None, sql, before_ts, 100 + ) + + updates = [] + + for redaction_id, event_id in rows: + redaction_event = yield self.get_event(redaction_id, allow_none=True) + original_event = yield self.get_event( + event_id, allow_rejected=True, allow_none=True + ) + + # The SQL above ensures that we have both the redaction and + # original event, so if the `get_event` calls return None it + # means that the redaction wasn't allowed. Either way we know that + # the result won't change so we mark the fact that we've checked. + if ( + redaction_event + and original_event + and original_event.internal_metadata.is_redacted() + ): + # Redaction was allowed + pruned_json = encode_json(prune_event_dict(original_event.get_dict())) + else: + # Redaction wasn't allowed + pruned_json = None + + updates.append((redaction_id, event_id, pruned_json)) + + def _update_censor_txn(txn): + for redaction_id, event_id, pruned_json in updates: + if pruned_json: + self._simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) + + self._simple_update_one_txn( + txn, + table="redactions", + keyvalues={"event_id": redaction_id}, + updatevalues={"have_censored": True}, + ) + + yield self.runInteraction("_update_censor_txn", _update_censor_txn) + + @defer.inlineCallbacks + def count_daily_messages(self): + """ + Returns an estimate of the number of messages sent in the last day. + + If it has been significantly less or more than one day since the last + call to this function, it will return None. + """ + + def _count_messages(txn): + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + count, = txn.fetchone() + return count + + ret = yield self.runInteraction("count_messages", _count_messages) + return ret + + @defer.inlineCallbacks + def count_daily_sent_messages(self): + def _count_messages(txn): + # This is good enough as if you have silly characters in your own + # hostname then thats your own fault. + like_clause = "%:" + self.hs.hostname + + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND sender LIKE ? + AND stream_ordering > ? + """ + + txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) + count, = txn.fetchone() + return count + + ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) + return ret + + @defer.inlineCallbacks + def count_daily_active_rooms(self): + def _count(txn): + sql = """ + SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + count, = txn.fetchone() + return count + + ret = yield self.runInteraction("count_daily_active_rooms", _count) + return ret + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + return -self._backfill_id_gen.get_current_token() + + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_id, upper_bound)) + new_event_updates.extend(txn) + + return new_event_updates + + return self.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + + return self.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + + @cached(num_args=5, max_entries=10) + def get_all_new_events( + self, + last_backfill_id, + last_forward_id, + current_backfill_id, + current_forward_id, + limit, + ): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + have_backfill_events = last_backfill_id != current_backfill_id + have_forward_events = last_forward_id != current_forward_id + + if not have_backfill_events and not have_forward_events: + return defer.succeed(AllNewEventsResult([], [], [], [], [])) + + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + if have_forward_events: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + + if len(new_forward_events) == limit: + upper_bound = new_forward_events[-1][0] + else: + upper_bound = current_forward_id + + sql = ( + "SELECT event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + forward_ex_outliers = txn.fetchall() + else: + new_forward_events = [] + forward_ex_outliers = [] + + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + if have_backfill_events: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + + if len(new_backfill_events) == limit: + upper_bound = new_backfill_events[-1][0] + else: + upper_bound = current_backfill_id + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_backfill_id, -upper_bound)) + backward_ex_outliers = txn.fetchall() + else: + new_backfill_events = [] + backward_ex_outliers = [] + + return AllNewEventsResult( + new_forward_events, + new_backfill_events, + forward_ex_outliers, + backward_ex_outliers, + ) + + return self.runInteraction("get_all_new_events", get_all_new_events_txn) + + def purge_history(self, room_id, token, delete_local_events): + """Deletes room history before a certain point + + Args: + room_id (str): + + token (str): A topological token to delete events before + + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). + """ + + return self.runInteraction( + "purge_history", + self._purge_history_txn, + room_id, + token, + delete_local_events, + ) + + def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): + token = RoomStreamToken.parse(token_str) + + # Tables that should be pruned: + # event_auth + # event_backward_extremities + # event_edges + # event_forward_extremities + # event_json + # event_push_actions + # event_reference_hashes + # event_search + # event_to_state_groups + # events + # rejections + # room_depth + # state_groups + # state_groups_state + + # we will build a temporary table listing the events so that we don't + # have to keep shovelling the list back and forth across the + # connection. Annoyingly the python sqlite driver commits the + # transaction on CREATE, so let's do this first. + # + # furthermore, we might already have the table from a previous (failed) + # purge attempt, so let's drop the table first. + + txn.execute("DROP TABLE IF EXISTS events_to_purge") + + txn.execute( + "CREATE TEMPORARY TABLE events_to_purge (" + " event_id TEXT NOT NULL," + " should_delete BOOLEAN NOT NULL" + ")" + ) + + # First ensure that we're not about to delete all the forward extremeties + txn.execute( + "SELECT e.event_id, e.depth FROM events as e " + "INNER JOIN event_forward_extremities as f " + "ON e.event_id = f.event_id " + "AND e.room_id = f.room_id " + "WHERE f.room_id = ?", + (room_id,), + ) + rows = txn.fetchall() + max_depth = max(row[1] for row in rows) + + if max_depth < token.topological: + # We need to ensure we don't delete all the events from the database + # otherwise we wouldn't be able to send any events (due to not + # having any backwards extremeties) + raise SynapseError( + 400, "topological_ordering is greater than forward extremeties" + ) + + logger.info("[purge] looking for events to delete") + + should_delete_expr = "state_key IS NULL" + should_delete_params = () + if not delete_local_events: + should_delete_expr += " AND event_id NOT LIKE ?" + + # We include the parameter twice since we use the expression twice + should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) + + should_delete_params += (room_id, token.topological) + + # Note that we insert events that are outliers and aren't going to be + # deleted, as nothing will happen to them. + txn.execute( + "INSERT INTO events_to_purge" + " SELECT event_id, %s" + " FROM events AS e LEFT JOIN state_events USING (event_id)" + " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" + % (should_delete_expr, should_delete_expr), + should_delete_params, + ) + + # We create the indices *after* insertion as that's a lot faster. + + # create an index on should_delete because later we'll be looking for + # the should_delete / shouldn't_delete subsets + txn.execute( + "CREATE INDEX events_to_purge_should_delete" + " ON events_to_purge(should_delete)" + ) + + # We do joins against events_to_purge for e.g. calculating state + # groups to purge, etc., so lets make an index. + txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)") + + txn.execute("SELECT event_id, should_delete FROM events_to_purge") + event_rows = txn.fetchall() + logger.info( + "[purge] found %i events before cutoff, of which %i can be deleted", + len(event_rows), + sum(1 for e in event_rows if e[1]), + ) + + logger.info("[purge] Finding new backward extremities") + + # We calculate the new entries for the backward extremeties by finding + # events to be purged that are pointed to by events we're not going to + # purge. + txn.execute( + "SELECT DISTINCT e.event_id FROM events_to_purge AS e" + " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" + " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" + " WHERE ep2.event_id IS NULL" + ) + new_backwards_extrems = txn.fetchall() + + logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) + + txn.execute( + "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) + ) + + # Update backward extremeties + txn.executemany( + "INSERT INTO event_backward_extremities (room_id, event_id)" + " VALUES (?, ?)", + [(room_id, event_id) for event_id, in new_backwards_extrems], + ) + + logger.info("[purge] finding redundant state groups") + + # Get all state groups that are referenced by events that are to be + # deleted. We then go and check if they are referenced by other events + # or state groups, and if not we delete them. + txn.execute( + """ + SELECT DISTINCT state_group FROM events_to_purge + INNER JOIN event_to_state_groups USING (event_id) + """ + ) + + referenced_state_groups = set(sg for sg, in txn) + logger.info( + "[purge] found %i referenced state groups", len(referenced_state_groups) + ) + + logger.info("[purge] finding state groups that can be deleted") + + _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups) + state_groups_to_delete, remaining_state_groups = _ + + logger.info( + "[purge] found %i state groups to delete", len(state_groups_to_delete) + ) + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. + for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state[sg] + + self._simple_delete_txn( + txn, table="state_groups_state", keyvalues={"state_group": sg} + ) + + self._simple_delete_txn( + txn, table="state_group_edges", keyvalues={"state_group": sg} + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": sg, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(curr_state) + ], + ) + + logger.info("[purge] removing redundant state groups") + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + ((sg,) for sg in state_groups_to_delete), + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + ((sg,) for sg in state_groups_to_delete), + ) + + logger.info("[purge] removing events from event_to_state_groups") + txn.execute( + "DELETE FROM event_to_state_groups " + "WHERE event_id IN (SELECT event_id from events_to_purge)" + ) + for event_id, _ in event_rows: + txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + + # Delete all remote non-state events + for table in ( + "events", + "event_json", + "event_auth", + "event_edges", + "event_forward_extremities", + "event_reference_hashes", + "event_search", + "rejections", + ): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,) + ) + + # event_push_actions lacks an index on event_id, and has one on + # (room_id, event_id) instead. + for table in ("event_push_actions",): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE room_id = ? AND event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,), + (room_id,), + ) + + # Mark all state and own events as outliers + logger.info("[purge] marking remaining events as outliers") + txn.execute( + "UPDATE events SET outlier = ?" + " WHERE event_id IN (" + " SELECT event_id FROM events_to_purge " + " WHERE NOT should_delete" + ")", + (True,), + ) + + # synapse tries to take out an exclusive lock on room_depth whenever it + # persists events (because upsert), and once we run this update, we + # will block that for the rest of our transaction. + # + # So, let's stick it at the end so that we don't block event + # persistence. + # + # We do this by calculating the minimum depth of the backwards + # extremities. However, the events in event_backward_extremities + # are ones we don't have yet so we need to look at the events that + # point to it via event_edges table. + txn.execute( + """ + SELECT COALESCE(MIN(depth), 0) + FROM event_backward_extremities AS eb + INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id + INNER JOIN events AS e ON e.event_id = eg.event_id + WHERE eb.room_id = ? + """, + (room_id,), + ) + min_depth, = txn.fetchone() + + logger.info("[purge] updating room_depth to %d", min_depth) + + txn.execute( + "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", + (min_depth, room_id), + ) + + # finally, drop the temp table. this will commit the txn in sqlite, + # so make sure to keep this actually last. + txn.execute("DROP TABLE events_to_purge") + + logger.info("[purge] done") + + def _find_unreferenced_groups_during_purge(self, txn, state_groups): + """Used when purging history to figure out which state groups can be + deleted and which need to be de-delta'ed (due to one of its prev groups + being scheduled for deletion). + + Args: + txn + state_groups (set[int]): Set of state groups referenced by events + that are going to be deleted. + + Returns: + tuple[set[int], set[int]]: The set of state groups that can be + deleted and the set of state groups that need to be de-delta'ed + """ + # Graph of state group -> previous group + graph = {} + + # Set of events that we have found to be referenced by events + referenced_groups = set() + + # Set of state groups we've already seen + state_groups_seen = set(state_groups) + + # Set of state groups to handle next. + next_to_search = set(state_groups) + while next_to_search: + # We bound size of groups we're looking up at once, to stop the + # SQL query getting too big + if len(next_to_search) < 100: + current_search = next_to_search + next_to_search = set() + else: + current_search = set(itertools.islice(next_to_search, 100)) + next_to_search -= current_search + + # Check if state groups are referenced + sql = """ + SELECT DISTINCT state_group FROM event_to_state_groups + LEFT JOIN events_to_purge AS ep USING (event_id) + WHERE ep.event_id IS NULL AND + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "state_group", current_search + ) + txn.execute(sql + clause, list(args)) + + referenced = set(sg for sg, in txn) + referenced_groups |= referenced + + # We don't continue iterating up the state group graphs for state + # groups that are referenced. + current_search -= referenced + + rows = self._simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=current_search, + keyvalues={}, + retcols=("prev_state_group", "state_group"), + ) + + prevs = set(row["state_group"] for row in rows) + # We don't bother re-handling groups we've already seen + prevs -= state_groups_seen + next_to_search |= prevs + state_groups_seen |= prevs + + for row in rows: + # Note: Each state group can have at most one prev group + graph[row["state_group"]] = row["prev_state_group"] + + to_delete = state_groups_seen - referenced_groups + + to_dedelta = set() + for sg in referenced_groups: + prev_sg = graph.get(sg) + if prev_sg and prev_sg in to_delete: + to_dedelta.add(sg) + + return to_delete, to_dedelta + + def purge_room(self, room_id): + """Deletes all record of a room + + Args: + room_id (str): + """ + + return self.runInteraction("purge_room", self._purge_room_txn, room_id) + + def _purge_room_txn(self, txn, room_id): + # first we have to delete the state groups states + logger.info("[purge] removing %s from state_groups_state", room_id) + + txn.execute( + """ + DELETE FROM state_groups_state WHERE state_group IN ( + SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) + WHERE events.room_id=? + ) + """, + (room_id,), + ) + + # ... and the state group edges + logger.info("[purge] removing %s from state_group_edges", room_id) + + txn.execute( + """ + DELETE FROM state_group_edges WHERE state_group IN ( + SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) + WHERE events.room_id=? + ) + """, + (room_id,), + ) + + # ... and the state groups + logger.info("[purge] removing %s from state_groups", room_id) + + txn.execute( + """ + DELETE FROM state_groups WHERE id IN ( + SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) + WHERE events.room_id=? + ) + """, + (room_id,), + ) + + # and then tables which lack an index on room_id but have one on event_id + for table in ( + "event_auth", + "event_edges", + "event_push_actions_staging", + "event_reference_hashes", + "event_relations", + "event_to_state_groups", + "redactions", + "rejections", + "state_events", + ): + logger.info("[purge] removing %s from %s", room_id, table) + + txn.execute( + """ + DELETE FROM %s WHERE event_id IN ( + SELECT event_id FROM events WHERE room_id=? + ) + """ + % (table,), + (room_id,), + ) + + # and finally, the tables with an index on room_id (or no useful index) + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "topics", + "users_in_public_rooms", + "users_who_share_private_rooms", + # no useful index, but let's clear them anyway + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + ): + logger.info("[purge] removing %s from %s", room_id, table) + txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,)) + + # Other tables we do NOT need to clear out: + # + # - blocked_rooms + # This is important, to make sure that we don't accidentally rejoin a blocked + # room after it was purged + # + # - user_directory + # This has a room_id column, but it is unused + # + + # Other tables that we might want to consider clearing out include: + # + # - event_reports + # Given that these are intended for abuse management my initial + # inclination is to leave them in place. + # + # - current_state_delta_stream + # - ex_outlier_stream + # - room_tags_revisions + # The problem with these is that they are largeish and there is no room_id + # index on them. In any case we should be clearing out 'stream' tables + # periodically anyway (#5888) + + # TODO: we could probably usefully do a bunch of cache invalidation here + + logger.info("[purge] done") + + @defer.inlineCallbacks + def is_event_after(self, event_id1, event_id2): + """Returns True if event_id1 is after event_id2 in the stream + """ + to_1, so_1 = yield self._get_event_ordering(event_id1) + to_2, so_2 = yield self._get_event_ordering(event_id2) + return (to_1, so_1) > (to_2, so_2) + + @cachedInlineCallbacks(max_entries=5000) + def _get_event_ordering(self, event_id): + res = yield self._simple_select_one( + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True, + ) + + if not res: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + return (int(res["topological_ordering"]), int(res["stream_ordering"])) + + def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit)) + return txn.fetchall() + + return self.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, + ) + + +AllNewEventsResult = namedtuple( + "AllNewEventsResult", + [ + "new_forward_events", + "new_backfill_events", + "forward_ex_outliers", + "backward_ex_outliers", + ], +) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py new file mode 100644 index 0000000000..31ea6f917f --- /dev/null +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -0,0 +1,505 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import text_type + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.storage._base import make_in_list_sql_clause +from synapse.storage.background_updates import BackgroundUpdateStore + +logger = logging.getLogger(__name__) + + +class EventsBackgroundUpdatesStore(BackgroundUpdateStore): + + EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" + DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" + + def __init__(self, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) + + self.register_background_update_handler( + self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts + ) + self.register_background_update_handler( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, + self._background_reindex_fields_sender, + ) + + self.register_background_index_update( + "event_contains_url_index", + index_name="event_contains_url_index", + table="events", + columns=["room_id", "topological_ordering", "stream_ordering"], + where_clause="contains_url = true AND outlier = false", + ) + + # an event_id index on event_search is useful for the purge_history + # api. Plus it means we get to enforce some integrity with a UNIQUE + # clause + self.register_background_index_update( + "event_search_event_id_idx", + index_name="event_search_event_id_idx", + table="event_search", + columns=["event_id"], + unique=True, + psql_only=True, + ) + + self.register_background_update_handler( + self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update + ) + + self.register_background_update_handler( + "redactions_received_ts", self._redactions_received_ts + ) + + # This index gets deleted in `event_fix_redactions_bytes` update + self.register_background_index_update( + "event_fix_redactions_bytes_create_index", + index_name="redactions_censored_redacts", + table="redactions", + columns=["redacts"], + where_clause="have_censored", + ) + + self.register_background_update_handler( + "event_fix_redactions_bytes", self._event_fix_redactions_bytes + ) + + @defer.inlineCallbacks + def _background_reindex_fields_sender(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_txn(txn): + sql = ( + "SELECT stream_ordering, event_id, json FROM events" + " INNER JOIN event_json USING (event_id)" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + + update_rows = [] + for row in rows: + try: + event_id = row[1] + event_json = json.loads(row[2]) + sender = event_json["sender"] + content = event_json["content"] + + contains_url = "url" in content + if contains_url: + contains_url &= isinstance(content["url"], text_type) + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + update_rows.append((sender, contains_url, event_id)) + + sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" + + for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): + clump = update_rows[index : index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows), + } + + self._background_update_progress_txn( + txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress + ) + + return len(rows) + + result = yield self.runInteraction( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) + + return result + + @defer.inlineCallbacks + def _background_reindex_origin_server_ts(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_search_txn(txn): + sql = ( + "SELECT stream_ordering, event_id FROM events" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + event_ids = [row[1] for row in rows] + + rows_to_update = [] + + chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] + for chunk in chunks: + ev_rows = self._simple_select_many_txn( + txn, + table="event_json", + column="event_id", + iterable=chunk, + retcols=["event_id", "json"], + keyvalues={}, + ) + + for row in ev_rows: + event_id = row["event_id"] + event_json = json.loads(row["json"]) + try: + origin_server_ts = event_json["origin_server_ts"] + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + rows_to_update.append((origin_server_ts, event_id)) + + sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" + + for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): + clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows_to_update), + } + + self._background_update_progress_txn( + txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + ) + + return len(rows_to_update) + + result = yield self.runInteraction( + self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + + return result + + @defer.inlineCallbacks + def _cleanup_extremities_bg_update(self, progress, batch_size): + """Background update to clean out extremities that should have been + deleted previously. + + Mainly used to deal with the aftermath of #5269. + """ + + # This works by first copying all existing forward extremities into the + # `_extremities_to_check` table at start up, and then checking each + # event in that table whether we have any descendants that are not + # soft-failed/rejected. If that is the case then we delete that event + # from the forward extremities table. + # + # For efficiency, we do this in batches by recursively pulling out all + # descendants of a batch until we find the non soft-failed/rejected + # events, i.e. the set of descendants whose chain of prev events back + # to the batch of extremities are all soft-failed or rejected. + # Typically, we won't find any such events as extremities will rarely + # have any descendants, but if they do then we should delete those + # extremities. + + def _cleanup_extremities_bg_update_txn(txn): + # The set of extremity event IDs that we're checking this round + original_set = set() + + # A dict[str, set[str]] of event ID to their prev events. + graph = {} + + # The set of descendants of the original set that are not rejected + # nor soft-failed. Ancestors of these events should be removed + # from the forward extremities table. + non_rejected_leaves = set() + + # Set of event IDs that have been soft failed, and for which we + # should check if they have descendants which haven't been soft + # failed. + soft_failed_events_to_lookup = set() + + # First, we get `batch_size` events from the table, pulling out + # their successor events, if any, and the successor events' + # rejection status. + txn.execute( + """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL, events.outlier + FROM ( + SELECT event_id AS prev_event_id + FROM _extremities_to_check + LIMIT ? + ) AS f + LEFT JOIN event_edges USING (prev_event_id) + LEFT JOIN events USING (event_id) + LEFT JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + """, + (batch_size,), + ) + + for prev_event_id, event_id, metadata, rejected, outlier in txn: + original_set.add(prev_event_id) + + if not event_id or outlier: + # Common case where the forward extremity doesn't have any + # descendants. + continue + + graph.setdefault(event_id, set()).add(prev_event_id) + + soft_failed = False + if metadata: + soft_failed = json.loads(metadata).get("soft_failed") + + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # Now we recursively check all the soft-failed descendants we + # found above in the same way, until we have nothing left to + # check. + while soft_failed_events_to_lookup: + # We only want to do 100 at a time, so we split given list + # into two. + batch = list(soft_failed_events_to_lookup) + to_check, to_defer = batch[:100], batch[100:] + soft_failed_events_to_lookup = set(to_defer) + + sql = """SELECT prev_event_id, event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + INNER JOIN event_json USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + NOT events.outlier + AND + """ + clause, args = make_in_list_sql_clause( + self.database_engine, "prev_event_id", to_check + ) + txn.execute(sql + clause, list(args)) + + for prev_event_id, event_id, metadata, rejected in txn: + if event_id in graph: + # Already handled this event previously, but we still + # want to record the edge. + graph[event_id].add(prev_event_id) + continue + + graph[event_id] = {prev_event_id} + + soft_failed = json.loads(metadata).get("soft_failed") + if soft_failed or rejected: + soft_failed_events_to_lookup.add(event_id) + else: + non_rejected_leaves.add(event_id) + + # We have a set of non-soft-failed descendants, so we recurse up + # the graph to find all ancestors and add them to the set of event + # IDs that we can delete from forward extremities table. + to_delete = set() + while non_rejected_leaves: + event_id = non_rejected_leaves.pop() + prev_event_ids = graph.get(event_id, set()) + non_rejected_leaves.update(prev_event_ids) + to_delete.update(prev_event_ids) + + to_delete.intersection_update(original_set) + + deleted = self._simple_delete_many_txn( + txn=txn, + table="event_forward_extremities", + column="event_id", + iterable=to_delete, + keyvalues={}, + ) + + logger.info( + "Deleted %d forward extremities of %d checked, to clean up #5269", + deleted, + len(original_set), + ) + + if deleted: + # We now need to invalidate the caches of these rooms + rows = self._simple_select_many_txn( + txn, + table="events", + column="event_id", + iterable=to_delete, + keyvalues={}, + retcols=("room_id",), + ) + room_ids = set(row["room_id"] for row in rows) + for room_id in room_ids: + txn.call_after( + self.get_latest_event_ids_in_room.invalidate, (room_id,) + ) + + self._simple_delete_many_txn( + txn=txn, + table="_extremities_to_check", + column="event_id", + iterable=original_set, + keyvalues={}, + ) + + return len(original_set) + + num_handled = yield self.runInteraction( + "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn + ) + + if not num_handled: + yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES) + + def _drop_table_txn(txn): + txn.execute("DROP TABLE _extremities_to_check") + + yield self.runInteraction( + "_cleanup_extremities_bg_update_drop_table", _drop_table_txn + ) + + return num_handled + + @defer.inlineCallbacks + def _redactions_received_ts(self, progress, batch_size): + """Handles filling out the `received_ts` column in redactions. + """ + last_event_id = progress.get("last_event_id", "") + + def _redactions_received_ts_txn(txn): + # Fetch the set of event IDs that we want to update + sql = """ + SELECT event_id FROM redactions + WHERE event_id > ? + ORDER BY event_id ASC + LIMIT ? + """ + + txn.execute(sql, (last_event_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + upper_event_id, = rows[-1] + + # Update the redactions with the received_ts. + # + # Note: Not all events have an associated received_ts, so we + # fallback to using origin_server_ts. If we for some reason don't + # have an origin_server_ts, lets just use the current timestamp. + # + # We don't want to leave it null, as then we'll never try and + # censor those redactions. + sql = """ + UPDATE redactions + SET received_ts = ( + SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events + WHERE events.event_id = redactions.event_id + ) + WHERE ? <= event_id AND event_id <= ? + """ + + txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) + + self._background_update_progress_txn( + txn, "redactions_received_ts", {"last_event_id": upper_event_id} + ) + + return len(rows) + + count = yield self.runInteraction( + "_redactions_received_ts", _redactions_received_ts_txn + ) + + if not count: + yield self._end_background_update("redactions_received_ts") + + return count + + @defer.inlineCallbacks + def _event_fix_redactions_bytes(self, progress, batch_size): + """Undoes hex encoded censored redacted event JSON. + """ + + def _event_fix_redactions_bytes_txn(txn): + # This update is quite fast due to new index. + txn.execute( + """ + UPDATE event_json + SET + json = convert_from(json::bytea, 'utf8') + FROM redactions + WHERE + redactions.have_censored + AND event_json.event_id = redactions.redacts + AND json NOT LIKE '{%'; + """ + ) + + txn.execute("DROP INDEX redactions_censored_redacts") + + yield self.runInteraction( + "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn + ) + + yield self._end_background_update("event_fix_redactions_bytes") + + return 1 diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py new file mode 100644 index 0000000000..4c4b76bd93 --- /dev/null +++ b/synapse/storage/data_stores/main/events_worker.py @@ -0,0 +1,882 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import itertools +import logging +from collections import namedtuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.api.errors import NotFoundError +from synapse.api.room_versions import EventFormatVersions +from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.events.utils import prune_event +from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.types import get_domain_from_id +from synapse.util import batch_iter +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + + +# These values are used in the `enqueus_event` and `_do_fetch` methods to +# control how we batch/bulk fetch events from the database. +# The values are plucked out of thing air to make initial sync run faster +# on jki.re +# TODO: Make these configurable. +EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events +EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events +EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events + + +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + +class EventsWorkerStore(SQLBaseStore): + def get_received_ts(self, event_id): + """Get received_ts (when it was persisted) for the event. + + Raises an exception for unknown events. + + Args: + event_id (str) + + Returns: + Deferred[int|None]: Timestamp in milliseconds, or None for events + that were persisted before received_ts was implemented. + """ + return self._simple_select_one_onecol( + table="events", + keyvalues={"event_id": event_id}, + retcol="received_ts", + desc="get_received_ts", + ) + + def get_received_ts_by_stream_pos(self, stream_ordering): + """Given a stream ordering get an approximate timestamp of when it + happened. + + This is done by simply taking the received ts of the first event that + has a stream ordering greater than or equal to the given stream pos. + If none exists returns the current time, on the assumption that it must + have happened recently. + + Args: + stream_ordering (int) + + Returns: + Deferred[int] + """ + + def _get_approximate_received_ts_txn(txn): + sql = """ + SELECT received_ts FROM events + WHERE stream_ordering >= ? + LIMIT 1 + """ + + txn.execute(sql, (stream_ordering,)) + row = txn.fetchone() + if row and row[0]: + ts = row[0] + else: + ts = self.clock.time_msec() + + return ts + + return self.runInteraction( + "get_approximate_received_ts", _get_approximate_received_ts_txn + ) + + @defer.inlineCallbacks + def get_event( + self, + event_id, + check_redacted=True, + get_prev_content=False, + allow_rejected=False, + allow_none=False, + check_room_id=None, + ): + """Get an event from the database by event_id. + + Args: + event_id (str): The event_id of the event to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + allow_none (bool): If True, return None if no event found, if + False throw a NotFoundError + check_room_id (str|None): if not None, check the room of the found event. + If there is a mismatch, behave as per allow_none. + + Returns: + Deferred[EventBase|None] + """ + if not isinstance(event_id, str): + raise TypeError("Invalid event event_id %r" % (event_id,)) + + events = yield self.get_events_as_list( + [event_id], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event = events[0] if events else None + + if event is not None and check_room_id is not None: + if event.room_id != check_room_id: + event = None + + if event is None and not allow_none: + raise NotFoundError("Could not find event %s" % (event_id,)) + + return event + + @defer.inlineCallbacks + def get_events( + self, + event_ids, + check_redacted=True, + get_prev_content=False, + allow_rejected=False, + ): + """Get events from the database + + Args: + event_ids (list): The event_ids of the events to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred : Dict from event_id to event. + """ + events = yield self.get_events_as_list( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + return {e.event_id: e for e in events} + + @defer.inlineCallbacks + def get_events_as_list( + self, + event_ids, + check_redacted=True, + get_prev_content=False, + allow_rejected=False, + ): + """Get events from the database and return in a list in the same order + as given by `event_ids` arg. + + Args: + event_ids (list): The event_ids of the events to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred[list[EventBase]]: List of events fetched from the database. The + events are in the same order as `event_ids` arg. + + Note that the returned list may be smaller than the list of event + IDs if not all events could be fetched. + """ + + if not event_ids: + return [] + + # there may be duplicates so we cast the list to a set + event_entry_map = yield self._get_events_from_cache_or_db( + set(event_ids), allow_rejected=allow_rejected + ) + + events = [] + for event_id in event_ids: + entry = event_entry_map.get(event_id, None) + if not entry: + continue + + if not allow_rejected: + assert not entry.event.rejected_reason, ( + "rejected event returned from _get_events_from_cache_or_db despite " + "allow_rejected=False" + ) + + # We may not have had the original event when we received a redaction, so + # we have to recheck auth now. + + if not allow_rejected and entry.event.type == EventTypes.Redaction: + if not hasattr(entry.event, "redacts"): + # A redacted redaction doesn't have a `redacts` key, in + # which case lets just withhold the event. + # + # Note: Most of the time if the redactions has been + # redacted we still have the un-redacted event in the DB + # and so we'll still see the `redacts` key. However, this + # isn't always true e.g. if we have censored the event. + logger.debug( + "Withholding redaction event %s as we don't have redacts key", + event_id, + ) + continue + + redacted_event_id = entry.event.redacts + event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + original_event_entry = event_map.get(redacted_event_id) + if not original_event_entry: + # we don't have the redacted event (or it was rejected). + # + # We assume that the redaction isn't authorized for now; if the + # redacted event later turns up, the redaction will be re-checked, + # and if it is found valid, the original will get redacted before it + # is served to the client. + logger.debug( + "Withholding redaction event %s since we don't (yet) have the " + "original %s", + event_id, + redacted_event_id, + ) + continue + + original_event = original_event_entry.event + if original_event.type == EventTypes.Create: + # we never serve redactions of Creates to clients. + logger.info( + "Withholding redaction %s of create event %s", + event_id, + redacted_event_id, + ) + continue + + if original_event.room_id != entry.event.room_id: + logger.info( + "Withholding redaction %s of event %s from a different room", + event_id, + redacted_event_id, + ) + continue + + if entry.event.internal_metadata.need_to_check_redaction(): + original_domain = get_domain_from_id(original_event.sender) + redaction_domain = get_domain_from_id(entry.event.sender) + if original_domain != redaction_domain: + # the senders don't match, so this is forbidden + logger.info( + "Withholding redaction %s whose sender domain %s doesn't " + "match that of redacted event %s %s", + event_id, + redaction_domain, + redacted_event_id, + original_domain, + ) + continue + + # Update the cache to save doing the checks again. + entry.event.internal_metadata.recheck_redaction = False + + if check_redacted and entry.redacted_event: + event = entry.redacted_event + else: + event = entry.event + + events.append(event) + + if get_prev_content: + if "replaces_state" in event.unsigned: + prev = yield self.get_event( + event.unsigned["replaces_state"], + get_prev_content=False, + allow_none=True, + ) + if prev: + event.unsigned = dict(event.unsigned) + event.unsigned["prev_content"] = prev.content + event.unsigned["prev_sender"] = prev.sender + + return events + + @defer.inlineCallbacks + def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + """Fetch a bunch of events from the cache or the database. + + If events are pulled from the database, they will be cached for future lookups. + + Args: + event_ids (Iterable[str]): The event_ids of the events to fetch + allow_rejected (bool): Whether to include rejected events + + Returns: + Deferred[Dict[str, _EventCacheEntry]]: + map from event id to result + """ + event_entry_map = self._get_events_from_cache( + event_ids, allow_rejected=allow_rejected + ) + + missing_events_ids = [e for e in event_ids if e not in event_entry_map] + + if missing_events_ids: + log_ctx = LoggingContext.current_context() + log_ctx.record_event_fetch(len(missing_events_ids)) + + # Note that _get_events_from_db is also responsible for turning db rows + # into FrozenEvents (via _get_event_from_row), which involves seeing if + # the events have been redacted, and if so pulling the redaction event out + # of the database to check it. + # + missing_events = yield self._get_events_from_db( + missing_events_ids, allow_rejected=allow_rejected + ) + + event_entry_map.update(missing_events) + + return event_entry_map + + def _invalidate_get_event_cache(self, event_id): + self._get_event_cache.invalidate((event_id,)) + + def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): + """Fetch events from the caches + + Args: + events (Iterable[str]): list of event_ids to fetch + allow_rejected (bool): Whether to return events that were rejected + update_metrics (bool): Whether to update the cache hit ratio metrics + + Returns: + dict of event_id -> _EventCacheEntry for each event_id in cache. If + allow_rejected is `False` then there will still be an entry but it + will be `None` + """ + event_map = {} + + for event_id in events: + ret = self._get_event_cache.get( + (event_id,), None, update_metrics=update_metrics + ) + if not ret: + continue + + if allow_rejected or not ret.event.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + + return event_map + + def _do_fetch(self, conn): + """Takes a database connection and waits for requests for events from + the _event_fetch_list queue. + """ + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if single_threaded or i > EVENT_QUEUE_ITERATIONS: + self._event_fetch_ongoing -= 1 + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + self._fetch_event_list(conn, event_list) + + def _fetch_event_list(self, conn, event_list): + """Handle a load of requests from the _event_fetch_list queue + + Args: + conn (twisted.enterprise.adbapi.Connection): database connection + + event_list (list[Tuple[list[str], Deferred]]): + The fetch requests. Each entry consists of a list of event + ids to be fetched, and a deferred to be completed once the + events have been fetched. + + The deferreds are callbacked with a dictionary mapping from event id + to event row. Note that it may well contain additional events that + were not part of this request. + """ + with Measure(self._clock, "_fetch_event_list"): + try: + events_to_fetch = set( + event_id for events, _ in event_list for event_id in events + ) + + row_dict = self._new_transaction( + conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch + ) + + # We only want to resolve deferreds from the main thread + def fire(): + for _, d in event_list: + d.callback(row_dict) + + with PreserveLoggingContext(): + self.hs.get_reactor().callFromThread(fire) + except Exception as e: + logger.exception("do_fetch") + + # We only want to resolve deferreds from the main thread + def fire(evs, exc): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(exc) + + with PreserveLoggingContext(): + self.hs.get_reactor().callFromThread(fire, event_list, e) + + @defer.inlineCallbacks + def _get_events_from_db(self, event_ids, allow_rejected=False): + """Fetch a bunch of events from the database. + + Returned events will be added to the cache for future lookups. + + Args: + event_ids (Iterable[str]): The event_ids of the events to fetch + allow_rejected (bool): Whether to include rejected events + + Returns: + Deferred[Dict[str, _EventCacheEntry]]: + map from event id to result. May return extra events which + weren't asked for. + """ + fetched_events = {} + events_to_fetch = event_ids + + while events_to_fetch: + row_map = yield self._enqueue_events(events_to_fetch) + + # we need to recursively fetch any redactions of those events + redaction_ids = set() + for event_id in events_to_fetch: + row = row_map.get(event_id) + fetched_events[event_id] = row + if row: + redaction_ids.update(row["redactions"]) + + events_to_fetch = redaction_ids.difference(fetched_events.keys()) + if events_to_fetch: + logger.debug("Also fetching redaction events %s", events_to_fetch) + + # build a map from event_id to EventBase + event_map = {} + for event_id, row in fetched_events.items(): + if not row: + continue + assert row["event_id"] == event_id + + rejected_reason = row["rejected_reason"] + + if not allow_rejected and rejected_reason: + continue + + d = json.loads(row["json"]) + internal_metadata = json.loads(row["internal_metadata"]) + + format_version = row["format_version"] + if format_version is None: + # This means that we stored the event before we had the concept + # of a event format version, so it must be a V1 event. + format_version = EventFormatVersions.V1 + + original_ev = event_type_from_format_version(format_version)( + event_dict=d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + event_map[event_id] = original_ev + + # finally, we can decide whether each one nededs redacting, and build + # the cache entries. + result_map = {} + for event_id, original_ev in event_map.items(): + redactions = fetched_events[event_id]["redactions"] + redacted_event = self._maybe_redact_event_row( + original_ev, redactions, event_map + ) + + cache_entry = _EventCacheEntry( + event=original_ev, redacted_event=redacted_event + ) + + self._get_event_cache.prefill((event_id,), cache_entry) + result_map[event_id] = cache_entry + + return result_map + + @defer.inlineCallbacks + def _enqueue_events(self, events): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. + + Args: + events (Iterable[str]): events to be fetched. + + Returns: + Deferred[Dict[str, Dict]]: map from event id to row data from the database. + May contain events that weren't requested. + """ + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append((events, events_d)) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + run_as_background_process( + "fetch_events", self.runWithConnection, self._do_fetch + ) + + logger.debug("Loading %d events: %s", len(events), events) + with PreserveLoggingContext(): + row_map = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) + + return row_map + + def _fetch_event_rows(self, txn, event_ids): + """Fetch event rows from the database + + Events which are not found are omitted from the result. + + The returned per-event dicts contain the following keys: + + * event_id (str) + + * json (str): json-encoded event structure + + * internal_metadata (str): json-encoded internal metadata dict + + * format_version (int|None): The format of the event. Hopefully one + of EventFormatVersions. 'None' means the event predates + EventFormatVersions (so the event is format V1). + + * rejected_reason (str|None): if the event was rejected, the reason + why. + + * redactions (List[str]): a list of event-ids which (claim to) redact + this event. + + Args: + txn (twisted.enterprise.adbapi.Connection): + event_ids (Iterable[str]): event IDs to fetch + + Returns: + Dict[str, Dict]: a map from event id to event info. + """ + event_dict = {} + for evs in batch_iter(event_ids, 200): + sql = ( + "SELECT " + " e.event_id, " + " e.internal_metadata," + " e.json," + " e.format_version, " + " rej.reason " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " WHERE " + ) + + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", evs + ) + + txn.execute(sql + clause, args) + + for row in txn: + event_id = row[0] + event_dict[event_id] = { + "event_id": event_id, + "internal_metadata": row[1], + "json": row[2], + "format_version": row[3], + "rejected_reason": row[4], + "redactions": [], + } + + # check for redactions + redactions_sql = "SELECT event_id, redacts FROM redactions WHERE " + + clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs) + + txn.execute(redactions_sql + clause, args) + + for (redacter, redacted) in txn: + d = event_dict.get(redacted) + if d: + d["redactions"].append(redacter) + + return event_dict + + def _maybe_redact_event_row(self, original_ev, redactions, event_map): + """Given an event object and a list of possible redacting event ids, + determine whether to honour any of those redactions and if so return a redacted + event. + + Args: + original_ev (EventBase): + redactions (iterable[str]): list of event ids of potential redaction events + event_map (dict[str, EventBase]): other events which have been fetched, in + which we can look up the redaaction events. Map from event id to event. + + Returns: + Deferred[EventBase|None]: if the event should be redacted, a pruned + event object. Otherwise, None. + """ + if original_ev.type == "m.room.create": + # we choose to ignore redactions of m.room.create events. + return None + + for redaction_id in redactions: + redaction_event = event_map.get(redaction_id) + if not redaction_event or redaction_event.rejected_reason: + # we don't have the redaction event, or the redaction event was not + # authorized. + logger.debug( + "%s was redacted by %s but redaction not found/authed", + original_ev.event_id, + redaction_id, + ) + continue + + if redaction_event.room_id != original_ev.room_id: + logger.debug( + "%s was redacted by %s but redaction was in a different room!", + original_ev.event_id, + redaction_id, + ) + continue + + # Starting in room version v3, some redactions need to be + # rechecked if we didn't have the redacted event at the + # time, so we recheck on read instead. + if redaction_event.internal_metadata.need_to_check_redaction(): + expected_domain = get_domain_from_id(original_ev.sender) + if get_domain_from_id(redaction_event.sender) == expected_domain: + # This redaction event is allowed. Mark as not needing a recheck. + redaction_event.internal_metadata.recheck_redaction = False + else: + # Senders don't match, so the event isn't actually redacted + logger.debug( + "%s was redacted by %s but the senders don't match", + original_ev.event_id, + redaction_id, + ) + continue + + logger.debug("Redacting %s due to %s", original_ev.event_id, redaction_id) + + # we found a good redaction event. Redact! + redacted_event = prune_event(original_ev) + redacted_event.unsigned["redacted_by"] = redaction_id + + # It's fine to add the event directly, since get_pdu_json + # will serialise this field correctly + redacted_event.unsigned["redacted_because"] = redaction_event + + return redacted_event + + # no valid redaction found for this event + return None + + @defer.inlineCallbacks + def have_events_in_timeline(self, event_ids): + """Given a list of event ids, check if we have already processed and + stored them as non outliers. + """ + rows = yield self._simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) + + return set(r["event_id"] for r in rows) + + @defer.inlineCallbacks + def have_seen_events(self, event_ids): + """Given a list of event ids, check if we have already processed them. + + Args: + event_ids (iterable[str]): + + Returns: + Deferred[set[str]]: The events we have already seen. + """ + results = set() + + def have_seen_events_txn(txn, chunk): + sql = "SELECT event_id FROM events as e WHERE " + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", chunk + ) + txn.execute(sql + clause, args) + for (event_id,) in txn: + results.add(event_id) + + # break the input up into chunks of 100 + input_iterator = iter(event_ids) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): + yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk) + return results + + def get_seen_events_with_rejections(self, event_ids): + """Given a list of event ids, check if we rejected them. + + Args: + event_ids (list[str]) + + Returns: + Deferred[dict[str, str|None): + Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps + to None. + """ + if not event_ids: + return defer.succeed({}) + + def f(txn): + sql = ( + "SELECT e.event_id, reason FROM events as e " + "LEFT JOIN rejections as r ON e.event_id = r.event_id " + "WHERE e.event_id = ?" + ) + + res = {} + for event_id in event_ids: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if row: + _, rejected = row + res[event_id] = rejected + + return res + + return self.runInteraction("get_seen_events_with_rejections", f) + + def _get_total_state_event_counts_txn(self, txn, room_id): + """ + See get_total_state_event_counts. + """ + # We join against the events table as that has an index on room_id + sql = """ + SELECT COUNT(*) FROM state_events + INNER JOIN events USING (room_id, event_id) + WHERE room_id=? + """ + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_total_state_event_counts(self, room_id): + """ + Gets the total number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.runInteraction( + "get_total_state_event_counts", + self._get_total_state_event_counts_txn, + room_id, + ) + + def _get_current_state_event_counts_txn(self, txn, room_id): + """ + See get_current_state_event_counts. + """ + sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?" + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_current_state_event_counts(self, room_id): + """ + Gets the current number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.runInteraction( + "get_current_state_event_counts", + self._get_current_state_event_counts_txn, + room_id, + ) + + @defer.inlineCallbacks + def get_room_complexity(self, room_id): + """ + Get a rough approximation of the complexity of the room. This is used by + remote servers to decide whether they wish to join the room or not. + Higher complexity value indicates that being in the room will consume + more resources. + + Args: + room_id (str) + + Returns: + Deferred[dict[str:int]] of complexity version to complexity. + """ + state_events = yield self.get_current_state_event_counts(room_id) + + # Call this one "v1", so we can introduce new ones as we want to develop + # it. + complexity_v1 = round(state_events / 500, 2) + + return {"v1": complexity_v1} diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py new file mode 100644 index 0000000000..a2a2a67927 --- /dev/null +++ b/synapse/storage/data_stores/main/filtering.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from canonicaljson import encode_canonical_json + +from synapse.api.errors import Codes, SynapseError +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util.caches.descriptors import cachedInlineCallbacks + + +class FilteringStore(SQLBaseStore): + @cachedInlineCallbacks(num_args=2) + def get_user_filter(self, user_localpart, filter_id): + # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail + # with a coherent error message rather than 500 M_UNKNOWN. + try: + int(filter_id) + except ValueError: + raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) + + def_json = yield self._simple_select_one_onecol( + table="user_filters", + keyvalues={"user_id": user_localpart, "filter_id": filter_id}, + retcol="filter_json", + allow_none=False, + desc="get_user_filter", + ) + + return db_to_json(def_json) + + def add_user_filter(self, user_localpart, user_filter): + def_json = encode_canonical_json(user_filter) + + # Need an atomic transaction to SELECT the maximal ID so far then + # INSERT a new one + def _do_txn(txn): + sql = ( + "SELECT filter_id FROM user_filters " + "WHERE user_id = ? AND filter_json = ?" + ) + txn.execute(sql, (user_localpart, bytearray(def_json))) + filter_id_response = txn.fetchone() + if filter_id_response is not None: + return filter_id_response[0] + + sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?" + txn.execute(sql, (user_localpart,)) + max_id = txn.fetchone()[0] + if max_id is None: + filter_id = 0 + else: + filter_id = max_id + 1 + + sql = ( + "INSERT INTO user_filters (user_id, filter_id, filter_json)" + "VALUES(?, ?, ?)" + ) + txn.execute(sql, (user_localpart, filter_id, bytearray(def_json))) + + return filter_id + + return self.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py new file mode 100644 index 0000000000..aeae5a2b28 --- /dev/null +++ b/synapse/storage/data_stores/main/group_server.py @@ -0,0 +1,1180 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore + +# The category ID for the "default" category. We don't store as null in the +# database to avoid the fun of null != null +_DEFAULT_CATEGORY_ID = "" +_DEFAULT_ROLE_ID = "" + + +class GroupServerStore(SQLBaseStore): + def set_group_join_policy(self, group_id, join_policy): + """Set the join policy of a group. + + join_policy can be one of: + * "invite" + * "open" + """ + return self._simple_update_one( + table="groups", + keyvalues={"group_id": group_id}, + updatevalues={"join_policy": join_policy}, + desc="set_group_join_policy", + ) + + def get_group(self, group_id): + return self._simple_select_one( + table="groups", + keyvalues={"group_id": group_id}, + retcols=( + "name", + "short_description", + "long_description", + "avatar_url", + "is_public", + "join_policy", + ), + allow_none=True, + desc="get_group", + ) + + def get_users_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = {"group_id": group_id} + if not include_private: + keyvalues["is_public"] = True + + return self._simple_select_list( + table="group_users", + keyvalues=keyvalues, + retcols=("user_id", "is_public", "is_admin"), + desc="get_users_in_group", + ) + + def get_invited_users_in_group(self, group_id): + # TODO: Pagination + + return self._simple_select_onecol( + table="group_invites", + keyvalues={"group_id": group_id}, + retcol="user_id", + desc="get_invited_users_in_group", + ) + + def get_rooms_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = {"group_id": group_id} + if not include_private: + keyvalues["is_public"] = True + + return self._simple_select_list( + table="group_rooms", + keyvalues=keyvalues, + retcols=("room_id", "is_public"), + desc="get_rooms_in_group", + ) + + def get_rooms_for_summary_by_category(self, group_id, include_private=False): + """Get the rooms and categories that should be included in a summary request + + Returns ([rooms], [categories]) + """ + + def _get_rooms_for_summary_txn(txn): + keyvalues = {"group_id": group_id} + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT room_id, is_public, category_id, room_order + FROM group_summary_rooms + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + rooms = [ + { + "room_id": row[0], + "is_public": row[1], + "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT category_id, is_public, profile, cat_order + FROM group_summary_room_categories + INNER JOIN group_room_categories USING (group_id, category_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + categories = { + row[0]: { + "is_public": row[1], + "profile": json.loads(row[2]), + "order": row[3], + } + for row in txn + } + + return rooms, categories + + return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn) + + def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): + return self.runInteraction( + "add_room_to_summary", + self._add_room_to_summary_txn, + group_id, + room_id, + category_id, + order, + is_public, + ) + + def _add_room_to_summary_txn( + self, txn, group_id, room_id, category_id, order, is_public + ): + """Add (or update) room's entry in summary. + + Args: + group_id (str) + room_id (str) + category_id (str): If not None then adds the category to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the room at that position, e.g. + an order of 1 will put the room first. Otherwise, the room gets + added to the end. + """ + room_in_group = self._simple_select_one_onecol_txn( + txn, + table="group_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + retcol="room_id", + allow_none=True, + ) + if not room_in_group: + raise SynapseError(400, "room not in group") + + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + else: + cat_exists = self._simple_select_one_onecol_txn( + txn, + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + raise SynapseError(400, "Category doesn't exist") + + # TODO: Check category is part of summary already + cat_exists = self._simple_select_one_onecol_txn( + txn, + table="group_summary_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + # If not, add it with an order larger than all others + txn.execute( + """ + INSERT INTO group_summary_room_categories + (group_id, category_id, cat_order) + SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 + FROM group_summary_room_categories + WHERE group_id = ? AND category_id = ? + """, + (group_id, category_id, group_id, category_id), + ) + + existing = self._simple_select_one_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + "category_id": category_id, + }, + retcols=("room_order", "is_public"), + allow_none=True, + ) + + if order is not None: + # Shuffle other room orders that come after the given order + sql = """ + UPDATE group_summary_rooms SET room_order = room_order + 1 + WHERE group_id = ? AND category_id = ? AND room_order >= ? + """ + txn.execute(sql, (group_id, category_id, order)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms + WHERE group_id = ? AND category_id = ? + """ + txn.execute(sql, (group_id, category_id)) + order, = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["room_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self._simple_update_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self._simple_insert_txn( + txn, + table="group_summary_rooms", + values={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + "room_order": order, + "is_public": is_public, + }, + ) + + def remove_room_from_summary(self, group_id, room_id, category_id): + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + + return self._simple_delete( + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + desc="remove_room_from_summary", + ) + + @defer.inlineCallbacks + def get_group_categories(self, group_id): + rows = yield self._simple_select_list( + table="group_room_categories", + keyvalues={"group_id": group_id}, + retcols=("category_id", "is_public", "profile"), + desc="get_group_categories", + ) + + return { + row["category_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + } + + @defer.inlineCallbacks + def get_group_category(self, group_id, category_id): + category = yield self._simple_select_one( + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + retcols=("is_public", "profile"), + desc="get_group_category", + ) + + category["profile"] = json.loads(category["profile"]) + + return category + + def upsert_group_category(self, group_id, category_id, profile, is_public): + """Add/update room category for group + """ + insertion_values = {} + update_values = {"category_id": category_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self._simple_upsert( + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_category", + ) + + def remove_group_category(self, group_id, category_id): + return self._simple_delete( + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + desc="remove_group_category", + ) + + @defer.inlineCallbacks + def get_group_roles(self, group_id): + rows = yield self._simple_select_list( + table="group_roles", + keyvalues={"group_id": group_id}, + retcols=("role_id", "is_public", "profile"), + desc="get_group_roles", + ) + + return { + row["role_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + } + + @defer.inlineCallbacks + def get_group_role(self, group_id, role_id): + role = yield self._simple_select_one( + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + retcols=("is_public", "profile"), + desc="get_group_role", + ) + + role["profile"] = json.loads(role["profile"]) + + return role + + def upsert_group_role(self, group_id, role_id, profile, is_public): + """Add/remove user role + """ + insertion_values = {} + update_values = {"role_id": role_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self._simple_upsert( + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_role", + ) + + def remove_group_role(self, group_id, role_id): + return self._simple_delete( + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + desc="remove_group_role", + ) + + def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): + return self.runInteraction( + "add_user_to_summary", + self._add_user_to_summary_txn, + group_id, + user_id, + role_id, + order, + is_public, + ) + + def _add_user_to_summary_txn( + self, txn, group_id, user_id, role_id, order, is_public + ): + """Add (or update) user's entry in summary. + + Args: + group_id (str) + user_id (str) + role_id (str): If not None then adds the role to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the user at that position, e.g. + an order of 1 will put the user first. Otherwise, the user gets + added to the end. + """ + user_in_group = self._simple_select_one_onecol_txn( + txn, + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + allow_none=True, + ) + if not user_in_group: + raise SynapseError(400, "user not in group") + + if role_id is None: + role_id = _DEFAULT_ROLE_ID + else: + role_exists = self._simple_select_one_onecol_txn( + txn, + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + raise SynapseError(400, "Role doesn't exist") + + # TODO: Check role is part of the summary already + role_exists = self._simple_select_one_onecol_txn( + txn, + table="group_summary_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + # If not, add it with an order larger than all others + txn.execute( + """ + INSERT INTO group_summary_roles + (group_id, role_id, role_order) + SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 + FROM group_summary_roles + WHERE group_id = ? AND role_id = ? + """, + (group_id, role_id, group_id, role_id), + ) + + existing = self._simple_select_one_txn( + txn, + table="group_summary_users", + keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, + retcols=("user_order", "is_public"), + allow_none=True, + ) + + if order is not None: + # Shuffle other users orders that come after the given order + sql = """ + UPDATE group_summary_users SET user_order = user_order + 1 + WHERE group_id = ? AND role_id = ? AND user_order >= ? + """ + txn.execute(sql, (group_id, role_id, order)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users + WHERE group_id = ? AND role_id = ? + """ + txn.execute(sql, (group_id, role_id)) + order, = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["user_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self._simple_update_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self._simple_insert_txn( + txn, + table="group_summary_users", + values={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + "user_order": order, + "is_public": is_public, + }, + ) + + def remove_user_from_summary(self, group_id, user_id, role_id): + if role_id is None: + role_id = _DEFAULT_ROLE_ID + + return self._simple_delete( + table="group_summary_users", + keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, + desc="remove_user_from_summary", + ) + + def get_users_for_summary_by_role(self, group_id, include_private=False): + """Get the users and roles that should be included in a summary request + + Returns ([users], [roles]) + """ + + def _get_users_for_summary_txn(txn): + keyvalues = {"group_id": group_id} + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT user_id, is_public, role_id, user_order + FROM group_summary_users + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + users = [ + { + "user_id": row[0], + "is_public": row[1], + "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT role_id, is_public, profile, role_order + FROM group_summary_roles + INNER JOIN group_roles USING (group_id, role_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + roles = { + row[0]: { + "is_public": row[1], + "profile": json.loads(row[2]), + "order": row[3], + } + for row in txn + } + + return users, roles + + return self.runInteraction( + "get_users_for_summary_by_role", _get_users_for_summary_txn + ) + + def is_user_in_group(self, user_id, group_id): + return self._simple_select_one_onecol( + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="is_user_in_group", + ).addCallback(lambda r: bool(r)) + + def is_user_admin_in_group(self, group_id, user_id): + return self._simple_select_one_onecol( + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="is_admin", + allow_none=True, + desc="is_user_admin_in_group", + ) + + def add_group_invite(self, group_id, user_id): + """Record that the group server has invited a user + """ + return self._simple_insert( + table="group_invites", + values={"group_id": group_id, "user_id": user_id}, + desc="add_group_invite", + ) + + def is_user_invited_to_local_group(self, group_id, user_id): + """Has the group server invited a user? + """ + return self._simple_select_one_onecol( + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + desc="is_user_invited_to_local_group", + allow_none=True, + ) + + def get_users_membership_info_in_group(self, group_id, user_id): + """Get a dict describing the membership of a user in a group. + + Example if joined: + + { + "membership": "join", + "is_public": True, + "is_privileged": False, + } + + Returns an empty dict if the user is not join/invite/etc + """ + + def _get_users_membership_in_group_txn(txn): + row = self._simple_select_one_txn( + txn, + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcols=("is_admin", "is_public"), + allow_none=True, + ) + + if row: + return { + "membership": "join", + "is_public": row["is_public"], + "is_privileged": row["is_admin"], + } + + row = self._simple_select_one_onecol_txn( + txn, + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + allow_none=True, + ) + + if row: + return {"membership": "invite"} + + return {} + + return self.runInteraction( + "get_users_membership_info_in_group", _get_users_membership_in_group_txn + ) + + def add_user_to_group( + self, + group_id, + user_id, + is_admin=False, + is_public=True, + local_attestation=None, + remote_attestation=None, + ): + """Add a user to the group server. + + Args: + group_id (str) + user_id (str) + is_admin (bool) + is_public (bool) + local_attestation (dict): The attestation the GS created to give + to the remote server. Optional if the user and group are on the + same server + remote_attestation (dict): The attestation given to GS by remote + server. Optional if the user and group are on the same server + """ + + def _add_user_to_group_txn(txn): + self._simple_insert_txn( + txn, + table="group_users", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "is_public": is_public, + }, + ) + + self._simple_delete_txn( + txn, + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + + if local_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + }, + ) + if remote_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + }, + ) + + return self.runInteraction("add_user_to_group", _add_user_to_group_txn) + + def remove_user_from_group(self, group_id, user_id): + def _remove_user_from_group_txn(txn): + self._simple_delete_txn( + txn, + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="group_summary_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + + return self.runInteraction( + "remove_user_from_group", _remove_user_from_group_txn + ) + + def add_room_to_group(self, group_id, room_id, is_public): + return self._simple_insert( + table="group_rooms", + values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, + desc="add_room_to_group", + ) + + def update_room_in_group_visibility(self, group_id, room_id, is_public): + return self._simple_update( + table="group_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + updatevalues={"is_public": is_public}, + desc="update_room_in_group_visibility", + ) + + def remove_room_from_group(self, group_id, room_id): + def _remove_room_from_group_txn(txn): + self._simple_delete_txn( + txn, + table="group_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + ) + + self._simple_delete_txn( + txn, + table="group_summary_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + ) + + return self.runInteraction( + "remove_room_from_group", _remove_room_from_group_txn + ) + + def get_publicised_groups_for_user(self, user_id): + """Get all groups a user is publicising + """ + return self._simple_select_onecol( + table="local_group_membership", + keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, + retcol="group_id", + desc="get_publicised_groups_for_user", + ) + + def update_group_publicity(self, group_id, user_id, publicise): + """Update whether the user is publicising their membership of the group + """ + return self._simple_update_one( + table="local_group_membership", + keyvalues={"group_id": group_id, "user_id": user_id}, + updatevalues={"is_publicised": publicise}, + desc="update_group_publicity", + ) + + @defer.inlineCallbacks + def register_user_group_membership( + self, + group_id, + user_id, + membership, + is_admin=False, + content={}, + local_attestation=None, + remote_attestation=None, + is_publicised=False, + ): + """Registers that a local user is a member of a (local or remote) group. + + Args: + group_id (str) + user_id (str) + membership (str) + is_admin (bool) + content (dict): Content of the membership, e.g. includes the inviter + if the user has been invited. + local_attestation (dict): If remote group then store the fact that we + have given out an attestation, else None. + remote_attestation (dict): If remote group then store the remote + attestation from the group, else None. + """ + + def _register_user_group_membership_txn(txn, next_id): + # TODO: Upsert? + self._simple_delete_txn( + txn, + table="local_group_membership", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self._simple_insert_txn( + txn, + table="local_group_membership", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "membership": membership, + "is_publicised": is_publicised, + "content": json.dumps(content), + }, + ) + + self._simple_insert_txn( + txn, + table="local_group_updates", + values={ + "stream_id": next_id, + "group_id": group_id, + "user_id": user_id, + "type": "membership", + "content": json.dumps( + {"membership": membership, "content": content} + ), + }, + ) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + + # TODO: Insert profile to ensure it comes down stream if its a join. + + if membership == "join": + if local_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + }, + ) + if remote_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + }, + ) + else: + self._simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + + return next_id + + with self._group_updates_id_gen.get_next() as next_id: + res = yield self.runInteraction( + "register_user_group_membership", + _register_user_group_membership_txn, + next_id, + ) + return res + + @defer.inlineCallbacks + def create_group( + self, group_id, user_id, name, avatar_url, short_description, long_description + ): + yield self._simple_insert( + table="groups", + values={ + "group_id": group_id, + "name": name, + "avatar_url": avatar_url, + "short_description": short_description, + "long_description": long_description, + "is_public": True, + }, + desc="create_group", + ) + + @defer.inlineCallbacks + def update_group_profile(self, group_id, profile): + yield self._simple_update_one( + table="groups", + keyvalues={"group_id": group_id}, + updatevalues=profile, + desc="update_group_profile", + ) + + def get_attestations_need_renewals(self, valid_until_ms): + """Get all attestations that need to be renewed until givent time + """ + + def _get_attestations_need_renewals_txn(txn): + sql = """ + SELECT group_id, user_id FROM group_attestations_renewals + WHERE valid_until_ms <= ? + """ + txn.execute(sql, (valid_until_ms,)) + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_attestations_need_renewals", _get_attestations_need_renewals_txn + ) + + def update_attestation_renewal(self, group_id, user_id, attestation): + """Update an attestation that we have renewed + """ + return self._simple_update_one( + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, + desc="update_attestation_renewal", + ) + + def update_remote_attestion(self, group_id, user_id, attestation): + """Update an attestation that a remote has renewed + """ + return self._simple_update_one( + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + updatevalues={ + "valid_until_ms": attestation["valid_until_ms"], + "attestation_json": json.dumps(attestation), + }, + desc="update_remote_attestion", + ) + + def remove_attestation_renewal(self, group_id, user_id): + """Remove an attestation that we thought we should renew, but actually + shouldn't. Ideally this would never get called as we would never + incorrectly try and do attestations for local users on local groups. + + Args: + group_id (str) + user_id (str) + """ + return self._simple_delete( + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + desc="remove_attestation_renewal", + ) + + @defer.inlineCallbacks + def get_remote_attestation(self, group_id, user_id): + """Get the attestation that proves the remote agrees that the user is + in the group. + """ + row = yield self._simple_select_one( + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcols=("valid_until_ms", "attestation_json"), + desc="get_remote_attestation", + allow_none=True, + ) + + now = int(self._clock.time_msec()) + if row and now < row["valid_until_ms"]: + return json.loads(row["attestation_json"]) + + return None + + def get_joined_groups(self, user_id): + return self._simple_select_onecol( + table="local_group_membership", + keyvalues={"user_id": user_id, "membership": "join"}, + retcol="group_id", + desc="get_joined_groups", + ) + + def get_all_groups_for_user(self, user_id, now_token): + def _get_all_groups_for_user_txn(txn): + sql = """ + SELECT group_id, type, membership, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND membership != 'leave' + AND stream_id <= ? + """ + txn.execute(sql, (user_id, now_token)) + return [ + { + "group_id": row[0], + "type": row[1], + "membership": row[2], + "content": json.loads(row[3]), + } + for row in txn + ] + + return self.runInteraction( + "get_all_groups_for_user", _get_all_groups_for_user_txn + ) + + def get_groups_changes_for_user(self, user_id, from_token, to_token): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_entity_changed( + user_id, from_token + ) + if not has_changed: + return [] + + def _get_groups_changes_for_user_txn(txn): + sql = """ + SELECT group_id, membership, type, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + """ + txn.execute(sql, (user_id, from_token, to_token)) + return [ + { + "group_id": group_id, + "membership": membership, + "type": gtype, + "content": json.loads(content_json), + } + for group_id, membership, gtype, content_json in txn + ] + + return self.runInteraction( + "get_groups_changes_for_user", _get_groups_changes_for_user_txn + ) + + def get_all_groups_changes(self, from_token, to_token, limit): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_any_entity_changed( + from_token + ) + if not has_changed: + return [] + + def _get_all_groups_changes_txn(txn): + sql = """ + SELECT stream_id, group_id, user_id, type, content + FROM local_group_updates + WHERE ? < stream_id AND stream_id <= ? + LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit)) + return [ + (stream_id, group_id, user_id, gtype, json.loads(content_json)) + for stream_id, group_id, user_id, gtype, content_json in txn + ] + + return self.runInteraction( + "get_all_groups_changes", _get_all_groups_changes_txn + ) + + def get_group_stream_token(self): + return self._group_updates_id_gen.get_current_token() + + def delete_group(self, group_id): + """Deletes a group fully from the database. + + Args: + group_id (str) + + Returns: + Deferred + """ + + def _delete_group_txn(txn): + tables = [ + "groups", + "group_users", + "group_invites", + "group_rooms", + "group_summary_rooms", + "group_summary_room_categories", + "group_room_categories", + "group_summary_users", + "group_summary_roles", + "group_roles", + "group_attestations_renewals", + "group_attestations_remote", + ] + + for table in tables: + self._simple_delete_txn( + txn, table=table, keyvalues={"group_id": group_id} + ) + + return self.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py new file mode 100644 index 0000000000..ebc7db3ed6 --- /dev/null +++ b/synapse/storage/data_stores/main/keys.py @@ -0,0 +1,214 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging + +import six + +from signedjson.key import decode_verify_key_bytes + +from synapse.storage._base import SQLBaseStore +from synapse.storage.keys import FetchKeyResult +from synapse.util import batch_iter +from synapse.util.caches.descriptors import cached, cachedList + +logger = logging.getLogger(__name__) + +# py2 sqlite has buffer hardcoded as only binary type, so we must use it, +# despite being deprecated and removed in favor of memoryview +if six.PY2: + db_binary_type = six.moves.builtins.buffer +else: + db_binary_type = memoryview + + +class KeyStore(SQLBaseStore): + """Persistence for signature verification keys + """ + + @cached() + def _get_server_verify_key(self, server_name_and_key_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" + ) + def get_server_verify_keys(self, server_name_and_key_ids): + """ + Args: + server_name_and_key_ids (iterable[Tuple[str, str]]): + iterable of (server_name, key-id) tuples to fetch keys for + + Returns: + Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: + map from (server_name, key_id) -> FetchKeyResult, or None if the key is + unknown + """ + keys = {} + + def _get_keys(txn, batch): + """Processes a batch of keys to fetch, and adds the result to `keys`.""" + + # batch_iter always returns tuples so it's safe to do len(batch) + sql = ( + "SELECT server_name, key_id, verify_key, ts_valid_until_ms " + "FROM server_signature_keys WHERE 1=0" + ) + " OR (server_name=? AND key_id=?)" * len(batch) + + txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) + + for row in txn: + server_name, key_id, key_bytes, ts_valid_until_ms = row + + if ts_valid_until_ms is None: + # Old keys may be stored with a ts_valid_until_ms of null, + # in which case we treat this as if it was set to `0`, i.e. + # it won't match key requests that define a minimum + # `ts_valid_until_ms`. + ts_valid_until_ms = 0 + + res = FetchKeyResult( + verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), + valid_until_ts=ts_valid_until_ms, + ) + keys[(server_name, key_id)] = res + + def _txn(txn): + for batch in batch_iter(server_name_and_key_ids, 50): + _get_keys(txn, batch) + return keys + + return self.runInteraction("get_server_verify_keys", _txn) + + def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + """Stores NACL verification keys for remote servers. + Args: + from_server (str): Where the verification keys were looked up + ts_added_ms (int): The time to record that the key was added + verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + keys to be stored. Each entry is a triplet of + (server_name, key_id, key). + """ + key_values = [] + value_values = [] + invalidations = [] + for server_name, key_id, fetch_result in verify_keys: + key_values.append((server_name, key_id)) + value_values.append( + ( + from_server, + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), + ) + ) + # invalidate takes a tuple corresponding to the params of + # _get_server_verify_key. _get_server_verify_key only takes one + # param, which is itself the 2-tuple (server_name, key_id). + invalidations.append((server_name, key_id)) + + def _invalidate(res): + f = self._get_server_verify_key.invalidate + for i in invalidations: + f((i,)) + return res + + return self.runInteraction( + "store_server_verify_keys", + self._simple_upsert_many_txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=key_values, + value_names=( + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "verify_key", + ), + value_values=value_values, + ).addCallback(_invalidate) + + def store_server_keys_json( + self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes + ): + """Stores the JSON bytes for a set of keys from a server + The JSON should be signed by the originating server, the intermediate + server, and by this server. Updates the value for the + (server_name, key_id, from_server) triplet if one already existed. + Args: + server_name (str): The name of the server. + key_id (str): The identifer of the key this JSON is for. + from_server (str): The server this JSON was fetched from. + ts_now_ms (int): The time now in milliseconds. + ts_valid_until_ms (int): The time when this json stops being valid. + key_json (bytes): The encoded JSON. + """ + return self._simple_upsert( + table="server_keys_json", + keyvalues={ + "server_name": server_name, + "key_id": key_id, + "from_server": from_server, + }, + values={ + "server_name": server_name, + "key_id": key_id, + "from_server": from_server, + "ts_added_ms": ts_now_ms, + "ts_valid_until_ms": ts_expires_ms, + "key_json": db_binary_type(key_json_bytes), + }, + desc="store_server_keys_json", + ) + + def get_server_keys_json(self, server_keys): + """Retrive the key json for a list of server_keys and key ids. + If no keys are found for a given server, key_id and source then + that server, key_id, and source triplet entry will be an empty list. + The JSON is returned as a byte array so that it can be efficiently + used in an HTTP response. + Args: + server_keys (list): List of (server_name, key_id, source) triplets. + Returns: + Deferred[dict[Tuple[str, str, str|None], list[dict]]]: + Dict mapping (server_name, key_id, source) triplets to lists of dicts + """ + + def _get_server_keys_json_txn(txn): + results = {} + for server_name, key_id, from_server in server_keys: + keyvalues = {"server_name": server_name} + if key_id is not None: + keyvalues["key_id"] = key_id + if from_server is not None: + keyvalues["from_server"] = from_server + rows = self._simple_select_list_txn( + txn, + "server_keys_json", + keyvalues=keyvalues, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + ) + results[(server_name, key_id, from_server)] = rows + return results + + return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py new file mode 100644 index 0000000000..84b5f3ad5e --- /dev/null +++ b/synapse/storage/data_stores/main/media_repository.py @@ -0,0 +1,378 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.storage.background_updates import BackgroundUpdateStore + + +class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + update_name="local_media_repository_url_idx", + index_name="local_media_repository_url_idx", + table="local_media_repository", + columns=["created_ts"], + where_clause="url_cache IS NOT NULL", + ) + + +class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): + """Persistence for attachments and avatars""" + + def __init__(self, db_conn, hs): + super(MediaRepositoryStore, self).__init__(db_conn, hs) + + def get_local_media(self, media_id): + """Get the metadata for a local piece of media + Returns: + None if the media_id doesn't exist. + """ + return self._simple_select_one( + "local_media_repository", + {"media_id": media_id}, + ( + "media_type", + "media_length", + "upload_name", + "created_ts", + "quarantined_by", + "url_cache", + ), + allow_none=True, + desc="get_local_media", + ) + + def store_local_media( + self, + media_id, + media_type, + time_now_ms, + upload_name, + media_length, + user_id, + url_cache=None, + ): + return self._simple_insert( + "local_media_repository", + { + "media_id": media_id, + "media_type": media_type, + "created_ts": time_now_ms, + "upload_name": upload_name, + "media_length": media_length, + "user_id": user_id.to_string(), + "url_cache": url_cache, + }, + desc="store_local_media", + ) + + def get_url_cache(self, url, ts): + """Get the media_id and ts for a cached URL as of the given timestamp + Returns: + None if the URL isn't cached. + """ + + def get_url_cache_txn(txn): + # get the most recently cached result (relative to the given ts) + sql = ( + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts <= ?" + " ORDER BY download_ts DESC LIMIT 1" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row: + # ...or if we've requested a timestamp older than the oldest + # copy in the cache, return the oldest copy (if any) + sql = ( + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts > ?" + " ORDER BY download_ts ASC LIMIT 1" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row: + return None + + return dict( + zip( + ( + "response_code", + "etag", + "expires_ts", + "og", + "media_id", + "download_ts", + ), + row, + ) + ) + + return self.runInteraction("get_url_cache", get_url_cache_txn) + + def store_url_cache( + self, url, response_code, etag, expires_ts, og, media_id, download_ts + ): + return self._simple_insert( + "local_media_repository_url_cache", + { + "url": url, + "response_code": response_code, + "etag": etag, + "expires_ts": expires_ts, + "og": og, + "media_id": media_id, + "download_ts": download_ts, + }, + desc="store_url_cache", + ) + + def get_local_media_thumbnails(self, media_id): + return self._simple_select_list( + "local_media_repository_thumbnails", + {"media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + ), + desc="get_local_media_thumbnails", + ) + + def store_local_thumbnail( + self, + media_id, + thumbnail_width, + thumbnail_height, + thumbnail_type, + thumbnail_method, + thumbnail_length, + ): + return self._simple_insert( + "local_media_repository_thumbnails", + { + "media_id": media_id, + "thumbnail_width": thumbnail_width, + "thumbnail_height": thumbnail_height, + "thumbnail_method": thumbnail_method, + "thumbnail_type": thumbnail_type, + "thumbnail_length": thumbnail_length, + }, + desc="store_local_thumbnail", + ) + + def get_cached_remote_media(self, origin, media_id): + return self._simple_select_one( + "remote_media_cache", + {"media_origin": origin, "media_id": media_id}, + ( + "media_type", + "media_length", + "upload_name", + "created_ts", + "filesystem_id", + "quarantined_by", + ), + allow_none=True, + desc="get_cached_remote_media", + ) + + def store_cached_remote_media( + self, + origin, + media_id, + media_type, + media_length, + time_now_ms, + upload_name, + filesystem_id, + ): + return self._simple_insert( + "remote_media_cache", + { + "media_origin": origin, + "media_id": media_id, + "media_type": media_type, + "media_length": media_length, + "created_ts": time_now_ms, + "upload_name": upload_name, + "filesystem_id": filesystem_id, + "last_access_ts": time_now_ms, + }, + desc="store_cached_remote_media", + ) + + def update_cached_last_access_time(self, local_media, remote_media, time_ms): + """Updates the last access time of the given media + + Args: + local_media (iterable[str]): Set of media_ids + remote_media (iterable[(str, str)]): Set of (server_name, media_id) + time_ms: Current time in milliseconds + """ + + def update_cache_txn(txn): + sql = ( + "UPDATE remote_media_cache SET last_access_ts = ?" + " WHERE media_origin = ? AND media_id = ?" + ) + + txn.executemany( + sql, + ( + (time_ms, media_origin, media_id) + for media_origin, media_id in remote_media + ), + ) + + sql = ( + "UPDATE local_media_repository SET last_access_ts = ?" + " WHERE media_id = ?" + ) + + txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) + + return self.runInteraction("update_cached_last_access_time", update_cache_txn) + + def get_remote_media_thumbnails(self, origin, media_id): + return self._simple_select_list( + "remote_media_cache_thumbnails", + {"media_origin": origin, "media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + "filesystem_id", + ), + desc="get_remote_media_thumbnails", + ) + + def store_remote_media_thumbnail( + self, + origin, + media_id, + filesystem_id, + thumbnail_width, + thumbnail_height, + thumbnail_type, + thumbnail_method, + thumbnail_length, + ): + return self._simple_insert( + "remote_media_cache_thumbnails", + { + "media_origin": origin, + "media_id": media_id, + "thumbnail_width": thumbnail_width, + "thumbnail_height": thumbnail_height, + "thumbnail_method": thumbnail_method, + "thumbnail_type": thumbnail_type, + "thumbnail_length": thumbnail_length, + "filesystem_id": filesystem_id, + }, + desc="store_remote_media_thumbnail", + ) + + def get_remote_media_before(self, before_ts): + sql = ( + "SELECT media_origin, media_id, filesystem_id" + " FROM remote_media_cache" + " WHERE last_access_ts < ?" + ) + + return self._execute( + "get_remote_media_before", self.cursor_to_dict, sql, before_ts + ) + + def delete_remote_media(self, media_origin, media_id): + def delete_remote_media_txn(txn): + self._simple_delete_txn( + txn, + "remote_media_cache", + keyvalues={"media_origin": media_origin, "media_id": media_id}, + ) + self._simple_delete_txn( + txn, + "remote_media_cache_thumbnails", + keyvalues={"media_origin": media_origin, "media_id": media_id}, + ) + + return self.runInteraction("delete_remote_media", delete_remote_media_txn) + + def get_expired_url_cache(self, now_ts): + sql = ( + "SELECT media_id FROM local_media_repository_url_cache" + " WHERE expires_ts < ?" + " ORDER BY expires_ts ASC" + " LIMIT 500" + ) + + def _get_expired_url_cache_txn(txn): + txn.execute(sql, (now_ts,)) + return [row[0] for row in txn] + + return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) + + def delete_url_cache(self, media_ids): + if len(media_ids) == 0: + return + + sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?" + + def _delete_url_cache_txn(txn): + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction("delete_url_cache", _delete_url_cache_txn) + + def get_url_cache_media_before(self, before_ts): + sql = ( + "SELECT media_id FROM local_media_repository" + " WHERE created_ts < ? AND url_cache IS NOT NULL" + " ORDER BY created_ts ASC" + " LIMIT 500" + ) + + def _get_url_cache_media_before_txn(txn): + txn.execute(sql, (before_ts,)) + return [row[0] for row in txn] + + return self.runInteraction( + "get_url_cache_media_before", _get_url_cache_media_before_txn + ) + + def delete_url_cache_media(self, media_ids): + if len(media_ids) == 0: + return + + def _delete_url_cache_media_txn(txn): + sql = "DELETE FROM local_media_repository" " WHERE media_id = ?" + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?" + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction( + "delete_url_cache_media", _delete_url_cache_media_txn + ) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py new file mode 100644 index 0000000000..e6ee1e4aaa --- /dev/null +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -0,0 +1,328 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + +# Number of msec of granularity to store the monthly_active_user timestamp +# This means it is not necessary to update the table on every request +LAST_SEEN_GRANULARITY = 60 * 60 * 1000 + + +class MonthlyActiveUsersStore(SQLBaseStore): + def __init__(self, dbconn, hs): + super(MonthlyActiveUsersStore, self).__init__(None, hs) + self._clock = hs.get_clock() + self.hs = hs + # Do not add more reserved users than the total allowable number + self._new_transaction( + dbconn, + "initialise_mau_threepids", + [], + [], + self._initialise_reserved_users, + hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value], + ) + + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ + + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) + + if user_id: + is_support = self.is_support_user_txn(txn, user_id) + if not is_support: + self.upsert_monthly_active_user_txn(txn, user_id) + else: + logger.warning("mau limit reserved threepid %s not found in db" % tp) + + @defer.inlineCallbacks + def reap_monthly_active_users(self): + """Cleans out monthly active user table to ensure that no stale + entries exist. + + Returns: + Deferred[] + """ + + def _reap_users(txn, reserved_users): + """ + Args: + reserved_users (tuple): reserved users to preserve + """ + + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + query_args = [thirty_days_ago] + base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" + + # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres + # when len(reserved_users) == 0. Works fine on sqlite. + if len(reserved_users) > 0: + # questionmarks is a hack to overcome sqlite not supporting + # tuples in 'WHERE IN %s' + question_marks = ",".join("?" * len(reserved_users)) + + query_args.extend(reserved_users) + sql = base_sql + " AND user_id NOT IN ({})".format(question_marks) + else: + sql = base_sql + + txn.execute(sql, query_args) + + max_mau_value = self.hs.config.max_mau_value + if self.hs.config.limit_usage_by_mau: + # If MAU user count still exceeds the MAU threshold, then delete on + # a least recently active basis. + # Note it is not possible to write this query using OFFSET due to + # incompatibilities in how sqlite and postgres support the feature. + # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present + # While Postgres does not require 'LIMIT', but also does not support + # negative LIMIT values. So there is no way to write it that both can + # support + if len(reserved_users) == 0: + sql = """ + DELETE FROM monthly_active_users + WHERE user_id NOT IN ( + SELECT user_id FROM monthly_active_users + ORDER BY timestamp DESC + LIMIT ? + ) + """ + txn.execute(sql, (max_mau_value,)) + # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres + # when len(reserved_users) == 0. Works fine on sqlite. + else: + # Must be >= 0 for postgres + num_of_non_reserved_users_to_remove = max( + max_mau_value - len(reserved_users), 0 + ) + + # It is important to filter reserved users twice to guard + # against the case where the reserved user is present in the + # SELECT, meaning that a legitmate mau is deleted. + sql = """ + DELETE FROM monthly_active_users + WHERE user_id NOT IN ( + SELECT user_id FROM monthly_active_users + WHERE user_id NOT IN ({}) + ORDER BY timestamp DESC + LIMIT ? + ) + AND user_id NOT IN ({}) + """.format( + question_marks, question_marks + ) + + query_args = [ + *reserved_users, + num_of_non_reserved_users_to_remove, + *reserved_users, + ] + + txn.execute(sql, query_args) + + reserved_users = yield self.get_registered_reserved_users() + yield self.runInteraction( + "reap_monthly_active_users", _reap_users, reserved_users + ) + # It seems poor to invalidate the whole cache, Postgres supports + # 'Returning' which would allow me to invalidate only the + # specific users, but sqlite has no way to do this and instead + # I would need to SELECT and the DELETE which without locking + # is racy. + # Have resolved to invalidate the whole cache for now and do + # something about it if and when the perf becomes significant + self.user_last_seen_monthly_active.invalidate_all() + self.get_monthly_active_count.invalidate_all() + + @cached(num_args=0) + def get_monthly_active_count(self): + """Generates current count of monthly active users + + Returns: + Defered[int]: Number of current monthly active users + """ + + def _count_users(txn): + sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" + + txn.execute(sql) + count, = txn.fetchone() + return count + + return self.runInteraction("count_users", _count_users) + + @defer.inlineCallbacks + def get_registered_reserved_users(self): + """Of the reserved threepids defined in config, which are associated + with registered users? + + Returns: + Defered[list]: Real reserved users + """ + users = [] + + for tp in self.hs.config.mau_limits_reserved_threepids[ + : self.hs.config.max_mau_value + ]: + user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + tp["medium"], tp["address"] + ) + if user_id: + users.append(user_id) + + return users + + @defer.inlineCallbacks + def upsert_monthly_active_user(self, user_id): + """Updates or inserts the user into the monthly active user table, which + is used to track the current MAU usage of the server + + Args: + user_id (str): user to add/update + """ + # Support user never to be included in MAU stats. Note I can't easily call this + # from upsert_monthly_active_user_txn because then I need a _txn form of + # is_support_user which is complicated because I want to cache the result. + # Therefore I call it here and ignore the case where + # upsert_monthly_active_user_txn is called directly from + # _initialise_reserved_users reasoning that it would be very strange to + # include a support user in this context. + + is_support = yield self.is_support_user(user_id) + if is_support: + return + + yield self.runInteraction( + "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id + ) + + user_in_mau = self.user_last_seen_monthly_active.cache.get( + (user_id,), None, update_metrics=False + ) + if user_in_mau is None: + self.get_monthly_active_count.invalidate(()) + + self.user_last_seen_monthly_active.invalidate((user_id,)) + + def upsert_monthly_active_user_txn(self, txn, user_id): + """Updates or inserts monthly active user member + + Note that, after calling this method, it will generally be necessary + to invalidate the caches on user_last_seen_monthly_active and + get_monthly_active_count. We can't do that here, because we are running + in a database thread rather than the main thread, and we can't call + txn.call_after because txn may not be a LoggingTransaction. + + We consciously do not call is_support_txn from this method because it + is not possible to cache the response. is_support_txn will be false in + almost all cases, so it seems reasonable to call it only for + upsert_monthly_active_user and to call is_support_txn manually + for cases where upsert_monthly_active_user_txn is called directly, + like _initialise_reserved_users + + In short, don't call this method with support users. (Support users + should not appear in the MAU stats). + + Args: + txn (cursor): + user_id (str): user to add/update + + Returns: + bool: True if a new entry was created, False if an + existing one was updated. + """ + + # Am consciously deciding to lock the table on the basis that is ought + # never be a big table and alternative approaches (batching multiple + # upserts into a single txn) introduced a lot of extra complexity. + # See https://github.com/matrix-org/synapse/issues/3854 for more + is_insert = self._simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) + + return is_insert + + @cached(num_args=1) + def user_last_seen_monthly_active(self, user_id): + """ + Checks if a given user is part of the monthly active user group + Arguments: + user_id (str): user to add/update + Return: + Deferred[int] : timestamp since last seen, None if never seen + + """ + + return self._simple_select_one_onecol( + table="monthly_active_users", + keyvalues={"user_id": user_id}, + retcol="timestamp", + allow_none=True, + desc="user_last_seen_monthly_active", + ) + + @defer.inlineCallbacks + def populate_monthly_active_users(self, user_id): + """Checks on the state of monthly active user limits and optionally + add the user to the monthly active tables + + Args: + user_id(str): the user_id to query + """ + if self.hs.config.limit_usage_by_mau or self.hs.config.mau_stats_only: + # Trial users and guests should not be included as part of MAU group + is_guest = yield self.is_guest(user_id) + if is_guest: + return + is_trial = yield self.is_trial_user(user_id) + if is_trial: + return + + last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) + now = self.hs.get_clock().time_msec() + + # We want to reduce to the total number of db writes, and are happy + # to trade accuracy of timestamp in order to lighten load. This means + # We always insert new users (where MAU threshold has not been reached), + # but only update if we have not previously seen the user for + # LAST_SEEN_GRANULARITY ms + if last_seen_timestamp is None: + # In the case where mau_stats_only is True and limit_usage_by_mau is + # False, there is no point in checking get_monthly_active_count - it + # adds no value and will break the logic if max_mau_value is exceeded. + if not self.hs.config.limit_usage_by_mau: + yield self.upsert_monthly_active_user(user_id) + else: + count = yield self.get_monthly_active_count() + if count < self.hs.config.max_mau_value: + yield self.upsert_monthly_active_user(user_id) + elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: + yield self.upsert_monthly_active_user(user_id) diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py new file mode 100644 index 0000000000..79b40044d9 --- /dev/null +++ b/synapse/storage/data_stores/main/openid.py @@ -0,0 +1,31 @@ +from synapse.storage._base import SQLBaseStore + + +class OpenIdStore(SQLBaseStore): + def insert_open_id_token(self, token, ts_valid_until_ms, user_id): + return self._simple_insert( + table="open_id_tokens", + values={ + "token": token, + "ts_valid_until_ms": ts_valid_until_ms, + "user_id": user_id, + }, + desc="insert_open_id_token", + ) + + def get_user_id_for_open_id_token(self, token, ts_now_ms): + def get_user_id_for_token_txn(txn): + sql = ( + "SELECT user_id FROM open_id_tokens" + " WHERE token = ? AND ? <= ts_valid_until_ms" + ) + + txn.execute(sql, (token, ts_now_ms)) + + rows = txn.fetchall() + if not rows: + return None + else: + return rows[0][0] + + return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py new file mode 100644 index 0000000000..523ed6575e --- /dev/null +++ b/synapse/storage/data_stores/main/presence.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.presence import UserPresenceState +from synapse.util import batch_iter +from synapse.util.caches.descriptors import cached, cachedList + + +class PresenceStore(SQLBaseStore): + @defer.inlineCallbacks + def update_presence(self, presence_states): + stream_ordering_manager = self._presence_id_gen.get_next_mult( + len(presence_states) + ) + + with stream_ordering_manager as stream_orderings: + yield self.runInteraction( + "update_presence", + self._update_presence_txn, + stream_orderings, + presence_states, + ) + + return stream_orderings[-1], self._presence_id_gen.get_current_token() + + def _update_presence_txn(self, txn, stream_orderings, presence_states): + for stream_id, state in zip(stream_orderings, presence_states): + txn.call_after( + self.presence_stream_cache.entity_has_changed, state.user_id, stream_id + ) + txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) + + # Actually insert new rows + self._simple_insert_many_txn( + txn, + table="presence_stream", + values=[ + { + "stream_id": stream_id, + "user_id": state.user_id, + "state": state.state, + "last_active_ts": state.last_active_ts, + "last_federation_update_ts": state.last_federation_update_ts, + "last_user_sync_ts": state.last_user_sync_ts, + "status_msg": state.status_msg, + "currently_active": state.currently_active, + } + for state in presence_states + ], + ) + + # Delete old rows to stop database from getting really big + sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " + + for states in batch_iter(presence_states, 50): + clause, args = make_in_list_sql_clause( + self.database_engine, "user_id", [s.user_id for s in states] + ) + txn.execute(sql + clause, [stream_id] + list(args)) + + def get_all_presence_updates(self, last_id, current_id): + if last_id == current_id: + return defer.succeed([]) + + def get_all_presence_updates_txn(txn): + sql = ( + "SELECT stream_id, user_id, state, last_active_ts," + " last_federation_update_ts, last_user_sync_ts, status_msg," + " currently_active" + " FROM presence_stream" + " WHERE ? < stream_id AND stream_id <= ?" + ) + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() + + return self.runInteraction( + "get_all_presence_updates", get_all_presence_updates_txn + ) + + @cached() + def _get_presence_for_user(self, user_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_presence_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def get_presence_for_users(self, user_ids): + rows = yield self._simple_select_many_batch( + table="presence_stream", + column="user_id", + iterable=user_ids, + keyvalues={}, + retcols=( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), + desc="get_presence_for_users", + ) + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return {row["user_id"]: UserPresenceState(**row) for row in rows} + + def get_current_presence_token(self): + return self._presence_id_gen.get_current_token() + + def allow_presence_visible(self, observed_localpart, observer_userid): + return self._simple_insert( + table="presence_allow_inbound", + values={ + "observed_user_id": observed_localpart, + "observer_user_id": observer_userid, + }, + desc="allow_presence_visible", + or_ignore=True, + ) + + def disallow_presence_visible(self, observed_localpart, observer_userid): + return self._simple_delete_one( + table="presence_allow_inbound", + keyvalues={ + "observed_user_id": observed_localpart, + "observer_user_id": observer_userid, + }, + desc="disallow_presence_visible", + ) diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py new file mode 100644 index 0000000000..e4e8a1c1d6 --- /dev/null +++ b/synapse/storage/data_stores/main/profile.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.roommember import ProfileInfo + + +class ProfileWorkerStore(SQLBaseStore): + @defer.inlineCallbacks + def get_profileinfo(self, user_localpart): + try: + profile = yield self._simple_select_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + ) + except StoreError as e: + if e.code == 404: + # no match + return ProfileInfo(None, None) + else: + raise + + return ProfileInfo( + avatar_url=profile["avatar_url"], display_name=profile["displayname"] + ) + + def get_profile_displayname(self, user_localpart): + return self._simple_select_one_onecol( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcol="displayname", + desc="get_profile_displayname", + ) + + def get_profile_avatar_url(self, user_localpart): + return self._simple_select_one_onecol( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcol="avatar_url", + desc="get_profile_avatar_url", + ) + + def get_from_remote_profile_cache(self, user_id): + return self._simple_select_one( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + retcols=("displayname", "avatar_url"), + allow_none=True, + desc="get_from_remote_profile_cache", + ) + + def create_profile(self, user_localpart): + return self._simple_insert( + table="profiles", values={"user_id": user_localpart}, desc="create_profile" + ) + + def set_profile_displayname(self, user_localpart, new_displayname): + return self._simple_update_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + updatevalues={"displayname": new_displayname}, + desc="set_profile_displayname", + ) + + def set_profile_avatar_url(self, user_localpart, new_avatar_url): + return self._simple_update_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + updatevalues={"avatar_url": new_avatar_url}, + desc="set_profile_avatar_url", + ) + + +class ProfileStore(ProfileWorkerStore): + def add_remote_profile_cache(self, user_id, displayname, avatar_url): + """Ensure we are caching the remote user's profiles. + + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. + """ + return self._simple_upsert( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", + ) + + def update_remote_profile_cache(self, user_id, displayname, avatar_url): + return self._simple_update( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="update_remote_profile_cache", + ) + + @defer.inlineCallbacks + def maybe_delete_remote_profile_cache(self, user_id): + """Check if we still care about the remote user's profile, and if we + don't then remove their profile from the cache + """ + subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + if not subscribed: + yield self._simple_delete( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + desc="delete_remote_profile_cache", + ) + + def get_remote_profile_cache_entries_that_expire(self, last_checked): + """Get all users who haven't been checked since `last_checked` + """ + + def _get_remote_profile_cache_entries_that_expire_txn(txn): + sql = """ + SELECT user_id, displayname, avatar_url + FROM remote_profile_cache + WHERE last_check < ? + """ + + txn.execute(sql, (last_checked,)) + + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_remote_profile_cache_entries_that_expire", + _get_remote_profile_cache_entries_that_expire_txn, + ) + + @defer.inlineCallbacks + def is_subscribed_remote_profile_for_user(self, user_id): + """Check whether we are interested in a remote user's profile. + """ + res = yield self._simple_select_one_onecol( + table="group_users", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + return True + + res = yield self._simple_select_one_onecol( + table="group_invites", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + return True diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py new file mode 100644 index 0000000000..cd95f1ce60 --- /dev/null +++ b/synapse/storage/data_stores/main/push_rule.py @@ -0,0 +1,713 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.push.baserules import list_with_base_rules +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore +from synapse.storage.data_stores.main.pusher import PusherWorkerStore +from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore +from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache + +logger = logging.getLogger(__name__) + + +def _load_rules(rawrules, enabled_map): + ruleslist = [] + for rawrule in rawrules: + rule = dict(rawrule) + rule["conditions"] = json.loads(rawrule["conditions"]) + rule["actions"] = json.loads(rawrule["actions"]) + ruleslist.append(rule) + + # We're going to be mutating this a lot, so do a deep copy + rules = list(list_with_base_rules(ruleslist)) + + for i, rule in enumerate(rules): + rule_id = rule["rule_id"] + if rule_id in enabled_map: + if rule.get("enabled", True) != bool(enabled_map[rule_id]): + # Rules are cached across users. + rule = dict(rule) + rule["enabled"] = bool(enabled_map[rule_id]) + rules[i] = rule + + return rules + + +class PushRulesWorkerStore( + ApplicationServiceWorkerStore, + ReceiptsWorkerStore, + PusherWorkerStore, + RoomMemberWorkerStore, + SQLBaseStore, +): + """This is an abstract base class where subclasses must implement + `get_max_push_rules_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(db_conn, hs) + + push_rules_prefill, push_rules_id = self._get_cache_dict( + db_conn, + "push_rules_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self.get_max_push_rules_stream_id(), + ) + + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", + push_rules_id, + prefilled_cache=push_rules_prefill, + ) + + @abc.abstractmethod + def get_max_push_rules_stream_id(self): + """Get the position of the push rules stream. + + Returns: + int + """ + raise NotImplementedError() + + @cachedInlineCallbacks(max_entries=5000) + def get_push_rules_for_user(self, user_id): + rows = yield self._simple_select_list( + table="push_rules", + keyvalues={"user_name": user_id}, + retcols=( + "user_name", + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ), + desc="get_push_rules_enabled_for_user", + ) + + rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) + + enabled_map = yield self.get_push_rules_enabled_for_user(user_id) + + rules = _load_rules(rows, enabled_map) + + return rules + + @cachedInlineCallbacks(max_entries=5000) + def get_push_rules_enabled_for_user(self, user_id): + results = yield self._simple_select_list( + table="push_rules_enable", + keyvalues={"user_name": user_id}, + retcols=("user_name", "rule_id", "enabled"), + desc="get_push_rules_enabled_for_user", + ) + return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} + + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + return defer.succeed(False) + else: + + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + + @cachedList( + cached_method_name="get_push_rules_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def bulk_get_push_rules(self, user_ids): + if not user_ids: + return {} + + results = {user_id: [] for user_id in user_ids} + + rows = yield self._simple_select_many_batch( + table="push_rules", + column="user_name", + iterable=user_ids, + retcols=("*",), + desc="bulk_get_push_rules", + ) + + rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) + + for row in rows: + results.setdefault(row["user_name"], []).append(row) + + enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) + + for user_id, rules in results.items(): + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + + return results + + @defer.inlineCallbacks + def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): + """Copy a single push rule from one room to another for a specific user. + + Args: + new_room_id (str): ID of the new room. + user_id (str): ID of user the push rule belongs to. + rule (Dict): A push rule. + """ + # Create new rule id + rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) + new_rule_id = rule_id_scope + "/" + new_room_id + + # Change room id in each condition + for condition in rule.get("conditions", []): + if condition.get("key") == "room_id": + condition["pattern"] = new_room_id + + # Add the rule for the new room + yield self.add_push_rule( + user_id=user_id, + rule_id=new_rule_id, + priority_class=rule["priority_class"], + conditions=rule["conditions"], + actions=rule["actions"], + ) + + @defer.inlineCallbacks + def copy_push_rules_from_room_to_room_for_user( + self, old_room_id, new_room_id, user_id + ): + """Copy all of the push rules from one room to another for a specific + user. + + Args: + old_room_id (str): ID of the old room. + new_room_id (str): ID of the new room. + user_id (str): ID of user to copy push rules for. + """ + # Retrieve push rules for this user + user_push_rules = yield self.get_push_rules_for_user(user_id) + + # Get rules relating to the old room and copy them to the new room + for rule in user_push_rules: + conditions = rule.get("conditions", []) + if any( + (c.get("key") == "room_id" and c.get("pattern") == old_room_id) + for c in conditions + ): + yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) + + @defer.inlineCallbacks + def bulk_get_push_rules_for_room(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + current_state_ids = yield context.get_current_state_ids(self) + result = yield self._bulk_get_push_rules_for_room( + event.room_id, state_group, current_state_ids, event=event + ) + return result + + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _bulk_get_push_rules_for_room( + self, room_id, state_group, current_state_ids, cache_context, event=None + ): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + # We also will want to generate notifs for other people in the room so + # their unread countss are correct in the event stream, but to avoid + # generating them for bot / AS users etc, we only do so for people who've + # sent a read receipt into the room. + + users_in_room = yield self._get_joined_users_from_context( + room_id, + state_group, + current_state_ids, + on_invalidate=cache_context.invalidate, + event=event, + ) + + # We ignore app service users for now. This is so that we don't fill + # up the `get_if_users_have_pushers` cache with AS entries that we + # know don't have pushers, nor even read receipts. + local_users_in_room = set( + u + for u in users_in_room + if self.hs.is_mine_id(u) + and not self.get_if_app_services_interested_in_user(u) + ) + + # users in the room who have pushers need to get push rules run because + # that's how their pushers work + if_users_with_pushers = yield self.get_if_users_have_pushers( + local_users_in_room, on_invalidate=cache_context.invalidate + ) + user_ids = set( + uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher + ) + + users_with_receipts = yield self.get_users_with_read_receipts_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + + # any users with pushers must be ours: they have pushers + for uid in users_with_receipts: + if uid in local_users_in_room: + user_ids.add(uid) + + rules_by_user = yield self.bulk_get_push_rules( + user_ids, on_invalidate=cache_context.invalidate + ) + + rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} + + return rules_by_user + + @cachedList( + cached_method_name="get_push_rules_enabled_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def bulk_get_push_rules_enabled(self, user_ids): + if not user_ids: + return {} + + results = {user_id: {} for user_id in user_ids} + + rows = yield self._simple_select_many_batch( + table="push_rules_enable", + column="user_name", + iterable=user_ids, + retcols=("user_name", "rule_id", "enabled"), + desc="bulk_get_push_rules_enabled", + ) + for row in rows: + enabled = bool(row["enabled"]) + results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled + return results + + +class PushRuleStore(PushRulesWorkerStore): + @defer.inlineCallbacks + def add_push_rule( + self, + user_id, + rule_id, + priority_class, + conditions, + actions, + before=None, + after=None, + ): + conditions_json = json.dumps(conditions) + actions_json = json.dumps(actions) + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + if before or after: + yield self.runInteraction( + "_add_push_rule_relative_txn", + self._add_push_rule_relative_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + before, + after, + ) + else: + yield self.runInteraction( + "_add_push_rule_highest_priority_txn", + self._add_push_rule_highest_priority_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + ) + + def _add_push_rule_relative_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + before, + after, + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") + + relative_to_rule = before or after + + res = self._simple_select_one_txn( + txn, + table="push_rules", + keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, + retcols=["priority_class", "priority"], + allow_none=True, + ) + + if not res: + raise RuleNotFoundException( + "before/after rule not found: %s" % (relative_to_rule,) + ) + + base_priority_class = res["priority_class"] + base_rule_priority = res["priority"] + + if base_priority_class != priority_class: + raise InconsistentRuleException( + "Given priority class does not match class of relative rule" + ) + + if before: + # Higher priority rules are executed first, So adding a rule before + # a rule means giving it a higher priority than that rule. + new_rule_priority = base_rule_priority + 1 + else: + # We increment the priority of the existing rules to make space for + # the new rule. Therefore if we want this rule to appear after + # an existing rule we give it the priority of the existing rule, + # and then increment the priority of the existing rule. + new_rule_priority = base_rule_priority + + sql = ( + "UPDATE push_rules SET priority = priority + 1" + " WHERE user_name = ? AND priority_class = ? AND priority >= ?" + ) + + txn.execute(sql, (user_id, priority_class, new_rule_priority)) + + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + new_rule_priority, + conditions_json, + actions_json, + ) + + def _add_push_rule_highest_priority_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") + + # find the highest priority rule in that class + sql = ( + "SELECT COUNT(*), MAX(priority) FROM push_rules" + " WHERE user_name = ? and priority_class = ?" + ) + txn.execute(sql, (user_id, priority_class)) + res = txn.fetchall() + (how_many, highest_prio) = res[0] + + new_prio = 0 + if how_many > 0: + new_prio = highest_prio + 1 + + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + new_prio, + conditions_json, + actions_json, + ) + + def _upsert_push_rule_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + priority, + conditions_json, + actions_json, + update_stream=True, + ): + """Specialised version of _simple_upsert_txn that picks a push_rule_id + using the _push_rule_id_gen if it needs to insert the rule. It assumes + that the "push_rules" table is locked""" + + sql = ( + "UPDATE push_rules" + " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" + " WHERE user_name = ? AND rule_id = ?" + ) + + txn.execute( + sql, + (priority_class, priority, conditions_json, actions_json, user_id, rule_id), + ) + + if txn.rowcount == 0: + # We didn't update a row with the given rule_id so insert one + push_rule_id = self._push_rule_id_gen.get_next() + + self._simple_insert_txn( + txn, + table="push_rules", + values={ + "id": push_rule_id, + "user_name": user_id, + "rule_id": rule_id, + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) + + if update_stream: + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ADD", + data={ + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) + + @defer.inlineCallbacks + def delete_push_rule(self, user_id, rule_id): + """ + Delete a push rule. Args specify the row to be deleted and can be + any of the columns in the push_rule table, but below are the + standard ones + + Args: + user_id (str): The matrix ID of the push rule owner + rule_id (str): The rule_id of the rule to be deleted + """ + + def delete_push_rule_txn(txn, stream_id, event_stream_ordering): + self._simple_delete_one_txn( + txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} + ) + + self._insert_push_rules_update_txn( + txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" + ) + + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.runInteraction( + "delete_push_rule", + delete_push_rule_txn, + stream_id, + event_stream_ordering, + ) + + @defer.inlineCallbacks + def set_push_rule_enabled(self, user_id, rule_id, enabled): + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.runInteraction( + "_set_push_rule_enabled_txn", + self._set_push_rule_enabled_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + enabled, + ) + + def _set_push_rule_enabled_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled + ): + new_id = self._push_rules_enable_id_gen.get_next() + self._simple_upsert_txn( + txn, + "push_rules_enable", + {"user_name": user_id, "rule_id": rule_id}, + {"enabled": 1 if enabled else 0}, + {"id": new_id}, + ) + + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ENABLE" if enabled else "DISABLE", + ) + + @defer.inlineCallbacks + def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): + actions_json = json.dumps(actions) + + def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): + if is_default_rule: + # Add a dummy rule to the rules table with the user specified + # actions. + priority_class = -1 + priority = 1 + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + priority, + "[]", + actions_json, + update_stream=False, + ) + else: + self._simple_update_one_txn( + txn, + "push_rules", + {"user_name": user_id, "rule_id": rule_id}, + {"actions": actions_json}, + ) + + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ACTIONS", + data={"actions": actions_json}, + ) + + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.runInteraction( + "set_push_rule_actions", + set_push_rule_actions_txn, + stream_id, + event_stream_ordering, + ) + + def _insert_push_rules_update_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None + ): + values = { + "stream_id": stream_id, + "event_stream_ordering": event_stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": op, + } + if data is not None: + values.update(data) + + self._simple_insert_txn(txn, "push_rules_stream", values=values) + + txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) + txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) + txn.call_after( + self.push_rules_stream_cache.entity_has_changed, user_id, stream_id + ) + + def get_all_push_rule_updates(self, last_id, current_id, limit): + """Get all the push rules changes that have happend on the server""" + if last_id == current_id: + return defer.succeed([]) + + def get_all_push_rule_updates_txn(txn): + sql = ( + "SELECT stream_id, event_stream_ordering, user_id, rule_id," + " op, priority_class, priority, conditions, actions" + " FROM push_rules_stream" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return self.runInteraction( + "get_all_push_rule_updates", get_all_push_rule_updates_txn + ) + + def get_push_rules_stream_token(self): + """Get the position of the push rules stream. + Returns a pair of a stream id for the push_rules stream and the + room stream ordering it corresponds to.""" + return self._push_rules_stream_id_gen.get_current_token() + + def get_max_push_rules_stream_id(self): + return self.get_push_rules_stream_token()[0] diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py new file mode 100644 index 0000000000..f005c1ae0a --- /dev/null +++ b/synapse/storage/data_stores/main/pusher.py @@ -0,0 +1,371 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import six + +from canonicaljson import encode_canonical_json, json + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList + +logger = logging.getLogger(__name__) + +if six.PY2: + db_binary_type = six.moves.builtins.buffer +else: + db_binary_type = memoryview + + +class PusherWorkerStore(SQLBaseStore): + def _decode_pushers_rows(self, rows): + for r in rows: + dataJson = r["data"] + r["data"] = None + try: + if isinstance(dataJson, db_binary_type): + dataJson = str(dataJson).decode("UTF8") + + r["data"] = json.loads(dataJson) + except Exception as e: + logger.warn( + "Invalid JSON in data for pusher %d: %s, %s", + r["id"], + dataJson, + e.args[0], + ) + pass + + if isinstance(r["pushkey"], db_binary_type): + r["pushkey"] = str(r["pushkey"]).decode("UTF8") + + return rows + + @defer.inlineCallbacks + def user_has_pusher(self, user_id): + ret = yield self._simple_select_one_onecol( + "pushers", {"user_name": user_id}, "id", allow_none=True + ) + return ret is not None + + def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): + return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) + + def get_pushers_by_user_id(self, user_id): + return self.get_pushers_by({"user_name": user_id}) + + @defer.inlineCallbacks + def get_pushers_by(self, keyvalues): + ret = yield self._simple_select_list( + "pushers", + keyvalues, + [ + "id", + "user_name", + "access_token", + "profile_tag", + "kind", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "ts", + "lang", + "data", + "last_stream_ordering", + "last_success", + "failing_since", + ], + desc="get_pushers_by", + ) + return self._decode_pushers_rows(ret) + + @defer.inlineCallbacks + def get_all_pushers(self): + def get_pushers(txn): + txn.execute("SELECT * FROM pushers") + rows = self.cursor_to_dict(txn) + + return self._decode_pushers_rows(rows) + + rows = yield self.runInteraction("get_all_pushers", get_pushers) + return rows + + def get_all_updated_pushers(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed(([], [])) + + def get_all_updated_pushers_txn(txn): + sql = ( + "SELECT id, user_name, access_token, profile_tag, kind," + " app_id, app_display_name, device_display_name, pushkey, ts," + " lang, data" + " FROM pushers" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + updated = txn.fetchall() + + sql = ( + "SELECT stream_id, user_id, app_id, pushkey" + " FROM deleted_pushers" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + deleted = txn.fetchall() + + return updated, deleted + + return self.runInteraction( + "get_all_updated_pushers", get_all_updated_pushers_txn + ) + + def get_all_updated_pushers_rows(self, last_id, current_id, limit): + """Get all the pushers that have changed between the given tokens. + + Returns: + Deferred(list(tuple)): each tuple consists of: + stream_id (str) + user_id (str) + app_id (str) + pushkey (str) + was_deleted (bool): whether the pusher was added/updated (False) + or deleted (True) + """ + + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_pushers_rows_txn(txn): + sql = ( + "SELECT id, user_name, app_id, pushkey" + " FROM pushers" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + results = [list(row) + [False] for row in txn] + + sql = ( + "SELECT stream_id, user_id, app_id, pushkey" + " FROM deleted_pushers" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + + results.extend(list(row) + [True] for row in txn) + results.sort() # Sort so that they're ordered by stream id + + return results + + return self.runInteraction( + "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn + ) + + @cachedInlineCallbacks(num_args=1, max_entries=15000) + def get_if_user_has_pusher(self, user_id): + # This only exists for the cachedList decorator + raise NotImplementedError() + + @cachedList( + cached_method_name="get_if_user_has_pusher", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def get_if_users_have_pushers(self, user_ids): + rows = yield self._simple_select_many_batch( + table="pushers", + column="user_name", + iterable=user_ids, + retcols=["user_name"], + desc="get_if_users_have_pushers", + ) + + result = {user_id: False for user_id in user_ids} + result.update({r["user_name"]: True for r in rows}) + + return result + + +class PusherStore(PusherWorkerStore): + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + + @defer.inlineCallbacks + def add_pusher( + self, + user_id, + access_token, + kind, + app_id, + app_display_name, + device_display_name, + pushkey, + pushkey_ts, + lang, + data, + last_stream_ordering, + profile_tag="", + ): + with self._pushers_id_gen.get_next() as stream_id: + # no need to lock because `pushers` has a unique key on + # (app_id, pushkey, user_name) so _simple_upsert will retry + yield self._simple_upsert( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + values={ + "access_token": access_token, + "kind": kind, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "ts": pushkey_ts, + "lang": lang, + "data": bytearray(encode_canonical_json(data)), + "last_stream_ordering": last_stream_ordering, + "profile_tag": profile_tag, + "id": stream_id, + }, + desc="add_pusher", + lock=False, + ) + + user_has_pusher = self.get_if_user_has_pusher.cache.get( + (user_id,), None, update_metrics=False + ) + + if user_has_pusher is not True: + # invalidate, since we the user might not have had a pusher before + yield self.runInteraction( + "add_pusher", + self._invalidate_cache_and_stream, + self.get_if_user_has_pusher, + (user_id,), + ) + + @defer.inlineCallbacks + def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): + def delete_pusher_txn(txn, stream_id): + self._invalidate_cache_and_stream( + txn, self.get_if_user_has_pusher, (user_id,) + ) + + self._simple_delete_one_txn( + txn, + "pushers", + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + ) + + # it's possible for us to end up with duplicate rows for + # (app_id, pushkey, user_id) at different stream_ids, but that + # doesn't really matter. + self._simple_insert_txn( + txn, + table="deleted_pushers", + values={ + "stream_id": stream_id, + "app_id": app_id, + "pushkey": pushkey, + "user_id": user_id, + }, + ) + + with self._pushers_id_gen.get_next() as stream_id: + yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id) + + @defer.inlineCallbacks + def update_pusher_last_stream_ordering( + self, app_id, pushkey, user_id, last_stream_ordering + ): + yield self._simple_update_one( + "pushers", + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + {"last_stream_ordering": last_stream_ordering}, + desc="update_pusher_last_stream_ordering", + ) + + @defer.inlineCallbacks + def update_pusher_last_stream_ordering_and_success( + self, app_id, pushkey, user_id, last_stream_ordering, last_success + ): + """Update the last stream ordering position we've processed up to for + the given pusher. + + Args: + app_id (str) + pushkey (str) + last_stream_ordering (int) + last_success (int) + + Returns: + Deferred[bool]: True if the pusher still exists; False if it has been deleted. + """ + updated = yield self._simple_update( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + updatevalues={ + "last_stream_ordering": last_stream_ordering, + "last_success": last_success, + }, + desc="update_pusher_last_stream_ordering_and_success", + ) + + return bool(updated) + + @defer.inlineCallbacks + def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): + yield self._simple_update( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + updatevalues={"failing_since": failing_since}, + desc="update_pusher_failing_since", + ) + + @defer.inlineCallbacks + def get_throttle_params_by_room(self, pusher_id): + res = yield self._simple_select_list( + "pusher_throttle", + {"pusher": pusher_id}, + ["room_id", "last_sent_ts", "throttle_ms"], + desc="get_throttle_params_by_room", + ) + + params_by_room = {} + for row in res: + params_by_room[row["room_id"]] = { + "last_sent_ts": row["last_sent_ts"], + "throttle_ms": row["throttle_ms"], + } + + return params_by_room + + @defer.inlineCallbacks + def set_throttle_params(self, pusher_id, room_id, params): + # no need to lock because `pusher_throttle` has a primary key on + # (pusher, room_id) so _simple_upsert will retry + yield self._simple_upsert( + "pusher_throttle", + {"pusher": pusher_id, "room_id": room_id}, + params, + desc="set_throttle_params", + lock=False, + ) diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py new file mode 100644 index 0000000000..0c24430f28 --- /dev/null +++ b/synapse/storage/data_stores/main/receipts.py @@ -0,0 +1,536 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache + +logger = logging.getLogger(__name__) + + +class ReceiptsWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_receipt_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(db_conn, hs) + + self._receipts_stream_cache = StreamChangeCache( + "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() + ) + + @abc.abstractmethod + def get_max_receipt_stream_id(self): + """Get the current max stream ID for receipts stream + + Returns: + int + """ + raise NotImplementedError() + + @cachedInlineCallbacks() + def get_users_with_read_receipts_in_room(self, room_id): + receipts = yield self.get_receipts_for_room(room_id, "m.read") + return set(r["user_id"] for r in receipts) + + @cached(num_args=2) + def get_receipts_for_room(self, room_id, receipt_type): + return self._simple_select_list( + table="receipts_linearized", + keyvalues={"room_id": room_id, "receipt_type": receipt_type}, + retcols=("user_id", "event_id"), + desc="get_receipts_for_room", + ) + + @cached(num_args=3) + def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): + return self._simple_select_one_onecol( + table="receipts_linearized", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + }, + retcol="event_id", + desc="get_own_receipt_for_user", + allow_none=True, + ) + + @cachedInlineCallbacks(num_args=2) + def get_receipts_for_user(self, user_id, receipt_type): + rows = yield self._simple_select_list( + table="receipts_linearized", + keyvalues={"user_id": user_id, "receipt_type": receipt_type}, + retcols=("room_id", "event_id"), + desc="get_receipts_for_user", + ) + + return {row["room_id"]: row["event_id"] for row in rows} + + @defer.inlineCallbacks + def get_receipts_for_user_with_orderings(self, user_id, receipt_type): + def f(txn): + sql = ( + "SELECT rl.room_id, rl.event_id," + " e.topological_ordering, e.stream_ordering" + " FROM receipts_linearized AS rl" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE rl.room_id = e.room_id" + " AND rl.event_id = e.event_id" + " AND user_id = ?" + ) + txn.execute(sql, (user_id,)) + return txn.fetchall() + + rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f) + return { + row[0]: { + "event_id": row[1], + "topological_ordering": row[2], + "stream_ordering": row[3], + } + for row in rows + } + + @defer.inlineCallbacks + def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + """Get receipts for multiple rooms for sending to clients. + + Args: + room_ids (list): List of room_ids. + to_key (int): Max stream id to fetch receipts upto. + from_key (int): Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + list: A list of receipts. + """ + room_ids = set(room_ids) + + if from_key is not None: + # Only ask the database about rooms where there have been new + # receipts added since `from_key` + room_ids = yield self._receipts_stream_cache.get_entities_changed( + room_ids, from_key + ) + + results = yield self._get_linearized_receipts_for_rooms( + room_ids, to_key, from_key=from_key + ) + + return [ev for res in results.values() for ev in res] + + def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + """Get receipts for a single room for sending to clients. + + Args: + room_ids (str): The room id. + to_key (int): Max stream id to fetch receipts upto. + from_key (int): Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + Deferred[list]: A list of receipts. + """ + if from_key is not None: + # Check the cache first to see if any new receipts have been added + # since`from_key`. If not we can no-op. + if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): + defer.succeed([]) + + return self._get_linearized_receipts_for_room(room_id, to_key, from_key) + + @cachedInlineCallbacks(num_args=3, tree=True) + def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + """See get_linearized_receipts_for_room + """ + + def f(txn): + if from_key: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id = ? AND stream_id > ? AND stream_id <= ?" + ) + + txn.execute(sql, (room_id, from_key, to_key)) + else: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id = ? AND stream_id <= ?" + ) + + txn.execute(sql, (room_id, to_key)) + + rows = self.cursor_to_dict(txn) + + return rows + + rows = yield self.runInteraction("get_linearized_receipts_for_room", f) + + if not rows: + return [] + + content = {} + for row in rows: + content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ + row["user_id"] + ] = json.loads(row["data"]) + + return [{"type": "m.receipt", "room_id": room_id, "content": content}] + + @cachedList( + cached_method_name="_get_linearized_receipts_for_room", + list_name="room_ids", + num_args=3, + inlineCallbacks=True, + ) + def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + if not room_ids: + return {} + + def f(txn): + if from_key: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id > ? AND stream_id <= ? AND + """ + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + txn.execute(sql + clause, [from_key, to_key] + list(args)) + else: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id <= ? AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + txn.execute(sql + clause, [to_key] + list(args)) + + return self.cursor_to_dict(txn) + + txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f) + + results = {} + for row in txn_results: + # We want a single event per room, since we want to batch the + # receipts by room, event and type. + room_event = results.setdefault( + row["room_id"], + {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + ) + + # The content is of the form: + # {"$foo:bar": { "read": { "@user:host": }, .. }, .. } + event_entry = room_event["content"].setdefault(row["event_id"], {}) + receipt_type = event_entry.setdefault(row["receipt_type"], {}) + + receipt_type[row["user_id"]] = json.loads(row["data"]) + + results = { + room_id: [results[room_id]] if room_id in results else [] + for room_id in room_ids + } + return results + + def get_all_updated_receipts(self, last_id, current_id, limit=None): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + ) + args = [last_id, current_id] + if limit is not None: + sql += " LIMIT ?" + args.append(limit) + txn.execute(sql, args) + + return (r[0:5] + (json.loads(r[5]),) for r in txn) + + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) + + def _invalidate_get_users_with_receipts_in_room( + self, room_id, receipt_type, user_id + ): + if receipt_type != "m.read": + return + + # Returns either an ObservableDeferred or the raw result + res = self.get_users_with_read_receipts_in_room.cache.get( + room_id, None, update_metrics=False + ) + + # first handle the Deferred case + if isinstance(res, defer.Deferred): + if res.called: + res = res.result + else: + res = None + + if res and user_id in res: + # We'd only be adding to the set, so no point invalidating if the + # user is already there + return + + self.get_users_with_read_receipts_in_room.invalidate((room_id,)) + + +class ReceiptsStore(ReceiptsWorkerStore): + def __init__(self, db_conn, hs): + # We instantiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + + super(ReceiptsStore, self).__init__(db_conn, hs) + + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_current_token() + + def insert_linearized_receipt_txn( + self, txn, room_id, receipt_type, user_id, event_id, data, stream_id + ): + """Inserts a read-receipt into the database if it's newer than the current RR + + Returns: int|None + None if the RR is older than the current RR + otherwise, the rx timestamp of the event that the RR corresponds to + (or 0 if the event is unknown) + """ + res = self._simple_select_one_txn( + txn, + table="events", + retcols=["stream_ordering", "received_ts"], + keyvalues={"event_id": event_id}, + allow_none=True, + ) + + stream_ordering = int(res["stream_ordering"]) if res else None + rx_ts = res["received_ts"] if res else 0 + + # We don't want to clobber receipts for more recent events, so we + # have to compare orderings of existing receipts + if stream_ordering is not None: + sql = ( + "SELECT stream_ordering, event_id FROM events" + " INNER JOIN receipts_linearized as r USING (event_id, room_id)" + " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + ) + txn.execute(sql, (room_id, receipt_type, user_id)) + + for so, eid in txn: + if int(so) >= stream_ordering: + logger.debug( + "Ignoring new receipt for %s in favour of existing " + "one for later event %s", + event_id, + eid, + ) + return None + + txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) + txn.call_after( + self._invalidate_get_users_with_receipts_in_room, + room_id, + receipt_type, + user_id, + ) + txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) + # FIXME: This shouldn't invalidate the whole cache + txn.call_after( + self._get_linearized_receipts_for_room.invalidate_many, (room_id,) + ) + + txn.call_after( + self._receipts_stream_cache.entity_has_changed, room_id, stream_id + ) + + txn.call_after( + self.get_last_receipt_event_id_for_user.invalidate, + (user_id, room_id, receipt_type), + ) + + self._simple_delete_txn( + txn, + table="receipts_linearized", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + }, + ) + + self._simple_insert_txn( + txn, + table="receipts_linearized", + values={ + "stream_id": stream_id, + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_id": event_id, + "data": json.dumps(data), + }, + ) + + if receipt_type == "m.read" and stream_ordering is not None: + self._remove_old_push_actions_before_txn( + txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering + ) + + return rx_ts + + @defer.inlineCallbacks + def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): + """Insert a receipt, either from local client or remote server. + + Automatically does conversion between linearized and graph + representations. + """ + if not event_ids: + return + + if len(event_ids) == 1: + linearized_event_id = event_ids[0] + else: + # we need to points in graph -> linearized form. + # TODO: Make this better. + def graph_to_linear(txn): + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", event_ids + ) + + sql = """ + SELECT event_id WHERE room_id = ? AND stream_ordering IN ( + SELECT max(stream_ordering) WHERE %s + ) + """ % ( + clause, + ) + + txn.execute(sql, [room_id] + list(args)) + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) + + linearized_event_id = yield self.runInteraction( + "insert_receipt_conv", graph_to_linear + ) + + stream_id_manager = self._receipts_id_gen.get_next() + with stream_id_manager as stream_id: + event_ts = yield self.runInteraction( + "insert_linearized_receipt", + self.insert_linearized_receipt_txn, + room_id, + receipt_type, + user_id, + linearized_event_id, + data, + stream_id=stream_id, + ) + + if event_ts is None: + return None + + now = self._clock.time_msec() + logger.debug( + "RR for event %s in %s (%i ms old)", + linearized_event_id, + room_id, + now - event_ts, + ) + + yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) + + max_persisted_id = self._receipts_id_gen.get_current_token() + + return stream_id, max_persisted_id + + def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): + return self.runInteraction( + "insert_graph_receipt", + self.insert_graph_receipt_txn, + room_id, + receipt_type, + user_id, + event_ids, + data, + ) + + def insert_graph_receipt_txn( + self, txn, room_id, receipt_type, user_id, event_ids, data + ): + txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) + txn.call_after( + self._invalidate_get_users_with_receipts_in_room, + room_id, + receipt_type, + user_id, + ) + txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) + # FIXME: This shouldn't invalidate the whole cache + txn.call_after( + self._get_linearized_receipts_for_room.invalidate_many, (room_id,) + ) + + self._simple_delete_txn( + txn, + table="receipts_graph", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + }, + ) + self._simple_insert_txn( + txn, + table="receipts_graph", + values={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_ids": json.dumps(event_ids), + "data": json.dumps(data), + }, + ) diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py new file mode 100644 index 0000000000..6c5b29288a --- /dev/null +++ b/synapse/storage/data_stores/main/registration.py @@ -0,0 +1,1499 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from six import iterkeys +from six.moves import range + +from twisted.internet import defer +from twisted.internet.defer import Deferred + +from synapse.api.constants import UserTypes +from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage import background_updates +from synapse.storage._base import SQLBaseStore +from synapse.types import UserID +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 + +logger = logging.getLogger(__name__) + + +class RegistrationWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(db_conn, hs) + + self.config = hs.config + self.clock = hs.get_clock() + + @cached() + def get_user_by_id(self, user_id): + return self._simple_select_one( + table="users", + keyvalues={"name": user_id}, + retcols=[ + "name", + "password_hash", + "is_guest", + "consent_version", + "consent_server_notice_sent", + "appservice_id", + "creation_ts", + "user_type", + ], + allow_none=True, + desc="get_user_by_id", + ) + + @defer.inlineCallbacks + def is_trial_user(self, user_id): + """Checks if user is in the "trial" period, i.e. within the first + N days of registration defined by `mau_trial_days` config + + Args: + user_id (str) + + Returns: + Deferred[bool] + """ + + info = yield self.get_user_by_id(user_id) + if not info: + return False + + now = self.clock.time_msec() + trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 + is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms + return is_trial + + @cached() + def get_user_by_access_token(self, token): + """Get a user from the given access token. + + Args: + token (str): The access token of a user. + Returns: + defer.Deferred: None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`, + `valid_until_ms`. + """ + return self.runInteraction( + "get_user_by_access_token", self._query_for_auth, token + ) + + @cachedInlineCallbacks() + def get_expiration_ts_for_user(self, user_id): + """Get the expiration timestamp for the account bearing a given user ID. + + Args: + user_id (str): The ID of the user. + Returns: + defer.Deferred: None, if the account has no expiration timestamp, + otherwise int representation of the timestamp (as a number of + milliseconds since epoch). + """ + res = yield self._simple_select_one_onecol( + table="account_validity", + keyvalues={"user_id": user_id}, + retcol="expiration_ts_ms", + allow_none=True, + desc="get_expiration_ts_for_user", + ) + return res + + @defer.inlineCallbacks + def set_account_validity_for_user( + self, user_id, expiration_ts, email_sent, renewal_token=None + ): + """Updates the account validity properties of the given account, with the + given values. + + Args: + user_id (str): ID of the account to update properties for. + expiration_ts (int): New expiration date, as a timestamp in milliseconds + since epoch. + email_sent (bool): True means a renewal email has been sent for this + account and there's no need to send another one for the current validity + period. + renewal_token (str): Renewal token the user can use to extend the validity + of their account. Defaults to no token. + """ + + def set_account_validity_for_user_txn(txn): + self._simple_update_txn( + txn=txn, + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={ + "expiration_ts_ms": expiration_ts, + "email_sent": email_sent, + "renewal_token": renewal_token, + }, + ) + self._invalidate_cache_and_stream( + txn, self.get_expiration_ts_for_user, (user_id,) + ) + + yield self.runInteraction( + "set_account_validity_for_user", set_account_validity_for_user_txn + ) + + @defer.inlineCallbacks + def set_renewal_token_for_user(self, user_id, renewal_token): + """Defines a renewal token for a given user. + + Args: + user_id (str): ID of the user to set the renewal token for. + renewal_token (str): Random unique string that will be used to renew the + user's account. + + Raises: + StoreError: The provided token is already set for another user. + """ + yield self._simple_update_one( + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={"renewal_token": renewal_token}, + desc="set_renewal_token_for_user", + ) + + @defer.inlineCallbacks + def get_user_from_renewal_token(self, renewal_token): + """Get a user ID from a renewal token. + + Args: + renewal_token (str): The renewal token to perform the lookup with. + + Returns: + defer.Deferred[str]: The ID of the user to which the token belongs. + """ + res = yield self._simple_select_one_onecol( + table="account_validity", + keyvalues={"renewal_token": renewal_token}, + retcol="user_id", + desc="get_user_from_renewal_token", + ) + + return res + + @defer.inlineCallbacks + def get_renewal_token_for_user(self, user_id): + """Get the renewal token associated with a given user ID. + + Args: + user_id (str): The user ID to lookup a token for. + + Returns: + defer.Deferred[str]: The renewal token associated with this user ID. + """ + res = yield self._simple_select_one_onecol( + table="account_validity", + keyvalues={"user_id": user_id}, + retcol="renewal_token", + desc="get_renewal_token_for_user", + ) + + return res + + @defer.inlineCallbacks + def get_users_expiring_soon(self): + """Selects users whose account will expire in the [now, now + renew_at] time + window (see configuration for account_validity for information on what renew_at + refers to). + + Returns: + Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] + """ + + def select_users_txn(txn, now_ms, renew_at): + sql = ( + "SELECT user_id, expiration_ts_ms FROM account_validity" + " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" + ) + values = [False, now_ms, renew_at] + txn.execute(sql, values) + return self.cursor_to_dict(txn) + + res = yield self.runInteraction( + "get_users_expiring_soon", + select_users_txn, + self.clock.time_msec(), + self.config.account_validity.renew_at, + ) + + return res + + @defer.inlineCallbacks + def set_renewal_mail_status(self, user_id, email_sent): + """Sets or unsets the flag that indicates whether a renewal email has been sent + to the user (and the user hasn't renewed their account yet). + + Args: + user_id (str): ID of the user to set/unset the flag for. + email_sent (bool): Flag which indicates whether a renewal email has been sent + to this user. + """ + yield self._simple_update_one( + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={"email_sent": email_sent}, + desc="set_renewal_mail_status", + ) + + @defer.inlineCallbacks + def delete_account_validity_for_user(self, user_id): + """Deletes the entry for the given user in the account validity table, removing + their expiration date and renewal token. + + Args: + user_id (str): ID of the user to remove from the account validity table. + """ + yield self._simple_delete_one( + table="account_validity", + keyvalues={"user_id": user_id}, + desc="delete_account_validity_for_user", + ) + + @defer.inlineCallbacks + def is_server_admin(self, user): + """Determines if a user is an admin of this homeserver. + + Args: + user (UserID): user ID of the user to test + + Returns (bool): + true iff the user is a server admin, false otherwise. + """ + res = yield self._simple_select_one_onecol( + table="users", + keyvalues={"name": user.to_string()}, + retcol="admin", + allow_none=True, + desc="is_server_admin", + ) + + return res if res else False + + def set_server_admin(self, user, admin): + """Sets whether a user is an admin of this homeserver. + + Args: + user (UserID): user ID of the user to test + admin (bool): true iff the user is to be a server admin, + false otherwise. + """ + return self._simple_update_one( + table="users", + keyvalues={"name": user.to_string()}, + updatevalues={"admin": 1 if admin else 0}, + desc="set_server_admin", + ) + + def _query_for_auth(self, txn, token): + sql = ( + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id, access_tokens.valid_until_ms" + " FROM users" + " INNER JOIN access_tokens on users.name = access_tokens.user_id" + " WHERE token = ?" + ) + + txn.execute(sql, (token,)) + rows = self.cursor_to_dict(txn) + if rows: + return rows[0] + + return None + + @cachedInlineCallbacks() + def is_real_user(self, user_id): + """Determines if the user is a real user, ie does not have a 'user_type'. + + Args: + user_id (str): user id to test + + Returns: + Deferred[bool]: True if user 'user_type' is null or empty string + """ + res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id) + return res + + @cachedInlineCallbacks() + def is_support_user(self, user_id): + """Determines if the user is of type UserTypes.SUPPORT + + Args: + user_id (str): user id to test + + Returns: + Deferred[bool]: True if user is of type UserTypes.SUPPORT + """ + res = yield self.runInteraction( + "is_support_user", self.is_support_user_txn, user_id + ) + return res + + def is_real_user_txn(self, txn, user_id): + res = self._simple_select_one_onecol_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + retcol="user_type", + allow_none=True, + ) + return res is None + + def is_support_user_txn(self, txn, user_id): + res = self._simple_select_one_onecol_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + retcol="user_type", + allow_none=True, + ) + return True if res == UserTypes.SUPPORT else False + + def get_users_by_id_case_insensitive(self, user_id): + """Gets users that match user_id case insensitively. + Returns a mapping of user_id -> password_hash. + """ + + def f(txn): + sql = ( + "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)" + ) + txn.execute(sql, (user_id,)) + return dict(txn) + + return self.runInteraction("get_users_by_id_case_insensitive", f) + + async def get_user_by_external_id( + self, auth_provider: str, external_id: str + ) -> str: + """Look up a user by their external auth id + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + + Returns: + str|None: the mxid of the user, or None if they are not known + """ + return await self._simple_select_one_onecol( + table="user_external_ids", + keyvalues={"auth_provider": auth_provider, "external_id": external_id}, + retcol="user_id", + allow_none=True, + desc="get_user_by_external_id", + ) + + @defer.inlineCallbacks + def count_all_users(self): + """Counts all users registered on the homeserver.""" + + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users") + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_users", _count_users) + return ret + + def count_daily_user_type(self): + """ + Counts 1) native non guest users + 2) native guests users + 3) bridged users + who registered on the homeserver in the past 24 hours + """ + + def _count_daily_user_type(txn): + yesterday = int(self._clock.time()) - (60 * 60 * 24) + + sql = """ + SELECT user_type, COALESCE(count(*), 0) AS count FROM ( + SELECT + CASE + WHEN is_guest=0 AND appservice_id IS NULL THEN 'native' + WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest' + WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged' + END AS user_type + FROM users + WHERE creation_ts > ? + ) AS t GROUP BY user_type + """ + results = {"native": 0, "guest": 0, "bridged": 0} + txn.execute(sql, (yesterday,)) + for row in txn: + results[row[0]] = row[1] + return results + + return self.runInteraction("count_daily_user_type", _count_daily_user_type) + + @defer.inlineCallbacks + def count_nonbridged_users(self): + def _count_users(txn): + txn.execute( + """ + SELECT COALESCE(COUNT(*), 0) FROM users + WHERE appservice_id IS NULL + """ + ) + count, = txn.fetchone() + return count + + ret = yield self.runInteraction("count_users", _count_users) + return ret + + @defer.inlineCallbacks + def count_real_users(self): + """Counts all users without a special user_type registered on the homeserver.""" + + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_real_users", _count_users) + return ret + + @defer.inlineCallbacks + def find_next_generated_user_id_localpart(self): + """ + Gets the localpart of the next generated user ID. + + Generated user IDs are integers, and we aim for them to be as small as + we can. Unfortunately, it's possible some of them are already taken by + existing users, and there may be gaps in the already taken range. This + function returns the start of the first allocatable gap. This is to + avoid the case of ID 10000000 being pre-allocated, so us wasting the + first (and shortest) many generated user IDs. + """ + + def _find_next_generated_user_id(txn): + # We bound between '@1' and '@a' to avoid pulling the entire table + # out. + txn.execute("SELECT name FROM users WHERE '@1' <= name AND name < '@a'") + + regex = re.compile(r"^@(\d+):") + + found = set() + + for (user_id,) in txn: + match = regex.search(user_id) + if match: + found.add(int(match.group(1))) + for i in range(len(found) + 1): + if i not in found: + return i + + return ( + ( + yield self.runInteraction( + "find_next_generated_user_id", _find_next_generated_user_id + ) + ) + ) + + @defer.inlineCallbacks + def get_user_id_by_threepid(self, medium, address): + """Returns user id from threepid + + Args: + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + Deferred[str|None]: user id or None if no user id/threepid mapping exists + """ + user_id = yield self.runInteraction( + "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address + ) + return user_id + + def get_user_id_by_threepid_txn(self, txn, medium, address): + """Returns user id from threepid + + Args: + txn (cursor): + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + str|None: user id or None if no user id/threepid mapping exists + """ + ret = self._simple_select_one_txn( + txn, + "user_threepids", + {"medium": medium, "address": address}, + ["user_id"], + True, + ) + if ret: + return ret["user_id"] + return None + + @defer.inlineCallbacks + def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + yield self._simple_upsert( + "user_threepids", + {"medium": medium, "address": address}, + {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, + ) + + @defer.inlineCallbacks + def user_get_threepids(self, user_id): + ret = yield self._simple_select_list( + "user_threepids", + {"user_id": user_id}, + ["medium", "address", "validated_at", "added_at"], + "user_get_threepids", + ) + return ret + + def user_delete_threepid(self, user_id, medium, address): + return self._simple_delete( + "user_threepids", + keyvalues={"user_id": user_id, "medium": medium, "address": address}, + desc="user_delete_threepids", + ) + + def add_user_bound_threepid(self, user_id, medium, address, id_server): + """The server proxied a bind request to the given identity server on + behalf of the given user. We need to remember this in case the user + asks us to unbind the threepid. + + Args: + user_id (str) + medium (str) + address (str) + id_server (str) + + Returns: + Deferred + """ + # We need to use an upsert, in case they user had already bound the + # threepid + return self._simple_upsert( + table="user_threepid_id_server", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + "id_server": id_server, + }, + values={}, + insertion_values={}, + desc="add_user_bound_threepid", + ) + + def user_get_bound_threepids(self, user_id): + """Get the threepids that a user has bound to an identity server through the homeserver + The homeserver remembers where binds to an identity server occurred. Using this + method can retrieve those threepids. + + Args: + user_id (str): The ID of the user to retrieve threepids for + + Returns: + Deferred[list[dict]]: List of dictionaries containing the following: + medium (str): The medium of the threepid (e.g "email") + address (str): The address of the threepid (e.g "bob@example.com") + """ + return self._simple_select_list( + table="user_threepid_id_server", + keyvalues={"user_id": user_id}, + retcols=["medium", "address"], + desc="user_get_bound_threepids", + ) + + def remove_user_bound_threepid(self, user_id, medium, address, id_server): + """The server proxied an unbind request to the given identity server on + behalf of the given user, so we remove the mapping of threepid to + identity server. + + Args: + user_id (str) + medium (str) + address (str) + id_server (str) + + Returns: + Deferred + """ + return self._simple_delete( + table="user_threepid_id_server", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + "id_server": id_server, + }, + desc="remove_user_bound_threepid", + ) + + def get_id_servers_user_bound(self, user_id, medium, address): + """Get the list of identity servers that the server proxied bind + requests to for given user and threepid + + Args: + user_id (str) + medium (str) + address (str) + + Returns: + Deferred[list[str]]: Resolves to a list of identity servers + """ + return self._simple_select_onecol( + table="user_threepid_id_server", + keyvalues={"user_id": user_id, "medium": medium, "address": address}, + retcol="id_server", + desc="get_id_servers_user_bound", + ) + + @cachedInlineCallbacks() + def get_user_deactivated_status(self, user_id): + """Retrieve the value for the `deactivated` property for the provided user. + + Args: + user_id (str): The ID of the user to retrieve the status for. + + Returns: + defer.Deferred(bool): The requested value. + """ + + res = yield self._simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="deactivated", + desc="get_user_deactivated_status", + ) + + # Convert the integer into a boolean. + return res == 1 + + def get_threepid_validation_session( + self, medium, client_secret, address=None, sid=None, validated=True + ): + """Gets a session_id and last_send_attempt (if available) for a + combination of validation metadata + + Args: + medium (str|None): The medium of the 3PID + address (str|None): The address of the 3PID + sid (str|None): The ID of the validation session + client_secret (str): A unique string provided by the client to help identify this + validation attempt + validated (bool|None): Whether sessions should be filtered by + whether they have been validated already or not. None to + perform no filtering + + Returns: + Deferred[dict|None]: A dict containing the following: + * address - address of the 3pid + * medium - medium of the 3pid + * client_secret - a secret provided by the client for this validation session + * session_id - ID of the validation session + * send_attempt - a number serving to dedupe send attempts for this session + * validated_at - timestamp of when this session was validated if so + + Otherwise None if a validation session is not found + """ + if not client_secret: + raise SynapseError( + 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM + ) + + keyvalues = {"client_secret": client_secret} + if medium: + keyvalues["medium"] = medium + if address: + keyvalues["address"] = address + if sid: + keyvalues["session_id"] = sid + + assert address or sid + + def get_threepid_validation_session_txn(txn): + sql = """ + SELECT address, session_id, medium, client_secret, + last_send_attempt, validated_at + FROM threepid_validation_session WHERE %s + """ % ( + " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), + ) + + if validated is not None: + sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") + + sql += " LIMIT 1" + + txn.execute(sql, list(keyvalues.values())) + rows = self.cursor_to_dict(txn) + if not rows: + return None + + return rows[0] + + return self.runInteraction( + "get_threepid_validation_session", get_threepid_validation_session_txn + ) + + def delete_threepid_session(self, session_id): + """Removes a threepid validation session from the database. This can + be done after validation has been performed and whatever action was + waiting on it has been carried out + + Args: + session_id (str): The ID of the session to delete + """ + + def delete_threepid_session_txn(txn): + self._simple_delete_txn( + txn, + table="threepid_validation_token", + keyvalues={"session_id": session_id}, + ) + self._simple_delete_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + ) + + return self.runInteraction( + "delete_threepid_session", delete_threepid_session_txn + ) + + +class RegistrationBackgroundUpdateStore( + RegistrationWorkerStore, background_updates.BackgroundUpdateStore +): + def __init__(self, db_conn, hs): + super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) + + self.clock = hs.get_clock() + self.config = hs.config + + self.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.register_background_index_update( + "users_creation_ts", + index_name="users_creation_ts", + table="users", + columns=["creation_ts"], + ) + + # we no longer use refresh tokens, but it's possible that some people + # might have a background update queued to build this index. Just + # clear the background update. + self.register_noop_background_update("refresh_tokens_device_index") + + self.register_background_update_handler( + "user_threepids_grandfather", self._bg_user_threepids_grandfather + ) + + self.register_background_update_handler( + "users_set_deactivated_flag", self._background_update_set_deactivated_flag + ) + + @defer.inlineCallbacks + def _background_update_set_deactivated_flag(self, progress, batch_size): + """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 + for each of them. + """ + + last_user = progress.get("user_id", "") + + def _background_update_set_deactivated_flag_txn(txn): + txn.execute( + """ + SELECT + users.name, + COUNT(access_tokens.token) AS count_tokens, + COUNT(user_threepids.address) AS count_threepids + FROM users + LEFT JOIN access_tokens ON (access_tokens.user_id = users.name) + LEFT JOIN user_threepids ON (user_threepids.user_id = users.name) + WHERE (users.password_hash IS NULL OR users.password_hash = '') + AND (users.appservice_id IS NULL OR users.appservice_id = '') + AND users.is_guest = 0 + AND users.name > ? + GROUP BY users.name + ORDER BY users.name ASC + LIMIT ?; + """, + (last_user, batch_size), + ) + + rows = self.cursor_to_dict(txn) + + if not rows: + return True, 0 + + rows_processed_nb = 0 + + for user in rows: + if not user["count_tokens"] and not user["count_threepids"]: + self.set_user_deactivated_status_txn(txn, user["name"], True) + rows_processed_nb += 1 + + logger.info("Marked %d rows as deactivated", rows_processed_nb) + + self._background_update_progress_txn( + txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} + ) + + if batch_size > len(rows): + return True, len(rows) + else: + return False, len(rows) + + end, nb_processed = yield self.runInteraction( + "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn + ) + + if end: + yield self._end_background_update("users_set_deactivated_flag") + + return nb_processed + + @defer.inlineCallbacks + def _bg_user_threepids_grandfather(self, progress, batch_size): + """We now track which identity servers a user binds their 3PID to, so + we need to handle the case of existing bindings where we didn't track + this. + + We do this by grandfathering in existing user threepids assuming that + they used one of the server configured trusted identity servers. + """ + id_servers = set(self.config.trusted_third_party_id_servers) + + def _bg_user_threepids_grandfather_txn(txn): + sql = """ + INSERT INTO user_threepid_id_server + (user_id, medium, address, id_server) + SELECT user_id, medium, address, ? + FROM user_threepids + """ + + txn.executemany(sql, [(id_server,) for id_server in id_servers]) + + if id_servers: + yield self.runInteraction( + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn + ) + + yield self._end_background_update("user_threepids_grandfather") + + return 1 + + +class RegistrationStore(RegistrationBackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(RegistrationStore, self).__init__(db_conn, hs) + + self._account_validity = hs.config.account_validity + + # Create a background job for culling expired 3PID validity tokens + def start_cull(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "cull_expired_threepid_validation_tokens", + self.cull_expired_threepid_validation_tokens, + ) + + hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) + + @defer.inlineCallbacks + def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): + """Adds an access token for the given user. + + Args: + user_id (str): The user ID. + token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token + valid_until_ms (int|None): when the token is valid until. None for + no expiry. + Raises: + StoreError if there was a problem adding this. + """ + next_id = self._access_tokens_id_gen.get_next() + + yield self._simple_insert( + "access_tokens", + { + "id": next_id, + "user_id": user_id, + "token": token, + "device_id": device_id, + "valid_until_ms": valid_until_ms, + }, + desc="add_access_token_to_user", + ) + + def register_user( + self, + user_id, + password_hash=None, + was_guest=False, + make_guest=False, + appservice_id=None, + create_profile_with_displayname=None, + admin=False, + user_type=None, + ): + """Attempts to register an account. + + Args: + user_id (str): The desired user ID to register. + password_hash (str): Optional. The password hash for this user. + was_guest (bool): Optional. Whether this is a guest account being + upgraded to a non-guest account. + make_guest (boolean): True if the the new user should be guest, + false to add a regular user account. + appservice_id (str): The ID of the appservice registering the user. + create_profile_with_displayname (unicode): Optionally create a profile for + the user, setting their displayname to the given value + admin (boolean): is an admin user? + user_type (str|None): type of user. One of the values from + api.constants.UserTypes, or None for a normal user. + + Raises: + StoreError if the user_id could not be registered. + """ + return self.runInteraction( + "register_user", + self._register_user, + user_id, + password_hash, + was_guest, + make_guest, + appservice_id, + create_profile_with_displayname, + admin, + user_type, + ) + + def _register_user( + self, + txn, + user_id, + password_hash, + was_guest, + make_guest, + appservice_id, + create_profile_with_displayname, + admin, + user_type, + ): + user_id_obj = UserID.from_string(user_id) + + now = int(self.clock.time()) + + try: + if was_guest: + # Ensure that the guest user actually exists + # ``allow_none=False`` makes this raise an exception + # if the row isn't in the database. + self._simple_select_one_txn( + txn, + "users", + keyvalues={"name": user_id, "is_guest": 1}, + retcols=("name",), + allow_none=False, + ) + + self._simple_update_one_txn( + txn, + "users", + keyvalues={"name": user_id, "is_guest": 1}, + updatevalues={ + "password_hash": password_hash, + "upgrade_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + "user_type": user_type, + }, + ) + else: + self._simple_insert_txn( + txn, + "users", + values={ + "name": user_id, + "password_hash": password_hash, + "creation_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + "user_type": user_type, + }, + ) + + except self.database_engine.module.IntegrityError: + raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + + if self._account_validity.enabled: + self.set_expiration_date_for_user_txn(txn, user_id) + + if create_profile_with_displayname: + # set a default displayname serverside to avoid ugly race + # between auto-joins and clients trying to set displaynames + # + # *obviously* the 'profiles' table uses localpart for user_id + # while everything else uses the full mxid. + txn.execute( + "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", + (user_id_obj.localpart, create_profile_with_displayname), + ) + + if self.hs.config.stats_enabled: + # we create a new completed user statistics row + + # we don't strictly need current_token since this user really can't + # have any state deltas before now (as it is a new user), but still, + # we include it for completeness. + current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) + self._update_stats_delta_txn( + txn, now, "user", user_id, {}, complete_with_stream_id=current_token + ) + + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + txn.call_after(self.is_guest.invalidate, (user_id,)) + + def record_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> Deferred: + """Record a mapping from an external user id to a mxid + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + return self._simple_insert( + table="user_external_ids", + values={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="record_user_external_id", + ) + + def user_set_password_hash(self, user_id, password_hash): + """ + NB. This does *not* evict any cache because the one use for this + removes most of the entries subsequently anyway so it would be + pointless. Use flush_user separately. + """ + + def user_set_password_hash_txn(txn): + self._simple_update_one_txn( + txn, "users", {"name": user_id}, {"password_hash": password_hash} + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + return self.runInteraction("user_set_password_hash", user_set_password_hash_txn) + + def user_set_consent_version(self, user_id, consent_version): + """Updates the user table to record privacy policy consent + + Args: + user_id (str): full mxid of the user to update + consent_version (str): version of the policy the user has consented + to + + Raises: + StoreError(404) if user not found + """ + + def f(txn): + self._simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_version": consent_version}, + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + return self.runInteraction("user_set_consent_version", f) + + def user_set_consent_server_notice_sent(self, user_id, consent_version): + """Updates the user table to record that we have sent the user a server + notice about privacy policy consent + + Args: + user_id (str): full mxid of the user to update + consent_version (str): version of the policy we have notified the + user about + + Raises: + StoreError(404) if user not found + """ + + def f(txn): + self._simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_server_notice_sent": consent_version}, + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + return self.runInteraction("user_set_consent_server_notice_sent", f) + + def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): + """ + Invalidate access tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_id (str): list of access_tokens IDs which should + *not* be deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + Returns: + defer.Deferred[list[str, int, str|None, int]]: a list of + (token, token id, device id) for each of the deleted tokens + """ + + def f(txn): + keyvalues = {"user_id": user_id} + if device_id is not None: + keyvalues["device_id"] = device_id + + items = keyvalues.items() + where_clause = " AND ".join(k + " = ?" for k, _ in items) + values = [v for _, v in items] + if except_token_id: + where_clause += " AND id != ?" + values.append(except_token_id) + + txn.execute( + "SELECT token, id, device_id FROM access_tokens WHERE %s" + % where_clause, + values, + ) + tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] + + for token, _, _ in tokens_and_devices: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (token,) + ) + + txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) + + return tokens_and_devices + + return self.runInteraction("user_delete_access_tokens", f) + + def delete_access_token(self, access_token): + def f(txn): + self._simple_delete_one_txn( + txn, table="access_tokens", keyvalues={"token": access_token} + ) + + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (access_token,) + ) + + return self.runInteraction("delete_access_token", f) + + @cachedInlineCallbacks() + def is_guest(self, user_id): + res = yield self._simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="is_guest", + allow_none=True, + desc="is_guest", + ) + + return res if res else False + + def add_user_pending_deactivation(self, user_id): + """ + Adds a user to the table of users who need to be parted from all the rooms they're + in + """ + return self._simple_insert( + "users_pending_deactivation", + values={"user_id": user_id}, + desc="add_user_pending_deactivation", + ) + + def del_user_pending_deactivation(self, user_id): + """ + Removes the given user to the table of users who need to be parted from all the + rooms they're in, effectively marking that user as fully deactivated. + """ + # XXX: This should be simple_delete_one but we failed to put a unique index on + # the table, so somehow duplicate entries have ended up in it. + return self._simple_delete( + "users_pending_deactivation", + keyvalues={"user_id": user_id}, + desc="del_user_pending_deactivation", + ) + + def get_user_pending_deactivation(self): + """ + Gets one user from the table of users waiting to be parted from all the rooms + they're in. + """ + return self._simple_select_one_onecol( + "users_pending_deactivation", + keyvalues={}, + retcol="user_id", + allow_none=True, + desc="get_users_pending_deactivation", + ) + + def validate_threepid_session(self, session_id, client_secret, token, current_ts): + """Attempt to validate a threepid session using a token + + Args: + session_id (str): The id of a validation session + client_secret (str): A unique string provided by the client to + help identify this validation attempt + token (str): A validation token + current_ts (int): The current unix time in milliseconds. Used for + checking token expiry status + + Raises: + ThreepidValidationError: if a matching validation token was not found or has + expired + + Returns: + deferred str|None: A str representing a link to redirect the user + to if there is one. + """ + + # Insert everything into a transaction in order to run atomically + def validate_threepid_session_txn(txn): + row = self._simple_select_one_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + retcols=["client_secret", "validated_at"], + allow_none=True, + ) + + if not row: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] + validated_at = row["validated_at"] + + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id" + ) + + row = self._simple_select_one_txn( + txn, + table="threepid_validation_token", + keyvalues={"session_id": session_id, "token": token}, + retcols=["expires", "next_link"], + allow_none=True, + ) + + if not row: + raise ThreepidValidationError( + 400, "Validation token not found or has expired" + ) + expires = row["expires"] + next_link = row["next_link"] + + # If the session is already validated, no need to revalidate + if validated_at: + return next_link + + if expires <= current_ts: + raise ThreepidValidationError( + 400, "This token has expired. Please request a new one" + ) + + # Looks good. Validate the session + self._simple_update_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + updatevalues={"validated_at": self.clock.time_msec()}, + ) + + return next_link + + # Return next_link if it exists + return self.runInteraction( + "validate_threepid_session_txn", validate_threepid_session_txn + ) + + def upsert_threepid_validation_session( + self, + medium, + address, + client_secret, + send_attempt, + session_id, + validated_at=None, + ): + """Upsert a threepid validation session + Args: + medium (str): The medium of the 3PID + address (str): The address of the 3PID + client_secret (str): A unique string provided by the client to + help identify this validation attempt + send_attempt (int): The latest send_attempt on this session + session_id (str): The id of this validation session + validated_at (int|None): The unix timestamp in milliseconds of + when the session was marked as valid + """ + insertion_values = { + "medium": medium, + "address": address, + "client_secret": client_secret, + } + + if validated_at: + insertion_values["validated_at"] = validated_at + + return self._simple_upsert( + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + values={"last_send_attempt": send_attempt}, + insertion_values=insertion_values, + desc="upsert_threepid_validation_session", + ) + + def start_or_continue_validation_session( + self, + medium, + address, + session_id, + client_secret, + send_attempt, + next_link, + token, + token_expires, + ): + """Creates a new threepid validation session if it does not already + exist and associates a new validation token with it + + Args: + medium (str): The medium of the 3PID + address (str): The address of the 3PID + session_id (str): The id of this validation session + client_secret (str): A unique string provided by the client to + help identify this validation attempt + send_attempt (int): The latest send_attempt on this session + next_link (str|None): The link to redirect the user to upon + successful validation + token (str): The validation token + token_expires (int): The timestamp for which after the token + will no longer be valid + """ + + def start_or_continue_validation_session_txn(txn): + # Create or update a validation session + self._simple_upsert_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + values={"last_send_attempt": send_attempt}, + insertion_values={ + "medium": medium, + "address": address, + "client_secret": client_secret, + }, + ) + + # Create a new validation token with this session ID + self._simple_insert_txn( + txn, + table="threepid_validation_token", + values={ + "session_id": session_id, + "token": token, + "next_link": next_link, + "expires": token_expires, + }, + ) + + return self.runInteraction( + "start_or_continue_validation_session", + start_or_continue_validation_session_txn, + ) + + def cull_expired_threepid_validation_tokens(self): + """Remove threepid validation tokens with expiry dates that have passed""" + + def cull_expired_threepid_validation_tokens_txn(txn, ts): + sql = """ + DELETE FROM threepid_validation_token WHERE + expires < ? + """ + return txn.execute(sql, (ts,)) + + return self.runInteraction( + "cull_expired_threepid_validation_tokens", + cull_expired_threepid_validation_tokens_txn, + self.clock.time_msec(), + ) + + @defer.inlineCallbacks + def set_user_deactivated_status(self, user_id, deactivated): + """Set the `deactivated` property for the provided user to the provided value. + + Args: + user_id (str): The ID of the user to set the status for. + deactivated (bool): The value to set for `deactivated`. + """ + + yield self.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) + + def set_user_deactivated_status_txn(self, txn, user_id, deactivated): + self._simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py new file mode 100644 index 0000000000..7d5de0ea2e --- /dev/null +++ b/synapse/storage/data_stores/main/rejections.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class RejectionsStore(SQLBaseStore): + def _store_rejections_txn(self, txn, event_id, reason): + self._simple_insert_txn( + txn, + table="rejections", + values={ + "event_id": event_id, + "reason": reason, + "last_check": self._clock.time_msec(), + }, + ) + + def get_rejection_reason(self, event_id): + return self._simple_select_one_onecol( + table="rejections", + retcol="reason", + keyvalues={"event_id": event_id}, + allow_none=True, + desc="get_rejection_reason", + ) diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py new file mode 100644 index 0000000000..858f65582b --- /dev/null +++ b/synapse/storage/data_stores/main/relations.py @@ -0,0 +1,385 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import attr + +from synapse.api.constants import RelationTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.stream import generate_pagination_where_clause +from synapse.storage.relations import ( + AggregationPaginationToken, + PaginationChunk, + RelationPaginationToken, +) +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +class RelationsWorkerStore(SQLBaseStore): + @cached(tree=True) + def get_relations_for_event( + self, + event_id, + relation_type=None, + event_type=None, + aggregation_key=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of relations for an event, ordered by topological ordering. + + Args: + event_id (str): Fetch events that relate to this event ID. + relation_type (str|None): Only fetch events with this relation + type, if given. + event_type (str|None): Only fetch events with this event type, if + given. + aggregation_key (str|None): Only fetch events with this aggregation + key, if given. + limit (int): Only fetch the most recent `limit` events. + direction (str): Whether to fetch the most recent first (`"b"`) or + the oldest first (`"f"`). + from_token (RelationPaginationToken|None): Fetch rows from the given + token, or from the start if None. + to_token (RelationPaginationToken|None): Fetch rows up to the given + token, or up to the end if None. + + Returns: + Deferred[PaginationChunk]: List of event IDs that match relations + requested. The rows are of the form `{"event_id": "..."}`. + """ + + where_clause = ["relates_to_id = ?"] + where_args = [event_id] + + if relation_type is not None: + where_clause.append("relation_type = ?") + where_args.append(relation_type) + + if event_type is not None: + where_clause.append("type = ?") + where_args.append(event_type) + + if aggregation_key: + where_clause.append("aggregation_key = ?") + where_args.append(aggregation_key) + + pagination_clause = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if pagination_clause: + where_clause.append(pagination_clause) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ? + """ % ( + " AND ".join(where_clause), + order, + order, + ) + + def _get_recent_references_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + last_topo_id = None + last_stream_id = None + events = [] + for row in txn: + events.append({"event_id": row[0]}) + last_topo_id = row[1] + last_stream_id = row[2] + + next_batch = None + if len(events) > limit and last_topo_id and last_stream_id: + next_batch = RelationPaginationToken(last_topo_id, last_stream_id) + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_recent_references_for_event", _get_recent_references_for_event_txn + ) + + @cached(tree=True) + def get_aggregation_groups_for_event( + self, + event_id, + event_type=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of annotations on the event, grouped by event type and + aggregation key, sorted by count. + + This is used e.g. to get the what and how many reactions have happend + on an event. + + Args: + event_id (str): Fetch events that relate to this event ID. + event_type (str|None): Only fetch events with this event type, if + given. + limit (int): Only fetch the `limit` groups. + direction (str): Whether to fetch the highest count first (`"b"`) or + the lowest count first (`"f"`). + from_token (AggregationPaginationToken|None): Fetch rows from the + given token, or from the start if None. + to_token (AggregationPaginationToken|None): Fetch rows up to the + given token, or up to the end if None. + + + Returns: + Deferred[PaginationChunk]: List of groups of annotations that + match. Each row is a dict with `type`, `key` and `count` fields. + """ + + where_clause = ["relates_to_id = ?", "relation_type = ?"] + where_args = [event_id, RelationTypes.ANNOTATION] + + if event_type: + where_clause.append("type = ?") + where_args.append(event_type) + + having_clause = generate_pagination_where_clause( + direction=direction, + column_names=("COUNT(*)", "MAX(stream_ordering)"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + if having_clause: + having_clause = "HAVING " + having_clause + else: + having_clause = "" + + sql = """ + SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE {where_clause} + GROUP BY relation_type, type, aggregation_key + {having_clause} + ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + LIMIT ? + """.format( + where_clause=" AND ".join(where_clause), + order=order, + having_clause=having_clause, + ) + + def _get_aggregation_groups_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + next_batch = None + events = [] + for row in txn: + events.append({"type": row[0], "key": row[1], "count": row[2]}) + next_batch = AggregationPaginationToken(row[2], row[3]) + + if len(events) <= limit: + next_batch = None + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + ) + + @cachedInlineCallbacks() + def get_applicable_edit(self, event_id): + """Get the most recent edit (if any) that has happened for the given + event. + + Correctly handles checking whether edits were allowed to happen. + + Args: + event_id (str): The original event ID + + Returns: + Deferred[EventBase|None]: Returns the most recent edit, if any. + """ + + # We only allow edits for `m.room.message` events that have the same sender + # and event type. We can't assert these things during regular event auth so + # we have to do the checks post hoc. + + # Fetches latest edit that has the same type and sender as the + # original, and is an `m.room.message`. + sql = """ + SELECT edit.event_id FROM events AS edit + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender + WHERE + relates_to_id = ? + AND relation_type = ? + AND edit.type = 'm.room.message' + ORDER by edit.origin_server_ts DESC, edit.event_id DESC + LIMIT 1 + """ + + def _get_applicable_edit_txn(txn): + txn.execute(sql, (event_id, RelationTypes.REPLACE)) + row = txn.fetchone() + if row: + return row[0] + + edit_id = yield self.runInteraction( + "get_applicable_edit", _get_applicable_edit_txn + ) + + if not edit_id: + return + + edit_event = yield self.get_event(edit_id, allow_none=True) + return edit_event + + def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + """Check if a user has already annotated an event with the same key + (e.g. already liked an event). + + Args: + parent_id (str): The event being annotated + event_type (str): The event type of the annotation + aggregation_key (str): The aggregation key of the annotation + sender (str): The sender of the annotation + + Returns: + Deferred[bool] + """ + + sql = """ + SELECT 1 FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + AND type = ? + AND sender = ? + AND aggregation_key = ? + LIMIT 1; + """ + + def _get_if_user_has_annotated_event(txn): + txn.execute( + sql, + ( + parent_id, + RelationTypes.ANNOTATION, + event_type, + sender, + aggregation_key, + ), + ) + + return bool(txn.fetchone()) + + return self.runInteraction( + "get_if_user_has_annotated_event", _get_if_user_has_annotated_event + ) + + +class RelationsStore(RelationsWorkerStore): + def _handle_event_relations(self, txn, event): + """Handles inserting relation data during peristence of events + + Args: + txn + event (EventBase) + """ + relation = event.content.get("m.relates_to") + if not relation: + # No relations + return + + rel_type = relation.get("rel_type") + if rel_type not in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): + # Unknown relation type + return + + parent_id = relation.get("event_id") + if not parent_id: + # Invalid relation + return + + aggregation_key = relation.get("key") + + self._simple_insert_txn( + txn, + table="event_relations", + values={ + "event_id": event.event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + "aggregation_key": aggregation_key, + }, + ) + + txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after( + self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + ) + + if rel_type == RelationTypes.REPLACE: + txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) + + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. + + Args: + txn + redacted_event_id (str): The event that was redacted. + """ + + self._simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} + ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py new file mode 100644 index 0000000000..4428e5c55d --- /dev/null +++ b/synapse/storage/data_stores/main/room.py @@ -0,0 +1,681 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import logging +import re +from typing import Optional, Tuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.search import SearchStore +from synapse.types import ThirdPartyInstanceID +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +OpsLevel = collections.namedtuple( + "OpsLevel", ("ban_level", "kick_level", "redact_level") +) + +RatelimitOverride = collections.namedtuple( + "RatelimitOverride", ("messages_per_second", "burst_count") +) + + +class RoomWorkerStore(SQLBaseStore): + def get_room(self, room_id): + """Retrieve a room. + + Args: + room_id (str): The ID of the room to retrieve. + Returns: + A dict containing the room information, or None if the room is unknown. + """ + return self._simple_select_one( + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("room_id", "is_public", "creator"), + desc="get_room", + allow_none=True, + ) + + def get_public_room_ids(self): + return self._simple_select_onecol( + table="rooms", + keyvalues={"is_public": True}, + retcol="room_id", + desc="get_public_room_ids", + ) + + def count_public_rooms(self, network_tuple, ignore_non_federatable): + """Counts the number of public rooms as tracked in the room_stats_current + and room_stats_state table. + + Args: + network_tuple (ThirdPartyInstanceID|None) + ignore_non_federatable (bool): If true filters out non-federatable rooms + """ + + def _count_public_rooms_txn(txn): + query_args = [] + + if network_tuple: + if network_tuple.appservice_id: + published_sql = """ + SELECT room_id from appservice_room_list + WHERE appservice_id = ? AND network_id = ? + """ + query_args.append(network_tuple.appservice_id) + query_args.append(network_tuple.network_id) + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + """ + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + UNION SELECT room_id from appservice_room_list + """ + + sql = """ + SELECT + COALESCE(COUNT(*), 0) + FROM ( + %(published_sql)s + ) published + INNER JOIN room_stats_state USING (room_id) + INNER JOIN room_stats_current USING (room_id) + WHERE + ( + join_rules = 'public' OR history_visibility = 'world_readable' + ) + AND joined_members > 0 + """ % { + "published_sql": published_sql + } + + txn.execute(sql, query_args) + return txn.fetchone()[0] + + return self.runInteraction("count_public_rooms", _count_public_rooms_txn) + + @defer.inlineCallbacks + def get_largest_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + search_filter: Optional[dict], + limit: Optional[int], + bounds: Optional[Tuple[int, str]], + forwards: bool, + ignore_non_federatable: bool = False, + ): + """Gets the largest public rooms (where largest is in terms of joined + members, as tracked in the statistics table). + + Args: + network_tuple + search_filter + limit: Maxmimum number of rows to return, unlimited otherwise. + bounds: An uppoer or lower bound to apply to result set if given, + consists of a joined member count and room_id (these are + excluded from result set). + forwards: true iff going forwards, going backwards otherwise + ignore_non_federatable: If true filters out non-federatable rooms. + + Returns: + Rooms in order: biggest number of joined users first. + We then arbitrarily use the room_id as a tie breaker. + + """ + + where_clauses = [] + query_args = [] + + if network_tuple: + if network_tuple.appservice_id: + published_sql = """ + SELECT room_id from appservice_room_list + WHERE appservice_id = ? AND network_id = ? + """ + query_args.append(network_tuple.appservice_id) + query_args.append(network_tuple.network_id) + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + """ + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + UNION SELECT room_id from appservice_room_list + """ + + # Work out the bounds if we're given them, these bounds look slightly + # odd, but are designed to help query planner use indices by pulling + # out a common bound. + if bounds: + last_joined_members, last_room_id = bounds + if forwards: + where_clauses.append( + """ + joined_members <= ? AND ( + joined_members < ? OR room_id < ? + ) + """ + ) + else: + where_clauses.append( + """ + joined_members >= ? AND ( + joined_members > ? OR room_id > ? + ) + """ + ) + + query_args += [last_joined_members, last_joined_members, last_room_id] + + if ignore_non_federatable: + where_clauses.append("is_federatable") + + if search_filter and search_filter.get("generic_search_term", None): + search_term = "%" + search_filter["generic_search_term"] + "%" + + where_clauses.append( + """ + ( + name LIKE ? + OR topic LIKE ? + OR canonical_alias LIKE ? + ) + """ + ) + query_args += [search_term, search_term, search_term] + + where_clause = "" + if where_clauses: + where_clause = " AND " + " AND ".join(where_clauses) + + sql = """ + SELECT + room_id, name, topic, canonical_alias, joined_members, + avatar, history_visibility, joined_members, guest_access + FROM ( + %(published_sql)s + ) published + INNER JOIN room_stats_state USING (room_id) + INNER JOIN room_stats_current USING (room_id) + WHERE + ( + join_rules = 'public' OR history_visibility = 'world_readable' + ) + AND joined_members > 0 + %(where_clause)s + ORDER BY joined_members %(dir)s, room_id %(dir)s + """ % { + "published_sql": published_sql, + "where_clause": where_clause, + "dir": "DESC" if forwards else "ASC", + } + + if limit is not None: + query_args.append(limit) + + sql += """ + LIMIT ? + """ + + def _get_largest_public_rooms_txn(txn): + txn.execute(sql, query_args) + + results = self.cursor_to_dict(txn) + + if not forwards: + results.reverse() + + return results + + ret_val = yield self.runInteraction( + "get_largest_public_rooms", _get_largest_public_rooms_txn + ) + defer.returnValue(ret_val) + + @cached(max_entries=10000) + def is_room_blocked(self, room_id): + return self._simple_select_one_onecol( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + retcol="1", + allow_none=True, + desc="is_room_blocked", + ) + + @cachedInlineCallbacks(max_entries=10000) + def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user + + Args: + user_id (str) + + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = yield self._simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) + + if row: + return RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + ) + else: + return None + + +class RoomStore(RoomWorkerStore, SearchStore): + @defer.inlineCallbacks + def store_room(self, room_id, room_creator_user_id, is_public): + """Stores a room. + + Args: + room_id (str): The desired room ID, can be None. + room_creator_user_id (str): The user ID of the room creator. + is_public (bool): True to indicate that this room should appear in + public room lists. + Raises: + StoreError if the room could not be stored. + """ + try: + + def store_room_txn(txn, next_id): + self._simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": room_creator_user_id, + "is_public": is_public, + }, + ) + if is_public: + self._simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + }, + ) + + with self._public_room_id_gen.get_next() as next_id: + yield self.runInteraction("store_room_txn", store_room_txn, next_id) + except Exception as e: + logger.error("store_room with room_id=%s failed: %s", room_id, e) + raise StoreError(500, "Problem creating room.") + + @defer.inlineCallbacks + def set_room_is_public(self, room_id, is_public): + def set_room_is_public_txn(txn, next_id): + self._simple_update_one_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"is_public": is_public}, + ) + + entries = self._simple_select_list_txn( + txn, + table="public_room_list_stream", + keyvalues={ + "room_id": room_id, + "appservice_id": None, + "network_id": None, + }, + retcols=("stream_id", "visibility"), + ) + + entries.sort(key=lambda r: r["stream_id"]) + + add_to_stream = True + if entries: + add_to_stream = bool(entries[-1]["visibility"]) != is_public + + if add_to_stream: + self._simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + "appservice_id": None, + "network_id": None, + }, + ) + + with self._public_room_id_gen.get_next() as next_id: + yield self.runInteraction( + "set_room_is_public", set_room_is_public_txn, next_id + ) + self.hs.get_notifier().on_new_replication_data() + + @defer.inlineCallbacks + def set_room_is_public_appservice( + self, room_id, appservice_id, network_id, is_public + ): + """Edit the appservice/network specific public room list. + + Each appservice can have a number of published room lists associated + with them, keyed off of an appservice defined `network_id`, which + basically represents a single instance of a bridge to a third party + network. + + Args: + room_id (str) + appservice_id (str) + network_id (str) + is_public (bool): Whether to publish or unpublish the room from the + list. + """ + + def set_room_is_public_appservice_txn(txn, next_id): + if is_public: + try: + self._simple_insert_txn( + txn, + table="appservice_room_list", + values={ + "appservice_id": appservice_id, + "network_id": network_id, + "room_id": room_id, + }, + ) + except self.database_engine.module.IntegrityError: + # We've already inserted, nothing to do. + return + else: + self._simple_delete_txn( + txn, + table="appservice_room_list", + keyvalues={ + "appservice_id": appservice_id, + "network_id": network_id, + "room_id": room_id, + }, + ) + + entries = self._simple_select_list_txn( + txn, + table="public_room_list_stream", + keyvalues={ + "room_id": room_id, + "appservice_id": appservice_id, + "network_id": network_id, + }, + retcols=("stream_id", "visibility"), + ) + + entries.sort(key=lambda r: r["stream_id"]) + + add_to_stream = True + if entries: + add_to_stream = bool(entries[-1]["visibility"]) != is_public + + if add_to_stream: + self._simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + "appservice_id": appservice_id, + "network_id": network_id, + }, + ) + + with self._public_room_id_gen.get_next() as next_id: + yield self.runInteraction( + "set_room_is_public_appservice", + set_room_is_public_appservice_txn, + next_id, + ) + self.hs.get_notifier().on_new_replication_data() + + def get_room_count(self): + """Retrieve a list of all rooms + """ + + def f(txn): + sql = "SELECT count(*) FROM rooms" + txn.execute(sql) + row = txn.fetchone() + return row[0] or 0 + + return self.runInteraction("get_rooms", f) + + def _store_room_topic_txn(self, txn, event): + if hasattr(event, "content") and "topic" in event.content: + self.store_event_search_txn( + txn, event, "content.topic", event.content["topic"] + ) + + def _store_room_name_txn(self, txn, event): + if hasattr(event, "content") and "name" in event.content: + self.store_event_search_txn( + txn, event, "content.name", event.content["name"] + ) + + def _store_room_message_txn(self, txn, event): + if hasattr(event, "content") and "body" in event.content: + self.store_event_search_txn( + txn, event, "content.body", event.content["body"] + ) + + def add_event_report( + self, room_id, event_id, user_id, reason, content, received_ts + ): + next_id = self._event_reports_id_gen.get_next() + return self._simple_insert( + table="event_reports", + values={ + "id": next_id, + "received_ts": received_ts, + "room_id": room_id, + "event_id": event_id, + "user_id": user_id, + "reason": reason, + "content": json.dumps(content), + }, + desc="add_event_report", + ) + + def get_current_public_room_stream_id(self): + return self._public_room_id_gen.get_current_token() + + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = """ + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + + txn.execute(sql, (prev_id, current_id, limit)) + return txn.fetchall() + + if prev_id == current_id: + return defer.succeed([]) + + return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms) + + @defer.inlineCallbacks + def block_room(self, room_id, user_id): + """Marks the room as blocked. Can be called multiple times. + + Args: + room_id (str): Room to block + user_id (str): Who blocked it + + Returns: + Deferred + """ + yield self._simple_upsert( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={"user_id": user_id}, + desc="block_room", + ) + yield self.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) + + def get_media_mxcs_in_room(self, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + + def _get_media_mxcs_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + local_media_mxcs = [] + remote_media_mxcs = [] + + # Convert the IDs to MXC URIs + for media_id in local_mxcs: + local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) + for hostname, media_id in remote_mxcs: + remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs + + return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) + + def quarantine_media_ids_in_room(self, room_id, quarantined_by): + """For a room loops through all events with media and quarantines + the associated media + """ + + def _quarantine_media_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + total_media_quarantined = 0 + + # Now update all the tables to set the quarantined_by flag + + txn.executemany( + """ + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? + """, + ((quarantined_by, media_id) for media_id in local_mxcs), + ) + + txn.executemany( + """ + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE media_origin = ? AND media_id = ? + """, + ( + (quarantined_by, origin, media_id) + for origin, media_id in remote_mxcs + ), + ) + + total_media_quarantined += len(local_mxcs) + total_media_quarantined += len(remote_mxcs) + + return total_media_quarantined + + return self.runInteraction( + "quarantine_media_in_room", _quarantine_media_in_room_txn + ) + + def _get_media_mxcs_in_room_txn(self, txn, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + txn (cursor) + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + + next_token = self.get_current_events_token() + 1 + local_media_mxcs = [] + remote_media_mxcs = [] + + while next_token: + sql = """ + SELECT stream_ordering, json FROM events + JOIN event_json USING (room_id, event_id) + WHERE room_id = ? + AND stream_ordering < ? + AND contains_url = ? AND outlier = ? + ORDER BY stream_ordering DESC + LIMIT ? + """ + txn.execute(sql, (room_id, next_token, True, False, 100)) + + next_token = None + for stream_ordering, content_json in txn: + next_token = stream_ordering + event_json = json.loads(content_json) + content = event_json["content"] + content_url = content.get("url") + thumbnail_url = content.get("info", {}).get("thumbnail_url") + + for url in (content_url, thumbnail_url): + if not url: + continue + matches = mxc_re.match(url) + if matches: + hostname = matches.group(1) + media_id = matches.group(2) + if hostname == self.hs.hostname: + local_media_mxcs.append(media_id) + else: + remote_media_mxcs.append((hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py new file mode 100644 index 0000000000..e47ab604dd --- /dev/null +++ b/synapse/storage/data_stores/main/roommember.py @@ -0,0 +1,1145 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems, itervalues + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.engines import Sqlite3Engine +from synapse.storage.roommember import ( + GetRoomsForUserWithStreamOrdering, + MemberSummary, + ProfileInfo, + RoomsForUser, +) +from synapse.types import get_domain_from_id +from synapse.util.async_helpers import Linearizer +from synapse.util.caches import intern_string +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.metrics import Measure +from synapse.util.stringutils import to_ascii + +logger = logging.getLogger(__name__) + + +_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" +_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" + + +class RoomMemberWorkerStore(EventsWorkerStore): + def __init__(self, db_conn, hs): + super(RoomMemberWorkerStore, self).__init__(db_conn, hs) + + # Is the current_state_events.membership up to date? Or is the + # background update still running? + self._current_state_events_membership_up_to_date = False + + txn = LoggingTransaction( + db_conn.cursor(), + name="_check_safe_current_state_events_membership_updated", + database_engine=self.database_engine, + ) + self._check_safe_current_state_events_membership_updated_txn(txn) + txn.close() + + if self.hs.config.metrics_flags.known_servers: + self._known_servers_count = 1 + self.hs.get_clock().looping_call( + run_as_background_process, + 60 * 1000, + "_count_known_servers", + self._count_known_servers, + ) + self.hs.get_clock().call_later( + 1000, + run_as_background_process, + "_count_known_servers", + self._count_known_servers, + ) + LaterGauge( + "synapse_federation_known_servers", + "", + [], + lambda: self._known_servers_count, + ) + + @defer.inlineCallbacks + def _count_known_servers(self): + """ + Count the servers that this server knows about. + + The statistic is stored on the class for the + `synapse_federation_known_servers` LaterGauge to collect. + """ + + def _transact(txn): + if isinstance(self.database_engine, Sqlite3Engine): + query = """ + SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) + FROM ( + SELECT rm.user_id as user_id, instr(rm.user_id, ':') + AS pos FROM room_memberships as rm + INNER JOIN current_state_events as c ON rm.event_id = c.event_id + WHERE c.type = 'm.room.member' + ) as out + """ + else: + query = """ + SELECT COUNT(DISTINCT split_part(state_key, ':', 2)) + FROM current_state_events + WHERE type = 'm.room.member' AND membership = 'join'; + """ + txn.execute(query) + return list(txn)[0][0] + + count = yield self.runInteraction("get_known_servers", _transact) + + # We always know about ourselves, even if we have nothing in + # room_memberships (for example, the server is new). + self._known_servers_count = max([count, 1]) + return self._known_servers_count + + def _check_safe_current_state_events_membership_updated_txn(self, txn): + """Checks if it is safe to assume the new current_state_events + membership column is up to date + """ + + pending_update = self._simple_select_one_txn( + txn, + table="background_updates", + keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, + retcols=["update_name"], + allow_none=True, + ) + + self._current_state_events_membership_up_to_date = not pending_update + + # If the update is still running, reschedule to run. + if pending_update: + self._clock.call_later( + 15.0, + run_as_background_process, + "_check_safe_current_state_events_membership_updated", + self.runInteraction, + "_check_safe_current_state_events_membership_updated", + self._check_safe_current_state_events_membership_updated_txn, + ) + + @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) + def get_hosts_in_room(self, room_id, cache_context): + """Returns the set of all hosts currently in the room + """ + user_ids = yield self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) + return hosts + + @cached(max_entries=100000, iterable=True) + def get_users_in_room(self, room_id): + return self.runInteraction( + "get_users_in_room", self.get_users_in_room_txn, room_id + ) + + def get_users_in_room_txn(self, txn, room_id): + # If we can assume current_state_events.membership is up to date + # then we can avoid a join, which is a Very Good Thing given how + # frequently this function gets called. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT state_key FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? AND membership = ? + """ + else: + sql = """ + SELECT state_key FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? + """ + + txn.execute(sql, (room_id, Membership.JOIN)) + return [to_ascii(r[0]) for r in txn] + + @cached(max_entries=100000) + def get_room_summary(self, room_id): + """ Get the details of a room roughly suitable for use by the room + summary extension to /sync. Useful when lazy loading room members. + Args: + room_id (str): The room ID to query + Returns: + Deferred[dict[str, MemberSummary]: + dict of membership states, pointing to a MemberSummary named tuple. + """ + + def _get_room_summary_txn(txn): + # first get counts. + # We do this all in one transaction to keep the cache small. + # FIXME: get rid of this when we have room_stats + + # If we can assume current_state_events.membership is up to date + # then we can avoid a join, which is a Very Good Thing given how + # frequently this function gets called. + if self._current_state_events_membership_up_to_date: + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT count(*), membership FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? + AND membership IS NOT NULL + GROUP BY membership + """ + else: + sql = """ + SELECT count(*), m.membership FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + res = {} + for count, membership in txn: + summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) + + # we order by membership and then fairly arbitrarily by event_id so + # heroes are consistent + if self._current_state_events_membership_up_to_date: + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT state_key, membership, event_id + FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? + AND membership IS NOT NULL + ORDER BY + CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + event_id ASC + LIMIT ? + """ + else: + sql = """ + SELECT c.state_key, m.membership, c.event_id + FROM room_memberships as m + INNER JOIN current_state_events as c USING (room_id, event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? + ORDER BY + CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + c.event_id ASC + LIMIT ? + """ + + # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. + txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) + for user_id, membership, event_id in txn: + summary = res[to_ascii(membership)] + # we will always have a summary for this membership type at this + # point given the summary currently contains the counts. + members = summary.members + members.append((to_ascii(user_id), to_ascii(event_id))) + + return res + + return self.runInteraction("get_room_summary", _get_room_summary_txn) + + def _get_user_counts_in_room_txn(self, txn, room_id): + """ + Get the user count in a room by membership. + + Args: + room_id (str) + membership (Membership) + + Returns: + Deferred[int] + """ + sql = """ + SELECT m.membership, count(*) FROM room_memberships as m + INNER JOIN current_state_events as c USING(event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + return {row[0]: row[1] for row in txn} + + @cached() + def get_invited_rooms_for_user(self, user_id): + """ Get all the rooms the user is invited to + Args: + user_id (str): The user ID. + Returns: + A deferred list of RoomsForUser. + """ + + return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE]) + + @defer.inlineCallbacks + def get_invite_for_user_in_room(self, user_id, room_id): + """Gets the invite for the given user and room + + Args: + user_id (str) + room_id (str) + + Returns: + Deferred: Resolves to either a RoomsForUser or None if no invite was + found. + """ + invites = yield self.get_invited_rooms_for_user(user_id) + for invite in invites: + if invite.room_id == room_id: + return invite + return None + + @defer.inlineCallbacks + def get_rooms_for_user_where_membership_is(self, user_id, membership_list): + """ Get all the rooms for this user where the membership for this user + matches one in the membership list. + + Filters out forgotten rooms. + + Args: + user_id (str): The user ID. + membership_list (list): A list of synapse.api.constants.Membership + values which the user must be in. + + Returns: + Deferred[list[RoomsForUser]] + """ + if not membership_list: + return defer.succeed(None) + + rooms = yield self.runInteraction( + "get_rooms_for_user_where_membership_is", + self._get_rooms_for_user_where_membership_is_txn, + user_id, + membership_list, + ) + + # Now we filter out forgotten rooms + forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) + return [room for room in rooms if room.room_id not in forgotten_rooms] + + def _get_rooms_for_user_where_membership_is_txn( + self, txn, user_id, membership_list + ): + + do_invite = Membership.INVITE in membership_list + membership_list = [m for m in membership_list if m != Membership.INVITE] + + results = [] + if membership_list: + if self._current_state_events_membership_up_to_date: + clause, args = make_in_list_sql_clause( + self.database_engine, "c.membership", membership_list + ) + sql = """ + SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND %s + """ % ( + clause, + ) + else: + clause, args = make_in_list_sql_clause( + self.database_engine, "m.membership", membership_list + ) + sql = """ + SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND %s + """ % ( + clause, + ) + + txn.execute(sql, (user_id, *args)) + results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] + + if do_invite: + sql = ( + "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" + " FROM local_invites as i" + " INNER JOIN events as e USING (event_id)" + " WHERE invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, (user_id,)) + results.extend( + RoomsForUser( + room_id=r["room_id"], + sender=r["inviter"], + event_id=r["event_id"], + stream_ordering=r["stream_ordering"], + membership=Membership.INVITE, + ) + for r in self.cursor_to_dict(txn) + ) + + return results + + @cachedInlineCallbacks(max_entries=500000, iterable=True) + def get_rooms_for_user_with_stream_ordering(self, user_id): + """Returns a set of room_ids the user is currently joined to + + Args: + user_id (str) + + Returns: + Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns + the rooms the user is in currently, along with the stream ordering + of the most recent join for that user and room. + """ + rooms = yield self.get_rooms_for_user_where_membership_is( + user_id, membership_list=[Membership.JOIN] + ) + return frozenset( + GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) + for r in rooms + ) + + @defer.inlineCallbacks + def get_rooms_for_user(self, user_id, on_invalidate=None): + """Returns a set of room_ids the user is currently joined to + """ + rooms = yield self.get_rooms_for_user_with_stream_ordering( + user_id, on_invalidate=on_invalidate + ) + return frozenset(r.room_id for r in rooms) + + @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) + def get_users_who_share_room_with_user(self, user_id, cache_context): + """Returns the set of users who share a room with `user_id` + """ + room_ids = yield self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate + ) + + user_who_share_room = set() + for room_id in room_ids: + user_ids = yield self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + user_who_share_room.update(user_ids) + + return user_who_share_room + + @defer.inlineCallbacks + def get_joined_users_from_context(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + current_state_ids = yield context.get_current_state_ids(self) + result = yield self._get_joined_users_from_context( + event.room_id, state_group, current_state_ids, event=event, context=context + ) + return result + + @defer.inlineCallbacks + def get_joined_users_from_state(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + with Measure(self._clock, "get_joined_users_from_state"): + return ( + yield self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry + ) + ) + + @cachedInlineCallbacks( + num_args=2, cache_context=True, iterable=True, max_entries=100000 + ) + def _get_joined_users_from_context( + self, + room_id, + state_group, + current_state_ids, + cache_context, + event=None, + context=None, + ): + # We don't use `state_group`, it's there so that we can cache based + # on it. However, it's important that it's never None, since two current_states + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + users_in_room = {} + member_event_ids = [ + e_id + for key, e_id in iteritems(current_state_ids) + if key[0] == EventTypes.Member + ] + + if context is not None: + # If we have a context with a delta from a previous state group, + # check if we also have the result from the previous group in cache. + # If we do then we can reuse that result and simply update it with + # any membership changes in `delta_ids` + if context.prev_group and context.delta_ids: + prev_res = self._get_joined_users_from_context.cache.get( + (room_id, context.prev_group), None + ) + if prev_res and isinstance(prev_res, dict): + users_in_room = dict(prev_res) + member_event_ids = [ + e_id + for key, e_id in iteritems(context.delta_ids) + if key[0] == EventTypes.Member + ] + for etype, state_key in context.delta_ids: + users_in_room.pop(state_key, None) + + # We check if we have any of the member event ids in the event cache + # before we ask the DB + + # We don't update the event cache hit ratio as it completely throws off + # the hit ratio counts. After all, we don't populate the cache if we + # miss it here + event_map = self._get_events_from_cache( + member_event_ids, allow_rejected=False, update_metrics=False + ) + + missing_member_event_ids = [] + for event_id in member_event_ids: + ev_entry = event_map.get(event_id) + if ev_entry: + if ev_entry.event.membership == Membership.JOIN: + users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( + display_name=to_ascii( + ev_entry.event.content.get("displayname", None) + ), + avatar_url=to_ascii( + ev_entry.event.content.get("avatar_url", None) + ), + ) + else: + missing_member_event_ids.append(event_id) + + if missing_member_event_ids: + event_to_memberships = yield self._get_joined_profiles_from_event_ids( + missing_member_event_ids + ) + users_in_room.update((row for row in event_to_memberships.values() if row)) + + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if event.event_id in member_event_ids: + users_in_room[to_ascii(event.state_key)] = ProfileInfo( + display_name=to_ascii(event.content.get("displayname", None)), + avatar_url=to_ascii(event.content.get("avatar_url", None)), + ) + + return users_in_room + + @cached(max_entries=10000) + def _get_joined_profile_from_event_id(self, event_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_joined_profile_from_event_id", + list_name="event_ids", + inlineCallbacks=True, + ) + def _get_joined_profiles_from_event_ids(self, event_ids): + """For given set of member event_ids check if they point to a join + event and if so return the associated user and profile info. + + Args: + event_ids (Iterable[str]): The member event IDs to lookup + + Returns: + Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID + to `user_id` and ProfileInfo (or None if not join event). + """ + + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=event_ids, + retcols=("user_id", "display_name", "avatar_url", "event_id"), + keyvalues={"membership": Membership.JOIN}, + batch_size=500, + desc="_get_membership_from_event_ids", + ) + + return { + row["event_id"]: ( + row["user_id"], + ProfileInfo( + avatar_url=row["avatar_url"], display_name=row["display_name"] + ), + ) + for row in rows + } + + @cachedInlineCallbacks(max_entries=10000) + def is_host_joined(self, room_id, host): + if "%" in host or "_" in host: + raise Exception("Invalid host name") + + sql = """ + SELECT state_key FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (event_id) + WHERE m.membership = 'join' + AND type = 'm.room.member' + AND c.room_id = ? + AND state_key LIKE ? + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) + + if not rows: + return False + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + return True + + @cachedInlineCallbacks() + def was_host_joined(self, room_id, host): + """Check whether the server is or ever was in the room. + + Args: + room_id (str) + host (str) + + Returns: + Deferred: Resolves to True if the host is/was in the room, otherwise + False. + """ + if "%" in host or "_" in host: + raise Exception("Invalid host name") + + sql = """ + SELECT user_id FROM room_memberships + WHERE room_id = ? + AND user_id LIKE ? + AND membership = 'join' + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) + + if not rows: + return False + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + return True + + @defer.inlineCallbacks + def get_joined_hosts(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + with Measure(self._clock, "get_joined_hosts"): + return ( + yield self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry + ) + ) + + @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) + # @defer.inlineCallbacks + def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + cache = self._get_joined_hosts_cache(room_id) + joined_hosts = yield cache.get_destinations(state_entry) + + return joined_hosts + + @cached(max_entries=10000) + def _get_joined_hosts_cache(self, room_id): + return _JoinedHostsCache(self, room_id) + + @cachedInlineCallbacks(num_args=2) + def did_forget(self, user_id, room_id): + """Returns whether user_id has elected to discard history for room_id. + + Returns False if they have since re-joined.""" + + def f(txn): + sql = ( + "SELECT" + " COUNT(*)" + " FROM" + " room_memberships" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + " AND" + " forgotten = 0" + ) + txn.execute(sql, (user_id, room_id)) + rows = txn.fetchall() + return rows[0][0] + + count = yield self.runInteraction("did_forget_membership", f) + return count == 0 + + @cached() + def get_forgotten_rooms_for_user(self, user_id): + """Gets all rooms the user has forgotten. + + Args: + user_id (str) + + Returns: + Deferred[set[str]] + """ + + def _get_forgotten_rooms_for_user_txn(txn): + # This is a slightly convoluted query that first looks up all rooms + # that the user has forgotten in the past, then rechecks that list + # to see if any have subsequently been updated. This is done so that + # we can use a partial index on `forgotten = 1` on the assumption + # that few users will actually forget many rooms. + # + # Note that a room is considered "forgotten" if *all* membership + # events for that user and room have the forgotten field set (as + # when a user forgets a room we update all rows for that user and + # room, not just the current one). + sql = """ + SELECT room_id, ( + SELECT count(*) FROM room_memberships + WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0 + ) AS count + FROM room_memberships AS m + WHERE user_id = ? AND forgotten = 1 + GROUP BY room_id, user_id; + """ + txn.execute(sql, (user_id,)) + return set(row[0] for row in txn if row[1] == 0) + + return self.runInteraction( + "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn + ) + + @defer.inlineCallbacks + def get_rooms_user_has_been_in(self, user_id): + """Get all rooms that the user has ever been in. + + Args: + user_id (str) + + Returns: + Deferred[set[str]]: Set of room IDs. + """ + + room_ids = yield self._simple_select_onecol( + table="room_memberships", + keyvalues={"membership": Membership.JOIN, "user_id": user_id}, + retcol="room_id", + desc="get_rooms_user_has_been_in", + ) + + return set(room_ids) + + +class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile + ) + self.register_background_update_handler( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, + self._background_current_state_membership, + ) + self.register_background_index_update( + "room_membership_forgotten_idx", + index_name="room_memberships_user_room_forgotten", + table="room_memberships", + columns=["user_id", "room_id"], + where_clause="forgotten = 1", + ) + + @defer.inlineCallbacks + def _background_add_membership_profile(self, progress, batch_size): + target_min_stream_id = progress.get( + "target_min_stream_id_inclusive", self._min_stream_order_on_start + ) + max_stream_id = progress.get( + "max_stream_id_exclusive", self._stream_order_on_start + 1 + ) + + INSERT_CLUMP_SIZE = 1000 + + def add_membership_profile_txn(txn): + sql = """ + SELECT stream_ordering, event_id, events.room_id, event_json.json + FROM events + INNER JOIN event_json USING (event_id) + INNER JOIN room_memberships USING (event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND type = 'm.room.member' + ORDER BY stream_ordering DESC + LIMIT ? + """ + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = self.cursor_to_dict(txn) + if not rows: + return 0 + + min_stream_id = rows[-1]["stream_ordering"] + + to_update = [] + for row in rows: + event_id = row["event_id"] + room_id = row["room_id"] + try: + event_json = json.loads(row["json"]) + content = event_json["content"] + except Exception: + continue + + display_name = content.get("displayname", None) + avatar_url = content.get("avatar_url", None) + + if display_name or avatar_url: + to_update.append((display_name, avatar_url, event_id, room_id)) + + to_update_sql = """ + UPDATE room_memberships SET display_name = ?, avatar_url = ? + WHERE event_id = ? AND room_id = ? + """ + for index in range(0, len(to_update), INSERT_CLUMP_SIZE): + clump = to_update[index : index + INSERT_CLUMP_SIZE] + txn.executemany(to_update_sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self._background_update_progress_txn( + txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress + ) + + return len(rows) + + result = yield self.runInteraction( + _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn + ) + + if not result: + yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) + + return result + + @defer.inlineCallbacks + def _background_current_state_membership(self, progress, batch_size): + """Update the new membership column on current_state_events. + + This works by iterating over all rooms in alphebetical order. + """ + + def _background_current_state_membership_txn(txn, last_processed_room): + processed = 0 + while processed < batch_size: + txn.execute( + """ + SELECT MIN(room_id) FROM current_state_events WHERE room_id > ? + """, + (last_processed_room,), + ) + row = txn.fetchone() + if not row or not row[0]: + return processed, True + + next_room, = row + + sql = """ + UPDATE current_state_events + SET membership = ( + SELECT membership FROM room_memberships + WHERE event_id = current_state_events.event_id + ) + WHERE room_id = ? + """ + txn.execute(sql, (next_room,)) + processed += txn.rowcount + + last_processed_room = next_room + + self._background_update_progress_txn( + txn, + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, + {"last_processed_room": last_processed_room}, + ) + + return processed, False + + # If we haven't got a last processed room then just use the empty + # string, which will compare before all room IDs correctly. + last_processed_room = progress.get("last_processed_room", "") + + row_count, finished = yield self.runInteraction( + "_background_current_state_membership_update", + _background_current_state_membership_txn, + last_processed_room, + ) + + if finished: + yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) + + return row_count + + +class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(RoomMemberStore, self).__init__(db_conn, hs) + + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database. + """ + self._simple_insert_many_txn( + txn, + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } + for event in events + ], + ) + + for event in events: + txn.call_after( + self._membership_stream_cache.entity_has_changed, + event.state_key, + event.internal_metadata.stream_ordering, + ) + txn.call_after( + self.get_invited_rooms_for_user.invalidate, (event.state_key,) + ) + + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. If the event is an + # outlier it is only current if its an "out of band membership", + # like a remote invite or a rejection of a remote invite. + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_out_of_band_membership() + ) + is_mine = self.hs.is_mine_id(event.state_key) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self._simple_insert_txn( + txn, + table="local_invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + }, + ) + else: + sql = ( + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute( + sql, + ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + ), + ) + + @defer.inlineCallbacks + def locally_reject_invite(self, user_id, room_id): + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, (stream_ordering, True, room_id, user_id)) + + with self._stream_id_gen.get_next() as stream_ordering: + yield self.runInteraction("locally_reject_invite", f, stream_ordering) + + def forget(self, user_id, room_id): + """Indicate that user_id wishes to discard history for room_id.""" + + def f(txn): + sql = ( + "UPDATE" + " room_memberships" + " SET" + " forgotten = 1" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + ) + txn.execute(sql, (user_id, room_id)) + + self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) + self._invalidate_cache_and_stream( + txn, self.get_forgotten_rooms_for_user, (user_id,) + ) + + return self.runInteraction("forget_membership", f) + + +class _JoinedHostsCache(object): + """Cache for joined hosts in a room that is optimised to handle updates + via state deltas. + """ + + def __init__(self, store, room_id): + self.store = store + self.room_id = room_id + + self.hosts_to_joined_users = {} + + self.state_group = object() + + self.linearizer = Linearizer("_JoinedHostsCache") + + self._len = 0 + + @defer.inlineCallbacks + def get_destinations(self, state_entry): + """Get set of destinations for a state entry + + Args: + state_entry(synapse.state._StateCacheEntry) + """ + if state_entry.state_group == self.state_group: + return frozenset(self.hosts_to_joined_users) + + with (yield self.linearizer.queue(())): + if state_entry.state_group == self.state_group: + pass + elif state_entry.prev_group == self.state_group: + for (typ, state_key), event_id in iteritems(state_entry.delta_ids): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = self.hosts_to_joined_users.setdefault(host, set()) + + event = yield self.store.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + self.hosts_to_joined_users.pop(host, None) + else: + joined_users = yield self.store.get_joined_users_from_state( + self.room_id, state_entry + ) + + self.hosts_to_joined_users = {} + for user_id in joined_users: + host = intern_string(get_domain_from_id(user_id)) + self.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + self.state_group = state_entry.state_group + else: + self.state_group = object() + self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) + return frozenset(self.hosts_to_joined_users) + + def __len__(self): + return self._len diff --git a/synapse/storage/data_stores/main/schema/delta/12/v12.sql b/synapse/storage/data_stores/main/schema/delta/12/v12.sql new file mode 100644 index 0000000000..5964c5aaac --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/12/v12.sql @@ -0,0 +1,63 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + UNIQUE (event_id) +); + +-- Push notification endpoints that users have configured +CREATE TABLE IF NOT EXISTS pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey VARBINARY(512) NOT NULL, + ts BIGINT UNSIGNED NOT NULL, + lang VARCHAR(8), + data LONGBLOB, + last_token TEXT, + last_success BIGINT UNSIGNED, + failing_since BIGINT UNSIGNED, + UNIQUE (app_id, pushkey) +); + +CREATE TABLE IF NOT EXISTS push_rules ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + priority_class TINYINT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + conditions TEXT NOT NULL, + actions TEXT NOT NULL, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); + +CREATE TABLE IF NOT EXISTS user_filters( + user_id TEXT, + filter_id BIGINT UNSIGNED, + filter_json LONGBLOB +); + +CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( + user_id, filter_id +); diff --git a/synapse/storage/data_stores/main/schema/delta/13/v13.sql b/synapse/storage/data_stores/main/schema/delta/13/v13.sql new file mode 100644 index 0000000000..f8649e5d99 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/13/v13.sql @@ -0,0 +1,19 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* We used to create a tables called application_services and + * application_services_regex, but these are no longer used and are removed in + * delta 54. + */ diff --git a/synapse/storage/data_stores/main/schema/delta/14/v14.sql b/synapse/storage/data_stores/main/schema/delta/14/v14.sql new file mode 100644 index 0000000000..a831920da6 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/14/v14.sql @@ -0,0 +1,23 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS push_rules_enable ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + enabled TINYINT, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_enable_user_name on push_rules_enable (user_name); diff --git a/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql new file mode 100644 index 0000000000..e4f5e76aec --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql @@ -0,0 +1,31 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS application_services_state( + as_id TEXT PRIMARY KEY, + state VARCHAR(5), + last_txn INTEGER +); + +CREATE TABLE IF NOT EXISTS application_services_txns( + as_id TEXT NOT NULL, + txn_id INTEGER NOT NULL, + event_ids TEXT NOT NULL, + UNIQUE(as_id, txn_id) +); + +CREATE INDEX IF NOT EXISTS application_services_txns_id ON application_services_txns ( + as_id +); diff --git a/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql new file mode 100644 index 0000000000..6b8d0f1ca7 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql @@ -0,0 +1,2 @@ + +CREATE INDEX IF NOT EXISTS presence_list_user_id ON presence_list (user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/15/v15.sql b/synapse/storage/data_stores/main/schema/delta/15/v15.sql new file mode 100644 index 0000000000..9523d2bcc3 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/15/v15.sql @@ -0,0 +1,24 @@ +-- Drop, copy & recreate pushers table to change unique key +-- Also add access_token column at the same time +CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey bytea NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data bytea, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey) +); +INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since) + SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers; +DROP TABLE pushers; +ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql new file mode 100644 index 0000000000..a48f215170 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql @@ -0,0 +1,4 @@ +CREATE INDEX events_order ON events (topological_ordering, stream_ordering); +CREATE INDEX events_order_room ON events ( + room_id, topological_ordering, stream_ordering +); diff --git a/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql new file mode 100644 index 0000000000..7a15265cb1 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql @@ -0,0 +1,2 @@ +CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id + ON remote_media_cache_thumbnails (media_id); \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql new file mode 100644 index 0000000000..65c97b5e2f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql @@ -0,0 +1,9 @@ + + +DELETE FROM event_to_state_groups WHERE state_group not in ( + SELECT MAX(state_group) FROM event_to_state_groups GROUP BY event_id +); + +DELETE FROM event_to_state_groups WHERE rowid not in ( + SELECT MIN(rowid) FROM event_to_state_groups GROUP BY event_id +); diff --git a/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql new file mode 100644 index 0000000000..f82486132b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql @@ -0,0 +1,3 @@ + +CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id); +CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias); diff --git a/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql new file mode 100644 index 0000000000..5b8de52c33 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql @@ -0,0 +1,72 @@ + +-- We can use SQLite features here, since other db support was only added in v16 + +-- +DELETE FROM current_state_events WHERE rowid not in ( + SELECT MIN(rowid) FROM current_state_events GROUP BY event_id +); + +DROP INDEX IF EXISTS current_state_events_event_id; +CREATE UNIQUE INDEX current_state_events_event_id ON current_state_events(event_id); + +-- +DELETE FROM room_memberships WHERE rowid not in ( + SELECT MIN(rowid) FROM room_memberships GROUP BY event_id +); + +DROP INDEX IF EXISTS room_memberships_event_id; +CREATE UNIQUE INDEX room_memberships_event_id ON room_memberships(event_id); + +-- +DELETE FROM topics WHERE rowid not in ( + SELECT MIN(rowid) FROM topics GROUP BY event_id +); + +DROP INDEX IF EXISTS topics_event_id; +CREATE UNIQUE INDEX topics_event_id ON topics(event_id); + +-- +DELETE FROM room_names WHERE rowid not in ( + SELECT MIN(rowid) FROM room_names GROUP BY event_id +); + +DROP INDEX IF EXISTS room_names_id; +CREATE UNIQUE INDEX room_names_id ON room_names(event_id); + +-- +DELETE FROM presence WHERE rowid not in ( + SELECT MIN(rowid) FROM presence GROUP BY user_id +); + +DROP INDEX IF EXISTS presence_id; +CREATE UNIQUE INDEX presence_id ON presence(user_id); + +-- +DELETE FROM presence_allow_inbound WHERE rowid not in ( + SELECT MIN(rowid) FROM presence_allow_inbound + GROUP BY observed_user_id, observer_user_id +); + +DROP INDEX IF EXISTS presence_allow_inbound_observers; +CREATE UNIQUE INDEX presence_allow_inbound_observers ON presence_allow_inbound( + observed_user_id, observer_user_id +); + +-- +DELETE FROM presence_list WHERE rowid not in ( + SELECT MIN(rowid) FROM presence_list + GROUP BY user_id, observed_user_id +); + +DROP INDEX IF EXISTS presence_list_observers; +CREATE UNIQUE INDEX presence_list_observers ON presence_list( + user_id, observed_user_id +); + +-- +DELETE FROM room_aliases WHERE rowid not in ( + SELECT MIN(rowid) FROM room_aliases GROUP BY room_alias +); + +DROP INDEX IF EXISTS room_aliases_id; +CREATE INDEX room_aliases_id ON room_aliases(room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/16/users.sql b/synapse/storage/data_stores/main/schema/delta/16/users.sql new file mode 100644 index 0000000000..cd0709250d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/16/users.sql @@ -0,0 +1,56 @@ +-- Convert `access_tokens`.user from rowids to user strings. +-- MUST BE DONE BEFORE REMOVING ID COLUMN FROM USERS TABLE BELOW +CREATE TABLE IF NOT EXISTS new_access_tokens( + id BIGINT UNSIGNED PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT, + token TEXT NOT NULL, + last_used BIGINT UNSIGNED, + UNIQUE(token) +); + +INSERT INTO new_access_tokens + SELECT a.id, u.name, a.device_id, a.token, a.last_used + FROM access_tokens as a + INNER JOIN users as u ON u.id = a.user_id; + +DROP TABLE access_tokens; + +ALTER TABLE new_access_tokens RENAME TO access_tokens; + +-- Remove ID column from `users` table +CREATE TABLE IF NOT EXISTS new_users( + name TEXT, + password_hash TEXT, + creation_ts BIGINT UNSIGNED, + admin BOOL DEFAULT 0 NOT NULL, + UNIQUE(name) +); + +INSERT INTO new_users SELECT name, password_hash, creation_ts, admin FROM users; + +DROP TABLE users; + +ALTER TABLE new_users RENAME TO users; + + +-- Remove UNIQUE constraint from `user_ips` table +CREATE TABLE IF NOT EXISTS new_user_ips ( + user_id TEXT NOT NULL, + access_token TEXT NOT NULL, + device_id TEXT, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + last_seen BIGINT UNSIGNED NOT NULL +); + +INSERT INTO new_user_ips + SELECT user, access_token, device_id, ip, user_agent, last_seen FROM user_ips; + +DROP TABLE user_ips; + +ALTER TABLE new_user_ips RENAME TO user_ips; + +CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user_id); +CREATE INDEX IF NOT EXISTS user_ips_user_ip ON user_ips(user_id, access_token, ip); + diff --git a/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql new file mode 100644 index 0000000000..7c9a90e27f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql @@ -0,0 +1,18 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP INDEX IF EXISTS sent_transaction_dest; +DROP INDEX IF EXISTS sent_transaction_sent; +DROP INDEX IF EXISTS user_ips_user; diff --git a/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql new file mode 100644 index 0000000000..70b247a06b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql @@ -0,0 +1,24 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS server_keys_json ( + server_name TEXT, -- Server name. + key_id TEXT, -- Requested key id. + from_server TEXT, -- Which server the keys were fetched from. + ts_added_ms INTEGER, -- When the keys were fetched + ts_valid_until_ms INTEGER, -- When this version of the keys exipires. + key_json bytea, -- JSON certificate for the remote server. + CONSTRAINT uniqueness UNIQUE (server_name, key_id, from_server) +); diff --git a/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql new file mode 100644 index 0000000000..c17715ac80 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql @@ -0,0 +1,9 @@ +CREATE TABLE user_threepids ( + user_id TEXT NOT NULL, + medium TEXT NOT NULL, + address TEXT NOT NULL, + validated_at BIGINT NOT NULL, + added_at BIGINT NOT NULL, + CONSTRAINT user_medium_address UNIQUE (user_id, medium, address) +); +CREATE INDEX user_threepids_user_id ON user_threepids(user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql new file mode 100644 index 0000000000..6e0871c92b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql @@ -0,0 +1,32 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS new_server_keys_json ( + server_name TEXT NOT NULL, -- Server name. + key_id TEXT NOT NULL, -- Requested key id. + from_server TEXT NOT NULL, -- Which server the keys were fetched from. + ts_added_ms BIGINT NOT NULL, -- When the keys were fetched + ts_valid_until_ms BIGINT NOT NULL, -- When this version of the keys exipires. + key_json bytea NOT NULL, -- JSON certificate for the remote server. + CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server) +); + +INSERT INTO new_server_keys_json + SELECT server_name, key_id, from_server,ts_added_ms, ts_valid_until_ms, key_json FROM server_keys_json ; + +DROP TABLE server_keys_json; + +ALTER TABLE new_server_keys_json RENAME TO server_keys_json; diff --git a/synapse/storage/data_stores/main/schema/delta/19/event_index.sql b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql new file mode 100644 index 0000000000..18b97b4332 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE INDEX events_order_topo_stream_room ON events( + topological_ordering, stream_ordering, room_id +); diff --git a/synapse/storage/data_stores/main/schema/delta/20/dummy.sql b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql new file mode 100644 index 0000000000..e0ac49d1ec --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql @@ -0,0 +1 @@ +SELECT 1; diff --git a/synapse/storage/data_stores/main/schema/delta/20/pushers.py b/synapse/storage/data_stores/main/schema/delta/20/pushers.py new file mode 100644 index 0000000000..3edfcfd783 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/20/pushers.py @@ -0,0 +1,88 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Main purpose of this upgrade is to change the unique key on the +pushers table again (it was missed when the v16 full schema was +made) but this also changes the pushkey and data columns to text. +When selecting a bytea column into a text column, postgres inserts +the hex encoded data, and there's no portable way of getting the +UTF-8 bytes, so we have to do it in Python. +""" + +import logging + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + logger.info("Porting pushers table...") + cur.execute( + """ + CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data TEXT, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) + ) + """ + ) + cur.execute( + """SELECT + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + FROM pushers + """ + ) + count = 0 + for row in cur.fetchall(): + row = list(row) + row[8] = bytes(row[8]).decode("utf-8") + row[11] = bytes(row[11]).decode("utf-8") + cur.execute( + database_engine.convert_param_style( + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + ) values (%s)""" + % (",".join(["?" for _ in range(len(row))])) + ), + row, + ) + count += 1 + cur.execute("DROP TABLE pushers") + cur.execute("ALTER TABLE pushers2 RENAME TO pushers") + logger.info("Moved %d pushers to new table", count) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql new file mode 100644 index 0000000000..4c2fb20b77 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql @@ -0,0 +1,34 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS e2e_device_keys_json ( + user_id TEXT NOT NULL, -- The user these keys are for. + device_id TEXT NOT NULL, -- Which of the user's devices these keys are for. + ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded. + key_json TEXT NOT NULL, -- The keys for the device as a JSON blob. + CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) +); + + +CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json ( + user_id TEXT NOT NULL, -- The user this one-time key is for. + device_id TEXT NOT NULL, -- The device this one-time key is for. + algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + ts_added_ms BIGINT NOT NULL, -- When this key was uploaded. + key_json TEXT NOT NULL, -- The key as a JSON blob. + CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) +); diff --git a/synapse/storage/data_stores/main/schema/delta/21/receipts.sql b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql new file mode 100644 index 0000000000..d070845477 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql @@ -0,0 +1,38 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS receipts_graph( + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_ids TEXT NOT NULL, + data TEXT NOT NULL, + CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id) +); + +CREATE TABLE IF NOT EXISTS receipts_linearized ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + data TEXT NOT NULL, + CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id) +); + +CREATE INDEX receipts_linearized_id ON receipts_linearized( + stream_id +); diff --git a/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql new file mode 100644 index 0000000000..bfc0b3bcaa --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql @@ -0,0 +1,22 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( + room_id, stream_id +); diff --git a/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql new file mode 100644 index 0000000000..87edfa454c --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql @@ -0,0 +1,19 @@ +CREATE TABLE IF NOT EXISTS user_threepids2 ( + user_id TEXT NOT NULL, + medium TEXT NOT NULL, + address TEXT NOT NULL, + validated_at BIGINT NOT NULL, + added_at BIGINT NOT NULL, + CONSTRAINT medium_address UNIQUE (medium, address) +); + +INSERT INTO user_threepids2 + SELECT * FROM user_threepids WHERE added_at IN ( + SELECT max(added_at) FROM user_threepids GROUP BY medium, address + ) +; + +DROP TABLE user_threepids; +ALTER TABLE user_threepids2 RENAME TO user_threepids; + +CREATE INDEX user_threepids_user_id ON user_threepids(user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql new file mode 100644 index 0000000000..ae09fa0065 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP INDEX IF EXISTS state_groups_state_tuple; diff --git a/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql new file mode 100644 index 0000000000..acea7483bd --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + /* We used to create a table called stats_reporting, but this is no longer + * used and is removed in delta 54. + */ \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql b/synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql new file mode 100644 index 0000000000..2ad9e8fa56 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql @@ -0,0 +1,21 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS background_updates( + update_name TEXT NOT NULL, -- The name of the background update. + progress_json TEXT NOT NULL, -- The current progress of the update as JSON. + CONSTRAINT background_updates_uniqueness UNIQUE (update_name) +); diff --git a/synapse/storage/data_stores/main/schema/delta/25/fts.py b/synapse/storage/data_stores/main/schema/delta/25/fts.py new file mode 100644 index 0000000000..4b2ffd35fd --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/25/fts.py @@ -0,0 +1,82 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import simplejson + +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +POSTGRES_TABLE = """ +CREATE TABLE IF NOT EXISTS event_search ( + event_id TEXT, + room_id TEXT, + sender TEXT, + key TEXT, + vector tsvector +); + +CREATE INDEX event_search_fts_idx ON event_search USING gin(vector); +CREATE INDEX event_search_ev_idx ON event_search(event_id); +CREATE INDEX event_search_ev_ridx ON event_search(room_id); +""" + + +SQLITE_TABLE = ( + "CREATE VIRTUAL TABLE event_search" + " USING fts4 ( event_id, room_id, sender, key, value )" +) + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + for statement in get_statements(POSTGRES_TABLE.splitlines()): + cur.execute(statement) + elif isinstance(database_engine, Sqlite3Engine): + cur.execute(SQLITE_TABLE) + else: + raise Exception("Unrecognized database engine") + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + } + progress_json = simplejson.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_search", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql new file mode 100644 index 0000000000..1ea389b471 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql @@ -0,0 +1,25 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This is a manual index of guest_access content of state events, + * so that we can join on them in SELECT statements. + */ +CREATE TABLE IF NOT EXISTS guest_access( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + guest_access TEXT NOT NULL, + UNIQUE (event_id) +); diff --git a/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql new file mode 100644 index 0000000000..f468fc1897 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql @@ -0,0 +1,25 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This is a manual index of history_visibility content of state events, + * so that we can join on them in SELECT statements. + */ +CREATE TABLE IF NOT EXISTS history_visibility( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + history_visibility TEXT NOT NULL, + UNIQUE (event_id) +); diff --git a/synapse/storage/data_stores/main/schema/delta/25/tags.sql b/synapse/storage/data_stores/main/schema/delta/25/tags.sql new file mode 100644 index 0000000000..7a32ce68e4 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/25/tags.sql @@ -0,0 +1,38 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS room_tags( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + tag TEXT NOT NULL, -- The name of the tag. + content TEXT NOT NULL, -- The JSON content of the tag. + CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) +); + +CREATE TABLE IF NOT EXISTS room_tags_revisions ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + stream_id BIGINT NOT NULL, -- The current version of the room tags. + CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) +); + +CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0); diff --git a/synapse/storage/data_stores/main/schema/delta/26/account_data.sql b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql new file mode 100644 index 0000000000..e395de2b5e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql @@ -0,0 +1,17 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id; diff --git a/synapse/storage/data_stores/main/schema/delta/27/account_data.sql b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql new file mode 100644 index 0000000000..bf0558b5b3 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql @@ -0,0 +1,36 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS account_data( + user_id TEXT NOT NULL, + account_data_type TEXT NOT NULL, -- The type of the account_data. + stream_id BIGINT NOT NULL, -- The version of the account_data. + content TEXT NOT NULL, -- The JSON content of the account_data + CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) +); + + +CREATE TABLE IF NOT EXISTS room_account_data( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + account_data_type TEXT NOT NULL, -- The type of the account_data. + stream_id BIGINT NOT NULL, -- The version of the account_data. + content TEXT NOT NULL, -- The JSON content of the account_data + CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) +); + + +CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); +CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql new file mode 100644 index 0000000000..e2094f37fe --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql @@ -0,0 +1,26 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Keeps track of what rooms users have left and don't want to be able to + * access again. + * + * If all users on this server have left a room, we can delete the room + * entirely. + * + * This column should always contain either 0 or 1. + */ + + ALTER TABLE room_memberships ADD COLUMN forgotten INTEGER DEFAULT 0; diff --git a/synapse/storage/data_stores/main/schema/delta/27/ts.py b/synapse/storage/data_stores/main/schema/delta/27/ts.py new file mode 100644 index 0000000000..414f9f5aa0 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/27/ts.py @@ -0,0 +1,61 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import simplejson + +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +ALTER_TABLE = ( + "ALTER TABLE events ADD COLUMN origin_server_ts BIGINT;" + "CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);" +) + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(ALTER_TABLE.splitlines()): + cur.execute(statement) + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + } + progress_json = simplejson.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_origin_server_ts", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql new file mode 100644 index 0000000000..4d519849df --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql @@ -0,0 +1,27 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_push_actions( + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + profile_tag VARCHAR(32), + actions TEXT NOT NULL, + CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) +); + + +CREATE INDEX event_push_actions_room_id_event_id_user_id_profile_tag on event_push_actions(room_id, event_id, user_id, profile_tag); +CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql new file mode 100644 index 0000000000..36609475f1 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ +CREATE INDEX events_room_stream on events(room_id, stream_ordering); diff --git a/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql new file mode 100644 index 0000000000..6c1fd68c5b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ +CREATE INDEX public_room_index on rooms(is_public); diff --git a/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql new file mode 100644 index 0000000000..cb84c69baa --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql @@ -0,0 +1,22 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ +CREATE INDEX receipts_linearized_user ON receipts_linearized( + user_id +); diff --git a/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql new file mode 100644 index 0000000000..3e4a9ab455 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Stores the timestamp when a user upgraded from a guest to a full user, if + * that happened. + */ + +ALTER TABLE users ADD COLUMN upgrade_ts BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql new file mode 100644 index 0000000000..21d2b420bf --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE users ADD is_guest SMALLINT DEFAULT 0 NOT NULL; +/* + * NB: any guest users created between 27 and 28 will be incorrectly + * marked as not guests: we don't bother to fill these in correctly + * because guest access is not really complete in 27 anyway so it's + * very unlikley there will be any guest users created. + */ diff --git a/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql new file mode 100644 index 0000000000..84b21cf813 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql @@ -0,0 +1,35 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE event_push_actions ADD COLUMN topological_ordering BIGINT; +ALTER TABLE event_push_actions ADD COLUMN stream_ordering BIGINT; +ALTER TABLE event_push_actions ADD COLUMN notif SMALLINT; +ALTER TABLE event_push_actions ADD COLUMN highlight SMALLINT; + +UPDATE event_push_actions SET stream_ordering = ( + SELECT stream_ordering FROM events WHERE event_id = event_push_actions.event_id +), topological_ordering = ( + SELECT topological_ordering FROM events WHERE event_id = event_push_actions.event_id +); + +UPDATE event_push_actions SET notif = 1, highlight = 0; + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ +CREATE INDEX event_push_actions_rm_tokens on event_push_actions( + user_id, room_id, topological_ordering, stream_ordering +); diff --git a/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql new file mode 100644 index 0000000000..c9d0dde638 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_aliases ADD COLUMN creator TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/30/as_users.py b/synapse/storage/data_stores/main/schema/delta/30/as_users.py new file mode 100644 index 0000000000..9b95411fb6 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/30/as_users.py @@ -0,0 +1,69 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from six.moves import range + +from synapse.config.appservice import load_appservices + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + # NULL indicates user was not registered by an appservice. + try: + cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") + except Exception: + # Maybe we already added the column? Hope so... + pass + + +def run_upgrade(cur, database_engine, config, *args, **kwargs): + cur.execute("SELECT name FROM users") + rows = cur.fetchall() + + config_files = [] + try: + config_files = config.app_service_config_files + except AttributeError: + logger.warning("Could not get app_service_config_files from config") + pass + + appservices = load_appservices(config.server_name, config_files) + + owned = {} + + for row in rows: + user_id = row[0] + for appservice in appservices: + if appservice.is_exclusive_user(user_id): + if user_id in owned.keys(): + logger.error( + "user_id %s was owned by more than one application" + " service (IDs %s and %s); assigning arbitrarily to %s" + % (user_id, owned[user_id], appservice.id, owned[user_id]) + ) + owned.setdefault(appservice.id, []).append(user_id) + + for as_id, user_ids in owned.items(): + n = 100 + user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) + for chunk in user_chunks: + cur.execute( + database_engine.convert_param_style( + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" + % (",".join("?" for _ in chunk),) + ), + [as_id] + chunk, + ) diff --git a/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql new file mode 100644 index 0000000000..712c454aa1 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql @@ -0,0 +1,25 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS deleted_pushers( + stream_id BIGINT NOT NULL, + app_id TEXT NOT NULL, + pushkey TEXT NOT NULL, + user_id TEXT NOT NULL, + /* We only track the most recent delete for each app_id, pushkey and user_id. */ + UNIQUE (app_id, pushkey, user_id) +); + +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql new file mode 100644 index 0000000000..606bbb037d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql @@ -0,0 +1,30 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + CREATE TABLE presence_stream( + stream_id BIGINT, + user_id TEXT, + state TEXT, + last_active_ts BIGINT, + last_federation_update_ts BIGINT, + last_user_sync_ts BIGINT, + status_msg TEXT, + currently_active BOOLEAN + ); + + CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); + CREATE INDEX presence_stream_user_id ON presence_stream(user_id); + CREATE INDEX presence_stream_state ON presence_stream(state); diff --git a/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql new file mode 100644 index 0000000000..f09db4faa6 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql @@ -0,0 +1,23 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/* This release removes the restriction that published rooms must have an alias, + * so we go back and ensure the only 'public' rooms are ones with an alias. + * We use (1 = 0) and (1 = 1) so that it works in both postgres and sqlite + */ +UPDATE rooms SET is_public = (1 = 0) WHERE is_public = (1 = 1) AND room_id not in ( + SELECT room_id FROM room_aliases +); diff --git a/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql new file mode 100644 index 0000000000..735aa8d5f6 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql @@ -0,0 +1,38 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + +CREATE TABLE push_rules_stream( + stream_id BIGINT NOT NULL, + event_stream_ordering BIGINT NOT NULL, + user_id TEXT NOT NULL, + rule_id TEXT NOT NULL, + op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE" + priority_class SMALLINT, + priority INTEGER, + conditions TEXT, + actions TEXT +); + +-- The extra data for each operation is: +-- * ENABLE, DISABLE, DELETE: [] +-- * ACTIONS: ["actions"] +-- * ADD: ["priority_class", "priority", "actions", "conditions"] + +-- Index for replication queries. +CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id); +-- Index for /sync queries. +CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql new file mode 100644 index 0000000000..e85699e82e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql @@ -0,0 +1,33 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/* We used to create a table called current_state_resets, but this is no + * longer used and is removed in delta 54. + */ + +/* The outlier events that have aquired a state group typically through + * backfill. This is tracked separately to the events table, as assigning a + * state group change the position of the existing event in the stream + * ordering. + * However since a stream_ordering is assigned in persist_event for the + * (event, state) pair, we can use that stream_ordering to identify when + * the new state was assigned for the event. + */ +CREATE TABLE IF NOT EXISTS ex_outlier_stream( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL, + event_id TEXT NOT NULL, + state_group BIGINT NOT NULL +); diff --git a/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql new file mode 100644 index 0000000000..0dd2f1360c --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Stores guest account access tokens generated for unbound 3pids. +CREATE TABLE threepid_guest_access_tokens( + medium TEXT, -- The medium of the 3pid. Must be "email". + address TEXT, -- The 3pid address. + guest_access_token TEXT, -- The access token for a guest user for this 3pid. + first_inviter TEXT -- User ID of the first user to invite this 3pid to a room. +); + +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); diff --git a/synapse/storage/data_stores/main/schema/delta/31/invites.sql b/synapse/storage/data_stores/main/schema/delta/31/invites.sql new file mode 100644 index 0000000000..2c57846d5a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/31/invites.sql @@ -0,0 +1,42 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE local_invites( + stream_id BIGINT NOT NULL, + inviter TEXT NOT NULL, + invitee TEXT NOT NULL, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + locally_rejected TEXT, + replaced_by TEXT +); + +-- Insert all invites for local users into new `invites` table +INSERT INTO local_invites SELECT + stream_ordering as stream_id, + sender as inviter, + state_key as invitee, + event_id, + room_id, + NULL as locally_rejected, + NULL as replaced_by + FROM events + NATURAL JOIN current_state_events + NATURAL JOIN room_memberships + WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); + +CREATE INDEX local_invites_id ON local_invites(stream_id); +CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql new file mode 100644 index 0000000000..9efb4280eb --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql @@ -0,0 +1,27 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE local_media_repository_url_cache( + url TEXT, -- the URL being cached + response_code INTEGER, -- the HTTP response code of this download attempt + etag TEXT, -- the etag header of this response + expires INTEGER, -- the number of ms this response was valid for + og TEXT, -- cache of the OG metadata of this URL as JSON + media_id TEXT, -- the media_id, if any, of the URL's content in the repo + download_ts BIGINT -- the timestamp of this download attempt +); + +CREATE INDEX local_media_repository_url_cache_by_url_download_ts + ON local_media_repository_url_cache(url, download_ts); diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers.py b/synapse/storage/data_stores/main/schema/delta/31/pushers.py new file mode 100644 index 0000000000..9bb504aad5 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/31/pushers.py @@ -0,0 +1,87 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Change the last_token to last_stream_ordering now that pushers no longer +# listen on an event stream but instead select out of the event_push_actions +# table. + + +import logging + +logger = logging.getLogger(__name__) + + +def token_to_stream_ordering(token): + return int(token[1:].split("_")[0]) + + +def run_create(cur, database_engine, *args, **kwargs): + logger.info("Porting pushers table, delta 31...") + cur.execute( + """ + CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data TEXT, + last_stream_ordering INTEGER, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) + ) + """ + ) + cur.execute( + """SELECT + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + FROM pushers + """ + ) + count = 0 + for row in cur.fetchall(): + row = list(row) + row[12] = token_to_stream_ordering(row[12]) + cur.execute( + database_engine.convert_param_style( + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_stream_ordering, last_success, + failing_since + ) values (%s)""" + % (",".join(["?" for _ in range(len(row))])) + ), + row, + ) + count += 1 + cur.execute("DROP TABLE pushers") + cur.execute("ALTER TABLE pushers2 RENAME TO pushers") + logger.info("Moved %d pushers to new table", count) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql new file mode 100644 index 0000000000..a82add88fd --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ + CREATE INDEX event_push_actions_stream_ordering on event_push_actions( + stream_ordering, user_id + ); diff --git a/synapse/storage/data_stores/main/schema/delta/31/search_update.py b/synapse/storage/data_stores/main/schema/delta/31/search_update.py new file mode 100644 index 0000000000..7d8ca5f93f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/31/search_update.py @@ -0,0 +1,66 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import simplejson + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +ALTER_TABLE = """ +ALTER TABLE event_search ADD COLUMN origin_server_ts BIGINT; +ALTER TABLE event_search ADD COLUMN stream_ordering BIGINT; +""" + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + for statement in get_statements(ALTER_TABLE.splitlines()): + cur.execute(statement) + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + "have_added_indexes": False, + } + progress_json = simplejson.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_search_order", progress_json)) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/32/events.sql b/synapse/storage/data_stores/main/schema/delta/32/events.sql new file mode 100644 index 0000000000..1dd0f9e170 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/32/events.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE events ADD COLUMN received_ts BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/32/openid.sql b/synapse/storage/data_stores/main/schema/delta/32/openid.sql new file mode 100644 index 0000000000..36f37b11c8 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/32/openid.sql @@ -0,0 +1,9 @@ + +CREATE TABLE open_id_tokens ( + token TEXT NOT NULL PRIMARY KEY, + ts_valid_until_ms bigint NOT NULL, + user_id TEXT NOT NULL, + UNIQUE (token) +); + +CREATE index open_id_tokens_ts_valid_until_ms ON open_id_tokens(ts_valid_until_ms); diff --git a/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql new file mode 100644 index 0000000000..d86d30c13c --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql @@ -0,0 +1,23 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE pusher_throttle( + pusher BIGINT NOT NULL, + room_id TEXT NOT NULL, + last_sent_ts BIGINT, + throttle_ms BIGINT, + PRIMARY KEY (pusher, room_id) +); diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql new file mode 100644 index 0000000000..4219cdd06a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql @@ -0,0 +1,34 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- The following indices are redundant, other indices are equivalent or +-- supersets +DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream +DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room +DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room +DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY +DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY +DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY +DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT + +DROP INDEX IF EXISTS st_extrem_id; -- Prefix of UNIQUE CONSTRAINT +DROP INDEX IF EXISTS event_signatures_id; -- Prefix of UNIQUE CONSTRAINT +DROP INDEX IF EXISTS redactions_event_id; -- Duplicate of UNIQUE CONSTRAINT + +-- The following indices were unused +DROP INDEX IF EXISTS remote_media_cache_thumbnails_media_id; +DROP INDEX IF EXISTS evauth_edges_auth_id; +DROP INDEX IF EXISTS presence_stream_state; diff --git a/synapse/storage/data_stores/main/schema/delta/32/reports.sql b/synapse/storage/data_stores/main/schema/delta/32/reports.sql new file mode 100644 index 0000000000..d13609776f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/32/reports.sql @@ -0,0 +1,25 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE event_reports( + id BIGINT NOT NULL PRIMARY KEY, + received_ts BIGINT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + reason TEXT, + content TEXT +); diff --git a/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql new file mode 100644 index 0000000000..61ad3fe3e8 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('access_tokens_device_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices.sql b/synapse/storage/data_stores/main/schema/delta/33/devices.sql new file mode 100644 index 0000000000..eca7268d82 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/33/devices.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE devices ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql new file mode 100644 index 0000000000..aa4a3b9f2f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql @@ -0,0 +1,19 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- make sure that we have a device record for each set of E2E keys, so that the +-- user can delete them if they like. +INSERT INTO devices + SELECT user_id, device_id, NULL FROM e2e_device_keys_json; diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql new file mode 100644 index 0000000000..6671573398 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- a previous version of the "devices_for_e2e_keys" delta set all the device +-- names to "unknown device". This wasn't terribly helpful +UPDATE devices + SET display_name = NULL + WHERE display_name = 'unknown device'; diff --git a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py new file mode 100644 index 0000000000..bff1256a7b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py @@ -0,0 +1,61 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import simplejson + +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +ALTER_TABLE = """ +ALTER TABLE events ADD COLUMN sender TEXT; +ALTER TABLE events ADD COLUMN contains_url BOOLEAN; +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(ALTER_TABLE.splitlines()): + cur.execute(statement) + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + } + progress_json = simplejson.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_fields_sender_url", progress_json)) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py new file mode 100644 index 0000000000..a26057dfb6 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py @@ -0,0 +1,30 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT" + + +def run_create(cur, database_engine, *args, **kwargs): + cur.execute(ALTER_TABLE) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + cur.execute( + database_engine.convert_param_style( + "UPDATE remote_media_cache SET last_access_ts = ?" + ), + (int(time.time() * 1000),), + ) diff --git a/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql new file mode 100644 index 0000000000..473f75a78e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('user_ips_device_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql new file mode 100644 index 0000000000..69e16eda0f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql @@ -0,0 +1,23 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS appservice_stream_position( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_ordering BIGINT, + CHECK (Lock='X') +); + +INSERT INTO appservice_stream_position (stream_ordering) + SELECT COALESCE(MAX(stream_ordering), 0) FROM events; diff --git a/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py new file mode 100644 index 0000000000..cf09e43e2b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py @@ -0,0 +1,46 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +# This stream is used to notify replication slaves that some caches have +# been invalidated that they cannot infer from the other streams. +CREATE_TABLE = """ +CREATE TABLE cache_invalidation_stream ( + stream_id BIGINT, + cache_func TEXT, + keys TEXT[], + invalidation_ts BIGINT +); + +CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + for statement in get_statements(CREATE_TABLE.splitlines()): + cur.execute(statement) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql new file mode 100644 index 0000000000..e68844c74a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_inbox ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + stream_id BIGINT NOT NULL, + message_json TEXT NOT NULL -- {"type":, "sender":, "content",} +); + +CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); +CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql new file mode 100644 index 0000000000..0d9fe1a99a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name'; +UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; + +DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name'; +UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; diff --git a/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py new file mode 100644 index 0000000000..67d505e68b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py @@ -0,0 +1,32 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.engines import PostgresEngine + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + cur.execute("TRUNCATE received_transactions") + else: + cur.execute("DELETE FROM received_transactions") + + cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)") + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql new file mode 100644 index 0000000000..33980d02f0 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication'); diff --git a/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql new file mode 100644 index 0000000000..6cd123027b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT into background_updates (update_name, progress_json) + VALUES ('event_contains_url_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql new file mode 100644 index 0000000000..17e6c43105 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql @@ -0,0 +1,39 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP TABLE IF EXISTS device_federation_outbox; +CREATE TABLE device_federation_outbox ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + queued_ts BIGINT NOT NULL, + messages_json TEXT NOT NULL +); + + +DROP INDEX IF EXISTS device_federation_outbox_destination_id; +CREATE INDEX device_federation_outbox_destination_id + ON device_federation_outbox(destination, stream_id); + + +DROP TABLE IF EXISTS device_federation_inbox; +CREATE TABLE device_federation_inbox ( + origin TEXT NOT NULL, + message_id TEXT NOT NULL, + received_ts BIGINT NOT NULL +); + +DROP INDEX IF EXISTS device_federation_inbox_sender_id; +CREATE INDEX device_federation_inbox_sender_id + ON device_federation_inbox(origin, message_id); diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql new file mode 100644 index 0000000000..7ab7d942e2 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_max_stream_id ( + stream_id BIGINT NOT NULL +); + +INSERT INTO device_max_stream_id (stream_id) + SELECT COALESCE(MAX(stream_id), 0) FROM device_inbox; diff --git a/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql new file mode 100644 index 0000000000..2e836d8e9c --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT into background_updates (update_name, progress_json) + VALUES ('epa_highlight_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql new file mode 100644 index 0000000000..dd2bf2e28a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql @@ -0,0 +1,33 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE public_room_list_stream ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + visibility BOOLEAN NOT NULL +); + +INSERT INTO public_room_list_stream (stream_id, room_id, visibility) + SELECT 1, room_id, is_public FROM rooms + WHERE is_public = CAST(1 AS BOOLEAN); + +CREATE INDEX public_room_list_stream_idx on public_room_list_stream( + stream_id +); + +CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( + room_id, stream_id +); diff --git a/synapse/storage/data_stores/main/schema/delta/35/state.sql b/synapse/storage/data_stores/main/schema/delta/35/state.sql new file mode 100644 index 0000000000..0f1fa68a89 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/35/state.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_group_edges( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); diff --git a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql new file mode 100644 index 0000000000..97e5067ef4 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('state_group_state_deduplication', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql new file mode 100644 index 0000000000..2b945d8a57 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql @@ -0,0 +1,37 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE stream_ordering_to_exterm ( + stream_ordering BIGINT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id) + SELECT stream_ordering, room_id, event_id FROM event_forward_extremities + INNER JOIN ( + SELECT room_id, max(stream_ordering) as stream_ordering FROM events + INNER JOIN event_forward_extremities USING (room_id, event_id) + GROUP BY room_id + ) AS rms USING (room_id); + +CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( + stream_ordering +); + +CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( + room_id, stream_ordering +); diff --git a/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql new file mode 100644 index 0000000000..90d8fd18f9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql @@ -0,0 +1,26 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Re-add some entries to stream_ordering_to_exterm that were incorrectly deleted +INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id) + SELECT + (SELECT stream_ordering FROM events where event_id = e.event_id) AS stream_ordering, + room_id, + event_id + FROM event_forward_extremities AS e + WHERE NOT EXISTS ( + SELECT room_id FROM stream_ordering_to_exterm AS s + WHERE s.room_id = e.room_id + ); diff --git a/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py new file mode 100644 index 0000000000..a377884169 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py @@ -0,0 +1,85 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + +DROP_INDICES = """ +-- We only ever query based on event_id +DROP INDEX IF EXISTS state_events_room_id; +DROP INDEX IF EXISTS state_events_type; +DROP INDEX IF EXISTS state_events_state_key; + +-- room_id is indexed elsewhere +DROP INDEX IF EXISTS current_state_events_room_id; +DROP INDEX IF EXISTS current_state_events_state_key; +DROP INDEX IF EXISTS current_state_events_type; + +DROP INDEX IF EXISTS transactions_have_ref; + +-- (topological_ordering, stream_ordering, room_id) seems like a strange index, +-- and is used incredibly rarely. +DROP INDEX IF EXISTS events_order_topo_stream_room; + +-- an equivalent index to this actually gets re-created in delta 41, because it +-- turned out that deleting it wasn't a great plan :/. In any case, let's +-- delete it here, and delta 41 will create a new one with an added UNIQUE +-- constraint +DROP INDEX IF EXISTS event_search_ev_idx; +""" + +POSTGRES_DROP_CONSTRAINT = """ +ALTER TABLE event_auth DROP CONSTRAINT IF EXISTS event_auth_event_id_auth_id_room_id_key; +""" + +SQLITE_DROP_CONSTRAINT = """ +DROP INDEX IF EXISTS evauth_edges_id; + +CREATE TABLE IF NOT EXISTS event_auth_new( + event_id TEXT NOT NULL, + auth_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +INSERT INTO event_auth_new + SELECT event_id, auth_id, room_id + FROM event_auth; + +DROP TABLE event_auth; + +ALTER TABLE event_auth_new RENAME TO event_auth; + +CREATE INDEX evauth_edges_id ON event_auth(event_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(DROP_INDICES.splitlines()): + cur.execute(statement) + + if isinstance(database_engine, PostgresEngine): + drop_constraint = POSTGRES_DROP_CONSTRAINT + else: + drop_constraint = SQLITE_DROP_CONSTRAINT + + for statement in get_statements(drop_constraint.splitlines()): + cur.execute(statement) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql new file mode 100644 index 0000000000..cf7a90dd10 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql @@ -0,0 +1,52 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Update any email addresses that were stored with mixed case into all + * lowercase + */ + + -- There may be "duplicate" emails (with different case) already in the table, + -- so we find them and move all but the most recently used account. + UPDATE user_threepids + SET medium = 'email_old' + WHERE medium = 'email' + AND address IN ( + -- We select all the addresses that are linked to the user_id that is NOT + -- the most recently created. + SELECT u.address + FROM + user_threepids AS u, + -- `duplicate_addresses` is a table of all the email addresses that + -- appear multiple times and when the binding was created + ( + SELECT lower(u1.address) AS address, max(u1.added_at) AS max_ts + FROM user_threepids AS u1 + INNER JOIN user_threepids AS u2 ON u1.medium = u2.medium AND lower(u1.address) = lower(u2.address) AND u1.address != u2.address + WHERE u1.medium = 'email' AND u2.medium = 'email' + GROUP BY lower(u1.address) + ) AS duplicate_addresses + WHERE + lower(u.address) = duplicate_addresses.address + AND u.added_at != max_ts -- NOT the most recently created + ); + + +-- This update is now safe since we've removed the duplicate addresses. +UPDATE user_threepids SET address = LOWER(address) WHERE medium = 'email'; + + +/* Add an index for the select we do on passwored reset */ +CREATE INDEX user_threepids_medium_address on user_threepids (medium, address); diff --git a/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql new file mode 100644 index 0000000000..515e6b8e84 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql @@ -0,0 +1,19 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We no longer do this given we back it out again in schema 47 + +-- INSERT into background_updates (update_name, progress_json) +-- VALUES ('event_search_postgres_gist', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql new file mode 100644 index 0000000000..74bdc49073 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql @@ -0,0 +1,29 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE appservice_room_list( + appservice_id TEXT NOT NULL, + network_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +-- Each appservice can have multiple published room lists associated with them, +-- keyed of a particular network_id +CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list( + appservice_id, network_id, room_id +); + +ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT; +ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql new file mode 100644 index 0000000000..00be801e90 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql new file mode 100644 index 0000000000..de2ad93e5c --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('event_push_actions_highlights_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql new file mode 100644 index 0000000000..5af814290b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + CREATE TABLE federation_stream_position( + type TEXT NOT NULL, + stream_id INTEGER NOT NULL + ); + + INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1); + INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events; diff --git a/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql new file mode 100644 index 0000000000..1bf911c8ab --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_memberships ADD COLUMN display_name TEXT; +ALTER TABLE room_memberships ADD COLUMN avatar_url TEXT; + +INSERT into background_updates (update_name, progress_json) + VALUES ('room_membership_profile_update', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql new file mode 100644 index 0000000000..7ffa189f39 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('current_state_members_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql new file mode 100644 index 0000000000..b9fe1f0480 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- turn the pre-fill startup query into a index-only scan on postgresql. +INSERT into background_updates (update_name, progress_json) + VALUES ('device_inbox_stream_index', '{}'); + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index'); diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql new file mode 100644 index 0000000000..dd6dcb65f1 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql @@ -0,0 +1,60 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Cache of remote devices. +CREATE TABLE device_lists_remote_cache ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + content TEXT NOT NULL +); + +-- The last update we got for a user. Empty if we're not receiving updates for +-- that user. +CREATE TABLE device_lists_remote_extremeties ( + user_id TEXT NOT NULL, + stream_id TEXT NOT NULL +); + +-- we used to create non-unique indexes on these tables, but as of update 52 we create +-- unique indexes concurrently: +-- +-- CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); +-- CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); + + +-- Stream of device lists updates. Includes both local and remotes +CREATE TABLE device_lists_stream ( + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL +); + +CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); + + +-- The stream of updates to send to other servers. We keep at least one row +-- per user that was sent so that the prev_id for any new updates can be +-- calculated +CREATE TABLE device_lists_outbound_pokes ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + sent BOOLEAN NOT NULL, + ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers +); + +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql new file mode 100644 index 0000000000..3918f0b794 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql @@ -0,0 +1,37 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Aggregate of old notification counts that have been deleted out of the +-- main event_push_actions table. This count does not include those that were +-- highlights, as they remain in the event_push_actions table. +CREATE TABLE event_push_summary ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notif_count BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL +); + +CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); + + +-- The stream ordering up to which we have aggregated the event_push_actions +-- table into event_push_summary +CREATE TABLE event_push_summary_stream_ordering ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_ordering BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); diff --git a/synapse/storage/data_stores/main/schema/delta/40/pushers.sql b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql new file mode 100644 index 0000000000..054a223f14 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql @@ -0,0 +1,39 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag TEXT NOT NULL, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang TEXT, + data TEXT, + last_stream_ordering INTEGER, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) +); + +INSERT INTO pushers2 SELECT * FROM PUSHERS; + +DROP TABLE PUSHERS; + +ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql new file mode 100644 index 0000000000..b7bee8b692 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('device_lists_stream_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql new file mode 100644 index 0000000000..62f0b9892b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql new file mode 100644 index 0000000000..5d9cfecf36 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_search_event_id_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql new file mode 100644 index 0000000000..a194bf0238 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql @@ -0,0 +1,22 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE ratelimit_override ( + user_id TEXT NOT NULL, + messages_per_second BIGINT, + burst_count BIGINT +); + +CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql new file mode 100644 index 0000000000..d28851aff8 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql @@ -0,0 +1,26 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE current_state_delta_stream ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT, -- Is null if the key was removed + prev_event_id TEXT -- Is null if the key was added +); + +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql new file mode 100644 index 0000000000..9ab8c14fa3 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql @@ -0,0 +1,33 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Table of last stream_id that we sent to destination for user_id. This is +-- used to fill out the `prev_id` fields of outbound device list updates. +CREATE TABLE device_lists_outbound_last_success ( + destination TEXT NOT NULL, + user_id TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +INSERT INTO device_lists_outbound_last_success + SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values + GROUP BY destination, user_id; + +CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( + destination, user_id, stream_id +); diff --git a/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql new file mode 100644 index 0000000000..b8821ac759 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('event_auth_state_only', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/42/user_dir.py b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py new file mode 100644 index 0000000000..506f326f4d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py @@ -0,0 +1,84 @@ +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.prepare_database import get_statements + +logger = logging.getLogger(__name__) + + +BOTH_TABLES = """ +CREATE TABLE user_directory_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO user_directory_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_directory ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, -- A room_id that we know the user is joined to + display_name TEXT, + avatar_url TEXT +); + +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); + +CREATE TABLE users_in_pubic_room ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL -- A room_id that we know is public +); + +CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id); +CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id); +""" + + +POSTGRES_TABLE = """ +CREATE TABLE user_directory_search ( + user_id TEXT NOT NULL, + vector tsvector +); + +CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector); +CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id); +""" + + +SQLITE_TABLE = """ +CREATE VIRTUAL TABLE user_directory_search + USING fts4 ( user_id, value ); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(BOTH_TABLES.splitlines()): + cur.execute(statement) + + if isinstance(database_engine, PostgresEngine): + for statement in get_statements(POSTGRES_TABLE.splitlines()): + cur.execute(statement) + elif isinstance(database_engine, Sqlite3Engine): + for statement in get_statements(SQLITE_TABLE.splitlines()): + cur.execute(statement) + else: + raise Exception("Unrecognized database engine") + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql new file mode 100644 index 0000000000..0e3cd143ff --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql @@ -0,0 +1,21 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE blocked_rooms ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL -- Admin who blocked the room +); + +CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql new file mode 100644 index 0000000000..630907ec4f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT; +ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql new file mode 100644 index 0000000000..45ebe020da --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql @@ -0,0 +1,16 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/43/user_share.sql b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql new file mode 100644 index 0000000000..ee7062abe4 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql @@ -0,0 +1,33 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Table keeping track of who shares a room with who. We only keep track +-- of this for local users, so `user_id` is local users only (but we do keep track +-- of which remote users share a room) +CREATE TABLE users_who_share_rooms ( + user_id TEXT NOT NULL, + other_user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + share_private BOOLEAN NOT NULL -- is the shared room private? i.e. they share a private room +); + + +CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id); +CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id); +CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id); + + +-- Make sure that we populate the table initially +UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql new file mode 100644 index 0000000000..b12f9b2ebf --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql @@ -0,0 +1,41 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was +-- removed and replaced with 46/local_media_repository_url_idx.sql. +-- +-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; + +-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support +-- indices on expressions until 3.9. +CREATE TABLE local_media_repository_url_cache_new( + url TEXT, + response_code INTEGER, + etag TEXT, + expires_ts BIGINT, + og TEXT, + media_id TEXT, + download_ts BIGINT +); + +INSERT INTO local_media_repository_url_cache_new + SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache; + +DROP TABLE local_media_repository_url_cache; +ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache; + +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); diff --git a/synapse/storage/data_stores/main/schema/delta/45/group_server.sql b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql new file mode 100644 index 0000000000..b2333848a0 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql @@ -0,0 +1,167 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE groups ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT +); + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); + + +-- list of users the group server thinks are joined +CREATE TABLE group_users ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + is_admin BOOLEAN NOT NULL, + is_public BOOLEAN NOT NULL -- whether the users membership can be seen by everyone +); + + +CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id); +CREATE INDEX groups_users_u_idx ON group_users(user_id); + +-- list of users the group server thinks are invited +CREATE TABLE group_invites ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL +); + +CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id); +CREATE INDEX groups_invites_u_idx ON group_invites(user_id); + + +CREATE TABLE group_rooms ( + group_id TEXT NOT NULL, + room_id TEXT NOT NULL, + is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone +); + +CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id); +CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id); + + +-- Rooms to include in the summary +CREATE TABLE group_summary_rooms ( + group_id TEXT NOT NULL, + room_id TEXT NOT NULL, + category_id TEXT NOT NULL, + room_order BIGINT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the room should be show to everyone + UNIQUE (group_id, category_id, room_id, room_order), + CHECK (room_order > 0) +); + +CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id); + + +-- Categories to include in the summary +CREATE TABLE group_summary_room_categories ( + group_id TEXT NOT NULL, + category_id TEXT NOT NULL, + cat_order BIGINT NOT NULL, + UNIQUE (group_id, category_id, cat_order), + CHECK (cat_order > 0) +); + +-- The categories in the group +CREATE TABLE group_room_categories ( + group_id TEXT NOT NULL, + category_id TEXT NOT NULL, + profile TEXT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the category should be show to everyone + UNIQUE (group_id, category_id) +); + +-- The users to include in the group summary +CREATE TABLE group_summary_users ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + role_id TEXT NOT NULL, + user_order BIGINT NOT NULL, + is_public BOOLEAN NOT NULL -- whether the user should be show to everyone +); + +CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id); + +-- The roles to include in the group summary +CREATE TABLE group_summary_roles ( + group_id TEXT NOT NULL, + role_id TEXT NOT NULL, + role_order BIGINT NOT NULL, + UNIQUE (group_id, role_id, role_order), + CHECK (role_order > 0) +); + + +-- The roles in a groups +CREATE TABLE group_roles ( + group_id TEXT NOT NULL, + role_id TEXT NOT NULL, + profile TEXT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the role should be show to everyone + UNIQUE (group_id, role_id) +); + + +-- List of attestations we've given out and need to renew +CREATE TABLE group_attestations_renewals ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + valid_until_ms BIGINT NOT NULL +); + +CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id); +CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id); +CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms); + + +-- List of attestations we've received from remotes and are interested in. +CREATE TABLE group_attestations_remote ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + valid_until_ms BIGINT NOT NULL, + attestation_json TEXT NOT NULL +); + +CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); +CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); +CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); + + +-- The group membership for the HS's users +CREATE TABLE local_group_membership ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + is_admin BOOLEAN NOT NULL, + membership TEXT NOT NULL, + is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership + content TEXT NOT NULL +); + +CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); +CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); + + +CREATE TABLE local_group_updates ( + stream_id BIGINT NOT NULL, + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL +); diff --git a/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql new file mode 100644 index 0000000000..e5ddc84df0 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql @@ -0,0 +1,28 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A subset of remote users whose profiles we have cached. +-- Whether a user is in this table or not is defined by the storage function +-- `is_subscribed_remote_profile_for_user` +CREATE TABLE remote_profile_cache ( + user_id TEXT NOT NULL, + displayname TEXT, + avatar_url TEXT, + last_check BIGINT NOT NULL +); + +CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); +CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql new file mode 100644 index 0000000000..68c48a89a9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* we no longer use (or create) the refresh_tokens table */ +DROP TABLE IF EXISTS refresh_tokens; diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql new file mode 100644 index 0000000000..bb307889c1 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- drop the unique constraint on deleted_pushers so that we can just insert +-- into it rather than upserting. + +CREATE TABLE deleted_pushers2 ( + stream_id BIGINT NOT NULL, + app_id TEXT NOT NULL, + pushkey TEXT NOT NULL, + user_id TEXT NOT NULL +); + +INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id) + SELECT stream_id, app_id, pushkey, user_id from deleted_pushers; + +DROP TABLE deleted_pushers; +ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers; + +-- create the index after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); + diff --git a/synapse/storage/data_stores/main/schema/delta/46/group_server.sql b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql new file mode 100644 index 0000000000..097679bc9a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql @@ -0,0 +1,32 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE groups_new ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT, + is_public BOOL NOT NULL -- whether non-members can access group APIs +); + +-- NB: awful hack to get the default to be true on postgres and 1 on sqlite +INSERT INTO groups_new + SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups; + +DROP TABLE groups; +ALTER TABLE groups_new RENAME TO groups; + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); diff --git a/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql new file mode 100644 index 0000000000..bbfc7f5d1a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will recreate the +-- local_media_repository_url_idx index. +-- +-- We do this as a bg update not because it is a particularly onerous +-- operation, but because we'd like it to be a partial index if possible, and +-- the background_index_update code will understand whether we are on +-- postgres or sqlite and behave accordingly. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('local_media_repository_url_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql new file mode 100644 index 0000000000..cb0d5a2576 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- change the user_directory table to also cover global local user profiles +-- rather than just profiles within specific rooms. + +CREATE TABLE user_directory2 ( + user_id TEXT NOT NULL, + room_id TEXT, + display_name TEXT, + avatar_url TEXT +); + +INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url) + SELECT user_id, room_id, display_name, avatar_url from user_directory; + +DROP TABLE user_directory; +ALTER TABLE user_directory2 RENAME TO user_directory; + +-- create indexes after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql new file mode 100644 index 0000000000..d9505f8da1 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this is just embarassing :| +ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms; + +-- this is only 300K rows on matrix.org and takes ~3s to generate the index, +-- so is hopefully not going to block anyone else for that long... +CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id); +CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id); +DROP INDEX users_in_pubic_room_room_idx; +DROP INDEX users_in_pubic_room_user_idx; diff --git a/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql new file mode 100644 index 0000000000..f505fb22b5 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql @@ -0,0 +1,16 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql new file mode 100644 index 0000000000..31d7a817eb --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_search_postgres_gin', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql new file mode 100644 index 0000000000..edccf4a96f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql @@ -0,0 +1,28 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Temporary staging area for push actions that have been calculated for an +-- event, but the event hasn't yet been persisted. +-- When the event is persisted the rows are moved over to the +-- event_push_actions table. +CREATE TABLE event_push_actions_staging ( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + actions TEXT NOT NULL, + notif SMALLINT NOT NULL, + highlight SMALLINT NOT NULL +); + +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); diff --git a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py new file mode 100644 index 0000000000..9fd1ccf6f7 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py @@ -0,0 +1,34 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + # if we already have some state groups, we want to start making new + # ones with a higher id. + cur.execute("SELECT max(id) FROM state_groups") + row = cur.fetchone() + + if row[0] is None: + start_val = 1 + else: + start_val = row[0] + 1 + + cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql new file mode 100644 index 0000000000..5237491506 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql @@ -0,0 +1,18 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* record the version of the privacy policy the user has consented to + */ +ALTER TABLE users ADD COLUMN consent_version TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql new file mode 100644 index 0000000000..9248b0b24a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('user_ips_last_seen_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql new file mode 100644 index 0000000000..e9013a6969 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql @@ -0,0 +1,25 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Store any accounts that have been requested to be deactivated. + * We part the account from all the rooms its in when its + * deactivated. This can take some time and synapse may be restarted + * before it completes, so store the user IDs here until the process + * is complete. + */ +CREATE TABLE users_pending_deactivation ( + user_id TEXT NOT NULL +); diff --git a/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py new file mode 100644 index 0000000000..49f5f2c003 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py @@ -0,0 +1,63 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +FIX_INDEXES = """ +-- rebuild indexes as uniques +DROP INDEX groups_invites_g_idx; +CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); +DROP INDEX groups_users_g_idx; +CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); + +-- rename other indexes to actually match their table names.. +DROP INDEX groups_users_u_idx; +CREATE INDEX group_users_u_idx ON group_users(user_id); +DROP INDEX groups_invites_u_idx; +CREATE INDEX group_invites_u_idx ON group_invites(user_id); +DROP INDEX groups_rooms_g_idx; +CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); +DROP INDEX groups_rooms_r_idx; +CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" + + # remove duplicates from group_users & group_invites tables + cur.execute( + """ + DELETE FROM group_users WHERE %s NOT IN ( + SELECT min(%s) FROM group_users GROUP BY group_id, user_id + ); + """ + % (rowid, rowid) + ) + cur.execute( + """ + DELETE FROM group_invites WHERE %s NOT IN ( + SELECT min(%s) FROM group_invites GROUP BY group_id, user_id + ); + """ + % (rowid, rowid) + ) + + for statement in get_statements(FIX_INDEXES.splitlines()): + cur.execute(statement) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql new file mode 100644 index 0000000000..ce26eaf0c9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql @@ -0,0 +1,22 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This isn't a real ENUM because sqlite doesn't support it + * and we use a default of NULL for inserted rows and interpret + * NULL at the python store level as necessary so that existing + * rows are given the correct default policy. + */ +ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite'; diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql new file mode 100644 index 0000000000..14dcf18d73 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql @@ -0,0 +1,20 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* record whether we have sent a server notice about consenting to the + * privacy policy. Specifically records the version of the policy we sent + * a message about. + */ +ALTER TABLE users ADD COLUMN consent_server_notice_sent TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql new file mode 100644 index 0000000000..3dd478196f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql @@ -0,0 +1,21 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, + device_id TEXT, + timestamp BIGINT NOT NULL ); +CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); +CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql new file mode 100644 index 0000000000..3a4ed59b5b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('user_ips_last_seen_only_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql new file mode 100644 index 0000000000..c93ae47532 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + +INSERT into background_updates (update_name, progress_json) + VALUES ('users_creation_ts', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql new file mode 100644 index 0000000000..5d8641a9ab --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql @@ -0,0 +1,21 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- a table of users who have requested that their details be erased +CREATE TABLE erased_users ( + user_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id); diff --git a/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py new file mode 100644 index 0000000000..b1684a8441 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +We want to stop populating 'event.content', so we need to make it nullable. + +If this has to be rolled back, then the following should populate the missing data: + +Postgres: + + UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej + WHERE ej.event_id = events.event_id AND + stream_ordering < ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering LIMIT 1 + ); + + UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej + WHERE ej.event_id = events.event_id AND + stream_ordering > ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering DESC LIMIT 1 + ); + +SQLite: + + UPDATE events SET content=( + SELECT json_extract(json,'$.content') FROM event_json ej + WHERE ej.event_id = events.event_id + ) + WHERE + stream_ordering < ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering LIMIT 1 + ) + OR stream_ordering > ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering DESC LIMIT 1 + ); + +""" + +import logging + +from synapse.storage.engines import PostgresEngine + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + pass + + +def run_upgrade(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + cur.execute( + """ + ALTER TABLE events ALTER COLUMN content DROP NOT NULL; + """ + ) + return + + # sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html + + cur.execute( + "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'" + ) + (oldsql,) = cur.fetchone() + + sql = oldsql.replace("content TEXT NOT NULL", "content TEXT") + if sql == oldsql: + raise Exception("Couldn't find null constraint to drop in %s" % oldsql) + + logger.info("Replacing definition of 'events' with: %s", sql) + + cur.execute("PRAGMA schema_version") + (oldver,) = cur.fetchone() + cur.execute("PRAGMA writable_schema=ON") + cur.execute( + "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", + (sql,), + ) + cur.execute("PRAGMA schema_version=%i" % (oldver + 1,)) + cur.execute("PRAGMA writable_schema=OFF") diff --git a/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql new file mode 100644 index 0000000000..c0e66a697d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql @@ -0,0 +1,39 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- users' optionally backed up encrypted e2e sessions +CREATE TABLE e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + version TEXT NOT NULL, + first_message_index INT, + forwarded_count INT, + is_verified BOOLEAN, + session_data TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); + +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE e2e_room_keys_versions ( + user_id TEXT NOT NULL, + version TEXT NOT NULL, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); diff --git a/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql new file mode 100644 index 0000000000..c9d537d5a3 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql @@ -0,0 +1,27 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- a table of monthly active users, for use where blocking based on mau limits +CREATE TABLE monthly_active_users ( + user_id TEXT NOT NULL, + -- Last time we saw the user. Not guaranteed to be accurate due to rate limiting + -- on updates, Granularity of updates governed by + -- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY + -- Measured in ms since epoch. + timestamp BIGINT NOT NULL +); + +CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id); +CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); diff --git a/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql new file mode 100644 index 0000000000..91e03d13e1 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This is needed to efficiently check for unreferenced state groups during +-- purge. Added events_to_state_group(state_group) index +INSERT into background_updates (update_name, progress_json) + VALUES ('event_to_state_groups_sg_index', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql new file mode 100644 index 0000000000..bfa49e6f92 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql @@ -0,0 +1,36 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will create a unique index on +-- device_lists_remote_cache +INSERT into background_updates (update_name, progress_json) + VALUES ('device_lists_remote_cache_unique_idx', '{}'); + +-- and one on device_lists_remote_extremeties +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ( + 'device_lists_remote_extremeties_unique_idx', '{}', + + -- doesn't really depend on this, but we need to make sure both happen + -- before we drop the old indexes. + 'device_lists_remote_cache_unique_idx' + ); + +-- once they complete, we can drop the old indexes. +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ( + 'drop_device_list_streams_non_unique_indexes', '{}', + 'device_lists_remote_extremeties_unique_idx' + ); diff --git a/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql new file mode 100644 index 0000000000..db687cccae --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql @@ -0,0 +1,53 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Change version column to an integer so we can do MAX() sensibly + */ +CREATE TABLE e2e_room_keys_versions_new ( + user_id TEXT NOT NULL, + version BIGINT NOT NULL, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL +); + +INSERT INTO e2e_room_keys_versions_new + SELECT user_id, CAST(version as BIGINT), algorithm, auth_data, deleted FROM e2e_room_keys_versions; + +DROP TABLE e2e_room_keys_versions; +ALTER TABLE e2e_room_keys_versions_new RENAME TO e2e_room_keys_versions; + +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); + +/* Change e2e_rooms_keys to match + */ +CREATE TABLE e2e_room_keys_new ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + version BIGINT NOT NULL, + first_message_index INT, + forwarded_count INT, + is_verified BOOLEAN, + session_data TEXT NOT NULL +); + +INSERT INTO e2e_room_keys_new + SELECT user_id, room_id, session_id, CAST(version as BIGINT), first_message_index, forwarded_count, is_verified, session_data FROM e2e_room_keys; + +DROP TABLE e2e_room_keys; +ALTER TABLE e2e_room_keys_new RENAME TO e2e_room_keys; + +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); diff --git a/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql new file mode 100644 index 0000000000..88ec2f83e5 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql @@ -0,0 +1,19 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* The type of the user: NULL for a regular user, or one of the constants in + * synapse.api.constants.UserTypes + */ +ALTER TABLE users ADD COLUMN user_type TEXT DEFAULT NULL; diff --git a/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql new file mode 100644 index 0000000000..e372f5a44a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql @@ -0,0 +1,16 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP TABLE IF EXISTS sent_transactions; diff --git a/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql new file mode 100644 index 0000000000..1d977c2834 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE event_json ADD COLUMN format_version INTEGER; diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql new file mode 100644 index 0000000000..ffcc896b58 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql @@ -0,0 +1,30 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Set up staging tables +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_user_directory_createtables', '{}'); + +-- Run through each room and update the user directory according to who is in it +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_user_directory_process_rooms', '{}', 'populate_user_directory_createtables'); + +-- Insert all users, if search_all_users is on +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_user_directory_process_users', '{}', 'populate_user_directory_process_rooms'); + +-- Clean up staging tables +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_user_directory_cleanup', '{}', 'populate_user_directory_process_users'); diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql new file mode 100644 index 0000000000..b812c5794f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql @@ -0,0 +1,30 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + -- analyze user_ips, to help ensure the correct indices are used +INSERT INTO background_updates (update_name, progress_json) VALUES + ('user_ips_analyze', '{}'); + +-- delete duplicates +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('user_ips_remove_dupes', '{}', 'user_ips_analyze'); + +-- add a new unique index to user_ips table +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('user_ips_device_unique_index', '{}', 'user_ips_remove_dupes'); + +-- drop the old original index +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index'); diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_share.sql b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql new file mode 100644 index 0000000000..5831b1a6f8 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql @@ -0,0 +1,44 @@ +/* Copyright 2017 Vector Creations Ltd, 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Old disused version of the tables below. +DROP TABLE IF EXISTS users_who_share_rooms; + +-- Tables keeping track of what users share rooms. This is a map of local users +-- to local or remote users, per room. Remote users cannot be in the user_id +-- column, only the other_user_id column. There are two tables, one for public +-- rooms and those for private rooms. +CREATE TABLE IF NOT EXISTS users_who_share_public_rooms ( + user_id TEXT NOT NULL, + other_user_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS users_who_share_private_rooms ( + user_id TEXT NOT NULL, + other_user_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX users_who_share_public_rooms_u_idx ON users_who_share_public_rooms(user_id, other_user_id, room_id); +CREATE INDEX users_who_share_public_rooms_r_idx ON users_who_share_public_rooms(room_id); +CREATE INDEX users_who_share_public_rooms_o_idx ON users_who_share_public_rooms(other_user_id); + +CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id); +CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id); +CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id); + +-- Make sure that we populate the tables initially by resetting the stream ID +UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql new file mode 100644 index 0000000000..80c2c573b6 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql @@ -0,0 +1,29 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks which identity server a user bound their threepid via. +CREATE TABLE user_threepid_id_server ( + user_id TEXT NOT NULL, + medium TEXT NOT NULL, + address TEXT NOT NULL, + id_server TEXT NOT NULL +); + +CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server( + user_id, medium, address, id_server +); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('user_threepids_grandfather', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql new file mode 100644 index 0000000000..f7827ca6d2 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql @@ -0,0 +1,28 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We don't need the old version of this table. +DROP TABLE IF EXISTS users_in_public_rooms; + +-- Old version of users_in_public_rooms +DROP TABLE IF EXISTS users_who_share_public_rooms; + +-- Track what users are in public rooms. +CREATE TABLE IF NOT EXISTS users_in_public_rooms ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql new file mode 100644 index 0000000000..0adb2ad55e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql @@ -0,0 +1,30 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We previously changed the schema for this table without renaming the file, which means +-- that some databases might still be using the old schema. This ensures Synapse uses the +-- right schema for the table. +DROP TABLE IF EXISTS account_validity; + +-- Track what users are in public rooms. +CREATE TABLE IF NOT EXISTS account_validity ( + user_id TEXT PRIMARY KEY, + expiration_ts_ms BIGINT NOT NULL, + email_sent BOOLEAN NOT NULL, + renewal_token TEXT +); + +CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms) +CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token) diff --git a/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql new file mode 100644 index 0000000000..c01aa9d2d9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql @@ -0,0 +1,23 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* When we can use this key until, before we have to refresh it. */ +ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT; + +UPDATE server_signature_keys SET ts_valid_until_ms = ( + SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE + skj.server_name = server_signature_keys.server_name AND + skj.key_id = server_signature_keys.key_id +); diff --git a/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql new file mode 100644 index 0000000000..b062ec840c --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql @@ -0,0 +1,23 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Start a background job to cleanup extremities that were incorrectly added +-- by bug #5269. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('delete_soft_failed_extremities', '{}'); + +DROP TABLE IF EXISTS _extremities_to_check; -- To make this delta schema file idempotent. +CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities; +CREATE INDEX _extremities_to_check_id ON _extremities_to_check(event_id); diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql new file mode 100644 index 0000000000..dbbe682697 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql @@ -0,0 +1,30 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- we need to do this first due to foreign constraints +DROP TABLE IF EXISTS application_services_regex; + +DROP TABLE IF EXISTS application_services; +DROP TABLE IF EXISTS transaction_id_to_pdu; +DROP TABLE IF EXISTS stats_reporting; +DROP TABLE IF EXISTS current_state_resets; +DROP TABLE IF EXISTS event_content_hashes; +DROP TABLE IF EXISTS event_destinations; +DROP TABLE IF EXISTS event_edge_hashes; +DROP TABLE IF EXISTS event_signatures; +DROP TABLE IF EXISTS feedback; +DROP TABLE IF EXISTS room_hosts; +DROP TABLE IF EXISTS server_tls_certificates; +DROP TABLE IF EXISTS state_forward_extremities; diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql new file mode 100644 index 0000000000..e6ee70c623 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP TABLE IF EXISTS presence_list; diff --git a/synapse/storage/data_stores/main/schema/delta/54/relations.sql b/synapse/storage/data_stores/main/schema/delta/54/relations.sql new file mode 100644 index 0000000000..134862b870 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/54/relations.sql @@ -0,0 +1,27 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks related events, like reactions, replies, edits, etc. Note that things +-- in this table are not necessarily "valid", e.g. it may contain edits from +-- people who don't have power to edit other peoples events. +CREATE TABLE IF NOT EXISTS event_relations ( + event_id TEXT NOT NULL, + relates_to_id TEXT NOT NULL, + relation_type TEXT NOT NULL, + aggregation_key TEXT +); + +CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); +CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats.sql b/synapse/storage/data_stores/main/schema/delta/54/stats.sql new file mode 100644 index 0000000000..652e58308e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/54/stats.sql @@ -0,0 +1,80 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stats_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO stats_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_stats ( + user_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + public_rooms INT NOT NULL, + private_rooms INT NOT NULL +); + +CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); + +CREATE TABLE room_stats ( + room_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + current_state_events INT NOT NULL, + joined_members INT NOT NULL, + invited_members INT NOT NULL, + left_members INT NOT NULL, + banned_members INT NOT NULL, + state_events INT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); + +-- cache of current room state; useful for the publicRooms list +CREATE TABLE room_state ( + room_id TEXT NOT NULL, + join_rules TEXT, + history_visibility TEXT, + encryption TEXT, + name TEXT, + topic TEXT, + avatar TEXT, + canonical_alias TEXT + -- get aliases straight from the right table +); + +CREATE UNIQUE INDEX room_state_room ON room_state(room_id); + +CREATE TABLE room_stats_earliest_token ( + room_id TEXT NOT NULL, + token BIGINT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); + +-- Set up staging tables +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_stats_createtables', '{}'); + +-- Run through each room and update stats +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', 'populate_stats_createtables'); + +-- Clean up staging tables +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_cleanup', '{}', 'populate_stats_process_rooms'); diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats2.sql b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql new file mode 100644 index 0000000000..3b2d48447f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql @@ -0,0 +1,28 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This delta file gets run after `54/stats.sql` delta. + +-- We want to add some indices to the temporary stats table, so we re-insert +-- 'populate_stats_createtables' if we are still processing the rooms update. +INSERT INTO background_updates (update_name, progress_json) + SELECT 'populate_stats_createtables', '{}' + WHERE + 'populate_stats_process_rooms' IN ( + SELECT update_name FROM background_updates + ) + AND 'populate_stats_createtables' NOT IN ( -- don't insert if already exists + SELECT update_name FROM background_updates + ); diff --git a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql new file mode 100644 index 0000000000..4590604bfd --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- when this access token can be used until, in ms since the epoch. NULL means the token +-- never expires. +ALTER TABLE access_tokens ADD COLUMN valid_until_ms BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql new file mode 100644 index 0000000000..a8eced2e0a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql @@ -0,0 +1,31 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS threepid_validation_session ( + session_id TEXT PRIMARY KEY, + medium TEXT NOT NULL, + address TEXT NOT NULL, + client_secret TEXT NOT NULL, + last_send_attempt BIGINT NOT NULL, + validated_at BIGINT +); + +CREATE TABLE IF NOT EXISTS threepid_validation_token ( + token TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + next_link TEXT, + expires BIGINT NOT NULL +); + +CREATE INDEX threepid_validation_token_session_id ON threepid_validation_token(session_id); diff --git a/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql new file mode 100644 index 0000000000..dabdde489b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql @@ -0,0 +1,19 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE users ADD deactivated SMALLINT DEFAULT 0 NOT NULL; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('users_set_deactivated_flag', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql new file mode 100644 index 0000000000..41807eb1e7 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Opentracing context data for inclusion in the device_list_update EDUs, as a + * json-encoded dictionary. NULL if opentracing is disabled (or not enabled for this destination). + */ +ALTER TABLE device_lists_outbound_pokes ADD opentracing_context TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql new file mode 100644 index 0000000000..473018676f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql @@ -0,0 +1,22 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We add membership to current state so that we don't need to join against +-- room_memberships, which can be surprisingly costly (we do such queries +-- very frequently). +-- This will be null for non-membership events and the content.membership key +-- for membership events. (Will also be null for membership events until the +-- background update job has finished). +ALTER TABLE current_state_events ADD membership TEXT; diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql new file mode 100644 index 0000000000..3133d42d4a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql @@ -0,0 +1,24 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We add membership to current state so that we don't need to join against +-- room_memberships, which can be surprisingly costly (we do such queries +-- very frequently). +-- This will be null for non-membership events and the content.membership key +-- for membership events. (Will also be null for membership events until the +-- background update job has finished). + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('current_state_events_membership', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql new file mode 100644 index 0000000000..f00889290b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Record the timestamp when a given server started failing + */ +ALTER TABLE destinations ADD failure_ts BIGINT; + +/* as a rough approximation, we assume that the server started failing at + * retry_interval before the last retry + */ +UPDATE destinations SET failure_ts = retry_last_ts - retry_interval + WHERE retry_last_ts > 0; diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres new file mode 100644 index 0000000000..b9bbb18a91 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres @@ -0,0 +1,18 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We want to store large retry intervals so we upgrade the column from INT +-- to BIGINT. We don't need to do this on SQLite. +ALTER TABLE destinations ALTER retry_interval SET DATA TYPE BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql new file mode 100644 index 0000000000..dfa902d0ba --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql @@ -0,0 +1,24 @@ +/* Copyright 2019 Matrix.org Foundation CIC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Track last seen information for a device in the devices table, rather +-- than relying on it being in the user_ips table (which we want to be able +-- to purge old entries from) +ALTER TABLE devices ADD COLUMN last_seen BIGINT; +ALTER TABLE devices ADD COLUMN ip TEXT; +ALTER TABLE devices ADD COLUMN user_agent TEXT; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('devices_last_seen', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql new file mode 100644 index 0000000000..9f09922c67 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- these tables are never used. +DROP TABLE IF EXISTS room_names; +DROP TABLE IF EXISTS topics; +DROP TABLE IF EXISTS history_visibility; +DROP TABLE IF EXISTS guest_access; diff --git a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql new file mode 100644 index 0000000000..014cb3b538 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 Matrix.org Foundation CIC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- version is supposed to be part of the room keys index +CREATE UNIQUE INDEX e2e_room_keys_with_version_idx ON e2e_room_keys(user_id, version, room_id, session_id); +DROP INDEX IF EXISTS e2e_room_keys_idx; diff --git a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql new file mode 100644 index 0000000000..7be31ffebb --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql new file mode 100644 index 0000000000..fe51b02309 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false; +CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored; diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql new file mode 100644 index 0000000000..77a5eca499 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE redactions ADD COLUMN received_ts BIGINT; +CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('redactions_received_ts', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres new file mode 100644 index 0000000000..67471f3ef5 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- There was a bug where we may have updated censored redactions as bytes, +-- which can (somehow) cause json to be inserted hex encoded. These updates go +-- and undoes any such hex encoded JSON. + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_fix_redactions_bytes_create_index', '{}'); + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('event_fix_redactions_bytes', '{}', 'event_fix_redactions_bytes_create_index'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql new file mode 100644 index 0000000000..92ab1f5e65 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Adds an index on room_memberships for fetching all forgotten rooms for a user +INSERT INTO background_updates (update_name, progress_json) VALUES + ('room_membership_forgotten_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql new file mode 100644 index 0000000000..163529c071 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql @@ -0,0 +1,152 @@ +/* Copyright 2018 New Vector Ltd + * Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +----- First clean up from previous versions of room stats. + +-- First remove old stats stuff +DROP TABLE IF EXISTS room_stats; +DROP TABLE IF EXISTS room_state; +DROP TABLE IF EXISTS room_stats_state; +DROP TABLE IF EXISTS user_stats; +DROP TABLE IF EXISTS room_stats_earliest_tokens; +DROP TABLE IF EXISTS _temp_populate_stats_position; +DROP TABLE IF EXISTS _temp_populate_stats_rooms; +DROP TABLE IF EXISTS stats_stream_pos; + +-- Unschedule old background updates if they're still scheduled +DELETE FROM background_updates WHERE update_name IN ( + 'populate_stats_createtables', + 'populate_stats_process_rooms', + 'populate_stats_process_users', + 'populate_stats_cleanup' +); + +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', ''); + +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); + +----- Create tables for our version of room stats. + +-- single-row table to track position of incremental updates +DROP TABLE IF EXISTS stats_incremental_position; +CREATE TABLE stats_incremental_position ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); + +-- insert a null row and make sure it is the only one. +INSERT INTO stats_incremental_position ( + stream_id +) SELECT COALESCE(MAX(stream_ordering), 0) from events; + +-- represents PRESENT room statistics for a room +-- only holds absolute fields +DROP TABLE IF EXISTS room_stats_current; +CREATE TABLE room_stats_current ( + room_id TEXT NOT NULL PRIMARY KEY, + + -- These are absolute counts + current_state_events INT NOT NULL, + joined_members INT NOT NULL, + invited_members INT NOT NULL, + left_members INT NOT NULL, + banned_members INT NOT NULL, + + local_users_in_room INT NOT NULL, + + -- The maximum delta stream position that this row takes into account. + completed_delta_stream_id BIGINT NOT NULL +); + + +-- represents HISTORICAL room statistics for a room +DROP TABLE IF EXISTS room_stats_historical; +CREATE TABLE room_stats_historical ( + room_id TEXT NOT NULL, + -- These stats cover the time from (end_ts - bucket_size)...end_ts (in ms). + -- Note that end_ts is quantised. + end_ts BIGINT NOT NULL, + bucket_size BIGINT NOT NULL, + + -- These stats are absolute counts + current_state_events BIGINT NOT NULL, + joined_members BIGINT NOT NULL, + invited_members BIGINT NOT NULL, + left_members BIGINT NOT NULL, + banned_members BIGINT NOT NULL, + local_users_in_room BIGINT NOT NULL, + + -- These stats are per time slice + total_events BIGINT NOT NULL, + total_event_bytes BIGINT NOT NULL, + + PRIMARY KEY (room_id, end_ts) +); + +-- We use this index to speed up deletion of ancient room stats. +CREATE INDEX room_stats_historical_end_ts ON room_stats_historical (end_ts); + +-- represents PRESENT statistics for a user +-- only holds absolute fields +DROP TABLE IF EXISTS user_stats_current; +CREATE TABLE user_stats_current ( + user_id TEXT NOT NULL PRIMARY KEY, + + joined_rooms BIGINT NOT NULL, + + -- The maximum delta stream position that this row takes into account. + completed_delta_stream_id BIGINT NOT NULL +); + +-- represents HISTORICAL statistics for a user +DROP TABLE IF EXISTS user_stats_historical; +CREATE TABLE user_stats_historical ( + user_id TEXT NOT NULL, + end_ts BIGINT NOT NULL, + bucket_size BIGINT NOT NULL, + + joined_rooms BIGINT NOT NULL, + + invites_sent BIGINT NOT NULL, + rooms_created BIGINT NOT NULL, + total_events BIGINT NOT NULL, + total_event_bytes BIGINT NOT NULL, + + PRIMARY KEY (user_id, end_ts) +); + +-- We use this index to speed up deletion of ancient user stats. +CREATE INDEX user_stats_historical_end_ts ON user_stats_historical (end_ts); + + +CREATE TABLE room_stats_state ( + room_id TEXT NOT NULL, + name TEXT, + canonical_alias TEXT, + join_rules TEXT, + history_visibility TEXT, + encryption TEXT, + avatar TEXT, + guest_access TEXT, + is_federatable BOOLEAN, + topic TEXT +); + +CREATE UNIQUE INDEX room_stats_state_room ON room_stats_state(room_id); diff --git a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py new file mode 100644 index 0000000000..1de8b54961 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py @@ -0,0 +1,52 @@ +import logging + +from synapse.storage.engines import PostgresEngine + +logger = logging.getLogger(__name__) + + +""" +This migration updates the user_filters table as follows: + + - drops any (user_id, filter_id) duplicates + - makes the columns NON-NULLable + - turns the index into a UNIQUE index +""" + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + select_clause = """ + SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json + FROM user_filters + """ + else: + select_clause = """ + SELECT * FROM user_filters GROUP BY user_id, filter_id + """ + sql = """ + DROP TABLE IF EXISTS user_filters_migration; + DROP INDEX IF EXISTS user_filters_unique; + CREATE TABLE user_filters_migration ( + user_id TEXT NOT NULL, + filter_id BIGINT NOT NULL, + filter_json BYTEA NOT NULL + ); + INSERT INTO user_filters_migration (user_id, filter_id, filter_json) + %s; + CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration + (user_id, filter_id); + DROP TABLE user_filters; + ALTER TABLE user_filters_migration RENAME TO user_filters; + """ % ( + select_clause, + ) + + if isinstance(database_engine, PostgresEngine): + cur.execute(sql) + else: + cur.executescript(sql) diff --git a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql new file mode 100644 index 0000000000..91390c4527 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql @@ -0,0 +1,24 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * a table which records mappings from external auth providers to mxids + */ +CREATE TABLE IF NOT EXISTS user_external_ids ( + auth_provider TEXT NOT NULL, + external_id TEXT NOT NULL, + user_id TEXT NOT NULL, + UNIQUE (auth_provider, external_id) +); diff --git a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql new file mode 100644 index 0000000000..149f8be8b6 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 Matrix.org Foundation CIC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this was apparently forgotten when the table was created back in delta 53. +CREATE INDEX users_in_public_rooms_r_idx ON users_in_public_rooms(room_id); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql new file mode 100644 index 0000000000..883fcd10b2 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql @@ -0,0 +1,37 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* We used to create tables called application_services and + * application_services_regex, but these are no longer used and are removed in + * delta 54. + */ + + +CREATE TABLE IF NOT EXISTS application_services_state( + as_id TEXT PRIMARY KEY, + state VARCHAR(5), + last_txn INTEGER +); + +CREATE TABLE IF NOT EXISTS application_services_txns( + as_id TEXT NOT NULL, + txn_id INTEGER NOT NULL, + event_ids TEXT NOT NULL, + UNIQUE(as_id, txn_id) +); + +CREATE INDEX application_services_txns_id ON application_services_txns ( + as_id +); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql new file mode 100644 index 0000000000..10ce2aa7a0 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql @@ -0,0 +1,70 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* We used to create tables called event_destinations and + * state_forward_extremities, but these are no longer used and are removed in + * delta 54. + */ + +CREATE TABLE IF NOT EXISTS event_forward_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (event_id, room_id) +); + +CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); +CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_backward_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (event_id, room_id) +); + +CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); +CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); + + +CREATE TABLE IF NOT EXISTS event_edges( + event_id TEXT NOT NULL, + prev_event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + is_state BOOL NOT NULL, -- true if this is a prev_state edge rather than a regular + -- event dag edge. + UNIQUE (event_id, prev_event_id, room_id, is_state) +); + +CREATE INDEX ev_edges_id ON event_edges(event_id); +CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); + + +CREATE TABLE IF NOT EXISTS room_depth( + room_id TEXT NOT NULL, + min_depth INTEGER NOT NULL, + UNIQUE (room_id) +); + +CREATE INDEX room_depth_room ON room_depth(room_id); + +CREATE TABLE IF NOT EXISTS event_auth( + event_id TEXT NOT NULL, + auth_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (event_id, auth_id, room_id) +); + +CREATE INDEX evauth_edges_id ON event_auth(event_id); +CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql new file mode 100644 index 0000000000..95826da431 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql @@ -0,0 +1,38 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + /* We used to create tables called event_content_hashes and event_edge_hashes, + * but these are no longer used and are removed in delta 54. + */ + +CREATE TABLE IF NOT EXISTS event_reference_hashes ( + event_id TEXT, + algorithm TEXT, + hash bytea, + UNIQUE (event_id, algorithm) +); + +CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id); + + +CREATE TABLE IF NOT EXISTS event_signatures ( + event_id TEXT, + signature_name TEXT, + key_id TEXT, + signature bytea, + UNIQUE (event_id, signature_name, key_id) +); + +CREATE INDEX event_signatures_id ON event_signatures(event_id); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql new file mode 100644 index 0000000000..a1a2aa8e5b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql @@ -0,0 +1,120 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* We used to create tables called room_hosts and feedback, + * but these are no longer used and are removed in delta 54. + */ + +CREATE TABLE IF NOT EXISTS events( + stream_ordering INTEGER PRIMARY KEY, + topological_ordering BIGINT NOT NULL, + event_id TEXT NOT NULL, + type TEXT NOT NULL, + room_id TEXT NOT NULL, + + -- 'content' used to be created NULLable, but as of delta 50 we drop that constraint. + -- the hack we use to drop the constraint doesn't work for an in-memory sqlite + -- database, which breaks the sytests. Hence, we no longer make it nullable. + content TEXT, + + unrecognized_keys TEXT, + processed BOOL NOT NULL, + outlier BOOL NOT NULL, + depth BIGINT DEFAULT 0 NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX events_stream_ordering ON events (stream_ordering); +CREATE INDEX events_topological_ordering ON events (topological_ordering); +CREATE INDEX events_order ON events (topological_ordering, stream_ordering); +CREATE INDEX events_room_id ON events (room_id); +CREATE INDEX events_order_room ON events ( + room_id, topological_ordering, stream_ordering +); + + +CREATE TABLE IF NOT EXISTS event_json( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + internal_metadata TEXT NOT NULL, + json TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX event_json_room_id ON event_json(room_id); + + +CREATE TABLE IF NOT EXISTS state_events( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + prev_state TEXT, + UNIQUE (event_id) +); + +CREATE INDEX state_events_room_id ON state_events (room_id); +CREATE INDEX state_events_type ON state_events (type); +CREATE INDEX state_events_state_key ON state_events (state_key); + + +CREATE TABLE IF NOT EXISTS current_state_events( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + UNIQUE (event_id), + UNIQUE (room_id, type, state_key) +); + +CREATE INDEX current_state_events_room_id ON current_state_events (room_id); +CREATE INDEX current_state_events_type ON current_state_events (type); +CREATE INDEX current_state_events_state_key ON current_state_events (state_key); + +CREATE TABLE IF NOT EXISTS room_memberships( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + sender TEXT NOT NULL, + room_id TEXT NOT NULL, + membership TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX room_memberships_room_id ON room_memberships (room_id); +CREATE INDEX room_memberships_user_id ON room_memberships (user_id); + +CREATE TABLE IF NOT EXISTS topics( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + topic TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX topics_room_id ON topics(room_id); + +CREATE TABLE IF NOT EXISTS room_names( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + name TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX room_names_room_id ON room_names(room_id); + +CREATE TABLE IF NOT EXISTS rooms( + room_id TEXT PRIMARY KEY NOT NULL, + is_public BOOL, + creator TEXT +); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql new file mode 100644 index 0000000000..11cdffdbb3 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql @@ -0,0 +1,26 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- we used to create a table called server_tls_certificates, but this is no +-- longer used, and is removed in delta 54. + +CREATE TABLE IF NOT EXISTS server_signature_keys( + server_name TEXT, -- Server name. + key_id TEXT, -- Key version. + from_server TEXT, -- Which key server the key was fetched form. + ts_added_ms BIGINT, -- When the key was added. + verify_key bytea, -- NACL verification key. + UNIQUE (server_name, key_id) +); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql new file mode 100644 index 0000000000..8f3759bb2a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql @@ -0,0 +1,68 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS local_media_repository ( + media_id TEXT, -- The id used to refer to the media. + media_type TEXT, -- The MIME-type of the media. + media_length INTEGER, -- Length of the media in bytes. + created_ts BIGINT, -- When the content was uploaded in ms. + upload_name TEXT, -- The name the media was uploaded with. + user_id TEXT, -- The user who uploaded the file. + UNIQUE (media_id) +); + +CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails ( + media_id TEXT, -- The id used to refer to the media. + thumbnail_width INTEGER, -- The width of the thumbnail in pixels. + thumbnail_height INTEGER, -- The height of the thumbnail in pixels. + thumbnail_type TEXT, -- The MIME-type of the thumbnail. + thumbnail_method TEXT, -- The method used to make the thumbnail. + thumbnail_length INTEGER, -- The length of the thumbnail in bytes. + UNIQUE ( + media_id, thumbnail_width, thumbnail_height, thumbnail_type + ) +); + +CREATE INDEX local_media_repository_thumbnails_media_id + ON local_media_repository_thumbnails (media_id); + +CREATE TABLE IF NOT EXISTS remote_media_cache ( + media_origin TEXT, -- The remote HS the media came from. + media_id TEXT, -- The id used to refer to the media on that server. + media_type TEXT, -- The MIME-type of the media. + created_ts BIGINT, -- When the content was uploaded in ms. + upload_name TEXT, -- The name the media was uploaded with. + media_length INTEGER, -- Length of the media in bytes. + filesystem_id TEXT, -- The name used to store the media on disk. + UNIQUE (media_origin, media_id) +); + +CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails ( + media_origin TEXT, -- The remote HS the media came from. + media_id TEXT, -- The id used to refer to the media. + thumbnail_width INTEGER, -- The width of the thumbnail in pixels. + thumbnail_height INTEGER, -- The height of the thumbnail in pixels. + thumbnail_method TEXT, -- The method used to make the thumbnail + thumbnail_type TEXT, -- The MIME-type of the thumbnail. + thumbnail_length INTEGER, -- The length of the thumbnail in bytes. + filesystem_id TEXT, -- The name used to store the media on disk. + UNIQUE ( + media_origin, media_id, thumbnail_width, thumbnail_height, + thumbnail_type + ) +); + +CREATE INDEX remote_media_cache_thumbnails_media_id + ON remote_media_cache_thumbnails (media_id); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql new file mode 100644 index 0000000000..01d2d8f833 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql @@ -0,0 +1,32 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS presence( + user_id TEXT NOT NULL, + state VARCHAR(20), + status_msg TEXT, + mtime BIGINT, -- miliseconds since last state change + UNIQUE (user_id) +); + +-- For each of /my/ users which possibly-remote users are allowed to see their +-- presence state +CREATE TABLE IF NOT EXISTS presence_allow_inbound( + observed_user_id TEXT NOT NULL, + observer_user_id TEXT NOT NULL, -- a UserID, + UNIQUE (observed_user_id, observer_user_id) +); + +-- We used to create a table called presence_list, but this is no longer used +-- and is removed in delta 54. \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql new file mode 100644 index 0000000000..c04f4747d9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql @@ -0,0 +1,20 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS profiles( + user_id TEXT NOT NULL, + displayname TEXT, + avatar_url TEXT, + UNIQUE(user_id) +); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql new file mode 100644 index 0000000000..e44465cf45 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql @@ -0,0 +1,74 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + UNIQUE (event_id) +); + +-- Push notification endpoints that users have configured +CREATE TABLE IF NOT EXISTS pushers ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey bytea NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data bytea, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey) +); + +CREATE TABLE IF NOT EXISTS push_rules ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + priority_class SMALLINT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + conditions TEXT NOT NULL, + actions TEXT NOT NULL, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX push_rules_user_name on push_rules (user_name); + +CREATE TABLE IF NOT EXISTS user_filters( + user_id TEXT, + filter_id BIGINT, + filter_json bytea +); + +CREATE INDEX user_filters_by_user_id_filter_id ON user_filters( + user_id, filter_id +); + +CREATE TABLE IF NOT EXISTS push_rules_enable ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + enabled SMALLINT, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql new file mode 100644 index 0000000000..318f0d9aa5 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql @@ -0,0 +1,22 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS redactions ( + event_id TEXT NOT NULL, + redacts TEXT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX redactions_event_id ON redactions (event_id); +CREATE INDEX redactions_redacts ON redactions (redacts); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql new file mode 100644 index 0000000000..d47da3b12f --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql @@ -0,0 +1,29 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS room_aliases( + room_alias TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (room_alias) +); + +CREATE INDEX room_aliases_id ON room_aliases(room_id); + +CREATE TABLE IF NOT EXISTS room_alias_servers( + room_alias TEXT NOT NULL, + server TEXT NOT NULL +); + +CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql new file mode 100644 index 0000000000..96391a8f0e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql @@ -0,0 +1,40 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS state_groups( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS state_groups_state( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS event_to_state_groups( + event_id TEXT NOT NULL, + state_group BIGINT NOT NULL, + UNIQUE (event_id) +); + +CREATE INDEX state_groups_id ON state_groups(id); + +CREATE INDEX state_groups_state_id ON state_groups_state(state_group); +CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key); +CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql new file mode 100644 index 0000000000..17e67bedac --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql @@ -0,0 +1,44 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- Stores what transaction ids we have received and what our response was +CREATE TABLE IF NOT EXISTS received_transactions( + transaction_id TEXT, + origin TEXT, + ts BIGINT, + response_code INTEGER, + response_json bytea, + has_been_referenced smallint default 0, -- Whether thishas been referenced by a prev_tx + UNIQUE (transaction_id, origin) +); + +CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0; + +-- For sent transactions only. +CREATE TABLE IF NOT EXISTS transaction_id_to_pdu( + transaction_id INTEGER, + destination TEXT, + pdu_id TEXT, + pdu_origin TEXT, + UNIQUE (transaction_id, destination) +); + +CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination); + +-- To track destination health +CREATE TABLE IF NOT EXISTS destinations( + destination TEXT PRIMARY KEY, + retry_last_ts BIGINT, + retry_interval INTEGER +); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql new file mode 100644 index 0000000000..f013aa8b18 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql @@ -0,0 +1,42 @@ +/* Copyright 2014-2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS users( + name TEXT, + password_hash TEXT, + creation_ts BIGINT, + admin SMALLINT DEFAULT 0 NOT NULL, + UNIQUE(name) +); + +CREATE TABLE IF NOT EXISTS access_tokens( + id BIGINT PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT, + token TEXT NOT NULL, + last_used BIGINT, + UNIQUE(token) +); + +CREATE TABLE IF NOT EXISTS user_ips ( + user_id TEXT NOT NULL, + access_token TEXT NOT NULL, + device_id TEXT, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + last_seen BIGINT NOT NULL +); + +CREATE INDEX user_ips_user ON user_ips(user_id); +CREATE INDEX user_ips_user_ip ON user_ips(user_id, access_token, ip); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres new file mode 100644 index 0000000000..4ad2929f32 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres @@ -0,0 +1,2035 @@ + + + + + +CREATE TABLE access_tokens ( + id bigint NOT NULL, + user_id text NOT NULL, + device_id text, + token text NOT NULL, + last_used bigint +); + + + +CREATE TABLE account_data ( + user_id text NOT NULL, + account_data_type text NOT NULL, + stream_id bigint NOT NULL, + content text NOT NULL +); + + + +CREATE TABLE account_data_max_stream_id ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_id bigint NOT NULL, + CONSTRAINT private_user_data_max_stream_id_lock_check CHECK ((lock = 'X'::bpchar)) +); + + + +CREATE TABLE account_validity ( + user_id text NOT NULL, + expiration_ts_ms bigint NOT NULL, + email_sent boolean NOT NULL, + renewal_token text +); + + + +CREATE TABLE application_services_state ( + as_id text NOT NULL, + state character varying(5), + last_txn integer +); + + + +CREATE TABLE application_services_txns ( + as_id text NOT NULL, + txn_id integer NOT NULL, + event_ids text NOT NULL +); + + + +CREATE TABLE appservice_room_list ( + appservice_id text NOT NULL, + network_id text NOT NULL, + room_id text NOT NULL +); + + + +CREATE TABLE appservice_stream_position ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_ordering bigint, + CONSTRAINT appservice_stream_position_lock_check CHECK ((lock = 'X'::bpchar)) +); + + +CREATE TABLE blocked_rooms ( + room_id text NOT NULL, + user_id text NOT NULL +); + + + +CREATE TABLE cache_invalidation_stream ( + stream_id bigint, + cache_func text, + keys text[], + invalidation_ts bigint +); + + + +CREATE TABLE current_state_delta_stream ( + stream_id bigint NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL, + event_id text, + prev_event_id text +); + + + +CREATE TABLE current_state_events ( + event_id text NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL +); + + + +CREATE TABLE deleted_pushers ( + stream_id bigint NOT NULL, + app_id text NOT NULL, + pushkey text NOT NULL, + user_id text NOT NULL +); + + + +CREATE TABLE destinations ( + destination text NOT NULL, + retry_last_ts bigint, + retry_interval integer +); + + + +CREATE TABLE device_federation_inbox ( + origin text NOT NULL, + message_id text NOT NULL, + received_ts bigint NOT NULL +); + + + +CREATE TABLE device_federation_outbox ( + destination text NOT NULL, + stream_id bigint NOT NULL, + queued_ts bigint NOT NULL, + messages_json text NOT NULL +); + + + +CREATE TABLE device_inbox ( + user_id text NOT NULL, + device_id text NOT NULL, + stream_id bigint NOT NULL, + message_json text NOT NULL +); + + + +CREATE TABLE device_lists_outbound_last_success ( + destination text NOT NULL, + user_id text NOT NULL, + stream_id bigint NOT NULL +); + + + +CREATE TABLE device_lists_outbound_pokes ( + destination text NOT NULL, + stream_id bigint NOT NULL, + user_id text NOT NULL, + device_id text NOT NULL, + sent boolean NOT NULL, + ts bigint NOT NULL +); + + + +CREATE TABLE device_lists_remote_cache ( + user_id text NOT NULL, + device_id text NOT NULL, + content text NOT NULL +); + + + +CREATE TABLE device_lists_remote_extremeties ( + user_id text NOT NULL, + stream_id text NOT NULL +); + + + +CREATE TABLE device_lists_stream ( + stream_id bigint NOT NULL, + user_id text NOT NULL, + device_id text NOT NULL +); + + + +CREATE TABLE device_max_stream_id ( + stream_id bigint NOT NULL +); + + + +CREATE TABLE devices ( + user_id text NOT NULL, + device_id text NOT NULL, + display_name text +); + + + +CREATE TABLE e2e_device_keys_json ( + user_id text NOT NULL, + device_id text NOT NULL, + ts_added_ms bigint NOT NULL, + key_json text NOT NULL +); + + + +CREATE TABLE e2e_one_time_keys_json ( + user_id text NOT NULL, + device_id text NOT NULL, + algorithm text NOT NULL, + key_id text NOT NULL, + ts_added_ms bigint NOT NULL, + key_json text NOT NULL +); + + + +CREATE TABLE e2e_room_keys ( + user_id text NOT NULL, + room_id text NOT NULL, + session_id text NOT NULL, + version bigint NOT NULL, + first_message_index integer, + forwarded_count integer, + is_verified boolean, + session_data text NOT NULL +); + + + +CREATE TABLE e2e_room_keys_versions ( + user_id text NOT NULL, + version bigint NOT NULL, + algorithm text NOT NULL, + auth_data text NOT NULL, + deleted smallint DEFAULT 0 NOT NULL +); + + + +CREATE TABLE erased_users ( + user_id text NOT NULL +); + + + +CREATE TABLE event_auth ( + event_id text NOT NULL, + auth_id text NOT NULL, + room_id text NOT NULL +); + + + +CREATE TABLE event_backward_extremities ( + event_id text NOT NULL, + room_id text NOT NULL +); + + + +CREATE TABLE event_edges ( + event_id text NOT NULL, + prev_event_id text NOT NULL, + room_id text NOT NULL, + is_state boolean NOT NULL +); + + + +CREATE TABLE event_forward_extremities ( + event_id text NOT NULL, + room_id text NOT NULL +); + + + +CREATE TABLE event_json ( + event_id text NOT NULL, + room_id text NOT NULL, + internal_metadata text NOT NULL, + json text NOT NULL, + format_version integer +); + + + +CREATE TABLE event_push_actions ( + room_id text NOT NULL, + event_id text NOT NULL, + user_id text NOT NULL, + profile_tag character varying(32), + actions text NOT NULL, + topological_ordering bigint, + stream_ordering bigint, + notif smallint, + highlight smallint +); + + + +CREATE TABLE event_push_actions_staging ( + event_id text NOT NULL, + user_id text NOT NULL, + actions text NOT NULL, + notif smallint NOT NULL, + highlight smallint NOT NULL +); + + + +CREATE TABLE event_push_summary ( + user_id text NOT NULL, + room_id text NOT NULL, + notif_count bigint NOT NULL, + stream_ordering bigint NOT NULL +); + + + +CREATE TABLE event_push_summary_stream_ordering ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_ordering bigint NOT NULL, + CONSTRAINT event_push_summary_stream_ordering_lock_check CHECK ((lock = 'X'::bpchar)) +); + + + +CREATE TABLE event_reference_hashes ( + event_id text, + algorithm text, + hash bytea +); + + + +CREATE TABLE event_relations ( + event_id text NOT NULL, + relates_to_id text NOT NULL, + relation_type text NOT NULL, + aggregation_key text +); + + + +CREATE TABLE event_reports ( + id bigint NOT NULL, + received_ts bigint NOT NULL, + room_id text NOT NULL, + event_id text NOT NULL, + user_id text NOT NULL, + reason text, + content text +); + + + +CREATE TABLE event_search ( + event_id text, + room_id text, + sender text, + key text, + vector tsvector, + origin_server_ts bigint, + stream_ordering bigint +); + + + +CREATE TABLE event_to_state_groups ( + event_id text NOT NULL, + state_group bigint NOT NULL +); + + + +CREATE TABLE events ( + stream_ordering integer NOT NULL, + topological_ordering bigint NOT NULL, + event_id text NOT NULL, + type text NOT NULL, + room_id text NOT NULL, + content text, + unrecognized_keys text, + processed boolean NOT NULL, + outlier boolean NOT NULL, + depth bigint DEFAULT 0 NOT NULL, + origin_server_ts bigint, + received_ts bigint, + sender text, + contains_url boolean +); + + + +CREATE TABLE ex_outlier_stream ( + event_stream_ordering bigint NOT NULL, + event_id text NOT NULL, + state_group bigint NOT NULL +); + + + +CREATE TABLE federation_stream_position ( + type text NOT NULL, + stream_id integer NOT NULL +); + + + +CREATE TABLE group_attestations_remote ( + group_id text NOT NULL, + user_id text NOT NULL, + valid_until_ms bigint NOT NULL, + attestation_json text NOT NULL +); + + + +CREATE TABLE group_attestations_renewals ( + group_id text NOT NULL, + user_id text NOT NULL, + valid_until_ms bigint NOT NULL +); + + + +CREATE TABLE group_invites ( + group_id text NOT NULL, + user_id text NOT NULL +); + + + +CREATE TABLE group_roles ( + group_id text NOT NULL, + role_id text NOT NULL, + profile text NOT NULL, + is_public boolean NOT NULL +); + + + +CREATE TABLE group_room_categories ( + group_id text NOT NULL, + category_id text NOT NULL, + profile text NOT NULL, + is_public boolean NOT NULL +); + + + +CREATE TABLE group_rooms ( + group_id text NOT NULL, + room_id text NOT NULL, + is_public boolean NOT NULL +); + + + +CREATE TABLE group_summary_roles ( + group_id text NOT NULL, + role_id text NOT NULL, + role_order bigint NOT NULL, + CONSTRAINT group_summary_roles_role_order_check CHECK ((role_order > 0)) +); + + + +CREATE TABLE group_summary_room_categories ( + group_id text NOT NULL, + category_id text NOT NULL, + cat_order bigint NOT NULL, + CONSTRAINT group_summary_room_categories_cat_order_check CHECK ((cat_order > 0)) +); + + + +CREATE TABLE group_summary_rooms ( + group_id text NOT NULL, + room_id text NOT NULL, + category_id text NOT NULL, + room_order bigint NOT NULL, + is_public boolean NOT NULL, + CONSTRAINT group_summary_rooms_room_order_check CHECK ((room_order > 0)) +); + + + +CREATE TABLE group_summary_users ( + group_id text NOT NULL, + user_id text NOT NULL, + role_id text NOT NULL, + user_order bigint NOT NULL, + is_public boolean NOT NULL +); + + + +CREATE TABLE group_users ( + group_id text NOT NULL, + user_id text NOT NULL, + is_admin boolean NOT NULL, + is_public boolean NOT NULL +); + + + +CREATE TABLE groups ( + group_id text NOT NULL, + name text, + avatar_url text, + short_description text, + long_description text, + is_public boolean NOT NULL, + join_policy text DEFAULT 'invite'::text NOT NULL +); + + + +CREATE TABLE guest_access ( + event_id text NOT NULL, + room_id text NOT NULL, + guest_access text NOT NULL +); + + + +CREATE TABLE history_visibility ( + event_id text NOT NULL, + room_id text NOT NULL, + history_visibility text NOT NULL +); + + + +CREATE TABLE local_group_membership ( + group_id text NOT NULL, + user_id text NOT NULL, + is_admin boolean NOT NULL, + membership text NOT NULL, + is_publicised boolean NOT NULL, + content text NOT NULL +); + + + +CREATE TABLE local_group_updates ( + stream_id bigint NOT NULL, + group_id text NOT NULL, + user_id text NOT NULL, + type text NOT NULL, + content text NOT NULL +); + + + +CREATE TABLE local_invites ( + stream_id bigint NOT NULL, + inviter text NOT NULL, + invitee text NOT NULL, + event_id text NOT NULL, + room_id text NOT NULL, + locally_rejected text, + replaced_by text +); + + + +CREATE TABLE local_media_repository ( + media_id text, + media_type text, + media_length integer, + created_ts bigint, + upload_name text, + user_id text, + quarantined_by text, + url_cache text, + last_access_ts bigint +); + + + +CREATE TABLE local_media_repository_thumbnails ( + media_id text, + thumbnail_width integer, + thumbnail_height integer, + thumbnail_type text, + thumbnail_method text, + thumbnail_length integer +); + + + +CREATE TABLE local_media_repository_url_cache ( + url text, + response_code integer, + etag text, + expires_ts bigint, + og text, + media_id text, + download_ts bigint +); + + + +CREATE TABLE monthly_active_users ( + user_id text NOT NULL, + "timestamp" bigint NOT NULL +); + + + +CREATE TABLE open_id_tokens ( + token text NOT NULL, + ts_valid_until_ms bigint NOT NULL, + user_id text NOT NULL +); + + + +CREATE TABLE presence ( + user_id text NOT NULL, + state character varying(20), + status_msg text, + mtime bigint +); + + + +CREATE TABLE presence_allow_inbound ( + observed_user_id text NOT NULL, + observer_user_id text NOT NULL +); + + + +CREATE TABLE presence_stream ( + stream_id bigint, + user_id text, + state text, + last_active_ts bigint, + last_federation_update_ts bigint, + last_user_sync_ts bigint, + status_msg text, + currently_active boolean +); + + + +CREATE TABLE profiles ( + user_id text NOT NULL, + displayname text, + avatar_url text +); + + + +CREATE TABLE public_room_list_stream ( + stream_id bigint NOT NULL, + room_id text NOT NULL, + visibility boolean NOT NULL, + appservice_id text, + network_id text +); + + + +CREATE TABLE push_rules ( + id bigint NOT NULL, + user_name text NOT NULL, + rule_id text NOT NULL, + priority_class smallint NOT NULL, + priority integer DEFAULT 0 NOT NULL, + conditions text NOT NULL, + actions text NOT NULL +); + + + +CREATE TABLE push_rules_enable ( + id bigint NOT NULL, + user_name text NOT NULL, + rule_id text NOT NULL, + enabled smallint +); + + + +CREATE TABLE push_rules_stream ( + stream_id bigint NOT NULL, + event_stream_ordering bigint NOT NULL, + user_id text NOT NULL, + rule_id text NOT NULL, + op text NOT NULL, + priority_class smallint, + priority integer, + conditions text, + actions text +); + + + +CREATE TABLE pusher_throttle ( + pusher bigint NOT NULL, + room_id text NOT NULL, + last_sent_ts bigint, + throttle_ms bigint +); + + + +CREATE TABLE pushers ( + id bigint NOT NULL, + user_name text NOT NULL, + access_token bigint, + profile_tag text NOT NULL, + kind text NOT NULL, + app_id text NOT NULL, + app_display_name text NOT NULL, + device_display_name text NOT NULL, + pushkey text NOT NULL, + ts bigint NOT NULL, + lang text, + data text, + last_stream_ordering integer, + last_success bigint, + failing_since bigint +); + + + +CREATE TABLE ratelimit_override ( + user_id text NOT NULL, + messages_per_second bigint, + burst_count bigint +); + + + +CREATE TABLE receipts_graph ( + room_id text NOT NULL, + receipt_type text NOT NULL, + user_id text NOT NULL, + event_ids text NOT NULL, + data text NOT NULL +); + + + +CREATE TABLE receipts_linearized ( + stream_id bigint NOT NULL, + room_id text NOT NULL, + receipt_type text NOT NULL, + user_id text NOT NULL, + event_id text NOT NULL, + data text NOT NULL +); + + + +CREATE TABLE received_transactions ( + transaction_id text, + origin text, + ts bigint, + response_code integer, + response_json bytea, + has_been_referenced smallint DEFAULT 0 +); + + + +CREATE TABLE redactions ( + event_id text NOT NULL, + redacts text NOT NULL +); + + + +CREATE TABLE rejections ( + event_id text NOT NULL, + reason text NOT NULL, + last_check text NOT NULL +); + + + +CREATE TABLE remote_media_cache ( + media_origin text, + media_id text, + media_type text, + created_ts bigint, + upload_name text, + media_length integer, + filesystem_id text, + last_access_ts bigint, + quarantined_by text +); + + + +CREATE TABLE remote_media_cache_thumbnails ( + media_origin text, + media_id text, + thumbnail_width integer, + thumbnail_height integer, + thumbnail_method text, + thumbnail_type text, + thumbnail_length integer, + filesystem_id text +); + + + +CREATE TABLE remote_profile_cache ( + user_id text NOT NULL, + displayname text, + avatar_url text, + last_check bigint NOT NULL +); + + + +CREATE TABLE room_account_data ( + user_id text NOT NULL, + room_id text NOT NULL, + account_data_type text NOT NULL, + stream_id bigint NOT NULL, + content text NOT NULL +); + + + +CREATE TABLE room_alias_servers ( + room_alias text NOT NULL, + server text NOT NULL +); + + + +CREATE TABLE room_aliases ( + room_alias text NOT NULL, + room_id text NOT NULL, + creator text +); + + + +CREATE TABLE room_depth ( + room_id text NOT NULL, + min_depth integer NOT NULL +); + + + +CREATE TABLE room_memberships ( + event_id text NOT NULL, + user_id text NOT NULL, + sender text NOT NULL, + room_id text NOT NULL, + membership text NOT NULL, + forgotten integer DEFAULT 0, + display_name text, + avatar_url text +); + + + +CREATE TABLE room_names ( + event_id text NOT NULL, + room_id text NOT NULL, + name text NOT NULL +); + + + +CREATE TABLE room_state ( + room_id text NOT NULL, + join_rules text, + history_visibility text, + encryption text, + name text, + topic text, + avatar text, + canonical_alias text +); + + + +CREATE TABLE room_stats ( + room_id text NOT NULL, + ts bigint NOT NULL, + bucket_size integer NOT NULL, + current_state_events integer NOT NULL, + joined_members integer NOT NULL, + invited_members integer NOT NULL, + left_members integer NOT NULL, + banned_members integer NOT NULL, + state_events integer NOT NULL +); + + + +CREATE TABLE room_stats_earliest_token ( + room_id text NOT NULL, + token bigint NOT NULL +); + + + +CREATE TABLE room_tags ( + user_id text NOT NULL, + room_id text NOT NULL, + tag text NOT NULL, + content text NOT NULL +); + + + +CREATE TABLE room_tags_revisions ( + user_id text NOT NULL, + room_id text NOT NULL, + stream_id bigint NOT NULL +); + + + +CREATE TABLE rooms ( + room_id text NOT NULL, + is_public boolean, + creator text +); + + + +CREATE TABLE server_keys_json ( + server_name text NOT NULL, + key_id text NOT NULL, + from_server text NOT NULL, + ts_added_ms bigint NOT NULL, + ts_valid_until_ms bigint NOT NULL, + key_json bytea NOT NULL +); + + + +CREATE TABLE server_signature_keys ( + server_name text, + key_id text, + from_server text, + ts_added_ms bigint, + verify_key bytea, + ts_valid_until_ms bigint +); + + + +CREATE TABLE state_events ( + event_id text NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL, + prev_state text +); + + + +CREATE TABLE state_group_edges ( + state_group bigint NOT NULL, + prev_state_group bigint NOT NULL +); + + + +CREATE SEQUENCE state_group_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; + + + +CREATE TABLE state_groups ( + id bigint NOT NULL, + room_id text NOT NULL, + event_id text NOT NULL +); + + + +CREATE TABLE state_groups_state ( + state_group bigint NOT NULL, + room_id text NOT NULL, + type text NOT NULL, + state_key text NOT NULL, + event_id text NOT NULL +); + + + +CREATE TABLE stats_stream_pos ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_id bigint, + CONSTRAINT stats_stream_pos_lock_check CHECK ((lock = 'X'::bpchar)) +); + + + +CREATE TABLE stream_ordering_to_exterm ( + stream_ordering bigint NOT NULL, + room_id text NOT NULL, + event_id text NOT NULL +); + + + +CREATE TABLE threepid_guest_access_tokens ( + medium text, + address text, + guest_access_token text, + first_inviter text +); + + + +CREATE TABLE topics ( + event_id text NOT NULL, + room_id text NOT NULL, + topic text NOT NULL +); + + + +CREATE TABLE user_daily_visits ( + user_id text NOT NULL, + device_id text, + "timestamp" bigint NOT NULL +); + + + +CREATE TABLE user_directory ( + user_id text NOT NULL, + room_id text, + display_name text, + avatar_url text +); + + + +CREATE TABLE user_directory_search ( + user_id text NOT NULL, + vector tsvector +); + + + +CREATE TABLE user_directory_stream_pos ( + lock character(1) DEFAULT 'X'::bpchar NOT NULL, + stream_id bigint, + CONSTRAINT user_directory_stream_pos_lock_check CHECK ((lock = 'X'::bpchar)) +); + + + +CREATE TABLE user_filters ( + user_id text, + filter_id bigint, + filter_json bytea +); + + + +CREATE TABLE user_ips ( + user_id text NOT NULL, + access_token text NOT NULL, + device_id text, + ip text NOT NULL, + user_agent text NOT NULL, + last_seen bigint NOT NULL +); + + + +CREATE TABLE user_stats ( + user_id text NOT NULL, + ts bigint NOT NULL, + bucket_size integer NOT NULL, + public_rooms integer NOT NULL, + private_rooms integer NOT NULL +); + + + +CREATE TABLE user_threepid_id_server ( + user_id text NOT NULL, + medium text NOT NULL, + address text NOT NULL, + id_server text NOT NULL +); + + + +CREATE TABLE user_threepids ( + user_id text NOT NULL, + medium text NOT NULL, + address text NOT NULL, + validated_at bigint NOT NULL, + added_at bigint NOT NULL +); + + + +CREATE TABLE users ( + name text, + password_hash text, + creation_ts bigint, + admin smallint DEFAULT 0 NOT NULL, + upgrade_ts bigint, + is_guest smallint DEFAULT 0 NOT NULL, + appservice_id text, + consent_version text, + consent_server_notice_sent text, + user_type text +); + + + +CREATE TABLE users_in_public_rooms ( + user_id text NOT NULL, + room_id text NOT NULL +); + + + +CREATE TABLE users_pending_deactivation ( + user_id text NOT NULL +); + + + +CREATE TABLE users_who_share_private_rooms ( + user_id text NOT NULL, + other_user_id text NOT NULL, + room_id text NOT NULL +); + + + +ALTER TABLE ONLY access_tokens + ADD CONSTRAINT access_tokens_pkey PRIMARY KEY (id); + + + +ALTER TABLE ONLY access_tokens + ADD CONSTRAINT access_tokens_token_key UNIQUE (token); + + + +ALTER TABLE ONLY account_data + ADD CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type); + + + +ALTER TABLE ONLY account_validity + ADD CONSTRAINT account_validity_pkey PRIMARY KEY (user_id); + + + +ALTER TABLE ONLY application_services_state + ADD CONSTRAINT application_services_state_pkey PRIMARY KEY (as_id); + + + +ALTER TABLE ONLY application_services_txns + ADD CONSTRAINT application_services_txns_as_id_txn_id_key UNIQUE (as_id, txn_id); + + + +ALTER TABLE ONLY appservice_stream_position + ADD CONSTRAINT appservice_stream_position_lock_key UNIQUE (lock); + + + +ALTER TABLE ONLY current_state_events + ADD CONSTRAINT current_state_events_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY current_state_events + ADD CONSTRAINT current_state_events_room_id_type_state_key_key UNIQUE (room_id, type, state_key); + + + +ALTER TABLE ONLY destinations + ADD CONSTRAINT destinations_pkey PRIMARY KEY (destination); + + + +ALTER TABLE ONLY devices + ADD CONSTRAINT device_uniqueness UNIQUE (user_id, device_id); + + + +ALTER TABLE ONLY e2e_device_keys_json + ADD CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id); + + + +ALTER TABLE ONLY e2e_one_time_keys_json + ADD CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id); + + + +ALTER TABLE ONLY event_backward_extremities + ADD CONSTRAINT event_backward_extremities_event_id_room_id_key UNIQUE (event_id, room_id); + + + +ALTER TABLE ONLY event_edges + ADD CONSTRAINT event_edges_event_id_prev_event_id_room_id_is_state_key UNIQUE (event_id, prev_event_id, room_id, is_state); + + + +ALTER TABLE ONLY event_forward_extremities + ADD CONSTRAINT event_forward_extremities_event_id_room_id_key UNIQUE (event_id, room_id); + + + +ALTER TABLE ONLY event_push_actions + ADD CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag); + + + +ALTER TABLE ONLY event_json + ADD CONSTRAINT event_json_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY event_push_summary_stream_ordering + ADD CONSTRAINT event_push_summary_stream_ordering_lock_key UNIQUE (lock); + + + +ALTER TABLE ONLY event_reference_hashes + ADD CONSTRAINT event_reference_hashes_event_id_algorithm_key UNIQUE (event_id, algorithm); + + + +ALTER TABLE ONLY event_reports + ADD CONSTRAINT event_reports_pkey PRIMARY KEY (id); + + + +ALTER TABLE ONLY event_to_state_groups + ADD CONSTRAINT event_to_state_groups_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY events + ADD CONSTRAINT events_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY events + ADD CONSTRAINT events_pkey PRIMARY KEY (stream_ordering); + + + +ALTER TABLE ONLY ex_outlier_stream + ADD CONSTRAINT ex_outlier_stream_pkey PRIMARY KEY (event_stream_ordering); + + + +ALTER TABLE ONLY group_roles + ADD CONSTRAINT group_roles_group_id_role_id_key UNIQUE (group_id, role_id); + + + +ALTER TABLE ONLY group_room_categories + ADD CONSTRAINT group_room_categories_group_id_category_id_key UNIQUE (group_id, category_id); + + + +ALTER TABLE ONLY group_summary_roles + ADD CONSTRAINT group_summary_roles_group_id_role_id_role_order_key UNIQUE (group_id, role_id, role_order); + + + +ALTER TABLE ONLY group_summary_room_categories + ADD CONSTRAINT group_summary_room_categories_group_id_category_id_cat_orde_key UNIQUE (group_id, category_id, cat_order); + + + +ALTER TABLE ONLY group_summary_rooms + ADD CONSTRAINT group_summary_rooms_group_id_category_id_room_id_room_order_key UNIQUE (group_id, category_id, room_id, room_order); + + + +ALTER TABLE ONLY guest_access + ADD CONSTRAINT guest_access_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY history_visibility + ADD CONSTRAINT history_visibility_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY local_media_repository + ADD CONSTRAINT local_media_repository_media_id_key UNIQUE (media_id); + + + +ALTER TABLE ONLY local_media_repository_thumbnails + ADD CONSTRAINT local_media_repository_thumbn_media_id_thumbnail_width_thum_key UNIQUE (media_id, thumbnail_width, thumbnail_height, thumbnail_type); + + + +ALTER TABLE ONLY user_threepids + ADD CONSTRAINT medium_address UNIQUE (medium, address); + + + +ALTER TABLE ONLY open_id_tokens + ADD CONSTRAINT open_id_tokens_pkey PRIMARY KEY (token); + + + +ALTER TABLE ONLY presence_allow_inbound + ADD CONSTRAINT presence_allow_inbound_observed_user_id_observer_user_id_key UNIQUE (observed_user_id, observer_user_id); + + + +ALTER TABLE ONLY presence + ADD CONSTRAINT presence_user_id_key UNIQUE (user_id); + + + +ALTER TABLE ONLY account_data_max_stream_id + ADD CONSTRAINT private_user_data_max_stream_id_lock_key UNIQUE (lock); + + + +ALTER TABLE ONLY profiles + ADD CONSTRAINT profiles_user_id_key UNIQUE (user_id); + + + +ALTER TABLE ONLY push_rules_enable + ADD CONSTRAINT push_rules_enable_pkey PRIMARY KEY (id); + + + +ALTER TABLE ONLY push_rules_enable + ADD CONSTRAINT push_rules_enable_user_name_rule_id_key UNIQUE (user_name, rule_id); + + + +ALTER TABLE ONLY push_rules + ADD CONSTRAINT push_rules_pkey PRIMARY KEY (id); + + + +ALTER TABLE ONLY push_rules + ADD CONSTRAINT push_rules_user_name_rule_id_key UNIQUE (user_name, rule_id); + + + +ALTER TABLE ONLY pusher_throttle + ADD CONSTRAINT pusher_throttle_pkey PRIMARY KEY (pusher, room_id); + + + +ALTER TABLE ONLY pushers + ADD CONSTRAINT pushers2_app_id_pushkey_user_name_key UNIQUE (app_id, pushkey, user_name); + + + +ALTER TABLE ONLY pushers + ADD CONSTRAINT pushers2_pkey PRIMARY KEY (id); + + + +ALTER TABLE ONLY receipts_graph + ADD CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id); + + + +ALTER TABLE ONLY receipts_linearized + ADD CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id); + + + +ALTER TABLE ONLY received_transactions + ADD CONSTRAINT received_transactions_transaction_id_origin_key UNIQUE (transaction_id, origin); + + + +ALTER TABLE ONLY redactions + ADD CONSTRAINT redactions_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY rejections + ADD CONSTRAINT rejections_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY remote_media_cache + ADD CONSTRAINT remote_media_cache_media_origin_media_id_key UNIQUE (media_origin, media_id); + + + +ALTER TABLE ONLY remote_media_cache_thumbnails + ADD CONSTRAINT remote_media_cache_thumbnails_media_origin_media_id_thumbna_key UNIQUE (media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type); + + + +ALTER TABLE ONLY room_account_data + ADD CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type); + + + +ALTER TABLE ONLY room_aliases + ADD CONSTRAINT room_aliases_room_alias_key UNIQUE (room_alias); + + + +ALTER TABLE ONLY room_depth + ADD CONSTRAINT room_depth_room_id_key UNIQUE (room_id); + + + +ALTER TABLE ONLY room_memberships + ADD CONSTRAINT room_memberships_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY room_names + ADD CONSTRAINT room_names_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY room_tags_revisions + ADD CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id); + + + +ALTER TABLE ONLY room_tags + ADD CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag); + + + +ALTER TABLE ONLY rooms + ADD CONSTRAINT rooms_pkey PRIMARY KEY (room_id); + + + +ALTER TABLE ONLY server_keys_json + ADD CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server); + + + +ALTER TABLE ONLY server_signature_keys + ADD CONSTRAINT server_signature_keys_server_name_key_id_key UNIQUE (server_name, key_id); + + + +ALTER TABLE ONLY state_events + ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY state_groups + ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id); + + + +ALTER TABLE ONLY stats_stream_pos + ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock); + + + +ALTER TABLE ONLY topics + ADD CONSTRAINT topics_event_id_key UNIQUE (event_id); + + + +ALTER TABLE ONLY user_directory_stream_pos + ADD CONSTRAINT user_directory_stream_pos_lock_key UNIQUE (lock); + + + +ALTER TABLE ONLY users + ADD CONSTRAINT users_name_key UNIQUE (name); + + + +CREATE INDEX access_tokens_device_id ON access_tokens USING btree (user_id, device_id); + + + +CREATE INDEX account_data_stream_id ON account_data USING btree (user_id, stream_id); + + + +CREATE INDEX application_services_txns_id ON application_services_txns USING btree (as_id); + + + +CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list USING btree (appservice_id, network_id, room_id); + + + +CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms USING btree (room_id); + + + +CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream USING btree (stream_id); + + + +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream USING btree (stream_id); + + + +CREATE INDEX current_state_events_member_index ON current_state_events USING btree (state_key) WHERE (type = 'm.room.member'::text); + + + +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers USING btree (stream_id); + + + +CREATE INDEX device_federation_inbox_sender_id ON device_federation_inbox USING btree (origin, message_id); + + + +CREATE INDEX device_federation_outbox_destination_id ON device_federation_outbox USING btree (destination, stream_id); + + + +CREATE INDEX device_federation_outbox_id ON device_federation_outbox USING btree (stream_id); + + + +CREATE INDEX device_inbox_stream_id_user_id ON device_inbox USING btree (stream_id, user_id); + + + +CREATE INDEX device_inbox_user_stream_id ON device_inbox USING btree (user_id, device_id, stream_id); + + + +CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success USING btree (destination, user_id, stream_id); + + + +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes USING btree (destination, stream_id); + + + +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes USING btree (stream_id); + + + +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes USING btree (destination, user_id); + + + +CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache USING btree (user_id, device_id); + + + +CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties USING btree (user_id); + + + +CREATE INDEX device_lists_stream_id ON device_lists_stream USING btree (stream_id, user_id); + + + +CREATE INDEX device_lists_stream_user_id ON device_lists_stream USING btree (user_id, device_id); + + + +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys USING btree (user_id, room_id, session_id); + + + +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions USING btree (user_id, version); + + + +CREATE UNIQUE INDEX erased_users_user ON erased_users USING btree (user_id); + + + +CREATE INDEX ev_b_extrem_id ON event_backward_extremities USING btree (event_id); + + + +CREATE INDEX ev_b_extrem_room ON event_backward_extremities USING btree (room_id); + + + +CREATE INDEX ev_edges_id ON event_edges USING btree (event_id); + + + +CREATE INDEX ev_edges_prev_id ON event_edges USING btree (prev_event_id); + + + +CREATE INDEX ev_extrem_id ON event_forward_extremities USING btree (event_id); + + + +CREATE INDEX ev_extrem_room ON event_forward_extremities USING btree (room_id); + + + +CREATE INDEX evauth_edges_id ON event_auth USING btree (event_id); + + + +CREATE INDEX event_contains_url_index ON events USING btree (room_id, topological_ordering, stream_ordering) WHERE ((contains_url = true) AND (outlier = false)); + + + +CREATE INDEX event_json_room_id ON event_json USING btree (room_id); + + + +CREATE INDEX event_push_actions_highlights_index ON event_push_actions USING btree (user_id, room_id, topological_ordering, stream_ordering) WHERE (highlight = 1); + + + +CREATE INDEX event_push_actions_rm_tokens ON event_push_actions USING btree (user_id, room_id, topological_ordering, stream_ordering); + + + +CREATE INDEX event_push_actions_room_id_user_id ON event_push_actions USING btree (room_id, user_id); + + + +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging USING btree (event_id); + + + +CREATE INDEX event_push_actions_stream_ordering ON event_push_actions USING btree (stream_ordering, user_id); + + + +CREATE INDEX event_push_actions_u_highlight ON event_push_actions USING btree (user_id, stream_ordering); + + + +CREATE INDEX event_push_summary_user_rm ON event_push_summary USING btree (user_id, room_id); + + + +CREATE INDEX event_reference_hashes_id ON event_reference_hashes USING btree (event_id); + + + +CREATE UNIQUE INDEX event_relations_id ON event_relations USING btree (event_id); + + + +CREATE INDEX event_relations_relates ON event_relations USING btree (relates_to_id, relation_type, aggregation_key); + + + +CREATE INDEX event_search_ev_ridx ON event_search USING btree (room_id); + + + +CREATE UNIQUE INDEX event_search_event_id_idx ON event_search USING btree (event_id); + + + +CREATE INDEX event_search_fts_idx ON event_search USING gin (vector); + + + +CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups USING btree (state_group); + + + +CREATE INDEX events_order_room ON events USING btree (room_id, topological_ordering, stream_ordering); + + + +CREATE INDEX events_room_stream ON events USING btree (room_id, stream_ordering); + + + +CREATE INDEX events_ts ON events USING btree (origin_server_ts, stream_ordering); + + + +CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote USING btree (group_id, user_id); + + + +CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote USING btree (user_id); + + + +CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote USING btree (valid_until_ms); + + + +CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals USING btree (group_id, user_id); + + + +CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals USING btree (user_id); + + + +CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals USING btree (valid_until_ms); + + + +CREATE UNIQUE INDEX group_invites_g_idx ON group_invites USING btree (group_id, user_id); + + + +CREATE INDEX group_invites_u_idx ON group_invites USING btree (user_id); + + + +CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms USING btree (group_id, room_id); + + + +CREATE INDEX group_rooms_r_idx ON group_rooms USING btree (room_id); + + + +CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms USING btree (group_id, room_id, category_id); + + + +CREATE INDEX group_summary_users_g_idx ON group_summary_users USING btree (group_id); + + + +CREATE UNIQUE INDEX group_users_g_idx ON group_users USING btree (group_id, user_id); + + + +CREATE INDEX group_users_u_idx ON group_users USING btree (user_id); + + + +CREATE UNIQUE INDEX groups_idx ON groups USING btree (group_id); + + + +CREATE INDEX local_group_membership_g_idx ON local_group_membership USING btree (group_id); + + + +CREATE INDEX local_group_membership_u_idx ON local_group_membership USING btree (user_id, group_id); + + + +CREATE INDEX local_invites_for_user_idx ON local_invites USING btree (invitee, locally_rejected, replaced_by, room_id); + + + +CREATE INDEX local_invites_id ON local_invites USING btree (stream_id); + + + +CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails USING btree (media_id); + + + +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache USING btree (url, download_ts); + + + +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache USING btree (expires_ts); + + + +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache USING btree (media_id); + + + +CREATE INDEX local_media_repository_url_idx ON local_media_repository USING btree (created_ts) WHERE (url_cache IS NOT NULL); + + + +CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users USING btree ("timestamp"); + + + +CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users USING btree (user_id); + + + +CREATE INDEX open_id_tokens_ts_valid_until_ms ON open_id_tokens USING btree (ts_valid_until_ms); + + + +CREATE INDEX presence_stream_id ON presence_stream USING btree (stream_id, user_id); + + + +CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id); + + + +CREATE INDEX public_room_index ON rooms USING btree (is_public); + + + +CREATE INDEX public_room_list_stream_idx ON public_room_list_stream USING btree (stream_id); + + + +CREATE INDEX public_room_list_stream_rm_idx ON public_room_list_stream USING btree (room_id, stream_id); + + + +CREATE INDEX push_rules_enable_user_name ON push_rules_enable USING btree (user_name); + + + +CREATE INDEX push_rules_stream_id ON push_rules_stream USING btree (stream_id); + + + +CREATE INDEX push_rules_stream_user_stream_id ON push_rules_stream USING btree (user_id, stream_id); + + + +CREATE INDEX push_rules_user_name ON push_rules USING btree (user_name); + + + +CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override USING btree (user_id); + + + +CREATE INDEX receipts_linearized_id ON receipts_linearized USING btree (stream_id); + + + +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized USING btree (room_id, stream_id); + + + +CREATE INDEX receipts_linearized_user ON receipts_linearized USING btree (user_id); + + + +CREATE INDEX received_transactions_ts ON received_transactions USING btree (ts); + + + +CREATE INDEX redactions_redacts ON redactions USING btree (redacts); + + + +CREATE INDEX remote_profile_cache_time ON remote_profile_cache USING btree (last_check); + + + +CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache USING btree (user_id); + + + +CREATE INDEX room_account_data_stream_id ON room_account_data USING btree (user_id, stream_id); + + + +CREATE INDEX room_alias_servers_alias ON room_alias_servers USING btree (room_alias); + + + +CREATE INDEX room_aliases_id ON room_aliases USING btree (room_id); + + + +CREATE INDEX room_depth_room ON room_depth USING btree (room_id); + + + +CREATE INDEX room_memberships_room_id ON room_memberships USING btree (room_id); + + + +CREATE INDEX room_memberships_user_id ON room_memberships USING btree (user_id); + + + +CREATE INDEX room_names_room_id ON room_names USING btree (room_id); + + + +CREATE UNIQUE INDEX room_state_room ON room_state USING btree (room_id); + + + +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token USING btree (room_id); + + + +CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts); + + + +CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group); + + + +CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group); + + + +CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key); + + + +CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering); + + + +CREATE INDEX stream_ordering_to_exterm_rm_idx ON stream_ordering_to_exterm USING btree (room_id, stream_ordering); + + + +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens USING btree (medium, address); + + + +CREATE INDEX topics_room_id ON topics USING btree (room_id); + + + +CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits USING btree ("timestamp"); + + + +CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits USING btree (user_id, "timestamp"); + + + +CREATE INDEX user_directory_room_idx ON user_directory USING btree (room_id); + + + +CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin (vector); + + + +CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search USING btree (user_id); + + + +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory USING btree (user_id); + + + +CREATE INDEX user_filters_by_user_id_filter_id ON user_filters USING btree (user_id, filter_id); + + + +CREATE INDEX user_ips_device_id ON user_ips USING btree (user_id, device_id, last_seen); + + + +CREATE INDEX user_ips_last_seen ON user_ips USING btree (user_id, last_seen); + + + +CREATE INDEX user_ips_last_seen_only ON user_ips USING btree (last_seen); + + + +CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips USING btree (user_id, access_token, ip); + + + +CREATE UNIQUE INDEX user_stats_user_ts ON user_stats USING btree (user_id, ts); + + + +CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server USING btree (user_id, medium, address, id_server); + + + +CREATE INDEX user_threepids_medium_address ON user_threepids USING btree (medium, address); + + + +CREATE INDEX user_threepids_user_id ON user_threepids USING btree (user_id); + + + +CREATE INDEX users_creation_ts ON users USING btree (creation_ts); + + + +CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms USING btree (user_id, room_id); + + + +CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms USING btree (other_user_id); + + + +CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms USING btree (room_id); + + + +CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms USING btree (user_id, other_user_id, room_id); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite new file mode 100644 index 0000000000..bad33291e7 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite @@ -0,0 +1,259 @@ +CREATE TABLE application_services_state( as_id TEXT PRIMARY KEY, state VARCHAR(5), last_txn INTEGER ); +CREATE TABLE application_services_txns( as_id TEXT NOT NULL, txn_id INTEGER NOT NULL, event_ids TEXT NOT NULL, UNIQUE(as_id, txn_id) ); +CREATE INDEX application_services_txns_id ON application_services_txns ( as_id ); +CREATE TABLE presence( user_id TEXT NOT NULL, state VARCHAR(20), status_msg TEXT, mtime BIGINT, UNIQUE (user_id) ); +CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_user_id TEXT NOT NULL, UNIQUE (observed_user_id, observer_user_id) ); +CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) ); +CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) ); +CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL ); +CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) ); +CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) ); +CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER ); +CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) ); +CREATE INDEX events_order_room ON events ( room_id, topological_ordering, stream_ordering ); +CREATE TABLE event_json( event_id TEXT NOT NULL, room_id TEXT NOT NULL, internal_metadata TEXT NOT NULL, json TEXT NOT NULL, format_version INTEGER, UNIQUE (event_id) ); +CREATE INDEX event_json_room_id ON event_json(room_id); +CREATE TABLE state_events( event_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, prev_state TEXT, UNIQUE (event_id) ); +CREATE TABLE current_state_events( event_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, UNIQUE (event_id), UNIQUE (room_id, type, state_key) ); +CREATE TABLE room_memberships( event_id TEXT NOT NULL, user_id TEXT NOT NULL, sender TEXT NOT NULL, room_id TEXT NOT NULL, membership TEXT NOT NULL, forgotten INTEGER DEFAULT 0, display_name TEXT, avatar_url TEXT, UNIQUE (event_id) ); +CREATE INDEX room_memberships_room_id ON room_memberships (room_id); +CREATE INDEX room_memberships_user_id ON room_memberships (user_id); +CREATE TABLE topics( event_id TEXT NOT NULL, room_id TEXT NOT NULL, topic TEXT NOT NULL, UNIQUE (event_id) ); +CREATE INDEX topics_room_id ON topics(room_id); +CREATE TABLE room_names( event_id TEXT NOT NULL, room_id TEXT NOT NULL, name TEXT NOT NULL, UNIQUE (event_id) ); +CREATE INDEX room_names_room_id ON room_names(room_id); +CREATE TABLE rooms( room_id TEXT PRIMARY KEY NOT NULL, is_public BOOL, creator TEXT ); +CREATE TABLE server_signature_keys( server_name TEXT, key_id TEXT, from_server TEXT, ts_added_ms BIGINT, verify_key bytea, ts_valid_until_ms BIGINT, UNIQUE (server_name, key_id) ); +CREATE TABLE rejections( event_id TEXT NOT NULL, reason TEXT NOT NULL, last_check TEXT NOT NULL, UNIQUE (event_id) ); +CREATE TABLE push_rules ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, rule_id TEXT NOT NULL, priority_class SMALLINT NOT NULL, priority INTEGER NOT NULL DEFAULT 0, conditions TEXT NOT NULL, actions TEXT NOT NULL, UNIQUE(user_name, rule_id) ); +CREATE INDEX push_rules_user_name on push_rules (user_name); +CREATE TABLE user_filters( user_id TEXT, filter_id BIGINT, filter_json bytea ); +CREATE INDEX user_filters_by_user_id_filter_id ON user_filters( user_id, filter_id ); +CREATE TABLE push_rules_enable ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, rule_id TEXT NOT NULL, enabled SMALLINT, UNIQUE(user_name, rule_id) ); +CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name); +CREATE TABLE event_forward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, UNIQUE (event_id, room_id) ); +CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); +CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); +CREATE TABLE event_backward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, UNIQUE (event_id, room_id) ); +CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); +CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); +CREATE TABLE event_edges( event_id TEXT NOT NULL, prev_event_id TEXT NOT NULL, room_id TEXT NOT NULL, is_state BOOL NOT NULL, UNIQUE (event_id, prev_event_id, room_id, is_state) ); +CREATE INDEX ev_edges_id ON event_edges(event_id); +CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); +CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) ); +CREATE INDEX room_depth_room ON room_depth(room_id); +CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); +CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL ); +CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) ); +CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) ); +CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); +CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id); +CREATE TABLE remote_media_cache ( media_origin TEXT, media_id TEXT, media_type TEXT, created_ts BIGINT, upload_name TEXT, media_length INTEGER, filesystem_id TEXT, last_access_ts BIGINT, quarantined_by TEXT, UNIQUE (media_origin, media_id) ); +CREATE TABLE remote_media_cache_thumbnails ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); +CREATE TABLE redactions ( event_id TEXT NOT NULL, redacts TEXT NOT NULL, UNIQUE (event_id) ); +CREATE INDEX redactions_redacts ON redactions (redacts); +CREATE TABLE room_aliases( room_alias TEXT NOT NULL, room_id TEXT NOT NULL, creator TEXT, UNIQUE (room_alias) ); +CREATE INDEX room_aliases_id ON room_aliases(room_id); +CREATE TABLE room_alias_servers( room_alias TEXT NOT NULL, server TEXT NOT NULL ); +CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias); +CREATE TABLE event_reference_hashes ( event_id TEXT, algorithm TEXT, hash bytea, UNIQUE (event_id, algorithm) ); +CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id); +CREATE TABLE IF NOT EXISTS "server_keys_json" ( server_name TEXT NOT NULL, key_id TEXT NOT NULL, from_server TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, ts_valid_until_ms BIGINT NOT NULL, key_json bytea NOT NULL, CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server) ); +CREATE TABLE e2e_device_keys_json ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, key_json TEXT NOT NULL, CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) ); +CREATE TABLE e2e_one_time_keys_json ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, algorithm TEXT NOT NULL, key_id TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, key_json TEXT NOT NULL, CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) ); +CREATE TABLE receipts_graph( room_id TEXT NOT NULL, receipt_type TEXT NOT NULL, user_id TEXT NOT NULL, event_ids TEXT NOT NULL, data TEXT NOT NULL, CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id) ); +CREATE TABLE receipts_linearized ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, receipt_type TEXT NOT NULL, user_id TEXT NOT NULL, event_id TEXT NOT NULL, data TEXT NOT NULL, CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id) ); +CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); +CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, validated_at BIGINT NOT NULL, added_at BIGINT NOT NULL, CONSTRAINT medium_address UNIQUE (medium, address) ); +CREATE INDEX user_threepids_user_id ON user_threepids(user_id); +CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value ) +/* event_search(event_id,room_id,sender,"key",value) */; +CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value'); +CREATE TABLE IF NOT EXISTS 'event_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); +CREATE TABLE IF NOT EXISTS 'event_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); +CREATE TABLE IF NOT EXISTS 'event_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); +CREATE TABLE IF NOT EXISTS 'event_search_stat'(id INTEGER PRIMARY KEY, value BLOB); +CREATE TABLE guest_access( event_id TEXT NOT NULL, room_id TEXT NOT NULL, guest_access TEXT NOT NULL, UNIQUE (event_id) ); +CREATE TABLE history_visibility( event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, UNIQUE (event_id) ); +CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) ); +CREATE TABLE room_tags_revisions ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, stream_id BIGINT NOT NULL, CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) ); +CREATE TABLE IF NOT EXISTS "account_data_max_stream_id"( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT NOT NULL, CHECK (Lock='X') ); +CREATE TABLE account_data( user_id TEXT NOT NULL, account_data_type TEXT NOT NULL, stream_id BIGINT NOT NULL, content TEXT NOT NULL, CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) ); +CREATE TABLE room_account_data( user_id TEXT NOT NULL, room_id TEXT NOT NULL, account_data_type TEXT NOT NULL, stream_id BIGINT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) ); +CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); +CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); +CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering); +CREATE TABLE event_push_actions( room_id TEXT NOT NULL, event_id TEXT NOT NULL, user_id TEXT NOT NULL, profile_tag VARCHAR(32), actions TEXT NOT NULL, topological_ordering BIGINT, stream_ordering BIGINT, notif SMALLINT, highlight SMALLINT, CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) ); +CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id); +CREATE INDEX events_room_stream on events(room_id, stream_ordering); +CREATE INDEX public_room_index on rooms(is_public); +CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); +CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering ); +CREATE TABLE presence_stream( stream_id BIGINT, user_id TEXT, state TEXT, last_active_ts BIGINT, last_federation_update_ts BIGINT, last_user_sync_ts BIGINT, status_msg TEXT, currently_active BOOLEAN ); +CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); +CREATE INDEX presence_stream_user_id ON presence_stream(user_id); +CREATE TABLE push_rules_stream( stream_id BIGINT NOT NULL, event_stream_ordering BIGINT NOT NULL, user_id TEXT NOT NULL, rule_id TEXT NOT NULL, op TEXT NOT NULL, priority_class SMALLINT, priority INTEGER, conditions TEXT, actions TEXT ); +CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id); +CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id); +CREATE TABLE ex_outlier_stream( event_stream_ordering BIGINT PRIMARY KEY NOT NULL, event_id TEXT NOT NULL, state_group BIGINT NOT NULL ); +CREATE TABLE threepid_guest_access_tokens( medium TEXT, address TEXT, guest_access_token TEXT, first_inviter TEXT ); +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); +CREATE TABLE local_invites( stream_id BIGINT NOT NULL, inviter TEXT NOT NULL, invitee TEXT NOT NULL, event_id TEXT NOT NULL, room_id TEXT NOT NULL, locally_rejected TEXT, replaced_by TEXT ); +CREATE INDEX local_invites_id ON local_invites(stream_id); +CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); +CREATE INDEX event_push_actions_stream_ordering on event_push_actions( stream_ordering, user_id ); +CREATE TABLE open_id_tokens ( token TEXT NOT NULL PRIMARY KEY, ts_valid_until_ms bigint NOT NULL, user_id TEXT NOT NULL, UNIQUE (token) ); +CREATE INDEX open_id_tokens_ts_valid_until_ms ON open_id_tokens(ts_valid_until_ms); +CREATE TABLE pusher_throttle( pusher BIGINT NOT NULL, room_id TEXT NOT NULL, last_sent_ts BIGINT, throttle_ms BIGINT, PRIMARY KEY (pusher, room_id) ); +CREATE TABLE event_reports( id BIGINT NOT NULL PRIMARY KEY, received_ts BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, user_id TEXT NOT NULL, reason TEXT, content TEXT ); +CREATE TABLE devices ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, display_name TEXT, CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) ); +CREATE TABLE appservice_stream_position( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_ordering BIGINT, CHECK (Lock='X') ); +CREATE TABLE device_inbox ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, stream_id BIGINT NOT NULL, message_json TEXT NOT NULL ); +CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); +CREATE INDEX received_transactions_ts ON received_transactions(ts); +CREATE TABLE device_federation_outbox ( destination TEXT NOT NULL, stream_id BIGINT NOT NULL, queued_ts BIGINT NOT NULL, messages_json TEXT NOT NULL ); +CREATE INDEX device_federation_outbox_destination_id ON device_federation_outbox(destination, stream_id); +CREATE TABLE device_federation_inbox ( origin TEXT NOT NULL, message_id TEXT NOT NULL, received_ts BIGINT NOT NULL ); +CREATE INDEX device_federation_inbox_sender_id ON device_federation_inbox(origin, message_id); +CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL ); +CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT); +CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id ); +CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id ); +CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL ); +CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); +CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); +CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering ); +CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering ); +CREATE TABLE IF NOT EXISTS "event_auth"( event_id TEXT NOT NULL, auth_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE INDEX evauth_edges_id ON event_auth(event_id); +CREATE INDEX user_threepids_medium_address on user_threepids (medium, address); +CREATE TABLE appservice_room_list( appservice_id TEXT NOT NULL, network_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list( appservice_id, network_id, room_id ); +CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id); +CREATE TABLE federation_stream_position( type TEXT NOT NULL, stream_id INTEGER NOT NULL ); +CREATE TABLE device_lists_remote_cache ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, content TEXT NOT NULL ); +CREATE TABLE device_lists_remote_extremeties ( user_id TEXT NOT NULL, stream_id TEXT NOT NULL ); +CREATE TABLE device_lists_stream ( stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL ); +CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); +CREATE TABLE device_lists_outbound_pokes ( destination TEXT NOT NULL, stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL, sent BOOLEAN NOT NULL, ts BIGINT NOT NULL ); +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); +CREATE TABLE event_push_summary ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, notif_count BIGINT NOT NULL, stream_ordering BIGINT NOT NULL ); +CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); +CREATE TABLE event_push_summary_stream_ordering ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_ordering BIGINT NOT NULL, CHECK (Lock='X') ); +CREATE TABLE IF NOT EXISTS "pushers" ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, access_token BIGINT DEFAULT NULL, profile_tag TEXT NOT NULL, kind TEXT NOT NULL, app_id TEXT NOT NULL, app_display_name TEXT NOT NULL, device_display_name TEXT NOT NULL, pushkey TEXT NOT NULL, ts BIGINT NOT NULL, lang TEXT, data TEXT, last_stream_ordering INTEGER, last_success BIGINT, failing_since BIGINT, UNIQUE (app_id, pushkey, user_name) ); +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); +CREATE TABLE ratelimit_override ( user_id TEXT NOT NULL, messages_per_second BIGINT, burst_count BIGINT ); +CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); +CREATE TABLE current_state_delta_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT, prev_event_id TEXT ); +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); +CREATE TABLE device_lists_outbound_last_success ( destination TEXT NOT NULL, user_id TEXT NOT NULL, stream_id BIGINT NOT NULL ); +CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( destination, user_id, stream_id ); +CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') ); +CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value ) +/* user_directory_search(user_id,value) */; +CREATE TABLE IF NOT EXISTS 'user_directory_search_content'(docid INTEGER PRIMARY KEY, 'c0user_id', 'c1value'); +CREATE TABLE IF NOT EXISTS 'user_directory_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); +CREATE TABLE IF NOT EXISTS 'user_directory_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); +CREATE TABLE IF NOT EXISTS 'user_directory_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); +CREATE TABLE IF NOT EXISTS 'user_directory_search_stat'(id INTEGER PRIMARY KEY, value BLOB); +CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL ); +CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); +CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT ); +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); +CREATE TABLE group_users ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, is_admin BOOLEAN NOT NULL, is_public BOOLEAN NOT NULL ); +CREATE TABLE group_invites ( group_id TEXT NOT NULL, user_id TEXT NOT NULL ); +CREATE TABLE group_rooms ( group_id TEXT NOT NULL, room_id TEXT NOT NULL, is_public BOOLEAN NOT NULL ); +CREATE TABLE group_summary_rooms ( group_id TEXT NOT NULL, room_id TEXT NOT NULL, category_id TEXT NOT NULL, room_order BIGINT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, category_id, room_id, room_order), CHECK (room_order > 0) ); +CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id); +CREATE TABLE group_summary_room_categories ( group_id TEXT NOT NULL, category_id TEXT NOT NULL, cat_order BIGINT NOT NULL, UNIQUE (group_id, category_id, cat_order), CHECK (cat_order > 0) ); +CREATE TABLE group_room_categories ( group_id TEXT NOT NULL, category_id TEXT NOT NULL, profile TEXT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, category_id) ); +CREATE TABLE group_summary_users ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, role_id TEXT NOT NULL, user_order BIGINT NOT NULL, is_public BOOLEAN NOT NULL ); +CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id); +CREATE TABLE group_summary_roles ( group_id TEXT NOT NULL, role_id TEXT NOT NULL, role_order BIGINT NOT NULL, UNIQUE (group_id, role_id, role_order), CHECK (role_order > 0) ); +CREATE TABLE group_roles ( group_id TEXT NOT NULL, role_id TEXT NOT NULL, profile TEXT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, role_id) ); +CREATE TABLE group_attestations_renewals ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, valid_until_ms BIGINT NOT NULL ); +CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id); +CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id); +CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms); +CREATE TABLE group_attestations_remote ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, valid_until_ms BIGINT NOT NULL, attestation_json TEXT NOT NULL ); +CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); +CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); +CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); +CREATE TABLE local_group_membership ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, is_admin BOOLEAN NOT NULL, membership TEXT NOT NULL, is_publicised BOOLEAN NOT NULL, content TEXT NOT NULL ); +CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); +CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); +CREATE TABLE local_group_updates ( stream_id BIGINT NOT NULL, group_id TEXT NOT NULL, user_id TEXT NOT NULL, type TEXT NOT NULL, content TEXT NOT NULL ); +CREATE TABLE remote_profile_cache ( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, last_check BIGINT NOT NULL ); +CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); +CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); +CREATE TABLE IF NOT EXISTS "deleted_pushers" ( stream_id BIGINT NOT NULL, app_id TEXT NOT NULL, pushkey TEXT NOT NULL, user_id TEXT NOT NULL ); +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); +CREATE TABLE IF NOT EXISTS "groups" ( group_id TEXT NOT NULL, name TEXT, avatar_url TEXT, short_description TEXT, long_description TEXT, is_public BOOL NOT NULL , join_policy TEXT NOT NULL DEFAULT 'invite'); +CREATE UNIQUE INDEX groups_idx ON groups(group_id); +CREATE TABLE IF NOT EXISTS "user_directory" ( user_id TEXT NOT NULL, room_id TEXT, display_name TEXT, avatar_url TEXT ); +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); +CREATE TABLE event_push_actions_staging ( event_id TEXT NOT NULL, user_id TEXT NOT NULL, actions TEXT NOT NULL, notif SMALLINT NOT NULL, highlight SMALLINT NOT NULL ); +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); +CREATE TABLE users_pending_deactivation ( user_id TEXT NOT NULL ); +CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); +CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); +CREATE INDEX group_users_u_idx ON group_users(user_id); +CREATE INDEX group_invites_u_idx ON group_invites(user_id); +CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); +CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); +CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL ); +CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); +CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); +CREATE TABLE erased_users ( user_id TEXT NOT NULL ); +CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id); +CREATE TABLE monthly_active_users ( user_id TEXT NOT NULL, timestamp BIGINT NOT NULL ); +CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id); +CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); +CREATE TABLE IF NOT EXISTS "e2e_room_keys_versions" ( user_id TEXT NOT NULL, version BIGINT NOT NULL, algorithm TEXT NOT NULL, auth_data TEXT NOT NULL, deleted SMALLINT DEFAULT 0 NOT NULL ); +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); +CREATE TABLE IF NOT EXISTS "e2e_room_keys" ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, version BIGINT NOT NULL, first_message_index INT, forwarded_count INT, is_verified BOOLEAN, session_data TEXT NOT NULL ); +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); +CREATE TABLE users_who_share_private_rooms ( user_id TEXT NOT NULL, other_user_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id); +CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id); +CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id); +CREATE TABLE user_threepid_id_server ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, id_server TEXT NOT NULL ); +CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server( user_id, medium, address, id_server ); +CREATE TABLE users_in_public_rooms ( user_id TEXT NOT NULL, room_id TEXT NOT NULL ); +CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id); +CREATE TABLE account_validity ( user_id TEXT PRIMARY KEY, expiration_ts_ms BIGINT NOT NULL, email_sent BOOLEAN NOT NULL, renewal_token TEXT ); +CREATE TABLE event_relations ( event_id TEXT NOT NULL, relates_to_id TEXT NOT NULL, relation_type TEXT NOT NULL, aggregation_key TEXT ); +CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); +CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); +CREATE TABLE stats_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') ); +CREATE TABLE user_stats ( user_id TEXT NOT NULL, ts BIGINT NOT NULL, bucket_size INT NOT NULL, public_rooms INT NOT NULL, private_rooms INT NOT NULL ); +CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); +CREATE TABLE room_stats ( room_id TEXT NOT NULL, ts BIGINT NOT NULL, bucket_size INT NOT NULL, current_state_events INT NOT NULL, joined_members INT NOT NULL, invited_members INT NOT NULL, left_members INT NOT NULL, banned_members INT NOT NULL, state_events INT NOT NULL ); +CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); +CREATE TABLE room_state ( room_id TEXT NOT NULL, join_rules TEXT, history_visibility TEXT, encryption TEXT, name TEXT, topic TEXT, avatar TEXT, canonical_alias TEXT ); +CREATE UNIQUE INDEX room_state_room ON room_state(room_id); +CREATE TABLE room_stats_earliest_token ( room_id TEXT NOT NULL, token BIGINT NOT NULL ); +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); +CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id); +CREATE INDEX user_ips_device_id ON user_ips (user_id, device_id, last_seen); +CREATE INDEX event_contains_url_index ON events (room_id, topological_ordering, stream_ordering); +CREATE INDEX event_push_actions_u_highlight ON event_push_actions (user_id, stream_ordering); +CREATE INDEX event_push_actions_highlights_index ON event_push_actions (user_id, room_id, topological_ordering, stream_ordering); +CREATE INDEX current_state_events_member_index ON current_state_events (state_key); +CREATE INDEX device_inbox_stream_id_user_id ON device_inbox (stream_id, user_id); +CREATE INDEX device_lists_stream_user_id ON device_lists_stream (user_id, device_id); +CREATE INDEX local_media_repository_url_idx ON local_media_repository (created_ts); +CREATE INDEX user_ips_last_seen ON user_ips (user_id, last_seen); +CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen); +CREATE INDEX users_creation_ts ON users (creation_ts); +CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group); +CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id); +CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key); +CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id); +CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql new file mode 100644 index 0000000000..c265fd20e2 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql @@ -0,0 +1,7 @@ + +INSERT INTO appservice_stream_position (stream_ordering) SELECT COALESCE(MAX(stream_ordering), 0) FROM events; +INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1); +INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events; +INSERT INTO user_directory_stream_pos (stream_id) VALUES (0); +INSERT INTO stats_stream_pos (stream_id) VALUES (0); +INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt new file mode 100644 index 0000000000..d3f6401344 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/README.txt @@ -0,0 +1,19 @@ +Building full schema dumps +========================== + +These schemas need to be made from a database that has had all background updates run. + +Postgres +-------- + +$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres + +SQLite +------ + +$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite + +After +----- + +Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas". \ No newline at end of file diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py new file mode 100644 index 0000000000..0e08497452 --- /dev/null +++ b/synapse/storage/data_stores/main/search.py @@ -0,0 +1,712 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from collections import namedtuple + +from six import string_types + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.storage._base import make_in_list_sql_clause +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.engines import PostgresEngine, Sqlite3Engine + +logger = logging.getLogger(__name__) + +SearchEntry = namedtuple( + "SearchEntry", + ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], +) + + +class SearchBackgroundUpdateStore(BackgroundUpdateStore): + + EVENT_SEARCH_UPDATE_NAME = "event_search" + EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" + EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" + EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" + + def __init__(self, db_conn, hs): + super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs) + + if not hs.config.enable_search: + return + + self.register_background_update_handler( + self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search + ) + self.register_background_update_handler( + self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order + ) + + # we used to have a background update to turn the GIN index into a + # GIST one; we no longer do that (obviously) because we actually want + # a GIN index. However, it's possible that some people might still have + # the background update queued, so we register a handler to clear the + # background update. + self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) + + self.register_background_update_handler( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search + ) + + @defer.inlineCallbacks + def _background_reindex_search(self, progress, batch_size): + # we work through the events table from highest stream id to lowest + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + TYPES = ["m.room.name", "m.room.message", "m.room.topic"] + + def reindex_search_txn(txn): + sql = ( + "SELECT stream_ordering, event_id, room_id, type, json, " + " origin_server_ts FROM events" + " JOIN event_json USING (room_id, event_id)" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " AND (%s)" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + # we could stream straight from the results into + # store_search_entries_txn with a generator function, but that + # would mean having two cursors open on the database at once. + # Instead we just build a list of results. + rows = self.cursor_to_dict(txn) + if not rows: + return 0 + + min_stream_id = rows[-1]["stream_ordering"] + + event_search_rows = [] + for row in rows: + try: + event_id = row["event_id"] + room_id = row["room_id"] + etype = row["type"] + stream_ordering = row["stream_ordering"] + origin_server_ts = row["origin_server_ts"] + try: + event_json = json.loads(row["json"]) + content = event_json["content"] + except Exception: + continue + + if etype == "m.room.message": + key = "content.body" + value = content["body"] + elif etype == "m.room.topic": + key = "content.topic" + value = content["topic"] + elif etype == "m.room.name": + key = "content.name" + value = content["name"] + else: + raise Exception("unexpected event type %s" % etype) + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + if not isinstance(value, string_types): + # If the event body, name or topic isn't a string + # then skip over it + continue + + event_search_rows.append( + SearchEntry( + key=key, + value=value, + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + origin_server_ts=origin_server_ts, + ) + ) + + self.store_search_entries_txn(txn, event_search_rows) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(event_search_rows), + } + + self._background_update_progress_txn( + txn, self.EVENT_SEARCH_UPDATE_NAME, progress + ) + + return len(event_search_rows) + + result = yield self.runInteraction( + self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) + + return result + + @defer.inlineCallbacks + def _background_reindex_gin_search(self, progress, batch_size): + """This handles old synapses which used GIST indexes, if any; + converting them back to be GIN as per the actual schema. + """ + + def create_index(conn): + conn.rollback() + + # we have to set autocommit, because postgres refuses to + # CREATE INDEX CONCURRENTLY without it. + conn.set_session(autocommit=True) + + try: + c = conn.cursor() + + # if we skipped the conversion to GIST, we may already/still + # have an event_search_fts_idx; unfortunately postgres 9.4 + # doesn't support CREATE INDEX IF EXISTS so we just catch the + # exception and ignore it. + import psycopg2 + + try: + c.execute( + "CREATE INDEX CONCURRENTLY event_search_fts_idx" + " ON event_search USING GIN (vector)" + ) + except psycopg2.ProgrammingError as e: + logger.warn( + "Ignoring error %r when trying to switch from GIST to GIN", e + ) + + # we should now be able to delete the GIST index. + c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist") + finally: + conn.set_session(autocommit=False) + + if isinstance(self.database_engine, PostgresEngine): + yield self.runWithConnection(create_index) + + yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) + return 1 + + @defer.inlineCallbacks + def _background_reindex_search_order(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + have_added_index = progress["have_added_indexes"] + + if not have_added_index: + + def create_index(conn): + conn.rollback() + conn.set_session(autocommit=True) + c = conn.cursor() + + # We create with NULLS FIRST so that when we search *backwards* + # we get the ones with non null origin_server_ts *first* + c.execute( + "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search(" + "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + ) + c.execute( + "CREATE INDEX CONCURRENTLY event_search_order ON event_search(" + "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + ) + conn.set_session(autocommit=False) + + yield self.runWithConnection(create_index) + + pg = dict(progress) + pg["have_added_indexes"] = True + + yield self.runInteraction( + self.EVENT_SEARCH_ORDER_UPDATE_NAME, + self._background_update_progress_txn, + self.EVENT_SEARCH_ORDER_UPDATE_NAME, + pg, + ) + + def reindex_search_txn(txn): + sql = ( + "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," + " origin_server_ts = e.origin_server_ts" + " FROM events AS e" + " WHERE e.event_id = es.event_id" + " AND ? <= e.stream_ordering AND e.stream_ordering < ?" + " RETURNING es.stream_ordering" + ) + + min_stream_id = max_stream_id - batch_size + txn.execute(sql, (min_stream_id, max_stream_id)) + rows = txn.fetchall() + + if min_stream_id < target_min_stream_id: + # We've recached the end. + return len(rows), False + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows), + "have_added_indexes": True, + } + + self._background_update_progress_txn( + txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress + ) + + return len(rows), True + + num_rows, finished = yield self.runInteraction( + self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn + ) + + if not finished: + yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME) + + return num_rows + + def store_search_entries_txn(self, txn, entries): + """Add entries to the search table + + Args: + txn (cursor): + entries (iterable[SearchEntry]): + entries to be added to the table + """ + if not self.hs.config.enable_search: + return + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" + " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + ) + + args = ( + ( + entry.event_id, + entry.room_id, + entry.key, + entry.value, + entry.stream_ordering, + entry.origin_server_ts, + ) + for entry in entries + ) + + txn.executemany(sql, args) + + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) + args = ( + (entry.event_id, entry.room_id, entry.key, entry.value) + for entry in entries + ) + + txn.executemany(sql, args) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + +class SearchStore(SearchBackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(SearchStore, self).__init__(db_conn, hs) + + def store_event_search_txn(self, txn, event, key, value): + """Add event to the search table + + Args: + txn (cursor): + event (EventBase): + key (str): + value (str): + """ + self.store_search_entries_txn( + txn, + ( + SearchEntry( + key=key, + value=value, + event_id=event.event_id, + room_id=event.room_id, + stream_ordering=event.internal_metadata.stream_ordering, + origin_server_ts=event.origin_server_ts, + ), + ), + ) + + @defer.inlineCallbacks + def search_msgs(self, room_ids, search_term, keys): + """Performs a full text search over events with given keys. + + Args: + room_ids (list): List of room ids to search in + search_term (str): Search term to search for + keys (list): List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + + Returns: + list of dicts + """ + clauses = [] + + search_query = search_query = _parse_query(self.database_engine, search_term) + + args = [] + + # Make sure we don't explode because the person is in too many rooms. + # We filter the results below regardless. + if len(room_ids) < 500: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + clauses = [clause] + + local_clauses = [] + for key in keys: + local_clauses.append("key = ?") + args.append(key) + + clauses.append("(%s)" % (" OR ".join(local_clauses),)) + + count_args = args + count_clauses = clauses + + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," + " room_id, event_id" + " FROM event_search" + " WHERE vector @@ to_tsquery('english', ?)" + ) + args = [search_query, search_query] + args + + count_sql = ( + "SELECT room_id, count(*) as count FROM event_search" + " WHERE vector @@ to_tsquery('english', ?)" + ) + count_args = [search_query] + count_args + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" + " FROM event_search" + " WHERE value MATCH ?" + ) + args = [search_query] + args + + count_sql = ( + "SELECT room_id, count(*) as count FROM event_search" + " WHERE value MATCH ?" + ) + count_args = [search_term] + count_args + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + for clause in clauses: + sql += " AND " + clause + + for clause in count_clauses: + count_sql += " AND " + clause + + # We add an arbitrary limit here to ensure we don't try to pull the + # entire table from the database. + sql += " ORDER BY rank DESC LIMIT 500" + + results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args) + + results = list(filter(lambda row: row["room_id"] in room_ids, results)) + + events = yield self.get_events_as_list([r["event_id"] for r in results]) + + event_map = {ev.event_id: ev for ev in events} + + highlights = None + if isinstance(self.database_engine, PostgresEngine): + highlights = yield self._find_highlights_in_postgres(search_query, events) + + count_sql += " GROUP BY room_id" + + count_results = yield self._execute( + "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + ) + + count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) + + return { + "results": [ + {"event": event_map[r["event_id"]], "rank": r["rank"]} + for r in results + if r["event_id"] in event_map + ], + "highlights": highlights, + "count": count, + } + + @defer.inlineCallbacks + def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): + """Performs a full text search over events with given keys. + + Args: + room_id (list): The room_ids to search in + search_term (str): Search term to search for + keys (list): List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + pagination_token (str): A pagination token previously returned + + Returns: + list of dicts + """ + clauses = [] + + search_query = search_query = _parse_query(self.database_engine, search_term) + + args = [] + + # Make sure we don't explode because the person is in too many rooms. + # We filter the results below regardless. + if len(room_ids) < 500: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + clauses = [clause] + + local_clauses = [] + for key in keys: + local_clauses.append("key = ?") + args.append(key) + + clauses.append("(%s)" % (" OR ".join(local_clauses),)) + + # take copies of the current args and clauses lists, before adding + # pagination clauses to main query. + count_args = list(args) + count_clauses = list(clauses) + + if pagination_token: + try: + origin_server_ts, stream = pagination_token.split(",") + origin_server_ts = int(origin_server_ts) + stream = int(stream) + except Exception: + raise SynapseError(400, "Invalid pagination token") + + clauses.append( + "(origin_server_ts < ?" + " OR (origin_server_ts = ? AND stream_ordering < ?))" + ) + args.extend([origin_server_ts, origin_server_ts, stream]) + + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," + " origin_server_ts, stream_ordering, room_id, event_id" + " FROM event_search" + " WHERE vector @@ to_tsquery('english', ?) AND " + ) + args = [search_query, search_query] + args + + count_sql = ( + "SELECT room_id, count(*) as count FROM event_search" + " WHERE vector @@ to_tsquery('english', ?) AND " + ) + count_args = [search_query] + count_args + elif isinstance(self.database_engine, Sqlite3Engine): + # We use CROSS JOIN here to ensure we use the right indexes. + # https://sqlite.org/optoverview.html#crossjoin + # + # We want to use the full text search index on event_search to + # extract all possible matches first, then lookup those matches + # in the events table to get the topological ordering. We need + # to use the indexes in this order because sqlite refuses to + # MATCH unless it uses the full text search index + sql = ( + "SELECT rank(matchinfo) as rank, room_id, event_id," + " origin_server_ts, stream_ordering" + " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" + " FROM event_search" + " WHERE value MATCH ?" + " )" + " CROSS JOIN events USING (event_id)" + " WHERE " + ) + args = [search_query] + args + + count_sql = ( + "SELECT room_id, count(*) as count FROM event_search" + " WHERE value MATCH ? AND " + ) + count_args = [search_term] + count_args + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + sql += " AND ".join(clauses) + count_sql += " AND ".join(count_clauses) + + # We add an arbitrary limit here to ensure we don't try to pull the + # entire table from the database. + if isinstance(self.database_engine, PostgresEngine): + sql += ( + " ORDER BY origin_server_ts DESC NULLS LAST," + " stream_ordering DESC NULLS LAST LIMIT ?" + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" + else: + raise Exception("Unrecognized database engine") + + args.append(limit) + + results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args) + + results = list(filter(lambda row: row["room_id"] in room_ids, results)) + + events = yield self.get_events_as_list([r["event_id"] for r in results]) + + event_map = {ev.event_id: ev for ev in events} + + highlights = None + if isinstance(self.database_engine, PostgresEngine): + highlights = yield self._find_highlights_in_postgres(search_query, events) + + count_sql += " GROUP BY room_id" + + count_results = yield self._execute( + "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + ) + + count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) + + return { + "results": [ + { + "event": event_map[r["event_id"]], + "rank": r["rank"], + "pagination_token": "%s,%s" + % (r["origin_server_ts"], r["stream_ordering"]), + } + for r in results + if r["event_id"] in event_map + ], + "highlights": highlights, + "count": count, + } + + def _find_highlights_in_postgres(self, search_query, events): + """Given a list of events and a search term, return a list of words + that match from the content of the event. + + This is used to give a list of words that clients can match against to + highlight the matching parts. + + Args: + search_query (str) + events (list): A list of events + + Returns: + deferred : A set of strings. + """ + + def f(txn): + highlight_words = set() + for event in events: + # As a hack we simply join values of all possible keys. This is + # fine since we're only using them to find possible highlights. + values = [] + for key in ("body", "name", "topic"): + v = event.content.get(key, None) + if v: + values.append(v) + + if not values: + continue + + value = " ".join(values) + + # We need to find some values for StartSel and StopSel that + # aren't in the value so that we can pick results out. + start_sel = "<" + stop_sel = ">" + + while start_sel in value: + start_sel += "<" + while stop_sel in value: + stop_sel += ">" + + query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( + _to_postgres_options( + { + "StartSel": start_sel, + "StopSel": stop_sel, + "MaxFragments": "50", + } + ) + ) + txn.execute(query, (value, search_query)) + headline, = txn.fetchall()[0] + + # Now we need to pick the possible highlights out of the haedline + # result. + matcher_regex = "%s(.*?)%s" % ( + re.escape(start_sel), + re.escape(stop_sel), + ) + + res = re.findall(matcher_regex, headline) + highlight_words.update([r.lower() for r in res]) + + return highlight_words + + return self.runInteraction("_find_highlights", f) + + +def _to_postgres_options(options_dict): + return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) + + +def _parse_query(database_engine, search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + if isinstance(database_engine, PostgresEngine): + return " & ".join(result + ":*" for result in results) + elif isinstance(database_engine, Sqlite3Engine): + return " & ".join(result + "*" for result in results) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py new file mode 100644 index 0000000000..556191b76f --- /dev/null +++ b/synapse/storage/data_stores/main/signatures.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six + +from unpaddedbase64 import encode_base64 + +from twisted.internet import defer + +from synapse.crypto.event_signing import compute_event_reference_hash +from synapse.storage._base import SQLBaseStore +from synapse.util.caches.descriptors import cached, cachedList + +# py2 sqlite has buffer hardcoded as only binary type, so we must use it, +# despite being deprecated and removed in favor of memoryview +if six.PY2: + db_binary_type = six.moves.builtins.buffer +else: + db_binary_type = memoryview + + +class SignatureWorkerStore(SQLBaseStore): + @cached() + def get_event_reference_hash(self, event_id): + # This is a dummy function to allow get_event_reference_hashes + # to use its cache + raise NotImplementedError() + + @cachedList( + cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 + ) + def get_event_reference_hashes(self, event_ids): + def f(txn): + return { + event_id: self._get_event_reference_hashes_txn(txn, event_id) + for event_id in event_ids + } + + return self.runInteraction("get_event_reference_hashes", f) + + @defer.inlineCallbacks + def add_event_hashes(self, event_ids): + hashes = yield self.get_event_reference_hashes(event_ids) + hashes = { + e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} + for e_id, h in hashes.items() + } + + return list(hashes.items()) + + def _get_event_reference_hashes_txn(self, txn, event_id): + """Get all the hashes for a given PDU. + Args: + txn (cursor): + event_id (str): Id for the Event. + Returns: + A dict[unicode, bytes] of algorithm -> hash. + """ + query = ( + "SELECT algorithm, hash" + " FROM event_reference_hashes" + " WHERE event_id = ?" + ) + txn.execute(query, (event_id,)) + return {k: v for k, v in txn} + + +class SignatureStore(SignatureWorkerStore): + """Persistence for event signatures and hashes""" + + def _store_event_reference_hashes_txn(self, txn, events): + """Store a hash for a PDU + Args: + txn (cursor): + events (list): list of Events. + """ + + vals = [] + for event in events: + ref_alg, ref_hash_bytes = compute_event_reference_hash(event) + vals.append( + { + "event_id": event.event_id, + "algorithm": ref_alg, + "hash": db_binary_type(ref_hash_bytes), + } + ) + + self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py new file mode 100644 index 0000000000..d54442e5fa --- /dev/null +++ b/synapse/storage/data_stores/main/state.py @@ -0,0 +1,1244 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import namedtuple + +from six import iteritems, itervalues +from six.moves import range + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.api.errors import NotFoundError +from synapse.storage._base import SQLBaseStore +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.engines import PostgresEngine +from synapse.storage.state import StateFilter +from synapse.util.caches import get_cache_factor_for, intern_string +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.stringutils import to_ascii + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +class StateGroupBackgroundUpdateStore(SQLBaseStore): + """Defines functions related to state groups needed to run the state backgroud + updates. + """ + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """ + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + def _get_state_groups_from_groups_txn( + self, txn, groups, state_filter=StateFilter.all() + ): + results = {group: {} for group in groups} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( + PARTITION BY type, state_key ORDER BY state_group ASC + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS event_id FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) + """ + + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in results[group] + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(results[group]) == max_entries_returned + ): + break + + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return results + + +# this inherits from EventsWorkerStore because it calls self.get_events +class StateGroupWorkerStore( + EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore +): + """The parts of StateGroupStore that can be called from workers. + """ + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + + def __init__(self, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(db_conn, hs) + + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache"), + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache"), + ) + + @defer.inlineCallbacks + def get_room_version(self, room_id): + """Get the room_version of a given room + + Args: + room_id (str) + + Returns: + Deferred[str] + + Raises: + NotFoundError if the room is unknown + """ + # for now we do this by looking at the create event. We may want to cache this + # more intelligently in future. + + # Retrieve the room's create event + create_event = yield self.get_create_event_for_room(room_id) + return create_event.content.get("room_version", "1") + + @defer.inlineCallbacks + def get_room_predecessor(self, room_id): + """Get the predecessor room of an upgraded room if one exists. + Otherwise return None. + + Args: + room_id (str) + + Returns: + Deferred[unicode|None]: predecessor room id + + Raises: + NotFoundError if the room is unknown + """ + # Retrieve the room's create event + create_event = yield self.get_create_event_for_room(room_id) + + # Return predecessor if present + return create_event.content.get("predecessor", None) + + @defer.inlineCallbacks + def get_create_event_for_room(self, room_id): + """Get the create state event for a room. + + Args: + room_id (str) + + Returns: + Deferred[EventBase]: The room creation event. + + Raises: + NotFoundError if the room is unknown + """ + state_ids = yield self.get_current_state_ids(room_id) + create_id = state_ids.get((EventTypes.Create, "")) + + # If we can't find the create event, assume we've hit a dead end + if not create_id: + raise NotFoundError("Unknown room %s" % (room_id)) + + # Retrieve the room's create event and return + create_event = yield self.get_event(create_id) + return create_event + + @cached(max_entries=100000, iterable=True) + def get_current_state_ids(self, room_id): + """Get the current state event ids for a room based on the + current_state_events table. + + Args: + room_id (str) + + Returns: + deferred: dict of (type, state_key) -> event_id + """ + + def _get_current_state_ids_txn(txn): + txn.execute( + """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """, + (room_id,), + ) + + return { + (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn + } + + return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn) + + # FIXME: how should this be cached? + def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): + """Get the current state event of a given type for a room based on the + current_state_events table. This may not be as up-to-date as the result + of doing a fresh state resolution as per state_handler.get_current_state + + Args: + room_id (str) + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + Deferred[dict[tuple[str, str], str]]: Map from type/state_key to + event ID. + """ + + where_clause, where_args = state_filter.make_sql_filter_clause() + + if not where_clause: + # We delegate to the cached version + return self.get_current_state_ids(room_id) + + def _get_filtered_current_state_ids_txn(txn): + results = {} + sql = """ + SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """ + + if where_clause: + sql += " AND (%s)" % (where_clause,) + + args = [room_id] + args.extend(where_args) + txn.execute(sql, args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + + return results + + return self.runInteraction( + "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn + ) + + @defer.inlineCallbacks + def get_canonical_alias_for_room(self, room_id): + """Get canonical alias for room, if any + + Args: + room_id (str) + + Returns: + Deferred[str|None]: The canonical alias, if any + """ + + state = yield self.get_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) + ) + + event_id = state.get((EventTypes.CanonicalAlias, "")) + if not event_id: + return + + event = yield self.get_event(event_id, allow_none=True) + if not event: + return + + return event.content.get("canonical_alias") + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + + def _get_state_group_delta_txn(txn): + prev_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self._simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ) + + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn) + + @defer.inlineCallbacks + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + if not event_ids: + return {} + + event_to_groups = yield self._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self._get_state_for_groups(groups) + + return group_to_state + + @defer.inlineCallbacks + def get_state_ids_for_group(self, state_group): + """Get the event IDs of all the state in the given state group + + Args: + state_group (int) + + Returns: + Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = yield self._get_state_for_groups((state_group,)) + + return group_to_state[state_group] + + @defer.inlineCallbacks + def get_state_groups(self, room_id, event_ids): + """ Get the state groups for the given list of event_ids + + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. + """ + if not event_ids: + return {} + + group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + + state_event_map = yield self.get_events( + [ + ev_id + for group_ids in itervalues(group_to_ids) + for ev_id in itervalues(group_ids) + ], + get_prev_content=False, + ) + + return { + group: [ + state_event_map[v] + for v in itervalues(event_id_map) + if v in state_event_map + ] + for group, event_id_map in iteritems(group_to_ids) + } + + @defer.inlineCallbacks + def _get_state_groups_from_groups(self, groups, state_filter): + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups(list[int]): list of state group IDs to query + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + results = {} + + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] + for chunk in chunks: + res = yield self.runInteraction( + "_get_state_groups_from_groups", + self._get_state_groups_from_groups_txn, + chunk, + state_filter, + ) + results.update(res) + + return results + + @defer.inlineCallbacks + def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. + + Args: + event_ids (list[string]) + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + """ + event_to_groups = yield self._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self._get_state_for_groups(groups, state_filter) + + state_event_map = yield self.get_events( + [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], + get_prev_content=False, + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in iteritems(group_to_state[group]) + if v in state_event_map + } + for event_id, group in iteritems(event_to_groups) + } + + return {event: event_to_state[event] for event in event_ids} + + @defer.inlineCallbacks + def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + """ + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) + + Args: + event_ids(list(str)): events whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from event_id -> (type, state_key) -> event_id + """ + event_to_groups = yield self._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self._get_state_for_groups(groups, state_filter) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in iteritems(event_to_groups) + } + + return {event: event_to_state[event] for event in event_ids} + + @defer.inlineCallbacks + def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_for_events([event_id], state_filter) + return state_map[event_id] + + @defer.inlineCallbacks + def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_ids_for_events([event_id], state_filter) + return state_map[event_id] + + @cached(max_entries=50000) + def _get_state_group_for_event(self, event_id): + return self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @cachedList( + cached_method_name="_get_state_group_for_event", + list_name="event_ids", + num_args=1, + inlineCallbacks=True, + ) + def _get_state_group_for_events(self, event_ids): + """Returns mapping event_id -> state_group + """ + rows = yield self._simple_select_many_batch( + table="event_to_state_groups", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=("event_id", "state_group"), + desc="_get_state_group_for_events", + ) + + return {row["event_id"]: row["state_group"] for row in rows} + + def _get_state_for_group_using_cache(self, cache, group, state_filter): + """Checks if group is in cache. See `_get_state_for_groups` + + Args: + cache(DictionaryCache): the state group cache to use + group(int): The state group to lookup + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + """ + is_all, known_absent, state_dict_ids = cache.get(group) + + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all + + # tracks whether any of our requested types are missing from the cache + missing_types = False + + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): + if key not in state_dict_ids and key not in known_absent: + missing_types = True + break + + return state_filter.filter_state(state_dict_ids), not missing_types + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + non_member_state, incomplete_groups_nm, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter + ) + ) + + member_state, incomplete_groups_m, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter + ) + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + return state + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + return state + + def _get_state_for_groups_using_cache(self, groups, cache, state_filter): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. + """ + results = {} + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids + + if not got_all: + incomplete_groups.add(group) + + return results, incomplete_groups + + def _insert_into_cache( + self, + group_to_state_dict, + state_filter, + cache_seq_num_members, + cache_seq_num_non_members, + ): + """Inserts results from querying the database into the relevant cache. + + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ + + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) + + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() + + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() + + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} + + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v + else: + state_dict_non_members[k] = v + + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self._simple_insert_txn( + txn, + table="state_groups", + values={"id": state_group, "room_id": room_id, "event_id": event_id}, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self._simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self._simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_ids) + ], + ) + else: + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(current_state_ids) + ], + ) + + # Prefill the state group caches with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_non_member_state_ids), + ) + + return state_group + + return self.runInteraction("store_state_group", _store_state_group_txn) + + +class StateBackgroundUpdateStore( + StateGroupBackgroundUpdateStore, BackgroundUpdateStore +): + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" + + def __init__(self, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state + ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + self.register_background_index_update( + self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, + index_name="event_to_state_groups_sg_index", + table="event_to_state_groups", + columns=["state_group"], + ) + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. + """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self._execute( + "_background_deduplicate_state", + None, + "SELECT coalesce(max(id), 0) FROM state_groups", + ) + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in range(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group), + ) + row = txn.fetchone() + if row: + state_group, room_id = row + + if not row or not state_group: + return True, count + + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,), + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id), + ) + prev_group, = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. + continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group] + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group] + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value + for key, value in iteritems(curr_state) + if prev_state.get(key, None) != value + } + + self._simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + ) + + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self._simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_state) + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress + ) + + return False, batch_size + + finished, result = yield self.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn + ) + + if finished: + yield self._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) + + return result * BATCH_SIZE_SCALE_FACTOR + + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + + yield self.runWithConnection(reindex_txn) + + yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + + return 1 + + +class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + def __init__(self, db_conn, hs): + super(StateStore, self).__init__(db_conn, hs) + + def _store_event_state_mappings_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.prev_group + continue + + state_groups[event.event_id] = context.state_group + + self._simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + {"state_group": state_group_id, "event_id": event_id} + for event_id, state_group_id in iteritems(state_groups) + ], + ) + + for event_id, state_group_id in iteritems(state_groups): + txn.call_after( + self._get_state_group_for_event.prefill, (event_id,), state_group_id + ) diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py new file mode 100644 index 0000000000..28f33ec18f --- /dev/null +++ b/synapse/storage/data_stores/main/state_deltas.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class StateDeltasStore(SQLBaseStore): + def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): + """Fetch a list of room state changes since the given stream id + + Each entry in the result contains the following fields: + - stream_id (int) + - room_id (str) + - type (str): event type + - state_key (str): + - event_id (str|None): new event_id for this state key. None if the + state has been deleted. + - prev_event_id (str|None): previous event_id for this state key. None + if it's new state. + + Args: + prev_stream_id (int): point to get changes since (exclusive) + max_stream_id (int): the point that we know has been correctly persisted + - ie, an upper limit to return changes from. + + Returns: + Deferred[tuple[int, list[dict]]: A tuple consisting of: + - the stream id which these results go up to + - list of current_state_delta_stream rows. If it is empty, we are + up to date. + """ + prev_stream_id = int(prev_stream_id) + + # check we're not going backwards + assert prev_stream_id <= max_stream_id + + if not self._curr_state_delta_stream_cache.has_any_entity_changed( + prev_stream_id + ): + # if the CSDs haven't changed between prev_stream_id and now, we + # know for certain that they haven't changed between prev_stream_id and + # max_stream_id. + return max_stream_id, [] + + def get_current_state_deltas_txn(txn): + # First we calculate the max stream id that will give us less than + # N results. + # We arbitarily limit to 100 stream_id entries to ensure we don't + # select toooo many. + sql = """ + SELECT stream_id, count(*) + FROM current_state_delta_stream + WHERE stream_id > ? AND stream_id <= ? + GROUP BY stream_id + ORDER BY stream_id ASC + LIMIT 100 + """ + txn.execute(sql, (prev_stream_id, max_stream_id)) + + total = 0 + + for stream_id, count in txn: + total += count + if total > 100: + # We arbitarily limit to 100 entries to ensure we don't + # select toooo many. + logger.debug( + "Clipping current_state_delta_stream rows to stream_id %i", + stream_id, + ) + clipped_stream_id = stream_id + break + else: + # if there's no problem, we may as well go right up to the max_stream_id + clipped_stream_id = max_stream_id + + # Now actually get the deltas + sql = """ + SELECT stream_id, room_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + txn.execute(sql, (prev_stream_id, clipped_stream_id)) + return clipped_stream_id, self.cursor_to_dict(txn) + + return self.runInteraction( + "get_current_state_deltas", get_current_state_deltas_txn + ) + + def _get_max_stream_id_in_current_state_deltas_txn(self, txn): + return self._simple_select_one_onecol_txn( + txn, + table="current_state_delta_stream", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), -1)", + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self.runInteraction( + "get_max_stream_id_in_current_state_deltas", + self._get_max_stream_id_in_current_state_deltas_txn, + ) diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py new file mode 100644 index 0000000000..5ab639b2ad --- /dev/null +++ b/synapse/storage/data_stores/main/stats.py @@ -0,0 +1,881 @@ +# -*- coding: utf-8 -*- +# Copyright 2018, 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from itertools import chain + +from twisted.internet import defer +from twisted.internet.defer import DeferredLock + +from synapse.api.constants import EventTypes, Membership +from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.engines import PostgresEngine +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + +# these fields track absolutes (e.g. total number of rooms on the server) +# You can think of these as Prometheus Gauges. +# You can draw these stats on a line graph. +# Example: number of users in a room +ABSOLUTE_STATS_FIELDS = { + "room": ( + "current_state_events", + "joined_members", + "invited_members", + "left_members", + "banned_members", + "local_users_in_room", + ), + "user": ("joined_rooms",), +} + +# these fields are per-timeslice and so should be reset to 0 upon a new slice +# You can draw these stats on a histogram. +# Example: number of events sent locally during a time slice +PER_SLICE_FIELDS = { + "room": ("total_events", "total_event_bytes"), + "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"), +} + +TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} + +# these are the tables (& ID columns) which contain our actual subjects +TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} + + +class StatsStore(StateDeltasStore): + def __init__(self, db_conn, hs): + super(StatsStore, self).__init__(db_conn, hs) + + self.server_name = hs.hostname + self.clock = self.hs.get_clock() + self.stats_enabled = hs.config.stats_enabled + self.stats_bucket_size = hs.config.stats_bucket_size + + self.stats_delta_processing_lock = DeferredLock() + + self.register_background_update_handler( + "populate_stats_process_rooms", self._populate_stats_process_rooms + ) + self.register_background_update_handler( + "populate_stats_process_users", self._populate_stats_process_users + ) + # we no longer need to perform clean-up, but we will give ourselves + # the potential to reintroduce it in the future – so documentation + # will still encourage the use of this no-op handler. + self.register_noop_background_update("populate_stats_cleanup") + self.register_noop_background_update("populate_stats_prepare") + + def quantise_stats_time(self, ts): + """ + Quantises a timestamp to be a multiple of the bucket size. + + Args: + ts (int): the timestamp to quantise, in milliseconds since the Unix + Epoch + + Returns: + int: a timestamp which + - is divisible by the bucket size; + - is no later than `ts`; and + - is the largest such timestamp. + """ + return (ts // self.stats_bucket_size) * self.stats_bucket_size + + @defer.inlineCallbacks + def _populate_stats_process_users(self, progress, batch_size): + """ + This is a background update which regenerates statistics for users. + """ + if not self.stats_enabled: + yield self._end_background_update("populate_stats_process_users") + return 1 + + last_user_id = progress.get("last_user_id", "") + + def _get_next_batch(txn): + sql = """ + SELECT DISTINCT name FROM users + WHERE name > ? + ORDER BY name ASC + LIMIT ? + """ + txn.execute(sql, (last_user_id, batch_size)) + return [r for r, in txn] + + users_to_work_on = yield self.runInteraction( + "_populate_stats_process_users", _get_next_batch + ) + + # No more rooms -- complete the transaction. + if not users_to_work_on: + yield self._end_background_update("populate_stats_process_users") + return 1 + + for user_id in users_to_work_on: + yield self._calculate_and_set_initial_state_for_user(user_id) + progress["last_user_id"] = user_id + + yield self.runInteraction( + "populate_stats_process_users", + self._background_update_progress_txn, + "populate_stats_process_users", + progress, + ) + + return len(users_to_work_on) + + @defer.inlineCallbacks + def _populate_stats_process_rooms(self, progress, batch_size): + """ + This is a background update which regenerates statistics for rooms. + """ + if not self.stats_enabled: + yield self._end_background_update("populate_stats_process_rooms") + return 1 + + last_room_id = progress.get("last_room_id", "") + + def _get_next_batch(txn): + sql = """ + SELECT DISTINCT room_id FROM current_state_events + WHERE room_id > ? + ORDER BY room_id ASC + LIMIT ? + """ + txn.execute(sql, (last_room_id, batch_size)) + return [r for r, in txn] + + rooms_to_work_on = yield self.runInteraction( + "populate_stats_rooms_get_batch", _get_next_batch + ) + + # No more rooms -- complete the transaction. + if not rooms_to_work_on: + yield self._end_background_update("populate_stats_process_rooms") + return 1 + + for room_id in rooms_to_work_on: + yield self._calculate_and_set_initial_state_for_room(room_id) + progress["last_room_id"] = room_id + + yield self.runInteraction( + "_populate_stats_process_rooms", + self._background_update_progress_txn, + "populate_stats_process_rooms", + progress, + ) + + return len(rooms_to_work_on) + + def get_stats_positions(self): + """ + Returns the stats processor positions. + """ + return self._simple_select_one_onecol( + table="stats_incremental_position", + keyvalues={}, + retcol="stream_id", + desc="stats_incremental_position", + ) + + def update_room_state(self, room_id, fields): + """ + Args: + room_id (str) + fields (dict[str:Any]) + """ + + # For whatever reason some of the fields may contain null bytes, which + # postgres isn't a fan of, so we replace those fields with null. + for col in ( + "join_rules", + "history_visibility", + "encryption", + "name", + "topic", + "avatar", + "canonical_alias", + ): + field = fields.get(col) + if field and "\0" in field: + fields[col] = None + + return self._simple_upsert( + table="room_stats_state", + keyvalues={"room_id": room_id}, + values=fields, + desc="update_room_state", + ) + + def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): + """ + Get statistics for a given subject. + + Args: + stats_type (str): The type of subject + stats_id (str): The ID of the subject (e.g. room_id or user_id) + start (int): Pagination start. Number of entries, not timestamp. + size (int): How many entries to return. + + Returns: + Deferred[list[dict]], where the dict has the keys of + ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". + """ + return self.runInteraction( + "get_statistics_for_subject", + self._get_statistics_for_subject_txn, + stats_type, + stats_id, + start, + size, + ) + + def _get_statistics_for_subject_txn( + self, txn, stats_type, stats_id, start, size=100 + ): + """ + Transaction-bound version of L{get_statistics_for_subject}. + """ + + table, id_col = TYPE_TO_TABLE[stats_type] + selected_columns = list( + ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] + ) + + slice_list = self._simple_select_list_paginate_txn( + txn, + table + "_historical", + {id_col: stats_id}, + "end_ts", + start, + size, + retcols=selected_columns + ["bucket_size", "end_ts"], + order_direction="DESC", + ) + + return slice_list + + def get_room_stats_state(self, room_id): + """ + Returns the current room_stats_state for a room. + + Args: + room_id (str): The ID of the room to return state for. + + Returns (dict): + Dictionary containing these keys: + "name", "topic", "canonical_alias", "avatar", "join_rules", + "history_visibility" + """ + return self._simple_select_one( + "room_stats_state", + {"room_id": room_id}, + retcols=( + "name", + "topic", + "canonical_alias", + "avatar", + "join_rules", + "history_visibility", + ), + ) + + @cached() + def get_earliest_token_for_stats(self, stats_type, id): + """ + Fetch the "earliest token". This is used by the room stats delta + processor to ignore deltas that have been processed between the + start of the background task and any particular room's stats + being calculated. + + Returns: + Deferred[int] + """ + table, id_col = TYPE_TO_TABLE[stats_type] + + return self._simple_select_one_onecol( + "%s_current" % (table,), + keyvalues={id_col: id}, + retcol="completed_delta_stream_id", + allow_none=True, + ) + + def bulk_update_stats_delta(self, ts, updates, stream_id): + """Bulk update stats tables for a given stream_id and updates the stats + incremental position. + + Args: + ts (int): Current timestamp in ms + updates(dict[str, dict[str, dict[str, Counter]]]): The updates to + commit as a mapping stats_type -> stats_id -> field -> delta. + stream_id (int): Current position. + + Returns: + Deferred + """ + + def _bulk_update_stats_delta_txn(txn): + for stats_type, stats_updates in updates.items(): + for stats_id, fields in stats_updates.items(): + logger.info( + "Updating %s stats for %s: %s", stats_type, stats_id, fields + ) + self._update_stats_delta_txn( + txn, + ts=ts, + stats_type=stats_type, + stats_id=stats_id, + fields=fields, + complete_with_stream_id=stream_id, + ) + + self._simple_update_one_txn( + txn, + table="stats_incremental_position", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + ) + + return self.runInteraction( + "bulk_update_stats_delta", _bulk_update_stats_delta_txn + ) + + def update_stats_delta( + self, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id, + absolute_field_overrides=None, + ): + """ + Updates the statistics for a subject, with a delta (difference/relative + change). + + Args: + ts (int): timestamp of the change + stats_type (str): "room" or "user" – the kind of subject + stats_id (str): the subject's ID (room ID or user ID) + fields (dict[str, int]): Deltas of stats values. + complete_with_stream_id (int, optional): + If supplied, converts an incomplete row into a complete row, + with the supplied stream_id marked as the stream_id where the + row was completed. + absolute_field_overrides (dict[str, int]): Current stats values + (i.e. not deltas) of absolute fields. + Does not work with per-slice fields. + """ + + return self.runInteraction( + "update_stats_delta", + self._update_stats_delta_txn, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id=complete_with_stream_id, + absolute_field_overrides=absolute_field_overrides, + ) + + def _update_stats_delta_txn( + self, + txn, + ts, + stats_type, + stats_id, + fields, + complete_with_stream_id, + absolute_field_overrides=None, + ): + if absolute_field_overrides is None: + absolute_field_overrides = {} + + table, id_col = TYPE_TO_TABLE[stats_type] + + quantised_ts = self.quantise_stats_time(int(ts)) + end_ts = quantised_ts + self.stats_bucket_size + + # Lets be paranoid and check that all the given field names are known + abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type] + slice_field_names = PER_SLICE_FIELDS[stats_type] + for field in chain(fields.keys(), absolute_field_overrides.keys()): + if field not in abs_field_names and field not in slice_field_names: + # guard against potential SQL injection dodginess + raise ValueError( + "%s is not a recognised field" + " for stats type %s" % (field, stats_type) + ) + + # Per slice fields do not get added to the _current table + + # This calculates the deltas (`field = field + ?` values) + # for absolute fields, + # * defaulting to 0 if not specified + # (required for the INSERT part of upserting to work) + # * omitting overrides specified in `absolute_field_overrides` + deltas_of_absolute_fields = { + key: fields.get(key, 0) + for key in abs_field_names + if key not in absolute_field_overrides + } + + # Keep the delta stream ID field up to date + absolute_field_overrides = absolute_field_overrides.copy() + absolute_field_overrides["completed_delta_stream_id"] = complete_with_stream_id + + # first upsert the `_current` table + self._upsert_with_additive_relatives_txn( + txn=txn, + table=table + "_current", + keyvalues={id_col: stats_id}, + absolutes=absolute_field_overrides, + additive_relatives=deltas_of_absolute_fields, + ) + + per_slice_additive_relatives = { + key: fields.get(key, 0) for key in slice_field_names + } + self._upsert_copy_from_table_with_additive_relatives_txn( + txn=txn, + into_table=table + "_historical", + keyvalues={id_col: stats_id}, + extra_dst_insvalues={"bucket_size": self.stats_bucket_size}, + extra_dst_keyvalues={"end_ts": end_ts}, + additive_relatives=per_slice_additive_relatives, + src_table=table + "_current", + copy_columns=abs_field_names, + ) + + def _upsert_with_additive_relatives_txn( + self, txn, table, keyvalues, absolutes, additive_relatives + ): + """Used to update values in the stats tables. + + This is basically a slightly convoluted upsert that *adds* to any + existing rows. + + Args: + txn + table (str): Table name + keyvalues (dict[str, any]): Row-identifying key values + absolutes (dict[str, any]): Absolute (set) fields + additive_relatives (dict[str, int]): Fields that will be added onto + if existing row present. + """ + if self.database_engine.can_native_upsert: + absolute_updates = [ + "%(field)s = EXCLUDED.%(field)s" % {"field": field} + for field in absolutes.keys() + ] + + relative_updates = [ + "%(field)s = EXCLUDED.%(field)s + %(table)s.%(field)s" + % {"table": table, "field": field} + for field in additive_relatives.keys() + ] + + insert_cols = [] + qargs = [] + + for (key, val) in chain( + keyvalues.items(), absolutes.items(), additive_relatives.items() + ): + insert_cols.append(key) + qargs.append(val) + + sql = """ + INSERT INTO %(table)s (%(insert_cols_cs)s) + VALUES (%(insert_vals_qs)s) + ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s + """ % { + "table": table, + "insert_cols_cs": ", ".join(insert_cols), + "insert_vals_qs": ", ".join( + ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives)) + ), + "key_columns": ", ".join(keyvalues), + "updates": ", ".join(chain(absolute_updates, relative_updates)), + } + + txn.execute(sql, qargs) + else: + self.database_engine.lock_table(txn, table) + retcols = list(chain(absolutes.keys(), additive_relatives.keys())) + current_row = self._simple_select_one_txn( + txn, table, keyvalues, retcols, allow_none=True + ) + if current_row is None: + merged_dict = {**keyvalues, **absolutes, **additive_relatives} + self._simple_insert_txn(txn, table, merged_dict) + else: + for (key, val) in additive_relatives.items(): + current_row[key] += val + current_row.update(absolutes) + self._simple_update_one_txn(txn, table, keyvalues, current_row) + + def _upsert_copy_from_table_with_additive_relatives_txn( + self, + txn, + into_table, + keyvalues, + extra_dst_keyvalues, + extra_dst_insvalues, + additive_relatives, + src_table, + copy_columns, + ): + """Updates the historic stats table with latest updates. + + This involves copying "absolute" fields from the `_current` table, and + adding relative fields to any existing values. + + Args: + txn: Transaction + into_table (str): The destination table to UPSERT the row into + keyvalues (dict[str, any]): Row-identifying key values + extra_dst_keyvalues (dict[str, any]): Additional keyvalues + for `into_table`. + extra_dst_insvalues (dict[str, any]): Additional values to insert + on new row creation for `into_table`. + additive_relatives (dict[str, any]): Fields that will be added onto + if existing row present. (Must be disjoint from copy_columns.) + src_table (str): The source table to copy from + copy_columns (iterable[str]): The list of columns to copy + """ + if self.database_engine.can_native_upsert: + ins_columns = chain( + keyvalues, + copy_columns, + additive_relatives, + extra_dst_keyvalues, + extra_dst_insvalues, + ) + sel_exprs = chain( + keyvalues, + copy_columns, + ( + "?" + for _ in chain( + additive_relatives, extra_dst_keyvalues, extra_dst_insvalues + ) + ), + ) + keyvalues_where = ("%s = ?" % f for f in keyvalues) + + sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns) + sets_ar = ( + "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f) + for f in additive_relatives + ) + + sql = """ + INSERT INTO %(into_table)s (%(ins_columns)s) + SELECT %(sel_exprs)s + FROM %(src_table)s + WHERE %(keyvalues_where)s + ON CONFLICT (%(keyvalues)s) + DO UPDATE SET %(sets)s + """ % { + "into_table": into_table, + "ins_columns": ", ".join(ins_columns), + "sel_exprs": ", ".join(sel_exprs), + "keyvalues_where": " AND ".join(keyvalues_where), + "src_table": src_table, + "keyvalues": ", ".join( + chain(keyvalues.keys(), extra_dst_keyvalues.keys()) + ), + "sets": ", ".join(chain(sets_cc, sets_ar)), + } + + qargs = list( + chain( + additive_relatives.values(), + extra_dst_keyvalues.values(), + extra_dst_insvalues.values(), + keyvalues.values(), + ) + ) + txn.execute(sql, qargs) + else: + self.database_engine.lock_table(txn, into_table) + src_row = self._simple_select_one_txn( + txn, src_table, keyvalues, copy_columns + ) + all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} + dest_current_row = self._simple_select_one_txn( + txn, + into_table, + keyvalues=all_dest_keyvalues, + retcols=list(chain(additive_relatives.keys(), copy_columns)), + allow_none=True, + ) + + if dest_current_row is None: + merged_dict = { + **keyvalues, + **extra_dst_keyvalues, + **extra_dst_insvalues, + **src_row, + **additive_relatives, + } + self._simple_insert_txn(txn, into_table, merged_dict) + else: + for (key, val) in additive_relatives.items(): + src_row[key] = dest_current_row[key] + val + self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + + def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): + """Fetches the counts of events in the given range of stream IDs. + + Args: + min_pos (int) + max_pos (int) + + Returns: + Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field + changes. + """ + + return self.runInteraction( + "stats_incremental_total_events_and_bytes", + self.get_changes_room_total_events_and_bytes_txn, + min_pos, + max_pos, + ) + + def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos): + """Gets the total_events and total_event_bytes counts for rooms and + senders, in a range of stream_orderings (including backfilled events). + + Args: + txn + low_pos (int): Low stream ordering + high_pos (int): High stream ordering + + Returns: + tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The + room and user deltas for total_events/total_event_bytes in the + format of `stats_id` -> fields + """ + + if low_pos >= high_pos: + # nothing to do here. + return {}, {} + + if isinstance(self.database_engine, PostgresEngine): + new_bytes_expression = "OCTET_LENGTH(json)" + else: + new_bytes_expression = "LENGTH(CAST(json AS BLOB))" + + sql = """ + SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes + FROM events INNER JOIN event_json USING (event_id) + WHERE (? < stream_ordering AND stream_ordering <= ?) + OR (? <= stream_ordering AND stream_ordering <= ?) + GROUP BY events.room_id + """ % ( + new_bytes_expression, + ) + + txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) + + room_deltas = { + room_id: {"total_events": new_events, "total_event_bytes": new_bytes} + for room_id, new_events, new_bytes in txn + } + + sql = """ + SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes + FROM events INNER JOIN event_json USING (event_id) + WHERE (? < stream_ordering AND stream_ordering <= ?) + OR (? <= stream_ordering AND stream_ordering <= ?) + GROUP BY events.sender + """ % ( + new_bytes_expression, + ) + + txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) + + user_deltas = { + user_id: {"total_events": new_events, "total_event_bytes": new_bytes} + for user_id, new_events, new_bytes in txn + if self.hs.is_mine_id(user_id) + } + + return room_deltas, user_deltas + + @defer.inlineCallbacks + def _calculate_and_set_initial_state_for_room(self, room_id): + """Calculate and insert an entry into room_stats_current. + + Args: + room_id (str) + + Returns: + Deferred[tuple[dict, dict, int]]: A tuple of room state, membership + counts and stream position. + """ + + def _fetch_current_state_stats(txn): + pos = self.get_room_max_stream_ordering() + + rows = self._simple_select_many_txn( + txn, + table="current_state_events", + column="type", + iterable=[ + EventTypes.Create, + EventTypes.JoinRules, + EventTypes.RoomHistoryVisibility, + EventTypes.Encryption, + EventTypes.Name, + EventTypes.Topic, + EventTypes.RoomAvatar, + EventTypes.CanonicalAlias, + ], + keyvalues={"room_id": room_id, "state_key": ""}, + retcols=["event_id"], + ) + + event_ids = [row["event_id"] for row in rows] + + txn.execute( + """ + SELECT membership, count(*) FROM current_state_events + WHERE room_id = ? AND type = 'm.room.member' + GROUP BY membership + """, + (room_id,), + ) + membership_counts = {membership: cnt for membership, cnt in txn} + + txn.execute( + """ + SELECT COALESCE(count(*), 0) FROM current_state_events + WHERE room_id = ? + """, + (room_id,), + ) + + current_state_events_count, = txn.fetchone() + + users_in_room = self.get_users_in_room_txn(txn, room_id) + + return ( + event_ids, + membership_counts, + current_state_events_count, + users_in_room, + pos, + ) + + ( + event_ids, + membership_counts, + current_state_events_count, + users_in_room, + pos, + ) = yield self.runInteraction( + "get_initial_state_for_room", _fetch_current_state_stats + ) + + state_event_map = yield self.get_events(event_ids, get_prev_content=False) + + room_state = { + "join_rules": None, + "history_visibility": None, + "encryption": None, + "name": None, + "topic": None, + "avatar": None, + "canonical_alias": None, + "is_federatable": True, + } + + for event in state_event_map.values(): + if event.type == EventTypes.JoinRules: + room_state["join_rules"] = event.content.get("join_rule") + elif event.type == EventTypes.RoomHistoryVisibility: + room_state["history_visibility"] = event.content.get( + "history_visibility" + ) + elif event.type == EventTypes.Encryption: + room_state["encryption"] = event.content.get("algorithm") + elif event.type == EventTypes.Name: + room_state["name"] = event.content.get("name") + elif event.type == EventTypes.Topic: + room_state["topic"] = event.content.get("topic") + elif event.type == EventTypes.RoomAvatar: + room_state["avatar"] = event.content.get("url") + elif event.type == EventTypes.CanonicalAlias: + room_state["canonical_alias"] = event.content.get("alias") + elif event.type == EventTypes.Create: + room_state["is_federatable"] = ( + event.content.get("m.federate", True) is True + ) + + yield self.update_room_state(room_id, room_state) + + local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] + + yield self.update_stats_delta( + ts=self.clock.time_msec(), + stats_type="room", + stats_id=room_id, + fields={}, + complete_with_stream_id=pos, + absolute_field_overrides={ + "current_state_events": current_state_events_count, + "joined_members": membership_counts.get(Membership.JOIN, 0), + "invited_members": membership_counts.get(Membership.INVITE, 0), + "left_members": membership_counts.get(Membership.LEAVE, 0), + "banned_members": membership_counts.get(Membership.BAN, 0), + "local_users_in_room": len(local_users_in_room), + }, + ) + + @defer.inlineCallbacks + def _calculate_and_set_initial_state_for_user(self, user_id): + def _calculate_and_set_initial_state_for_user_txn(txn): + pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) + + txn.execute( + """ + SELECT COUNT(distinct room_id) FROM current_state_events + WHERE type = 'm.room.member' AND state_key = ? + AND membership = 'join' + """, + (user_id,), + ) + count, = txn.fetchone() + return count, pos + + joined_rooms, pos = yield self.runInteraction( + "calculate_and_set_initial_state_for_user", + _calculate_and_set_initial_state_for_user_txn, + ) + + yield self.update_stats_delta( + ts=self.clock.time_msec(), + stats_type="user", + stats_id=user_id, + fields={}, + complete_with_stream_id=pos, + absolute_field_overrides={"joined_rooms": joined_rooms}, + ) diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py new file mode 100644 index 0000000000..263999dfca --- /dev/null +++ b/synapse/storage/data_stores/main/stream.py @@ -0,0 +1,948 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" This module is responsible for getting events from the DB for pagination +and event streaming. + +The order it returns events in depend on whether we are streaming forwards or +are paginating backwards. We do this because we want to handle out of order +messages nicely, while still returning them in the correct order when we +paginate bacwards. + +This is implemented by keeping two ordering columns: stream_ordering and +topological_ordering. Stream ordering is basically insertion/received order +(except for events from backfill requests). The topological_ordering is a +weak ordering of events based on the pdu graph. + +This means that we have to have two different types of tokens, depending on +what sort order was used: + - stream tokens are of the form: "s%d", which maps directly to the column + - topological tokems: "t%d-%d", where the integers map to the topological + and stream ordering columns respectively. +""" + +import abc +import logging +from collections import namedtuple + +from six.moves import range + +from twisted.internet import defer + +from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.engines import PostgresEngine +from synapse.types import RoomStreamToken +from synapse.util.caches.stream_change_cache import StreamChangeCache + +logger = logging.getLogger(__name__) + + +MAX_STREAM_SIZE = 1000 + + +_STREAM_TOKEN = "stream" +_TOPOLOGICAL_TOKEN = "topological" + + +# Used as return values for pagination APIs +_EventDictReturn = namedtuple( + "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering") +) + + +def generate_pagination_where_clause( + direction, column_names, from_token, to_token, engine +): + """Creates an SQL expression to bound the columns by the pagination + tokens. + + For example creates an SQL expression like: + + (6, 7) >= (topological_ordering, stream_ordering) + AND (5, 3) < (topological_ordering, stream_ordering) + + would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). + + Note that tokens are considered to be after the row they are in, e.g. if + a row A has a token T, then we consider A to be before T. This convention + is important when figuring out inequalities for the generated SQL, and + produces the following result: + - If paginating forwards then we exclude any rows matching the from + token, but include those that match the to token. + - If paginating backwards then we include any rows matching the from + token, but include those that match the to token. + + Args: + direction (str): Whether we're paginating backwards("b") or + forwards ("f"). + column_names (tuple[str, str]): The column names to bound. Must *not* + be user defined as these get inserted directly into the SQL + statement without escapes. + from_token (tuple[int, int]|None): The start point for the pagination. + This is an exclusive minimum bound if direction is "f", and an + inclusive maximum bound if direction is "b". + to_token (tuple[int, int]|None): The endpoint point for the pagination. + This is an inclusive maximum bound if direction is "f", and an + exclusive minimum bound if direction is "b". + engine: The database engine to generate the clauses for + + Returns: + str: The sql expression + """ + assert direction in ("b", "f") + + where_clause = [] + if from_token: + where_clause.append( + _make_generic_sql_bound( + bound=">=" if direction == "b" else "<", + column_names=column_names, + values=from_token, + engine=engine, + ) + ) + + if to_token: + where_clause.append( + _make_generic_sql_bound( + bound="<" if direction == "b" else ">=", + column_names=column_names, + values=to_token, + engine=engine, + ) + ) + + return " AND ".join(where_clause) + + +def _make_generic_sql_bound(bound, column_names, values, engine): + """Create an SQL expression that bounds the given column names by the + values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. + + Only works with two columns. + + Older versions of SQLite don't support that syntax so we have to expand it + out manually. + + Args: + bound (str): The comparison operator to use. One of ">", "<", ">=", + "<=", where the values are on the left and columns on the right. + names (tuple[str, str]): The column names. Must *not* be user defined + as these get inserted directly into the SQL statement without + escapes. + values (tuple[int|None, int]): The values to bound the columns by. If + the first value is None then only creates a bound on the second + column. + engine: The database engine to generate the SQL for + + Returns: + str + """ + + assert bound in (">", "<", ">=", "<=") + + name1, name2 = column_names + val1, val2 = values + + if val1 is None: + val2 = int(val2) + return "(%d %s %s)" % (val2, bound, name2) + + val1 = int(val1) + val2 = int(val2) + + if isinstance(engine, PostgresEngine): + # Postgres doesn't optimise ``(x < a) OR (x=a AND y ? AND stream_ordering <= ?" + " ORDER BY stream_ordering %s LIMIT ?" + ) % (order,) + txn.execute(sql, (room_id, from_id, to_id, limit)) + + rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + return rows + + rows = yield self.runInteraction("get_room_events_stream_for_room", f) + + ret = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + + self._set_before_and_after(ret, rows, topo_order=from_id is None) + + if order.lower() == "desc": + ret.reverse() + + if rows: + key = "s%d" % min(r.stream_ordering for r in rows) + else: + # Assume we didn't get anything because there was nothing to + # get. + key = from_key + + return ret, key + + @defer.inlineCallbacks + def get_membership_changes_for_user(self, user_id, from_key, to_key): + from_id = RoomStreamToken.parse_stream_token(from_key).stream + to_id = RoomStreamToken.parse_stream_token(to_key).stream + + if from_key == to_key: + return [] + + if from_id: + has_changed = self._membership_stream_cache.has_entity_changed( + user_id, int(from_id) + ) + if not has_changed: + return [] + + def f(txn): + sql = ( + "SELECT m.event_id, stream_ordering FROM events AS e," + " room_memberships AS m" + " WHERE e.event_id = m.event_id" + " AND m.user_id = ?" + " AND e.stream_ordering > ? AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + ) + txn.execute(sql, (user_id, from_id, to_id)) + + rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + + return rows + + rows = yield self.runInteraction("get_membership_changes_for_user", f) + + ret = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + + self._set_before_and_after(ret, rows, topo_order=False) + + return ret + + @defer.inlineCallbacks + def get_recent_events_for_room(self, room_id, limit, end_token): + """Get the most recent events in the room in topological ordering. + + Args: + room_id (str) + limit (int) + end_token (str): The stream token representing now. + + Returns: + Deferred[tuple[list[FrozenEvent], str]]: Returns a list of + events and a token pointing to the start of the returned + events. + The events returned are in ascending order. + """ + + rows, token = yield self.get_recent_event_ids_for_room( + room_id, limit, end_token + ) + + logger.debug("stream before") + events = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + logger.debug("stream after") + + self._set_before_and_after(events, rows) + + return (events, token) + + @defer.inlineCallbacks + def get_recent_event_ids_for_room(self, room_id, limit, end_token): + """Get the most recent events in the room in topological ordering. + + Args: + room_id (str) + limit (int) + end_token (str): The stream token representing now. + + Returns: + Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of + _EventDictReturn and a token pointing to the start of the returned + events. + The events returned are in ascending order. + """ + # Allow a zero limit here, and no-op. + if limit == 0: + return [], end_token + + end_token = RoomStreamToken.parse(end_token) + + rows, token = yield self.runInteraction( + "get_recent_event_ids_for_room", + self._paginate_room_events_txn, + room_id, + from_token=end_token, + limit=limit, + ) + + # We want to return the results in ascending order. + rows.reverse() + + return rows, token + + def get_room_event_after_stream_ordering(self, room_id, stream_ordering): + """Gets details of the first event in a room at or after a stream ordering + + Args: + room_id (str): + stream_ordering (int): + + Returns: + Deferred[(int, int, str)]: + (stream ordering, topological ordering, event_id) + """ + + def _f(txn): + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering >= ?" + " AND NOT outlier" + " ORDER BY stream_ordering" + " LIMIT 1" + ) + txn.execute(sql, (room_id, stream_ordering)) + return txn.fetchone() + + return self.runInteraction("get_room_event_after_stream_ordering", _f) + + @defer.inlineCallbacks + def get_room_events_max_id(self, room_id=None): + """Returns the current token for rooms stream. + + By default, it returns the current global stream token. Specifying a + `room_id` causes it to return the current room specific topological + token. + """ + token = yield self.get_room_max_stream_ordering() + if room_id is None: + return "s%d" % (token,) + else: + topo = yield self.runInteraction( + "_get_max_topological_txn", self._get_max_topological_txn, room_id + ) + return "t%d-%d" % (topo, token) + + def get_stream_token_for_event(self, event_id): + """The stream token for an event + Args: + event_id(str): The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A deferred "s%d" stream token. + """ + return self._simple_select_one_onecol( + table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" + ).addCallback(lambda row: "s%d" % (row,)) + + def get_topological_token_for_event(self, event_id): + """The stream token for an event + Args: + event_id(str): The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A deferred "t%d-%d" topological token. + """ + return self._simple_select_one( + table="events", + keyvalues={"event_id": event_id}, + retcols=("stream_ordering", "topological_ordering"), + desc="get_topological_token_for_event", + ).addCallback( + lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) + ) + + def get_max_topological_token(self, room_id, stream_key): + """Get the max topological token in a room before the given stream + ordering. + + Args: + room_id (str) + stream_key (int) + + Returns: + Deferred[int] + """ + sql = ( + "SELECT coalesce(max(topological_ordering), 0) FROM events" + " WHERE room_id = ? AND stream_ordering < ?" + ) + return self._execute( + "get_max_topological_token", None, sql, room_id, stream_key + ).addCallback(lambda r: r[0][0] if r else 0) + + def _get_max_topological_txn(self, txn, room_id): + txn.execute( + "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?", + (room_id,), + ) + + rows = txn.fetchall() + return rows[0][0] if rows else 0 + + @staticmethod + def _set_before_and_after(events, rows, topo_order=True): + """Inserts ordering information to events' internal metadata from + the DB rows. + + Args: + events (list[FrozenEvent]) + rows (list[_EventDictReturn]) + topo_order (bool): Whether the events were ordered topologically + or by stream ordering. If true then all rows should have a non + null topological_ordering. + """ + for event, row in zip(events, rows): + stream = row.stream_ordering + if topo_order and row.topological_ordering: + topo = row.topological_ordering + else: + topo = None + internal = event.internal_metadata + internal.before = str(RoomStreamToken(topo, stream - 1)) + internal.after = str(RoomStreamToken(topo, stream)) + internal.order = (int(topo) if topo else 0, int(stream)) + + @defer.inlineCallbacks + def get_events_around( + self, room_id, event_id, before_limit, after_limit, event_filter=None + ): + """Retrieve events and pagination tokens around a given event in a + room. + + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + event_filter (Filter|None) + + Returns: + dict + """ + + results = yield self.runInteraction( + "get_events_around", + self._get_events_around_txn, + room_id, + event_id, + before_limit, + after_limit, + event_filter, + ) + + events_before = yield self.get_events_as_list( + [e for e in results["before"]["event_ids"]], get_prev_content=True + ) + + events_after = yield self.get_events_as_list( + [e for e in results["after"]["event_ids"]], get_prev_content=True + ) + + return { + "events_before": events_before, + "events_after": events_after, + "start": results["before"]["token"], + "end": results["after"]["token"], + } + + def _get_events_around_txn( + self, txn, room_id, event_id, before_limit, after_limit, event_filter + ): + """Retrieves event_ids and pagination tokens around a given event in a + room. + + Args: + room_id (str) + event_id (str) + before_limit (int) + after_limit (int) + event_filter (Filter|None) + + Returns: + dict + """ + + results = self._simple_select_one_txn( + txn, + "events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcols=["stream_ordering", "topological_ordering"], + ) + + # Paginating backwards includes the event at the token, but paginating + # forward doesn't. + before_token = RoomStreamToken( + results["topological_ordering"] - 1, results["stream_ordering"] + ) + + after_token = RoomStreamToken( + results["topological_ordering"], results["stream_ordering"] + ) + + rows, start_token = self._paginate_room_events_txn( + txn, + room_id, + before_token, + direction="b", + limit=before_limit, + event_filter=event_filter, + ) + events_before = [r.event_id for r in rows] + + rows, end_token = self._paginate_room_events_txn( + txn, + room_id, + after_token, + direction="f", + limit=after_limit, + event_filter=event_filter, + ) + events_after = [r.event_id for r in rows] + + return { + "before": {"event_ids": events_before, "token": start_token}, + "after": {"event_ids": events_after, "token": end_token}, + } + + @defer.inlineCallbacks + def get_all_new_events_stream(self, from_id, current_id, limit): + """Get all new events + + Returns all events with from_id < stream_ordering <= current_id. + + Args: + from_id (int): the stream_ordering of the last event we processed + current_id (int): the stream_ordering of the most recently processed event + limit (int): the maximum number of events to return + + Returns: + Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where + `next_id` is the next value to pass as `from_id` (it will either be the + stream_ordering of the last returned event, or, if fewer than `limit` events + were found, `current_id`. + """ + + def get_all_new_events_stream_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id" + " FROM events AS e" + " WHERE" + " ? < e.stream_ordering AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + + txn.execute(sql, (from_id, current_id, limit)) + rows = txn.fetchall() + + upper_bound = current_id + if len(rows) == limit: + upper_bound = rows[-1][0] + + return upper_bound, [row[1] for row in rows] + + upper_bound, event_ids = yield self.runInteraction( + "get_all_new_events_stream", get_all_new_events_stream_txn + ) + + events = yield self.get_events_as_list(event_ids) + + return upper_bound, events + + def get_federation_out_pos(self, typ): + return self._simple_select_one_onecol( + table="federation_stream_position", + retcol="stream_id", + keyvalues={"type": typ}, + desc="get_federation_out_pos", + ) + + def update_federation_out_pos(self, typ, stream_id): + return self._simple_update_one( + table="federation_stream_position", + keyvalues={"type": typ}, + updatevalues={"stream_id": stream_id}, + desc="update_federation_out_pos", + ) + + def has_room_changed_since(self, room_id, stream_id): + return self._events_stream_cache.has_entity_changed(room_id, stream_id) + + def _paginate_room_events_txn( + self, + txn, + room_id, + from_token, + to_token=None, + direction="b", + limit=-1, + event_filter=None, + ): + """Returns list of events before or after a given token. + + Args: + txn + room_id (str) + from_token (RoomStreamToken): The token used to stream from + to_token (RoomStreamToken|None): A token which if given limits the + results to only those before + direction(char): Either 'b' or 'f' to indicate whether we are + paginating forwards or backwards from `from_key`. + limit (int): The maximum number of events to return. + event_filter (Filter|None): If provided filters the events to + those that match the filter. + + Returns: + Deferred[tuple[list[_EventDictReturn], str]]: Returns the results + as a list of _EventDictReturn and a token that points to the end + of the result set. If no events are returned then the end of the + stream has been reached (i.e. there are no events between + `from_token` and `to_token`), or `limit` is zero. + """ + + assert int(limit) >= 0 + + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. + args = [False, room_id] + if direction == "b": + order = "DESC" + else: + order = "ASC" + + bounds = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=from_token, + to_token=to_token, + engine=self.database_engine, + ) + + filter_clause, filter_args = filter_to_clause(event_filter) + + if filter_clause: + bounds += " AND " + filter_clause + args.extend(filter_args) + + args.append(int(limit)) + + sql = ( + "SELECT event_id, topological_ordering, stream_ordering" + " FROM events" + " WHERE outlier = ? AND room_id = ? AND %(bounds)s" + " ORDER BY topological_ordering %(order)s," + " stream_ordering %(order)s LIMIT ?" + ) % {"bounds": bounds, "order": order} + + txn.execute(sql, args) + + rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] + + if rows: + topo = rows[-1].topological_ordering + toke = rows[-1].stream_ordering + if direction == "b": + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + toke -= 1 + next_token = RoomStreamToken(topo, toke) + else: + # TODO (erikj): We should work out what to do here instead. + next_token = to_token if to_token else from_token + + return rows, str(next_token) + + @defer.inlineCallbacks + def paginate_room_events( + self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None + ): + """Returns list of events before or after a given token. + + Args: + room_id (str) + from_key (str): The token used to stream from + to_key (str|None): A token which if given limits the results to + only those before + direction(char): Either 'b' or 'f' to indicate whether we are + paginating forwards or backwards from `from_key`. + limit (int): The maximum number of events to return. + event_filter (Filter|None): If provided filters the events to + those that match the filter. + + Returns: + tuple[list[FrozenEvent], str]: Returns the results as a list of + events and a token that points to the end of the result set. If no + events are returned then the end of the stream has been reached + (i.e. there are no events between `from_key` and `to_key`). + """ + + from_key = RoomStreamToken.parse(from_key) + if to_key: + to_key = RoomStreamToken.parse(to_key) + + rows, token = yield self.runInteraction( + "paginate_room_events", + self._paginate_room_events_txn, + room_id, + from_key, + to_key, + direction, + limit, + event_filter, + ) + + events = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True + ) + + self._set_before_and_after(events, rows) + + return (events, token) + + +class StreamStore(StreamWorkerStore): + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py new file mode 100644 index 0000000000..10d1887f75 --- /dev/null +++ b/synapse/storage/data_stores/main/tags.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six.moves import range + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + + +class TagsWorkerStore(AccountDataWorkerStore): + @cached() + def get_tags_for_user(self, user_id): + """Get all the tags for a user. + + + Args: + user_id(str): The user to get the tags for. + Returns: + A deferred dict mapping from room_id strings to dicts mapping from + tag strings to tag content. + """ + + deferred = self._simple_select_list( + "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] + ) + + @deferred.addCallback + def tags_by_room(rows): + tags_by_room = {} + for row in rows: + room_tags = tags_by_room.setdefault(row["room_id"], {}) + room_tags[row["tag"]] = json.loads(row["content"]) + return tags_by_room + + return deferred + + @defer.inlineCallbacks + def get_all_updated_tags(self, last_id, current_id, limit): + """Get all the client tags that have changed on the server + Args: + last_id(int): The position to fetch from. + current_id(int): The position to fetch up to. + Returns: + A deferred list of tuples of stream_id int, user_id string, + room_id string, tag string and content string. + """ + if last_id == current_id: + return [] + + def get_all_updated_tags_txn(txn): + sql = ( + "SELECT stream_id, user_id, room_id" + " FROM room_tags_revisions as r" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + tag_ids = yield self.runInteraction( + "get_all_updated_tags", get_all_updated_tags_txn + ) + + def get_tag_content(txn, tag_ids): + sql = ( + "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?" + ) + results = [] + for stream_id, user_id, room_id in tag_ids: + txn.execute(sql, (user_id, room_id)) + tags = [] + for tag, content in txn: + tags.append(json.dumps(tag) + ":" + content) + tag_json = "{" + ",".join(tags) + "}" + results.append((stream_id, user_id, room_id, tag_json)) + + return results + + batch_size = 50 + results = [] + for i in range(0, len(tag_ids), batch_size): + tags = yield self.runInteraction( + "get_all_updated_tag_content", + get_tag_content, + tag_ids[i : i + batch_size], + ) + results.extend(tags) + + return results + + @defer.inlineCallbacks + def get_updated_tags(self, user_id, stream_id): + """Get all the tags for the rooms where the tags have changed since the + given version + + Args: + user_id(str): The user to get the tags for. + stream_id(int): The earliest update to get for the user. + Returns: + A deferred dict mapping from room_id strings to lists of tag + strings for all the rooms that changed since the stream_id token. + """ + + def get_updated_tags_txn(txn): + sql = ( + "SELECT room_id from room_tags_revisions" + " WHERE user_id = ? AND stream_id > ?" + ) + txn.execute(sql, (user_id, stream_id)) + room_ids = [row[0] for row in txn] + return room_ids + + changed = self._account_data_stream_cache.has_entity_changed( + user_id, int(stream_id) + ) + if not changed: + return {} + + room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn) + + results = {} + if room_ids: + tags_by_room = yield self.get_tags_for_user(user_id) + for room_id in room_ids: + results[room_id] = tags_by_room.get(room_id, {}) + + return results + + def get_tags_for_room(self, user_id, room_id): + """Get all the tags for the given room + Args: + user_id(str): The user to get tags for + room_id(str): The room to get tags for + Returns: + A deferred list of string tags. + """ + return self._simple_select_list( + table="room_tags", + keyvalues={"user_id": user_id, "room_id": room_id}, + retcols=("tag", "content"), + desc="get_tags_for_room", + ).addCallback( + lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows} + ) + + +class TagsStore(TagsWorkerStore): + @defer.inlineCallbacks + def add_tag_to_room(self, user_id, room_id, tag, content): + """Add a tag to a room for a user. + Args: + user_id(str): The user to add a tag for. + room_id(str): The room to add a tag for. + tag(str): The tag name to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the tag has been added. + """ + content_json = json.dumps(content) + + def add_tag_txn(txn, next_id): + self._simple_upsert_txn( + txn, + table="room_tags", + keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, + values={"content": content_json}, + ) + self._update_revision_txn(txn, user_id, room_id, next_id) + + with self._account_data_id_gen.get_next() as next_id: + yield self.runInteraction("add_tag", add_tag_txn, next_id) + + self.get_tags_for_user.invalidate((user_id,)) + + result = self._account_data_id_gen.get_current_token() + return result + + @defer.inlineCallbacks + def remove_tag_from_room(self, user_id, room_id, tag): + """Remove a tag from a room for a user. + Returns: + A deferred that completes once the tag has been removed + """ + + def remove_tag_txn(txn, next_id): + sql = ( + "DELETE FROM room_tags " + " WHERE user_id = ? AND room_id = ? AND tag = ?" + ) + txn.execute(sql, (user_id, room_id, tag)) + self._update_revision_txn(txn, user_id, room_id, next_id) + + with self._account_data_id_gen.get_next() as next_id: + yield self.runInteraction("remove_tag", remove_tag_txn, next_id) + + self.get_tags_for_user.invalidate((user_id,)) + + result = self._account_data_id_gen.get_current_token() + return result + + def _update_revision_txn(self, txn, user_id, room_id, next_id): + """Update the latest revision of the tags for the given user and room. + + Args: + txn: The database cursor + user_id(str): The ID of the user. + room_id(str): The ID of the room. + next_id(int): The the revision to advance to. + """ + + txn.call_after( + self._account_data_stream_cache.entity_has_changed, user_id, next_id + ) + + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + + update_sql = ( + "UPDATE room_tags_revisions" + " SET stream_id = ?" + " WHERE user_id = ?" + " AND room_id = ?" + ) + txn.execute(update_sql, (next_id, user_id, room_id)) + + if txn.rowcount == 0: + insert_sql = ( + "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)" + " VALUES (?, ?, ?)" + ) + try: + txn.execute(insert_sql, (user_id, room_id, next_id)) + except self.database_engine.module.IntegrityError: + # Ignore insertion errors. It doesn't matter if the row wasn't + # inserted because if two updates happend concurrently the one + # with the higher stream_id will not be reported to a client + # unless the previous update has completed. It doesn't matter + # which stream_id ends up in the table, as long as it is higher + # than the id that the client has. + pass diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py new file mode 100644 index 0000000000..01b1be5e14 --- /dev/null +++ b/synapse/storage/data_stores/main/transactions.py @@ -0,0 +1,273 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import namedtuple + +import six + +from canonicaljson import encode_canonical_json + +from twisted.internet import defer + +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util.caches.expiringcache import ExpiringCache + +# py2 sqlite has buffer hardcoded as only binary type, so we must use it, +# despite being deprecated and removed in favor of memoryview +if six.PY2: + db_binary_type = six.moves.builtins.buffer +else: + db_binary_type = memoryview + +logger = logging.getLogger(__name__) + + +_TransactionRow = namedtuple( + "_TransactionRow", + ("id", "transaction_id", "destination", "ts", "response_code", "response_json"), +) + +_UpdateTransactionRow = namedtuple( + "_TransactionRow", ("response_code", "response_json") +) + +SENTINEL = object() + + +class TransactionStore(SQLBaseStore): + """A collection of queries for handling PDUs. + """ + + def __init__(self, db_conn, hs): + super(TransactionStore, self).__init__(db_conn, hs) + + self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) + + self._destination_retry_cache = ExpiringCache( + cache_name="get_destination_retry_timings", + clock=self._clock, + expiry_ms=5 * 60 * 1000, + ) + + def get_received_txn_response(self, transaction_id, origin): + """For an incoming transaction from a given origin, check if we have + already responded to it. If so, return the response code and response + body (as a dict). + + Args: + transaction_id (str) + origin(str) + + Returns: + tuple: None if we have not previously responded to + this transaction or a 2-tuple of (int, dict) + """ + + return self.runInteraction( + "get_received_txn_response", + self._get_received_txn_response, + transaction_id, + origin, + ) + + def _get_received_txn_response(self, txn, transaction_id, origin): + result = self._simple_select_one_txn( + txn, + table="received_transactions", + keyvalues={"transaction_id": transaction_id, "origin": origin}, + retcols=( + "transaction_id", + "origin", + "ts", + "response_code", + "response_json", + "has_been_referenced", + ), + allow_none=True, + ) + + if result and result["response_code"]: + return result["response_code"], db_to_json(result["response_json"]) + + else: + return None + + def set_received_txn_response(self, transaction_id, origin, code, response_dict): + """Persist the response we returened for an incoming transaction, and + should return for subsequent transactions with the same transaction_id + and origin. + + Args: + txn + transaction_id (str) + origin (str) + code (int) + response_json (str) + """ + + return self._simple_insert( + table="received_transactions", + values={ + "transaction_id": transaction_id, + "origin": origin, + "response_code": code, + "response_json": db_binary_type(encode_canonical_json(response_dict)), + "ts": self._clock.time_msec(), + }, + or_ignore=True, + desc="set_received_txn_response", + ) + + @defer.inlineCallbacks + def get_destination_retry_timings(self, destination): + """Gets the current retry timings (if any) for a given destination. + + Args: + destination (str) + + Returns: + None if not retrying + Otherwise a dict for the retry scheme + """ + + result = self._destination_retry_cache.get(destination, SENTINEL) + if result is not SENTINEL: + return result + + result = yield self.runInteraction( + "get_destination_retry_timings", + self._get_destination_retry_timings, + destination, + ) + + # We don't hugely care about race conditions between getting and + # invalidating the cache, since we time out fairly quickly anyway. + self._destination_retry_cache[destination] = result + return result + + def _get_destination_retry_timings(self, txn, destination): + result = self._simple_select_one_txn( + txn, + table="destinations", + keyvalues={"destination": destination}, + retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), + allow_none=True, + ) + + if result and result["retry_last_ts"] > 0: + return result + else: + return None + + def set_destination_retry_timings( + self, destination, failure_ts, retry_last_ts, retry_interval + ): + """Sets the current retry timings for a given destination. + Both timings should be zero if retrying is no longer occuring. + + Args: + destination (str) + failure_ts (int|None) - when the server started failing (ms since epoch) + retry_last_ts (int) - time of last retry attempt in unix epoch ms + retry_interval (int) - how long until next retry in ms + """ + + self._destination_retry_cache.pop(destination, None) + return self.runInteraction( + "set_destination_retry_timings", + self._set_destination_retry_timings, + destination, + failure_ts, + retry_last_ts, + retry_interval, + ) + + def _set_destination_retry_timings( + self, txn, destination, failure_ts, retry_last_ts, retry_interval + ): + + if self.database_engine.can_native_upsert: + # Upsert retry time interval if retry_interval is zero (i.e. we're + # resetting it) or greater than the existing retry interval. + + sql = """ + INSERT INTO destinations ( + destination, failure_ts, retry_last_ts, retry_interval + ) + VALUES (?, ?, ?, ?) + ON CONFLICT (destination) DO UPDATE SET + failure_ts = EXCLUDED.failure_ts, + retry_last_ts = EXCLUDED.retry_last_ts, + retry_interval = EXCLUDED.retry_interval + WHERE + EXCLUDED.retry_interval = 0 + OR destinations.retry_interval < EXCLUDED.retry_interval + """ + + txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) + + return + + self.database_engine.lock_table(txn, "destinations") + + # We need to be careful here as the data may have changed from under us + # due to a worker setting the timings. + + prev_row = self._simple_select_one_txn( + txn, + table="destinations", + keyvalues={"destination": destination}, + retcols=("failure_ts", "retry_last_ts", "retry_interval"), + allow_none=True, + ) + + if not prev_row: + self._simple_insert_txn( + txn, + table="destinations", + values={ + "destination": destination, + "failure_ts": failure_ts, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + }, + ) + elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: + self._simple_update_one_txn( + txn, + "destinations", + keyvalues={"destination": destination}, + updatevalues={ + "failure_ts": failure_ts, + "retry_last_ts": retry_last_ts, + "retry_interval": retry_interval, + }, + ) + + def _start_cleanup_transactions(self): + return run_as_background_process( + "cleanup_transactions", self._cleanup_transactions + ) + + def _cleanup_transactions(self): + now = self._clock.time_msec() + month_ago = now - 30 * 24 * 60 * 60 * 1000 + + def _cleanup_transactions_txn(txn): + txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + + return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py new file mode 100644 index 0000000000..652abe0e6a --- /dev/null +++ b/synapse/storage/data_stores/main/user_directory.py @@ -0,0 +1,827 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, JoinRules +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.data_stores.main.state import StateFilter +from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import get_domain_from_id, get_localpart_from_id +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + + +TEMP_TABLE = "_temp_populate_user_directory" + + +class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore): + + # How many records do we calculate before sending it to + # add_users_who_share_private_rooms? + SHARE_PRIVATE_WORKING_SET = 500 + + def __init__(self, db_conn, hs): + super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs) + + self.server_name = hs.hostname + + self.register_background_update_handler( + "populate_user_directory_createtables", + self._populate_user_directory_createtables, + ) + self.register_background_update_handler( + "populate_user_directory_process_rooms", + self._populate_user_directory_process_rooms, + ) + self.register_background_update_handler( + "populate_user_directory_process_users", + self._populate_user_directory_process_users, + ) + self.register_background_update_handler( + "populate_user_directory_cleanup", self._populate_user_directory_cleanup + ) + + @defer.inlineCallbacks + def _populate_user_directory_createtables(self, progress, batch_size): + + # Get all the rooms that we want to process. + def _make_staging_area(txn): + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)" + ) + txn.execute(sql) + + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_position(position TEXT NOT NULL)" + ) + txn.execute(sql) + + # Get rooms we want to process from the database + sql = """ + SELECT room_id, count(*) FROM current_state_events + GROUP BY room_id + """ + txn.execute(sql) + rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] + self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + del rooms + + # If search all users is on, get all the users we want to add. + if self.hs.config.user_directory_search_all_users: + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_users(user_id TEXT NOT NULL)" + ) + txn.execute(sql) + + txn.execute("SELECT name FROM users") + users = [{"user_id": x[0]} for x in txn.fetchall()] + + self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + + new_pos = yield self.get_max_stream_id_in_current_state_deltas() + yield self.runInteraction( + "populate_user_directory_temp_build", _make_staging_area + ) + yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + + yield self._end_background_update("populate_user_directory_createtables") + return 1 + + @defer.inlineCallbacks + def _populate_user_directory_cleanup(self, progress, batch_size): + """ + Update the user directory stream position, then clean up the old tables. + """ + position = yield self._simple_select_one_onecol( + TEMP_TABLE + "_position", None, "position" + ) + yield self.update_user_directory_stream_pos(position) + + def _delete_staging_area(txn): + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") + + yield self.runInteraction( + "populate_user_directory_cleanup", _delete_staging_area + ) + + yield self._end_background_update("populate_user_directory_cleanup") + return 1 + + @defer.inlineCallbacks + def _populate_user_directory_process_rooms(self, progress, batch_size): + """ + Args: + progress (dict) + batch_size (int): Maximum number of state events to process + per cycle. + """ + state = self.hs.get_state_handler() + + # If we don't have progress filed, delete everything. + if not progress: + yield self.delete_all_from_user_dir() + + def _get_next_batch(txn): + # Only fetch 250 rooms, so we don't fetch too many at once, even + # if those 250 rooms have less than batch_size state events. + sql = """ + SELECT room_id, events FROM %s + ORDER BY events DESC + LIMIT 250 + """ % ( + TEMP_TABLE + "_rooms", + ) + txn.execute(sql) + rooms_to_work_on = txn.fetchall() + + if not rooms_to_work_on: + return None + + # Get how many are left to process, so we can give status on how + # far we are in processing + txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") + progress["remaining"] = txn.fetchone()[0] + + return rooms_to_work_on + + rooms_to_work_on = yield self.runInteraction( + "populate_user_directory_temp_read", _get_next_batch + ) + + # No more rooms -- complete the transaction. + if not rooms_to_work_on: + yield self._end_background_update("populate_user_directory_process_rooms") + return 1 + + logger.info( + "Processing the next %d rooms of %d remaining" + % (len(rooms_to_work_on), progress["remaining"]) + ) + + processed_event_count = 0 + + for room_id, event_count in rooms_to_work_on: + is_in_room = yield self.is_host_joined(room_id, self.server_name) + + if is_in_room: + is_public = yield self.is_room_world_readable_or_publicly_joinable( + room_id + ) + + users_with_profile = yield state.get_current_users_in_room(room_id) + user_ids = set(users_with_profile) + + # Update each user in the user directory. + for user_id, profile in users_with_profile.items(): + yield self.update_profile_in_user_dir( + user_id, profile.display_name, profile.avatar_url + ) + + to_insert = set() + + if is_public: + for user_id in user_ids: + if self.get_if_app_services_interested_in_user(user_id): + continue + + to_insert.add(user_id) + + if to_insert: + yield self.add_users_in_public_rooms(room_id, to_insert) + to_insert.clear() + else: + for user_id in user_ids: + if not self.hs.is_mine_id(user_id): + continue + + if self.get_if_app_services_interested_in_user(user_id): + continue + + for other_user_id in user_ids: + if user_id == other_user_id: + continue + + user_set = (user_id, other_user_id) + to_insert.add(user_set) + + # If it gets too big, stop and write to the database + # to prevent storing too much in RAM. + if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET: + yield self.add_users_who_share_private_room( + room_id, to_insert + ) + to_insert.clear() + + if to_insert: + yield self.add_users_who_share_private_room(room_id, to_insert) + to_insert.clear() + + # We've finished a room. Delete it from the table. + yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + # Update the remaining counter. + progress["remaining"] -= 1 + yield self.runInteraction( + "populate_user_directory", + self._background_update_progress_txn, + "populate_user_directory_process_rooms", + progress, + ) + + processed_event_count += event_count + + if processed_event_count > batch_size: + # Don't process any more rooms, we've hit our batch size. + return processed_event_count + + return processed_event_count + + @defer.inlineCallbacks + def _populate_user_directory_process_users(self, progress, batch_size): + """ + If search_all_users is enabled, add all of the users to the user directory. + """ + if not self.hs.config.user_directory_search_all_users: + yield self._end_background_update("populate_user_directory_process_users") + return 1 + + def _get_next_batch(txn): + sql = "SELECT user_id FROM %s LIMIT %s" % ( + TEMP_TABLE + "_users", + str(batch_size), + ) + txn.execute(sql) + users_to_work_on = txn.fetchall() + + if not users_to_work_on: + return None + + users_to_work_on = [x[0] for x in users_to_work_on] + + # Get how many are left to process, so we can give status on how + # far we are in processing + sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users" + txn.execute(sql) + progress["remaining"] = txn.fetchone()[0] + + return users_to_work_on + + users_to_work_on = yield self.runInteraction( + "populate_user_directory_temp_read", _get_next_batch + ) + + # No more users -- complete the transaction. + if not users_to_work_on: + yield self._end_background_update("populate_user_directory_process_users") + return 1 + + logger.info( + "Processing the next %d users of %d remaining" + % (len(users_to_work_on), progress["remaining"]) + ) + + for user_id in users_to_work_on: + profile = yield self.get_profileinfo(get_localpart_from_id(user_id)) + yield self.update_profile_in_user_dir( + user_id, profile.display_name, profile.avatar_url + ) + + # We've finished processing a user. Delete it from the table. + yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) + # Update the remaining counter. + progress["remaining"] -= 1 + yield self.runInteraction( + "populate_user_directory", + self._background_update_progress_txn, + "populate_user_directory_process_users", + progress, + ) + + return len(users_to_work_on) + + @defer.inlineCallbacks + def is_room_world_readable_or_publicly_joinable(self, room_id): + """Check if the room is either world_readable or publically joinable + """ + + # Create a state filter that only queries join and history state event + types_to_filter = ( + (EventTypes.JoinRules, ""), + (EventTypes.RoomHistoryVisibility, ""), + ) + + current_state_ids = yield self.get_filtered_current_state_ids( + room_id, StateFilter.from_types(types_to_filter) + ) + + join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) + if join_rules_id: + join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) + if join_rule_ev: + if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: + return True + + hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_id: + hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) + if hist_vis_ev: + if hist_vis_ev.content.get("history_visibility") == "world_readable": + return True + + return False + + def update_profile_in_user_dir(self, user_id, display_name, avatar_url): + """ + Update or add a user's profile in the user directory. + """ + + def _update_profile_in_user_dir_txn(txn): + new_entry = self._simple_upsert_txn( + txn, + table="user_directory", + keyvalues={"user_id": user_id}, + values={"display_name": display_name, "avatar_url": avatar_url}, + lock=False, # We're only inserter + ) + + if isinstance(self.database_engine, PostgresEngine): + # We weight the localpart most highly, then display name and finally + # server name + if self.database_engine.can_native_upsert: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector + """ + txn.execute( + sql, + ( + user_id, + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, + ), + ) + else: + # TODO: Remove this code after we've bumped the minimum version + # of postgres to always support upserts, so we can get rid of + # `new_entry` usage + if new_entry is True: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + txn.execute( + sql, + ( + user_id, + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, + ), + ) + elif new_entry is False: + sql = """ + UPDATE user_directory_search + SET vector = setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + WHERE user_id = ? + """ + txn.execute( + sql, + ( + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, + user_id, + ), + ) + else: + raise RuntimeError( + "upsert returned None when 'can_native_upsert' is False" + ) + elif isinstance(self.database_engine, Sqlite3Engine): + value = "%s %s" % (user_id, display_name) if display_name else user_id + self._simple_upsert_txn( + txn, + table="user_directory_search", + keyvalues={"user_id": user_id}, + values={"value": value}, + lock=False, # We're only inserter + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) + + return self.runInteraction( + "update_profile_in_user_dir", _update_profile_in_user_dir_txn + ) + + def add_users_who_share_private_room(self, room_id, user_id_tuples): + """Insert entries into the users_who_share_private_rooms table. The first + user should be a local user. + + Args: + room_id (str) + user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + """ + + def _add_users_who_share_room_txn(txn): + self._simple_upsert_many_txn( + txn, + table="users_who_share_private_rooms", + key_names=["user_id", "other_user_id", "room_id"], + key_values=[ + (user_id, other_user_id, room_id) + for user_id, other_user_id in user_id_tuples + ], + value_names=(), + value_values=None, + ) + + return self.runInteraction( + "add_users_who_share_room", _add_users_who_share_room_txn + ) + + def add_users_in_public_rooms(self, room_id, user_ids): + """Insert entries into the users_who_share_private_rooms table. The first + user should be a local user. + + Args: + room_id (str) + user_ids (list[str]) + """ + + def _add_users_in_public_rooms_txn(txn): + + self._simple_upsert_many_txn( + txn, + table="users_in_public_rooms", + key_names=["user_id", "room_id"], + key_values=[(user_id, room_id) for user_id in user_ids], + value_names=(), + value_values=None, + ) + + return self.runInteraction( + "add_users_in_public_rooms", _add_users_in_public_rooms_txn + ) + + def delete_all_from_user_dir(self): + """Delete the entire user directory + """ + + def _delete_all_from_user_dir_txn(txn): + txn.execute("DELETE FROM user_directory") + txn.execute("DELETE FROM user_directory_search") + txn.execute("DELETE FROM users_in_public_rooms") + txn.execute("DELETE FROM users_who_share_private_rooms") + txn.call_after(self.get_user_in_directory.invalidate_all) + + return self.runInteraction( + "delete_all_from_user_dir", _delete_all_from_user_dir_txn + ) + + @cached() + def get_user_in_directory(self, user_id): + return self._simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("display_name", "avatar_url"), + allow_none=True, + desc="get_user_in_directory", + ) + + def update_user_directory_stream_pos(self, stream_id): + return self._simple_update_one( + table="user_directory_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_user_directory_stream_pos", + ) + + +class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): + + # How many records do we calculate before sending it to + # add_users_who_share_private_rooms? + SHARE_PRIVATE_WORKING_SET = 500 + + def __init__(self, db_conn, hs): + super(UserDirectoryStore, self).__init__(db_conn, hs) + + def remove_from_user_dir(self, user_id): + def _remove_from_user_dir_txn(txn): + self._simple_delete_txn( + txn, table="user_directory", keyvalues={"user_id": user_id} + ) + self._simple_delete_txn( + txn, table="user_directory_search", keyvalues={"user_id": user_id} + ) + self._simple_delete_txn( + txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} + ) + self._simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"other_user_id": user_id}, + ) + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) + + return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) + + @defer.inlineCallbacks + def get_users_in_dir_due_to_room(self, room_id): + """Get all user_ids that are in the room directory because they're + in the given room_id + """ + user_ids_share_pub = yield self._simple_select_onecol( + table="users_in_public_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids_share_priv = yield self._simple_select_onecol( + table="users_who_share_private_rooms", + keyvalues={"room_id": room_id}, + retcol="other_user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids = set(user_ids_share_pub) + user_ids.update(user_ids_share_priv) + + return user_ids + + def remove_user_who_share_room(self, user_id, room_id): + """ + Deletes entries in the users_who_share_*_rooms table. The first + user should be a local user. + + Args: + user_id (str) + room_id (str) + """ + + def _remove_user_who_share_room_txn(txn): + self._simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"user_id": user_id, "room_id": room_id}, + ) + self._simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"other_user_id": user_id, "room_id": room_id}, + ) + self._simple_delete_txn( + txn, + table="users_in_public_rooms", + keyvalues={"user_id": user_id, "room_id": room_id}, + ) + + return self.runInteraction( + "remove_user_who_share_room", _remove_user_who_share_room_txn + ) + + @defer.inlineCallbacks + def get_user_dir_rooms_user_is_in(self, user_id): + """ + Returns the rooms that a user is in. + + Args: + user_id(str): Must be a local user + + Returns: + list: user_id + """ + rows = yield self._simple_select_onecol( + table="users_who_share_private_rooms", + keyvalues={"user_id": user_id}, + retcol="room_id", + desc="get_rooms_user_is_in", + ) + + pub_rows = yield self._simple_select_onecol( + table="users_in_public_rooms", + keyvalues={"user_id": user_id}, + retcol="room_id", + desc="get_rooms_user_is_in", + ) + + users = set(pub_rows) + users.update(rows) + return list(users) + + @defer.inlineCallbacks + def get_rooms_in_common_for_users(self, user_id, other_user_id): + """Given two user_ids find out the list of rooms they share. + """ + sql = """ + SELECT room_id FROM ( + SELECT c.room_id FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (event_id) + WHERE type = 'm.room.member' + AND m.membership = 'join' + AND state_key = ? + ) AS f1 INNER JOIN ( + SELECT c.room_id FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (event_id) + WHERE type = 'm.room.member' + AND m.membership = 'join' + AND state_key = ? + ) f2 USING (room_id) + """ + + rows = yield self._execute( + "get_rooms_in_common_for_users", None, sql, user_id, other_user_id + ) + + return [room_id for room_id, in rows] + + def get_user_directory_stream_pos(self): + return self._simple_select_one_onecol( + table="user_directory_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="get_user_directory_stream_pos", + ) + + @defer.inlineCallbacks + def search_user_dir(self, user_id, search_term, limit): + """Searches for users in directory + + Returns: + dict of the form:: + + { + "limited": , # whether there were more results or not + "results": [ # Ordered by best match first + { + "user_id": , + "display_name": , + "avatar_url": + } + ] + } + """ + + if self.hs.config.user_directory_search_all_users: + join_args = (user_id,) + where_clause = "user_id != ?" + else: + join_args = (user_id,) + where_clause = """ + ( + EXISTS (select 1 from users_in_public_rooms WHERE user_id = t.user_id) + OR EXISTS ( + SELECT 1 FROM users_who_share_private_rooms + WHERE user_id = ? AND other_user_id = t.user_id + ) + ) + """ + + if isinstance(self.database_engine, PostgresEngine): + full_query, exact_query, prefix_query = _parse_query_postgres(search_term) + + # We order by rank and then if they have profile info + # The ranking algorithm is hand tweaked for "best" results. Broadly + # the idea is we give a higher weight to exact matches. + # The array of numbers are the weights for the various part of the + # search: (domain, _, display name, localpart) + sql = """ + SELECT d.user_id AS user_id, display_name, avatar_url + FROM user_directory_search as t + INNER JOIN user_directory AS d USING (user_id) + WHERE + %s + AND vector @@ to_tsquery('english', ?) + ORDER BY + (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) + * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END) + * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END) + * ( + 3 * ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) + + ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) + ) + DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ % ( + where_clause, + ) + args = join_args + (full_query, exact_query, prefix_query, limit + 1) + elif isinstance(self.database_engine, Sqlite3Engine): + search_query = _parse_query_sqlite(search_term) + + sql = """ + SELECT d.user_id AS user_id, display_name, avatar_url + FROM user_directory_search as t + INNER JOIN user_directory AS d USING (user_id) + WHERE + %s + AND value MATCH ? + ORDER BY + rank(matchinfo(user_directory_search)) DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ % ( + where_clause, + ) + args = join_args + (search_query, limit + 1) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + results = yield self._execute( + "search_user_dir", self.cursor_to_dict, sql, *args + ) + + limited = len(results) > limit + + return {"limited": limited, "results": results} + + +def _parse_query_sqlite(search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + + We specifically add both a prefix and non prefix matching term so that + exact matches get ranked higher. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + return " & ".join("(%s* OR %s)" % (result, result) for result in results) + + +def _parse_query_postgres(search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + both = " & ".join("(%s:* | %s)" % (result, result) for result in results) + exact = " & ".join("%s" % (result,) for result in results) + prefix = " & ".join("%s:*" % (result,) for result in results) + + return both, exact, prefix diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py new file mode 100644 index 0000000000..aa4f0da5f0 --- /dev/null +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator + +from synapse.storage._base import SQLBaseStore +from synapse.util.caches.descriptors import cached, cachedList + + +class UserErasureWorkerStore(SQLBaseStore): + @cached() + def is_user_erased(self, user_id): + """ + Check if the given user id has requested erasure + + Args: + user_id (str): full user id to check + + Returns: + Deferred[bool]: True if the user has requested erasure + """ + return self._simple_select_onecol( + table="erased_users", + keyvalues={"user_id": user_id}, + retcol="1", + desc="is_user_erased", + ).addCallback(operator.truth) + + @cachedList( + cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True + ) + def are_users_erased(self, user_ids): + """ + Checks which users in a list have requested erasure + + Args: + user_ids (iterable[str]): full user id to check + + Returns: + Deferred[dict[str, bool]]: + for each user, whether the user has requested erasure. + """ + # this serves the dual purpose of (a) making sure we can do len and + # iterate it multiple times, and (b) avoiding duplicates. + user_ids = tuple(set(user_ids)) + + rows = yield self._simple_select_many_batch( + table="erased_users", + column="user_id", + iterable=user_ids, + retcols=("user_id",), + desc="are_users_erased", + ) + erased_users = set(row["user_id"] for row in rows) + + res = dict((u, u in erased_users) for u in user_ids) + return res + + +class UserErasureStore(UserErasureWorkerStore): + def mark_user_erased(self, user_id): + """Indicate that user_id wishes their message history to be erased. + + Args: + user_id (str): full user_id to be erased + """ + + def f(txn): + # first check if they are already in the list + txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) + if txn.fetchone(): + return + + # they are not already there: do the insert. + txn.execute("INSERT INTO erased_users (user_id) VALUES (?)", (user_id,)) + + self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) + + return self.runInteraction("mark_user_erased", f) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py deleted file mode 100644 index f04aad0743..0000000000 --- a/synapse/storage/deviceinbox.py +++ /dev/null @@ -1,457 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.util.caches.expiringcache import ExpiringCache - -logger = logging.getLogger(__name__) - - -class DeviceInboxWorkerStore(SQLBaseStore): - def get_to_device_stream_token(self): - return self._device_inbox_id_gen.get_current_token() - - def get_new_messages_for_device( - self, user_id, device_id, last_stream_id, current_stream_id, limit=100 - ): - """ - Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - current_stream_id(int): The current position of the to device - message stream. - Returns: - Deferred ([dict], int): List of messages for the device and where - in the stream the messages got to. - """ - has_changed = self._device_inbox_stream_cache.has_entity_changed( - user_id, last_stream_id - ) - if not has_changed: - return defer.succeed(([], current_stream_id)) - - def get_new_messages_for_device_txn(txn): - sql = ( - "SELECT stream_id, message_json FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - " LIMIT ?" - ) - txn.execute( - sql, (user_id, device_id, last_stream_id, current_stream_id, limit) - ) - messages = [] - for row in txn: - stream_pos = row[0] - messages.append(json.loads(row[1])) - if len(messages) < limit: - stream_pos = current_stream_id - return messages, stream_pos - - return self.runInteraction( - "get_new_messages_for_device", get_new_messages_for_device_txn - ) - - @trace - @defer.inlineCallbacks - def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): - """ - Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves to the number of messages deleted. - """ - # If we have cached the last stream id we've deleted up to, we can - # check if there is likely to be anything that needs deleting - last_deleted_stream_id = self._last_device_delete_cache.get( - (user_id, device_id), None - ) - - set_tag("last_deleted_stream_id", last_deleted_stream_id) - - if last_deleted_stream_id: - has_changed = self._device_inbox_stream_cache.has_entity_changed( - user_id, last_deleted_stream_id - ) - if not has_changed: - log_kv({"message": "No changes in cache since last check"}) - return 0 - - def delete_messages_for_device_txn(txn): - sql = ( - "DELETE FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND stream_id <= ?" - ) - txn.execute(sql, (user_id, device_id, up_to_stream_id)) - return txn.rowcount - - count = yield self.runInteraction( - "delete_messages_for_device", delete_messages_for_device_txn - ) - - log_kv( - {"message": "deleted {} messages for device".format(count), "count": count} - ) - - # Update the cache, ensuring that we only ever increase the value - last_deleted_stream_id = self._last_device_delete_cache.get( - (user_id, device_id), 0 - ) - self._last_device_delete_cache[(user_id, device_id)] = max( - last_deleted_stream_id, up_to_stream_id - ) - - return count - - @trace - def get_new_device_msgs_for_remote( - self, destination, last_stream_id, current_stream_id, limit - ): - """ - Args: - destination(str): The name of the remote server. - last_stream_id(int|long): The last position of the device message stream - that the server sent up to. - current_stream_id(int|long): The current position of the device - message stream. - Returns: - Deferred ([dict], int|long): List of messages for the device and where - in the stream the messages got to. - """ - - set_tag("destination", destination) - set_tag("last_stream_id", last_stream_id) - set_tag("current_stream_id", current_stream_id) - set_tag("limit", limit) - - has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( - destination, last_stream_id - ) - if not has_changed or last_stream_id == current_stream_id: - log_kv({"message": "No new messages in stream"}) - return defer.succeed(([], current_stream_id)) - - if limit <= 0: - # This can happen if we run out of room for EDUs in the transaction. - return defer.succeed(([], last_stream_id)) - - @trace - def get_new_messages_for_remote_destination_txn(txn): - sql = ( - "SELECT stream_id, messages_json FROM device_federation_outbox" - " WHERE destination = ?" - " AND ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - " LIMIT ?" - ) - txn.execute(sql, (destination, last_stream_id, current_stream_id, limit)) - messages = [] - for row in txn: - stream_pos = row[0] - messages.append(json.loads(row[1])) - if len(messages) < limit: - log_kv({"message": "Set stream position to current position"}) - stream_pos = current_stream_id - return messages, stream_pos - - return self.runInteraction( - "get_new_device_msgs_for_remote", - get_new_messages_for_remote_destination_txn, - ) - - @trace - def delete_device_msgs_for_remote(self, destination, up_to_stream_id): - """Used to delete messages when the remote destination acknowledges - their receipt. - - Args: - destination(str): The destination server_name - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves when the messages have been deleted. - """ - - def delete_messages_for_remote_destination_txn(txn): - sql = ( - "DELETE FROM device_federation_outbox" - " WHERE destination = ?" - " AND stream_id <= ?" - ) - txn.execute(sql, (destination, up_to_stream_id)) - - return self.runInteraction( - "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn - ) - - -class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): - DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - - def __init__(self, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) - - self.register_background_index_update( - "device_inbox_stream_index", - index_name="device_inbox_stream_id_user_id", - table="device_inbox", - columns=["stream_id", "user_id"], - ) - - self.register_background_update_handler( - self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox - ) - - @defer.inlineCallbacks - def _background_drop_index_device_inbox(self, progress, batch_size): - def reindex_txn(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") - txn.close() - - yield self.runWithConnection(reindex_txn) - - yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) - - return 1 - - -class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): - DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - - def __init__(self, db_conn, hs): - super(DeviceInboxStore, self).__init__(db_conn, hs) - - # Map of (user_id, device_id) to the last stream_id that has been - # deleted up to. This is so that we can no op deletions. - self._last_device_delete_cache = ExpiringCache( - cache_name="last_device_delete_cache", - clock=self._clock, - max_len=10000, - expiry_ms=30 * 60 * 1000, - ) - - @trace - @defer.inlineCallbacks - def add_messages_to_device_inbox( - self, local_messages_by_user_then_device, remote_messages_by_destination - ): - """Used to send messages from this server. - - Args: - sender_user_id(str): The ID of the user sending these messages. - local_messages_by_user_and_device(dict): - Dictionary of user_id to device_id to message. - remote_messages_by_destination(dict): - Dictionary of destination server_name to the EDU JSON to send. - Returns: - A deferred stream_id that resolves when the messages have been - inserted. - """ - - def add_messages_txn(txn, now_ms, stream_id): - # Add the local messages directly to the local inbox. - self._add_messages_to_local_device_inbox_txn( - txn, stream_id, local_messages_by_user_then_device - ) - - # Add the remote messages to the federation outbox. - # We'll send them to a remote server when we next send a - # federation transaction to that destination. - sql = ( - "INSERT INTO device_federation_outbox" - " (destination, stream_id, queued_ts, messages_json)" - " VALUES (?,?,?,?)" - ) - rows = [] - for destination, edu in remote_messages_by_destination.items(): - edu_json = json.dumps(edu) - rows.append((destination, stream_id, now_ms, edu_json)) - txn.executemany(sql, rows) - - with self._device_inbox_id_gen.get_next() as stream_id: - now_ms = self.clock.time_msec() - yield self.runInteraction( - "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id - ) - for user_id in local_messages_by_user_then_device.keys(): - self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) - for destination in remote_messages_by_destination.keys(): - self._device_federation_outbox_stream_cache.entity_has_changed( - destination, stream_id - ) - - return self._device_inbox_id_gen.get_current_token() - - @defer.inlineCallbacks - def add_messages_from_remote_to_device_inbox( - self, origin, message_id, local_messages_by_user_then_device - ): - def add_messages_txn(txn, now_ms, stream_id): - # Check if we've already inserted a matching message_id for that - # origin. This can happen if the origin doesn't receive our - # acknowledgement from the first time we received the message. - already_inserted = self._simple_select_one_txn( - txn, - table="device_federation_inbox", - keyvalues={"origin": origin, "message_id": message_id}, - retcols=("message_id",), - allow_none=True, - ) - if already_inserted is not None: - return - - # Add an entry for this message_id so that we know we've processed - # it. - self._simple_insert_txn( - txn, - table="device_federation_inbox", - values={ - "origin": origin, - "message_id": message_id, - "received_ts": now_ms, - }, - ) - - # Add the messages to the approriate local device inboxes so that - # they'll be sent to the devices when they next sync. - self._add_messages_to_local_device_inbox_txn( - txn, stream_id, local_messages_by_user_then_device - ) - - with self._device_inbox_id_gen.get_next() as stream_id: - now_ms = self.clock.time_msec() - yield self.runInteraction( - "add_messages_from_remote_to_device_inbox", - add_messages_txn, - now_ms, - stream_id, - ) - for user_id in local_messages_by_user_then_device.keys(): - self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) - - return stream_id - - def _add_messages_to_local_device_inbox_txn( - self, txn, stream_id, messages_by_user_then_device - ): - sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" - txn.execute(sql, (stream_id, stream_id)) - - local_by_user_then_device = {} - for user_id, messages_by_device in messages_by_user_then_device.items(): - messages_json_for_user = {} - devices = list(messages_by_device.keys()) - if len(devices) == 1 and devices[0] == "*": - # Handle wildcard device_ids. - sql = "SELECT device_id FROM devices" " WHERE user_id = ?" - txn.execute(sql, (user_id,)) - message_json = json.dumps(messages_by_device["*"]) - for row in txn: - # Add the message for all devices for this user on this - # server. - device = row[0] - messages_json_for_user[device] = message_json - else: - if not devices: - continue - - clause, args = make_in_list_sql_clause( - txn.database_engine, "device_id", devices - ) - sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause - - # TODO: Maybe this needs to be done in batches if there are - # too many local devices for a given user. - txn.execute(sql, [user_id] + list(args)) - for row in txn: - # Only insert into the local inbox if the device exists on - # this server - device = row[0] - message_json = json.dumps(messages_by_device[device]) - messages_json_for_user[device] = message_json - - if messages_json_for_user: - local_by_user_then_device[user_id] = messages_json_for_user - - if not local_by_user_then_device: - return - - sql = ( - "INSERT INTO device_inbox" - " (user_id, device_id, stream_id, message_json)" - " VALUES (?,?,?,?)" - ) - rows = [] - for user_id, messages_by_device in local_by_user_then_device.items(): - for device_id, message_json in messages_by_device.items(): - rows.append((user_id, device_id, stream_id, message_json)) - - txn.executemany(sql, rows) - - def get_all_new_device_messages(self, last_pos, current_pos, limit): - """ - Args: - last_pos(int): - current_pos(int): - limit(int): - Returns: - A deferred list of rows from the device inbox - """ - if last_pos == current_pos: - return defer.succeed([]) - - def get_all_new_device_messages_txn(txn): - # We limit like this as we might have multiple rows per stream_id, and - # we want to make sure we always get all entries for any stream_id - # we return. - upper_pos = min(current_pos, last_pos + limit) - sql = ( - "SELECT max(stream_id), user_id" - " FROM device_inbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY user_id" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows = txn.fetchall() - - sql = ( - "SELECT max(stream_id), destination" - " FROM device_federation_outbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY destination" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn) - - # Order by ascending stream ordering - rows.sort() - - return rows - - return self.runInteraction( - "get_all_new_device_messages", get_all_new_device_messages_txn - ) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py deleted file mode 100644 index ac5239e50a..0000000000 --- a/synapse/storage/devices.py +++ /dev/null @@ -1,937 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from six import iteritems - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.logging.opentracing import ( - get_active_span_text_map, - set_tag, - trace, - whitelisted_homeserver, -) -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - Cache, - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, -) -from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList - -logger = logging.getLogger(__name__) - -DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( - "drop_device_list_streams_non_unique_indexes" -) - - -class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id, device_id): - """Retrieve a device. - - Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to retrieve - Returns: - defer.Deferred for a dict containing the device information - Raises: - StoreError: if the device is not found - """ - return self._simple_select_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcols=("user_id", "device_id", "display_name"), - desc="get_device", - ) - - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): - """Retrieve all of a user's registered devices. - - Args: - user_id (str): - Returns: - defer.Deferred: resolves to a dict from device_id to a dict - containing "device_id", "user_id" and "display_name" for each - device. - """ - devices = yield self._simple_select_list( - table="devices", - keyvalues={"user_id": user_id}, - retcols=("user_id", "device_id", "display_name"), - desc="get_devices_by_user", - ) - - return {d["device_id"]: d for d in devices} - - @trace - @defer.inlineCallbacks - def get_devices_by_remote(self, destination, from_stream_id, limit): - """Get stream of updates to send to remote servers - - Returns: - Deferred[tuple[int, list[dict]]]: - current stream id (ie, the stream id of the last update included in the - response), and the list of updates - """ - now_stream_id = self._device_list_id_gen.get_current_token() - - has_changed = self._device_list_federation_stream_cache.has_entity_changed( - destination, int(from_stream_id) - ) - if not has_changed: - return now_stream_id, [] - - # We retrieve n+1 devices from the list of outbound pokes where n is - # our outbound device update limit. We then check if the very last - # device has the same stream_id as the second-to-last device. If so, - # then we ignore all devices with that stream_id and only send the - # devices with a lower stream_id. - # - # If when culling the list we end up with no devices afterwards, we - # consider the device update to be too large, and simply skip the - # stream_id; the rationale being that such a large device list update - # is likely an error. - updates = yield self.runInteraction( - "get_devices_by_remote", - self._get_devices_by_remote_txn, - destination, - from_stream_id, - now_stream_id, - limit + 1, - ) - - # Return an empty list if there are no updates - if not updates: - return now_stream_id, [] - - # if we have exceeded the limit, we need to exclude any results with the - # same stream_id as the last row. - if len(updates) > limit: - stream_id_cutoff = updates[-1][2] - now_stream_id = stream_id_cutoff - 1 - else: - stream_id_cutoff = None - - # Perform the equivalent of a GROUP BY - # - # Iterate through the updates list and copy non-duplicate - # (user_id, device_id) entries into a map, with the value being - # the max stream_id across each set of duplicate entries - # - # maps (user_id, device_id) -> (stream_id, opentracing_context) - # as long as their stream_id does not match that of the last row - # - # opentracing_context contains the opentracing metadata for the request - # that created the poke - # - # The most recent request's opentracing_context is used as the - # context which created the Edu. - - query_map = {} - for update in updates: - if stream_id_cutoff is not None and update[2] >= stream_id_cutoff: - # Stop processing updates - break - - key = (update[0], update[1]) - - update_context = update[3] - update_stream_id = update[2] - - previous_update_stream_id, _ = query_map.get(key, (0, None)) - - if update_stream_id > previous_update_stream_id: - query_map[key] = (update_stream_id, update_context) - - # If we didn't find any updates with a stream_id lower than the cutoff, it - # means that there are more than limit updates all of which have the same - # steam_id. - - # That should only happen if a client is spamming the server with new - # devices, in which case E2E isn't going to work well anyway. We'll just - # skip that stream_id and return an empty list, and continue with the next - # stream_id next time. - if not query_map: - return stream_id_cutoff, [] - - results = yield self._get_device_update_edus_by_remote( - destination, from_stream_id, query_map - ) - - return now_stream_id, results - - def _get_devices_by_remote_txn( - self, txn, destination, from_stream_id, now_stream_id, limit - ): - """Return device update information for a given remote destination - - Args: - txn (LoggingTransaction): The transaction to execute - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - now_stream_id (int): The maximum stream_id to filter updates by, inclusive - limit (int): Maximum number of device updates to return - - Returns: - List: List of device updates - """ - sql = """ - SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes - WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? - ORDER BY stream_id - LIMIT ? - """ - txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit)) - - return list(txn) - - @defer.inlineCallbacks - def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): - """Returns a list of device update EDUs as well as E2EE keys - - Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping - user_id/device_id to update stream_id and the relevent json-encoded - opentracing context - - Returns: - List[Dict]: List of objects representing an device update EDU - - """ - devices = yield self.runInteraction( - "_get_e2e_device_keys_txn", - self._get_e2e_device_keys_txn, - query_map.keys(), - include_all_devices=True, - include_deleted_devices=True, - ) - - results = [] - for user_id, user_devices in iteritems(devices): - # The prev_id for the first row is always the last row before - # `from_stream_id` - prev_id = yield self._get_last_device_update_for_remote_user( - destination, user_id, from_stream_id - ) - for device_id, device in iteritems(user_devices): - stream_id, opentracing_context = query_map[(user_id, device_id)] - result = { - "user_id": user_id, - "device_id": device_id, - "prev_id": [prev_id] if prev_id else [], - "stream_id": stream_id, - "org.matrix.opentracing_context": opentracing_context, - } - - prev_id = stream_id - - if device is not None: - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - else: - result["deleted"] = True - - results.append(result) - - return results - - def _get_last_device_update_for_remote_user( - self, destination, user_id, from_stream_id - ): - def f(txn): - prev_sent_id_sql = """ - SELECT coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_last_success - WHERE destination = ? AND user_id = ? AND stream_id <= ? - """ - txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) - rows = txn.fetchall() - return rows[0][0] - - return self.runInteraction("get_last_device_update_for_remote_user", f) - - def mark_as_sent_devices_by_remote(self, destination, stream_id): - """Mark that updates have successfully been sent to the destination. - """ - return self.runInteraction( - "mark_as_sent_devices_by_remote", - self._mark_as_sent_devices_by_remote_txn, - destination, - stream_id, - ) - - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): - # We update the device_lists_outbound_last_success with the successfully - # poked users. We do the join to see which users need to be inserted and - # which updated. - sql = """ - SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) - FROM device_lists_outbound_pokes as o - LEFT JOIN device_lists_outbound_last_success as s - USING (destination, user_id) - WHERE destination = ? AND o.stream_id <= ? - GROUP BY user_id - """ - txn.execute(sql, (destination, stream_id)) - rows = txn.fetchall() - - sql = """ - UPDATE device_lists_outbound_last_success - SET stream_id = ? - WHERE destination = ? AND user_id = ? - """ - txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2])) - - sql = """ - INSERT INTO device_lists_outbound_last_success - (destination, user_id, stream_id) VALUES (?, ?, ?) - """ - txn.executemany( - sql, ((destination, row[0], row[1]) for row in rows if not row[2]) - ) - - # Delete all sent outbound pokes - sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND stream_id <= ? - """ - txn.execute(sql, (destination, stream_id)) - - def get_device_stream_token(self): - return self._device_list_id_gen.get_current_token() - - @trace - @defer.inlineCallbacks - def get_user_devices_from_cache(self, query_list): - """Get the devices (and keys if any) for remote users from the cache. - - Args: - query_list(list): List of (user_id, device_ids), if device_ids is - falsey then return all device ids for that user. - - Returns: - (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is - a set of user_ids and results_map is a mapping of - user_id -> device_id -> device_info - """ - user_ids = set(user_id for user_id, _ in query_list) - user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) - user_ids_in_cache = set( - user_id for user_id, stream_id in user_map.items() if stream_id - ) - user_ids_not_in_cache = user_ids - user_ids_in_cache - - results = {} - for user_id, device_id in query_list: - if user_id not in user_ids_in_cache: - continue - - if device_id: - device = yield self._get_cached_user_device(user_id, device_id) - results.setdefault(user_id, {})[device_id] = device - else: - results[user_id] = yield self._get_cached_devices_for_user(user_id) - - set_tag("in_cache", results) - set_tag("not_in_cache", user_ids_not_in_cache) - - return user_ids_not_in_cache, results - - @cachedInlineCallbacks(num_args=2, tree=True) - def _get_cached_user_device(self, user_id, device_id): - content = yield self._simple_select_one_onecol( - table="device_lists_remote_cache", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="content", - desc="_get_cached_user_device", - ) - return db_to_json(content) - - @cachedInlineCallbacks() - def _get_cached_devices_for_user(self, user_id): - devices = yield self._simple_select_list( - table="device_lists_remote_cache", - keyvalues={"user_id": user_id}, - retcols=("device_id", "content"), - desc="_get_cached_devices_for_user", - ) - return { - device["device_id"]: db_to_json(device["content"]) for device in devices - } - - def get_devices_with_keys_by_user(self, user_id): - """Get all devices (with any device keys) for a user - - Returns: - (stream_id, devices) - """ - return self.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, - user_id, - ) - - def _get_devices_with_keys_by_user_txn(self, txn, user_id): - now_stream_id = self._device_list_id_gen.get_current_token() - - devices = self._get_e2e_device_keys_txn( - txn, [(user_id, None)], include_all_devices=True - ) - - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in iteritems(user_devices): - result = {"device_id": device_id} - - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - - results.append(result) - - return now_stream_id, results - - return now_stream_id, [] - - def get_users_whose_devices_changed(self, from_key, user_ids): - """Get set of users whose devices have changed since `from_key` that - are in the given list of user_ids. - - Args: - from_key (str): The device lists stream token - user_ids (Iterable[str]) - - Returns: - Deferred[set[str]]: The set of user_ids whose devices have changed - since `from_key` - """ - from_key = int(from_key) - - # Get set of users who *may* have changed. Users not in the returned - # list have definitely not changed. - to_check = list( - self._device_list_stream_cache.get_entities_changed(user_ids, from_key) - ) - - if not to_check: - return defer.succeed(set()) - - def _get_users_whose_devices_changed_txn(txn): - changes = set() - - sql = """ - SELECT DISTINCT user_id FROM device_lists_stream - WHERE stream_id > ? - AND - """ - - for chunk in batch_iter(to_check, 100): - clause, args = make_in_list_sql_clause( - txn.database_engine, "user_id", chunk - ) - txn.execute(sql + clause, (from_key,) + tuple(args)) - changes.update(user_id for user_id, in txn) - - return changes - - return self.runInteraction( - "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn - ) - - def get_all_device_list_changes_for_remotes(self, from_key, to_key): - """Return a list of `(stream_id, user_id, destination)` which is the - combined list of changes to devices, and which destinations need to be - poked. `destination` may be None if no destinations need to be poked. - """ - # We do a group by here as there can be a large number of duplicate - # entries, since we throw away device IDs. - sql = """ - SELECT MAX(stream_id) AS stream_id, user_id, destination - FROM device_lists_stream - LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) - WHERE ? < stream_id AND stream_id <= ? - GROUP BY user_id, destination - """ - return self._execute( - "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key - ) - - @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id): - """Get the last stream_id we got for a user. May be None if we haven't - got any information for them. - """ - return self._simple_select_one_onecol( - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - retcol="stream_id", - desc="get_device_list_last_stream_id_for_remote", - allow_none=True, - ) - - @cachedList( - cached_method_name="get_device_list_last_stream_id_for_remote", - list_name="user_ids", - inlineCallbacks=True, - ) - def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self._simple_select_many_batch( - table="device_lists_remote_extremeties", - column="user_id", - iterable=user_ids, - retcols=("user_id", "stream_id"), - desc="get_device_list_last_stream_id_for_remotes", - ) - - results = {user_id: None for user_id in user_ids} - results.update({row["user_id"]: row["stream_id"] for row in rows}) - - return results - - -class DeviceBackgroundUpdateStore(BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) - - self.register_background_index_update( - "device_lists_stream_idx", - index_name="device_lists_stream_user_id", - table="device_lists_stream", - columns=["user_id", "device_id"], - ) - - # create a unique index on device_lists_remote_cache - self.register_background_index_update( - "device_lists_remote_cache_unique_idx", - index_name="device_lists_remote_cache_unique_id", - table="device_lists_remote_cache", - columns=["user_id", "device_id"], - unique=True, - ) - - # And one on device_lists_remote_extremeties - self.register_background_index_update( - "device_lists_remote_extremeties_unique_idx", - index_name="device_lists_remote_extremeties_unique_idx", - table="device_lists_remote_extremeties", - columns=["user_id"], - unique=True, - ) - - # once they complete, we can remove the old non-unique indexes. - self.register_background_update_handler( - DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, - self._drop_device_list_streams_non_unique_indexes, - ) - - @defer.inlineCallbacks - def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): - def f(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") - txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") - txn.close() - - yield self.runWithConnection(f) - yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) - return 1 - - -class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceStore, self).__init__(db_conn, hs) - - # Map of (user_id, device_id) -> bool. If there is an entry that implies - # the device exists. - self.device_id_exists_cache = Cache( - name="device_id_exists", keylen=2, max_entries=10000 - ) - - self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) - - @defer.inlineCallbacks - def store_device(self, user_id, device_id, initial_device_display_name): - """Ensure the given device is known; add it to the store if not - - Args: - user_id (str): id of user associated with the device - device_id (str): id of device - initial_device_display_name (str): initial displayname of the - device. Ignored if device exists. - Returns: - defer.Deferred: boolean whether the device was inserted or an - existing device existed with that ID. - """ - key = (user_id, device_id) - if self.device_id_exists_cache.get(key, None): - return False - - try: - inserted = yield self._simple_insert( - "devices", - values={ - "user_id": user_id, - "device_id": device_id, - "display_name": initial_device_display_name, - }, - desc="store_device", - or_ignore=True, - ) - self.device_id_exists_cache.prefill(key, True) - return inserted - except Exception as e: - logger.error( - "store_device with device_id=%s(%r) user_id=%s(%r)" - " display_name=%s(%r) failed: %s", - type(device_id).__name__, - device_id, - type(user_id).__name__, - user_id, - type(initial_device_display_name).__name__, - initial_device_display_name, - e, - ) - raise StoreError(500, "Problem storing device.") - - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): - """Delete a device. - - Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to delete - Returns: - defer.Deferred - """ - yield self._simple_delete_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="delete_device", - ) - - self.device_id_exists_cache.invalidate((user_id, device_id)) - - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): - """Deletes several devices. - - Args: - user_id (str): The ID of the user which owns the devices - device_ids (list): The IDs of the devices to delete - Returns: - defer.Deferred - """ - yield self._simple_delete_many( - table="devices", - column="device_id", - iterable=device_ids, - keyvalues={"user_id": user_id}, - desc="delete_devices", - ) - for device_id in device_ids: - self.device_id_exists_cache.invalidate((user_id, device_id)) - - def update_device(self, user_id, device_id, new_display_name=None): - """Update a device. - - Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to update - new_display_name (str|None): new displayname for device; None - to leave unchanged - Raises: - StoreError: if the device is not found - Returns: - defer.Deferred - """ - updates = {} - if new_display_name is not None: - updates["display_name"] = new_display_name - if not updates: - return defer.succeed(None) - return self._simple_update_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - updatevalues=updates, - desc="update_device", - ) - - @defer.inlineCallbacks - def mark_remote_user_device_list_as_unsubscribed(self, user_id): - """Mark that we no longer track device lists for remote user. - """ - yield self._simple_delete( - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - desc="mark_remote_user_device_list_as_unsubscribed", - ) - self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) - - def update_remote_device_list_cache_entry( - self, user_id, device_id, content, stream_id - ): - """Updates a single device in the cache of a remote user's devicelist. - - Note: assumes that we are the only thread that can be updating this user's - device list. - - Args: - user_id (str): User to update device list for - device_id (str): ID of decivice being updated - content (dict): new data on this device - stream_id (int): the version of the device list - - Returns: - Deferred[None] - """ - return self.runInteraction( - "update_remote_device_list_cache_entry", - self._update_remote_device_list_cache_entry_txn, - user_id, - device_id, - content, - stream_id, - ) - - def _update_remote_device_list_cache_entry_txn( - self, txn, user_id, device_id, content, stream_id - ): - if content.get("deleted"): - self._simple_delete_txn( - txn, - table="device_lists_remote_cache", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - - txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) - else: - self._simple_upsert_txn( - txn, - table="device_lists_remote_cache", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"content": json.dumps(content)}, - # we don't need to lock, because we assume we are the only thread - # updating this user's devices. - lock=False, - ) - - txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id)) - txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) - txn.call_after( - self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) - ) - - self._simple_upsert_txn( - txn, - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - values={"stream_id": stream_id}, - # again, we can assume we are the only thread updating this user's - # extremity. - lock=False, - ) - - def update_remote_device_list_cache(self, user_id, devices, stream_id): - """Replace the entire cache of the remote user's devices. - - Note: assumes that we are the only thread that can be updating this user's - device list. - - Args: - user_id (str): User to update device list for - devices (list[dict]): list of device objects supplied over federation - stream_id (int): the version of the device list - - Returns: - Deferred[None] - """ - return self.runInteraction( - "update_remote_device_list_cache", - self._update_remote_device_list_cache_txn, - user_id, - devices, - stream_id, - ) - - def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self._simple_delete_txn( - txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} - ) - - self._simple_insert_many_txn( - txn, - table="device_lists_remote_cache", - values=[ - { - "user_id": user_id, - "device_id": content["device_id"], - "content": json.dumps(content), - } - for content in devices - ], - ) - - txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) - txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) - txn.call_after( - self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) - ) - - self._simple_upsert_txn( - txn, - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - values={"stream_id": stream_id}, - # we don't need to lock, because we can assume we are the only thread - # updating this user's extremity. - lock=False, - ) - - @defer.inlineCallbacks - def add_device_change_to_streams(self, user_id, device_ids, hosts): - """Persist that a user's devices have been updated, and which hosts - (if any) should be poked. - """ - with self._device_list_id_gen.get_next() as stream_id: - yield self.runInteraction( - "add_device_change_to_streams", - self._add_device_change_txn, - user_id, - device_ids, - hosts, - stream_id, - ) - return stream_id - - def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): - now = self._clock.time_msec() - - txn.call_after( - self._device_list_stream_cache.entity_has_changed, user_id, stream_id - ) - for host in hosts: - txn.call_after( - self._device_list_federation_stream_cache.entity_has_changed, - host, - stream_id, - ) - - # Delete older entries in the table, as we really only care about - # when the latest change happened. - txn.executemany( - """ - DELETE FROM device_lists_stream - WHERE user_id = ? AND device_id = ? AND stream_id < ? - """, - [(user_id, device_id, stream_id) for device_id in device_ids], - ) - - self._simple_insert_many_txn( - txn, - table="device_lists_stream", - values=[ - {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} - for device_id in device_ids - ], - ) - - context = get_active_span_text_map() - - self._simple_insert_many_txn( - txn, - table="device_lists_outbound_pokes", - values=[ - { - "destination": destination, - "stream_id": stream_id, - "user_id": user_id, - "device_id": device_id, - "sent": False, - "ts": now, - "opentracing_context": json.dumps(context) - if whitelisted_homeserver(destination) - else "{}", - } - for destination in hosts - for device_id in device_ids - ], - ) - - def _prune_old_outbound_device_pokes(self): - """Delete old entries out of the device_lists_outbound_pokes to ensure - that we don't fill up due to dead servers. We keep one entry per - (destination, user_id) tuple to ensure that the prev_ids remain correct - if the server does come back. - """ - yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 - - def _prune_txn(txn): - select_sql = """ - SELECT destination, user_id, max(stream_id) as stream_id - FROM device_lists_outbound_pokes - GROUP BY destination, user_id - HAVING min(ts) < ? AND count(*) > 1 - """ - - txn.execute(select_sql, (yesterday,)) - rows = txn.fetchall() - - if not rows: - return - - delete_sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? - """ - - txn.executemany( - delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows) - ) - - # Since we've deleted unsent deltas, we need to remove the entry - # of last successful sent so that the prev_ids are correctly set. - sql = """ - DELETE FROM device_lists_outbound_last_success - WHERE destination = ? AND user_id = ? - """ - txn.executemany(sql, ((row[0], row[1]) for row in rows)) - - logger.info("Pruned %d device list outbound pokes", txn.rowcount) - - return run_as_background_process( - "prune_old_outbound_device_pokes", - self.runInteraction, - "_prune_old_outbound_device_pokes", - _prune_txn, - ) diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py deleted file mode 100644 index eed7757ed5..0000000000 --- a/synapse/storage/directory.py +++ /dev/null @@ -1,174 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import namedtuple - -from twisted.internet import defer - -from synapse.api.errors import SynapseError -from synapse.util.caches.descriptors import cached - -from ._base import SQLBaseStore - -RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) - - -class DirectoryWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): - """ Get's the room_id and server list for a given room_alias - - Args: - room_alias (RoomAlias) - - Returns: - Deferred: results in namedtuple with keys "room_id" and - "servers" or None if no association can be found - """ - room_id = yield self._simple_select_one_onecol( - "room_aliases", - {"room_alias": room_alias.to_string()}, - "room_id", - allow_none=True, - desc="get_association_from_room_alias", - ) - - if not room_id: - return None - - servers = yield self._simple_select_onecol( - "room_alias_servers", - {"room_alias": room_alias.to_string()}, - "server", - desc="get_association_from_room_alias", - ) - - if not servers: - return None - - return RoomAliasMapping(room_id, room_alias.to_string(), servers) - - def get_room_alias_creator(self, room_alias): - return self._simple_select_one_onecol( - table="room_aliases", - keyvalues={"room_alias": room_alias}, - retcol="creator", - desc="get_room_alias_creator", - ) - - @cached(max_entries=5000) - def get_aliases_for_room(self, room_id): - return self._simple_select_onecol( - "room_aliases", - {"room_id": room_id}, - "room_alias", - desc="get_aliases_for_room", - ) - - -class DirectoryStore(DirectoryWorkerStore): - @defer.inlineCallbacks - def create_room_alias_association(self, room_alias, room_id, servers, creator=None): - """ Creates an association between a room alias and room_id/servers - - Args: - room_alias (RoomAlias) - room_id (str) - servers (list) - creator (str): Optional user_id of creator. - - Returns: - Deferred - """ - - def alias_txn(txn): - self._simple_insert_txn( - txn, - "room_aliases", - { - "room_alias": room_alias.to_string(), - "room_id": room_id, - "creator": creator, - }, - ) - - self._simple_insert_many_txn( - txn, - table="room_alias_servers", - values=[ - {"room_alias": room_alias.to_string(), "server": server} - for server in servers - ], - ) - - self._invalidate_cache_and_stream( - txn, self.get_aliases_for_room, (room_id,) - ) - - try: - ret = yield self.runInteraction("create_room_alias_association", alias_txn) - except self.database_engine.module.IntegrityError: - raise SynapseError( - 409, "Room alias %s already exists" % room_alias.to_string() - ) - return ret - - @defer.inlineCallbacks - def delete_room_alias(self, room_alias): - room_id = yield self.runInteraction( - "delete_room_alias", self._delete_room_alias_txn, room_alias - ) - - return room_id - - def _delete_room_alias_txn(self, txn, room_alias): - txn.execute( - "SELECT room_id FROM room_aliases WHERE room_alias = ?", - (room_alias.to_string(),), - ) - - res = txn.fetchone() - if res: - room_id = res[0] - else: - return None - - txn.execute( - "DELETE FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),) - ) - - txn.execute( - "DELETE FROM room_alias_servers WHERE room_alias = ?", - (room_alias.to_string(),), - ) - - self._invalidate_cache_and_stream(txn, self.get_aliases_for_room, (room_id,)) - - return room_id - - def update_aliases_for_room(self, old_room_id, new_room_id, creator): - def _update_aliases_for_room_txn(txn): - sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" - txn.execute(sql, (new_room_id, creator, old_room_id)) - self._invalidate_cache_and_stream( - txn, self.get_aliases_for_room, (old_room_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_aliases_for_room, (new_room_id,) - ) - - return self.runInteraction( - "_update_aliases_for_room_txn", _update_aliases_for_room_txn - ) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py deleted file mode 100644 index be2fe2bab6..0000000000 --- a/synapse/storage/e2e_room_keys.py +++ /dev/null @@ -1,337 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.logging.opentracing import log_kv, trace - -from ._base import SQLBaseStore - - -class EndToEndRoomKeyStore(SQLBaseStore): - @defer.inlineCallbacks - def get_e2e_room_key(self, user_id, version, room_id, session_id): - """Get the encrypted E2E room key for a given session from a given - backup version of room_keys. We only store the 'best' room key for a given - session at a given time, as determined by the handler. - - Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): the ID of the room whose keys we're querying. - This is a bit redundant as it's implied by the session_id, but - we include for consistency with the rest of the API. - session_id(str): the session whose room_key we're querying. - - Returns: - A deferred dict giving the session_data and message metadata for - this room key. - """ - - row = yield self._simple_select_one( - table="e2e_room_keys", - keyvalues={ - "user_id": user_id, - "version": version, - "room_id": room_id, - "session_id": session_id, - }, - retcols=( - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", - ), - desc="get_e2e_room_key", - ) - - row["session_data"] = json.loads(row["session_data"]) - - return row - - @defer.inlineCallbacks - def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): - """Replaces or inserts the encrypted E2E room key for a given session in - a given backup - - Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_id(str): the ID of the room whose keys we're setting - session_id(str): the session whose room_key we're setting - room_key(dict): the room_key being set - Raises: - StoreError - """ - - yield self._simple_upsert( - table="e2e_room_keys", - keyvalues={ - "user_id": user_id, - "version": version, - "room_id": room_id, - "session_id": session_id, - }, - values={ - "first_message_index": room_key["first_message_index"], - "forwarded_count": room_key["forwarded_count"], - "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), - }, - lock=False, - ) - log_kv( - { - "message": "Set room key", - "room_id": room_id, - "session_id": session_id, - "room_key": room_key, - } - ) - - @trace - @defer.inlineCallbacks - def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): - """Bulk get the E2E room keys for a given backup, optionally filtered to a given - room, or a given session. - - Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): Optional. the ID of the room whose keys we're querying, if any. - If not specified, we return the keys for all the rooms in the backup. - session_id(str): Optional. the session whose room_key we're querying, if any. - If specified, we also require the room_id to be specified. - If not specified, we return all the keys in this version of - the backup (or for the specified room) - - Returns: - A deferred list of dicts giving the session_data and message metadata for - these room keys. - """ - - try: - version = int(version) - except ValueError: - return {"rooms": {}} - - keyvalues = {"user_id": user_id, "version": version} - if room_id: - keyvalues["room_id"] = room_id - if session_id: - keyvalues["session_id"] = session_id - - rows = yield self._simple_select_list( - table="e2e_room_keys", - keyvalues=keyvalues, - retcols=( - "user_id", - "room_id", - "session_id", - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", - ), - desc="get_e2e_room_keys", - ) - - sessions = {"rooms": {}} - for row in rows: - room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) - room_entry["sessions"][row["session_id"]] = { - "first_message_index": row["first_message_index"], - "forwarded_count": row["forwarded_count"], - "is_verified": row["is_verified"], - "session_data": json.loads(row["session_data"]), - } - - return sessions - - @trace - @defer.inlineCallbacks - def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): - """Bulk delete the E2E room keys for a given backup, optionally filtered to a given - room or a given session. - - Args: - user_id(str): the user whose backup we're deleting from - version(str): the version ID of the backup for the set of keys we're deleting - room_id(str): Optional. the ID of the room whose keys we're deleting, if any. - If not specified, we delete the keys for all the rooms in the backup. - session_id(str): Optional. the session whose room_key we're querying, if any. - If specified, we also require the room_id to be specified. - If not specified, we delete all the keys in this version of - the backup (or for the specified room) - - Returns: - A deferred of the deletion transaction - """ - - keyvalues = {"user_id": user_id, "version": int(version)} - if room_id: - keyvalues["room_id"] = room_id - if session_id: - keyvalues["session_id"] = session_id - - yield self._simple_delete( - table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" - ) - - @staticmethod - def _get_current_version(txn, user_id): - txn.execute( - "SELECT MAX(version) FROM e2e_room_keys_versions " - "WHERE user_id=? AND deleted=0", - (user_id,), - ) - row = txn.fetchone() - if not row: - raise StoreError(404, "No current backup version") - return row[0] - - def get_e2e_room_keys_version_info(self, user_id, version=None): - """Get info metadata about a version of our room_keys backup. - - Args: - user_id(str): the user whose backup we're querying - version(str): Optional. the version ID of the backup we're querying about - If missing, we return the information about the current version. - Raises: - StoreError: with code 404 if there are no e2e_room_keys_versions present - Returns: - A deferred dict giving the info metadata for this backup version, with - fields including: - version(str) - algorithm(str) - auth_data(object): opaque dict supplied by the client - """ - - def _get_e2e_room_keys_version_info_txn(txn): - if version is None: - this_version = self._get_current_version(txn, user_id) - else: - try: - this_version = int(version) - except ValueError: - # Our versions are all ints so if we can't convert it to an integer, - # it isn't there. - raise StoreError(404, "No row found") - - result = self._simple_select_one_txn( - txn, - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data"), - ) - result["auth_data"] = json.loads(result["auth_data"]) - result["version"] = str(result["version"]) - return result - - return self.runInteraction( - "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn - ) - - @trace - def create_e2e_room_keys_version(self, user_id, info): - """Atomically creates a new version of this user's e2e_room_keys store - with the given version info. - - Args: - user_id(str): the user whose backup we're creating a version - info(dict): the info about the backup version to be created - - Returns: - A deferred string for the newly created version ID - """ - - def _create_e2e_room_keys_version_txn(txn): - txn.execute( - "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", - (user_id,), - ) - current_version = txn.fetchone()[0] - if current_version is None: - current_version = "0" - - new_version = str(int(current_version) + 1) - - self._simple_insert_txn( - txn, - table="e2e_room_keys_versions", - values={ - "user_id": user_id, - "version": new_version, - "algorithm": info["algorithm"], - "auth_data": json.dumps(info["auth_data"]), - }, - ) - - return new_version - - return self.runInteraction( - "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn - ) - - @trace - def update_e2e_room_keys_version(self, user_id, version, info): - """Update a given backup version - - Args: - user_id(str): the user whose backup version we're updating - version(str): the version ID of the backup version we're updating - info(dict): the new backup version info to store - """ - - return self._simple_update( - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": version}, - updatevalues={"auth_data": json.dumps(info["auth_data"])}, - desc="update_e2e_room_keys_version", - ) - - @trace - def delete_e2e_room_keys_version(self, user_id, version=None): - """Delete a given backup version of the user's room keys. - Doesn't delete their actual key data. - - Args: - user_id(str): the user whose backup version we're deleting - version(str): Optional. the version ID of the backup version we're deleting - If missing, we delete the current backup version info. - Raises: - StoreError: with code 404 if there are no e2e_room_keys_versions present, - or if the version requested doesn't exist. - """ - - def _delete_e2e_room_keys_version_txn(txn): - if version is None: - this_version = self._get_current_version(txn, user_id) - else: - this_version = version - - return self._simple_update_one_txn( - txn, - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": this_version}, - updatevalues={"deleted": 1}, - ) - - return self.runInteraction( - "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn - ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py deleted file mode 100644 index 872bc75490..0000000000 --- a/synapse/storage/end_to_end_keys.py +++ /dev/null @@ -1,323 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from six import iteritems - -from canonicaljson import encode_canonical_json - -from twisted.internet import defer - -from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.util.caches.descriptors import cached - -from ._base import SQLBaseStore, db_to_json - - -class EndToEndKeyWorkerStore(SQLBaseStore): - @trace - @defer.inlineCallbacks - def get_e2e_device_keys( - self, query_list, include_all_devices=False, include_deleted_devices=False - ): - """Fetch a list of device keys. - Args: - query_list(list): List of pairs of user_ids and device_ids. - include_all_devices (bool): whether to include entries for devices - that don't have device keys - include_deleted_devices (bool): whether to include null entries for - devices which no longer exist (but were in the query_list). - This option only takes effect if include_all_devices is true. - Returns: - Dict mapping from user-id to dict mapping from device_id to - key data. The key data will be a dict in the same format as the - DeviceKeys type returned by POST /_matrix/client/r0/keys/query. - """ - set_tag("query_list", query_list) - if not query_list: - return {} - - results = yield self.runInteraction( - "get_e2e_device_keys", - self._get_e2e_device_keys_txn, - query_list, - include_all_devices, - include_deleted_devices, - ) - - # Build the result structure, un-jsonify the results, and add the - # "unsigned" section - rv = {} - for user_id, device_keys in iteritems(results): - rv[user_id] = {} - for device_id, device_info in iteritems(device_keys): - r = db_to_json(device_info.pop("key_json")) - r["unsigned"] = {} - display_name = device_info["device_display_name"] - if display_name is not None: - r["unsigned"]["device_display_name"] = display_name - rv[user_id][device_id] = r - - return rv - - @trace - def _get_e2e_device_keys_txn( - self, txn, query_list, include_all_devices=False, include_deleted_devices=False - ): - set_tag("include_all_devices", include_all_devices) - set_tag("include_deleted_devices", include_deleted_devices) - - query_clauses = [] - query_params = [] - - if include_all_devices is False: - include_deleted_devices = False - - if include_deleted_devices: - deleted_devices = set(query_list) - - for (user_id, device_id) in query_list: - query_clause = "user_id = ?" - query_params.append(user_id) - - if device_id is not None: - query_clause += " AND device_id = ?" - query_params.append(device_id) - - query_clauses.append(query_clause) - - sql = ( - "SELECT user_id, device_id, " - " d.display_name AS device_display_name, " - " k.key_json" - " FROM devices d" - " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" - " WHERE %s" - ) % ( - "LEFT" if include_all_devices else "INNER", - " OR ".join("(" + q + ")" for q in query_clauses), - ) - - txn.execute(sql, query_params) - rows = self.cursor_to_dict(txn) - - result = {} - for row in rows: - if include_deleted_devices: - deleted_devices.remove((row["user_id"], row["device_id"])) - result.setdefault(row["user_id"], {})[row["device_id"]] = row - - if include_deleted_devices: - for user_id, device_id in deleted_devices: - result.setdefault(user_id, {})[device_id] = None - - log_kv(result) - return result - - @defer.inlineCallbacks - def get_e2e_one_time_keys(self, user_id, device_id, key_ids): - """Retrieve a number of one-time keys for a user - - Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - key_ids(list[str]): list of key ids (excluding algorithm) to - retrieve - - Returns: - deferred resolving to Dict[(str, str), str]: map from (algorithm, - key_id) to json string for key - """ - - rows = yield self._simple_select_many_batch( - table="e2e_one_time_keys_json", - column="key_id", - iterable=key_ids, - retcols=("algorithm", "key_id", "key_json"), - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="add_e2e_one_time_keys_check", - ) - result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} - log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) - return result - - @defer.inlineCallbacks - def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): - """Insert some new one time keys for a device. Errors if any of the - keys already exist. - - Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - time_now(long): insertion time to record (ms since epoch) - new_keys(iterable[(str, str, str)]: keys to add - each a tuple of - (algorithm, key_id, key json) - """ - - def _add_e2e_one_time_keys(txn): - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("new_keys", new_keys) - # We are protected from race between lookup and insertion due to - # a unique constraint. If there is a race of two calls to - # `add_e2e_one_time_keys` then they'll conflict and we will only - # insert one set. - self._simple_insert_many_txn( - txn, - table="e2e_one_time_keys_json", - values=[ - { - "user_id": user_id, - "device_id": device_id, - "algorithm": algorithm, - "key_id": key_id, - "ts_added_ms": time_now, - "key_json": json_bytes, - } - for algorithm, key_id, json_bytes in new_keys - ], - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - - yield self.runInteraction( - "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys - ) - - @cached(max_entries=10000) - def count_e2e_one_time_keys(self, user_id, device_id): - """ Count the number of one time keys the server has for a device - Returns: - Dict mapping from algorithm to number of keys for that algorithm. - """ - - def _count_e2e_one_time_keys(txn): - sql = ( - "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ?" - " GROUP BY algorithm" - ) - txn.execute(sql, (user_id, device_id)) - result = {} - for algorithm, key_count in txn: - result[algorithm] = key_count - return result - - return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) - - -class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): - """Stores device keys for a device. Returns whether there was a change - or the keys were already in the database. - """ - - def _set_e2e_device_keys_txn(txn): - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("time_now", time_now) - set_tag("device_keys", device_keys) - - old_key_json = self._simple_select_one_onecol_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="key_json", - allow_none=True, - ) - - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_key_json = encode_canonical_json(device_keys).decode("utf-8") - - if old_key_json == new_key_json: - log_kv({"Message": "Device key already stored."}) - return False - - self._simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"ts_added_ms": time_now, "key_json": new_key_json}, - ) - log_kv({"message": "Device keys stored."}) - return True - - return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) - - def claim_e2e_one_time_keys(self, query_list): - """Take a list of one time keys out of the database""" - - @trace - def _claim_e2e_one_time_keys(txn): - sql = ( - "SELECT key_id, key_json FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" - ) - result = {} - delete = [] - for user_id, device_id, algorithm in query_list: - user_result = result.setdefault(user_id, {}) - device_result = user_result.setdefault(device_id, {}) - txn.execute(sql, (user_id, device_id, algorithm)) - for key_id, key_json in txn: - device_result[algorithm + ":" + key_id] = key_json - delete.append((user_id, device_id, algorithm, key_id)) - sql = ( - "DELETE FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " AND key_id = ?" - ) - for user_id, device_id, algorithm, key_id in delete: - log_kv( - { - "message": "Executing claim e2e_one_time_keys transaction on database." - } - ) - txn.execute(sql, (user_id, device_id, algorithm, key_id)) - log_kv({"message": "finished executing and invalidating cache"}) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - return result - - return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys) - - def delete_e2e_keys_by_device(self, user_id, device_id): - def delete_e2e_keys_by_device_txn(txn): - log_kv( - { - "message": "Deleting keys for device", - "device_id": device_id, - "user_id": user_id, - } - ) - self._simple_delete_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - self._simple_delete_txn( - txn, - table="e2e_one_time_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - - return self.runInteraction( - "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn - ) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py deleted file mode 100644 index 47cc10d32a..0000000000 --- a/synapse/storage/event_federation.py +++ /dev/null @@ -1,672 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import itertools -import logging -import random - -from six.moves import range -from six.moves.queue import Empty, PriorityQueue - -from unpaddedbase64 import encode_base64 - -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.events_worker import EventsWorkerStore -from synapse.storage.signatures import SignatureWorkerStore -from synapse.util.caches.descriptors import cached - -logger = logging.getLogger(__name__) - - -class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): - """Get auth events for given event_ids. The events *must* be state events. - - Args: - event_ids (list): state events - include_given (bool): include the given events in result - - Returns: - list of events - """ - return self.get_auth_chain_ids( - event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) - - def get_auth_chain_ids(self, event_ids, include_given=False): - """Get auth events for given event_ids. The events *must* be state events. - - Args: - event_ids (list): state events - include_given (bool): include the given events in result - - Returns: - list of event_ids - """ - return self.runInteraction( - "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given - ) - - def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): - if include_given: - results = set(event_ids) - else: - results = set() - - base_sql = "SELECT auth_id FROM event_auth WHERE " - - front = set(event_ids) - while front: - new_front = set() - front_list = list(front) - chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)] - for chunk in chunks: - clause, args = make_in_list_sql_clause( - txn.database_engine, "event_id", chunk - ) - txn.execute(base_sql + clause, list(args)) - new_front.update([r[0] for r in txn]) - - new_front -= results - - front = new_front - results.update(front) - - return list(results) - - def get_oldest_events_in_room(self, room_id): - return self.runInteraction( - "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id - ) - - def get_oldest_events_with_depth_in_room(self, room_id): - return self.runInteraction( - "get_oldest_events_with_depth_in_room", - self.get_oldest_events_with_depth_in_room_txn, - room_id, - ) - - def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): - sql = ( - "SELECT b.event_id, MAX(e.depth) FROM events as e" - " INNER JOIN event_edges as g" - " ON g.event_id = e.event_id" - " INNER JOIN event_backward_extremities as b" - " ON g.prev_event_id = b.event_id" - " WHERE b.room_id = ? AND g.is_state is ?" - " GROUP BY b.event_id" - ) - - txn.execute(sql, (room_id, False)) - - return dict(txn) - - @defer.inlineCallbacks - def get_max_depth_of(self, event_ids): - """Returns the max depth of a set of event IDs - - Args: - event_ids (list[str]) - - Returns - Deferred[int] - """ - rows = yield self._simple_select_many_batch( - table="events", - column="event_id", - iterable=event_ids, - retcols=("depth",), - desc="get_max_depth_of", - ) - - if not rows: - return 0 - else: - return max(row["depth"] for row in rows) - - def _get_oldest_events_in_room_txn(self, txn, room_id): - return self._simple_select_onecol_txn( - txn, - table="event_backward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - ) - - @defer.inlineCallbacks - def get_prev_events_for_room(self, room_id): - """ - Gets a subset of the current forward extremities in the given room. - - Limits the result to 10 extremities, so that we can avoid creating - events which refer to hundreds of prev_events. - - Args: - room_id (str): room_id - - Returns: - Deferred[list[(str, dict[str, str], int)]] - for each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. - """ - res = yield self.get_latest_event_ids_and_hashes_in_room(room_id) - if len(res) > 10: - # Sort by reverse depth, so we point to the most recent. - res.sort(key=lambda a: -a[2]) - - # we use half of the limit for the actual most recent events, and - # the other half to randomly point to some of the older events, to - # make sure that we don't completely ignore the older events. - res = res[0:5] + random.sample(res[5:], 5) - - return res - - def get_latest_event_ids_and_hashes_in_room(self, room_id): - """ - Gets the current forward extremities in the given room - - Args: - room_id (str): room_id - - Returns: - Deferred[list[(str, dict[str, str], int)]] - for each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. - """ - - return self.runInteraction( - "get_latest_event_ids_and_hashes_in_room", - self._get_latest_event_ids_and_hashes_in_room, - room_id, - ) - - def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): - """Get the top rooms with at least N extremities. - - Args: - min_count (int): The minimum number of extremities - limit (int): The maximum number of rooms to return. - room_id_filter (iterable[str]): room_ids to exclude from the results - - Returns: - Deferred[list]: At most `limit` room IDs that have at least - `min_count` extremities, sorted by extremity count. - """ - - def _get_rooms_with_many_extremities_txn(txn): - where_clause = "1=1" - if room_id_filter: - where_clause = "room_id NOT IN (%s)" % ( - ",".join("?" for _ in room_id_filter), - ) - - sql = """ - SELECT room_id FROM event_forward_extremities - WHERE %s - GROUP BY room_id - HAVING count(*) > ? - ORDER BY count(*) DESC - LIMIT ? - """ % ( - where_clause, - ) - - query_args = list(itertools.chain(room_id_filter, [min_count, limit])) - txn.execute(sql, query_args) - return [room_id for room_id, in txn] - - return self.runInteraction( - "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn - ) - - @cached(max_entries=5000, iterable=True) - def get_latest_event_ids_in_room(self, room_id): - return self._simple_select_onecol( - table="event_forward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - desc="get_latest_event_ids_in_room", - ) - - def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id): - sql = ( - "SELECT e.event_id, e.depth FROM events as e " - "INNER JOIN event_forward_extremities as f " - "ON e.event_id = f.event_id " - "AND e.room_id = f.room_id " - "WHERE f.room_id = ?" - ) - - txn.execute(sql, (room_id,)) - - results = [] - for event_id, depth in txn.fetchall(): - hashes = self._get_event_reference_hashes_txn(txn, event_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() if k == "sha256" - } - results.append((event_id, prev_hashes, depth)) - - return results - - def get_min_depth(self, room_id): - """ For hte given room, get the minimum depth we have seen for it. - """ - return self.runInteraction( - "get_min_depth", self._get_min_depth_interaction, room_id - ) - - def _get_min_depth_interaction(self, txn, room_id): - min_depth = self._simple_select_one_onecol_txn( - txn, - table="room_depth", - keyvalues={"room_id": room_id}, - retcol="min_depth", - allow_none=True, - ) - - return int(min_depth) if min_depth is not None else None - - def get_forward_extremeties_for_room(self, room_id, stream_ordering): - """For a given room_id and stream_ordering, return the forward - extremeties of the room at that point in "time". - - Throws a StoreError if we have since purged the index for - stream_orderings from that point. - - Args: - room_id (str): - stream_ordering (int): - - Returns: - deferred, which resolves to a list of event_ids - """ - # We want to make the cache more effective, so we clamp to the last - # change before the given ordering. - last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) - - # We don't always have a full stream_to_exterm_id table, e.g. after - # the upgrade that introduced it, so we make sure we never ask for a - # stream_ordering from before a restart - last_change = max(self._stream_order_on_start, last_change) - - # provided the last_change is recent enough, we now clamp the requested - # stream_ordering to it. - if last_change > self.stream_ordering_month_ago: - stream_ordering = min(last_change, stream_ordering) - - return self._get_forward_extremeties_for_room(room_id, stream_ordering) - - @cached(max_entries=5000, num_args=2) - def _get_forward_extremeties_for_room(self, room_id, stream_ordering): - """For a given room_id and stream_ordering, return the forward - extremeties of the room at that point in "time". - - Throws a StoreError if we have since purged the index for - stream_orderings from that point. - """ - - if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old") - - sql = """ - SELECT event_id FROM stream_ordering_to_exterm - INNER JOIN ( - SELECT room_id, MAX(stream_ordering) AS stream_ordering - FROM stream_ordering_to_exterm - WHERE stream_ordering <= ? GROUP BY room_id - ) AS rms USING (room_id, stream_ordering) - WHERE room_id = ? - """ - - def get_forward_extremeties_for_room_txn(txn): - txn.execute(sql, (stream_ordering, room_id)) - return [event_id for event_id, in txn] - - return self.runInteraction( - "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn - ) - - def get_backfill_events(self, room_id, event_list, limit): - """Get a list of Events for a given topic that occurred before (and - including) the events in event_list. Return a list of max size `limit` - - Args: - txn - room_id (str) - event_list (list) - limit (int) - """ - return ( - self.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) - ) - - def _get_backfill_events(self, txn, room_id, event_list, limit): - logger.debug( - "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit - ) - - event_results = set() - - # We want to make sure that we do a breadth-first, "depth" ordered - # search. - - query = ( - "SELECT depth, prev_event_id FROM event_edges" - " INNER JOIN events" - " ON prev_event_id = events.event_id" - " WHERE event_edges.event_id = ?" - " AND event_edges.is_state = ?" - " LIMIT ?" - ) - - queue = PriorityQueue() - - for event_id in event_list: - depth = self._simple_select_one_onecol_txn( - txn, - table="events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcol="depth", - allow_none=True, - ) - - if depth: - queue.put((-depth, event_id)) - - while not queue.empty() and len(event_results) < limit: - try: - _, event_id = queue.get_nowait() - except Empty: - break - - if event_id in event_results: - continue - - event_results.add(event_id) - - txn.execute(query, (event_id, False, limit - len(event_results))) - - for row in txn: - if row[1] not in event_results: - queue.put((-row[0], row[1])) - - return event_results - - @defer.inlineCallbacks - def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.runInteraction( - "get_missing_events", - self._get_missing_events, - room_id, - earliest_events, - latest_events, - limit, - ) - events = yield self.get_events_as_list(ids) - return events - - def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): - - seen_events = set(earliest_events) - front = set(latest_events) - seen_events - event_results = [] - - query = ( - "SELECT prev_event_id FROM event_edges " - "WHERE room_id = ? AND event_id = ? AND is_state = ? " - "LIMIT ?" - ) - - while front and len(event_results) < limit: - new_front = set() - for event_id in front: - txn.execute( - query, (room_id, event_id, False, limit - len(event_results)) - ) - - new_results = set(t[0] for t in txn) - seen_events - - new_front |= new_results - seen_events |= new_results - event_results.extend(new_results) - - front = new_front - - # we built the list working backwards from latest_events; we now need to - # reverse it so that the events are approximately chronological. - event_results.reverse() - return event_results - - @defer.inlineCallbacks - def get_successor_events(self, event_ids): - """Fetch all events that have the given events as a prev event - - Args: - event_ids (iterable[str]) - - Returns: - Deferred[list[str]] - """ - rows = yield self._simple_select_many_batch( - table="event_edges", - column="prev_event_id", - iterable=event_ids, - retcols=("event_id",), - desc="get_successor_events", - ) - - return [row["event_id"] for row in rows] - - -class EventFederationStore(EventFederationWorkerStore): - """ Responsible for storing and serving up the various graphs associated - with an event. Including the main event graph and the auth chains for an - event. - - Also has methods for getting the front (latest) and back (oldest) edges - of the event graphs. These are used to generate the parents for new events - and backfilling from another server respectively. - """ - - EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - - def __init__(self, db_conn, hs): - super(EventFederationStore, self).__init__(db_conn, hs) - - self.register_background_update_handler( - self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth - ) - - hs.get_clock().looping_call( - self._delete_old_forward_extrem_cache, 60 * 60 * 1000 - ) - - def _update_min_depth_for_room_txn(self, txn, room_id, depth): - min_depth = self._get_min_depth_interaction(txn, room_id) - - if min_depth and depth >= min_depth: - return - - self._simple_upsert_txn( - txn, - table="room_depth", - keyvalues={"room_id": room_id}, - values={"min_depth": depth}, - ) - - def _handle_mult_prev_events(self, txn, events): - """ - For the given event, update the event edges table and forward and - backward extremities tables. - """ - self._simple_insert_many_txn( - txn, - table="event_edges", - values=[ - { - "event_id": ev.event_id, - "prev_event_id": e_id, - "room_id": ev.room_id, - "is_state": False, - } - for ev in events - for e_id in ev.prev_event_ids() - ], - ) - - self._update_backward_extremeties(txn, events) - - def _update_backward_extremeties(self, txn, events): - """Updates the event_backward_extremities tables based on the new/updated - events being persisted. - - This is called for new events *and* for events that were outliers, but - are now being persisted as non-outliers. - - Forward extremities are handled when we first start persisting the events. - """ - events_by_room = {} - for ev in events: - events_by_room.setdefault(ev.room_id, []).append(ev) - - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" - " )" - ) - - txn.executemany( - query, - [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) - for ev in events - for e_id in ev.prev_event_ids() - if not ev.internal_metadata.is_outlier() - ], - ) - - query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - ) - txn.executemany( - query, - [ - (ev.event_id, ev.room_id) - for ev in events - if not ev.internal_metadata.is_outlier() - ], - ) - - def _delete_old_forward_extrem_cache(self): - def _delete_old_forward_extrem_cache_txn(txn): - # Delete entries older than a month, while making sure we don't delete - # the only entries for a room. - sql = """ - DELETE FROM stream_ordering_to_exterm - WHERE - room_id IN ( - SELECT room_id - FROM stream_ordering_to_exterm - WHERE stream_ordering > ? - ) AND stream_ordering < ? - """ - txn.execute( - sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) - ) - - return run_as_background_process( - "delete_old_forward_extrem_cache", - self.runInteraction, - "_delete_old_forward_extrem_cache", - _delete_old_forward_extrem_cache_txn, - ) - - def clean_room_for_join(self, room_id): - return self.runInteraction( - "clean_room_for_join", self._clean_room_for_join_txn, room_id - ) - - def _clean_room_for_join_txn(self, txn, room_id): - query = "DELETE FROM event_forward_extremities WHERE room_id = ?" - - txn.execute(query, (room_id,)) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - - @defer.inlineCallbacks - def _background_delete_non_state_event_auth(self, progress, batch_size): - def delete_event_auth(txn): - target_min_stream_id = progress.get("target_min_stream_id_inclusive") - max_stream_id = progress.get("max_stream_id_exclusive") - - if not target_min_stream_id or not max_stream_id: - txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events") - rows = txn.fetchall() - target_min_stream_id = rows[0][0] - - txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events") - rows = txn.fetchall() - max_stream_id = rows[0][0] - - min_stream_id = max_stream_id - batch_size - - sql = """ - DELETE FROM event_auth - WHERE event_id IN ( - SELECT event_id FROM events - LEFT JOIN state_events USING (room_id, event_id) - WHERE ? <= stream_ordering AND stream_ordering < ? - AND state_key IS null - ) - """ - - txn.execute(sql, (min_stream_id, max_stream_id)) - - new_progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - } - - self._background_update_progress_txn( - txn, self.EVENT_AUTH_STATE_ONLY, new_progress - ) - - return min_stream_id >= target_min_stream_id - - result = yield self.runInteraction( - self.EVENT_AUTH_STATE_ONLY, delete_event_auth - ) - - if not result: - yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) - - return batch_size diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py deleted file mode 100644 index 22025effbc..0000000000 --- a/synapse/storage/event_push_actions.py +++ /dev/null @@ -1,960 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from six import iteritems - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks - -logger = logging.getLogger(__name__) - - -DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] -DEFAULT_HIGHLIGHT_ACTION = [ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, -] - - -def _serialize_action(actions, is_highlight): - """Custom serializer for actions. This allows us to "compress" common actions. - - We use the fact that most users have the same actions for notifs (and for - highlights). - We store these default actions as the empty string rather than the full JSON. - Since the empty string isn't valid JSON there is no risk of this clashing with - any real JSON actions - """ - if is_highlight: - if actions == DEFAULT_HIGHLIGHT_ACTION: - return "" # We use empty string as the column is non-NULL - else: - if actions == DEFAULT_NOTIF_ACTION: - return "" - return json.dumps(actions) - - -def _deserialize_action(actions, is_highlight): - """Custom deserializer for actions. This allows us to "compress" common actions - """ - if actions: - return json.loads(actions) - - if is_highlight: - return DEFAULT_HIGHLIGHT_ACTION - else: - return DEFAULT_NOTIF_ACTION - - -class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) - - # These get correctly set by _find_stream_orderings_for_times_txn - self.stream_ordering_month_ago = None - self.stream_ordering_day_ago = None - - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - ) - self._find_stream_orderings_for_times_txn(cur) - cur.close() - - self.find_stream_orderings_looping_call = self._clock.looping_call( - self._find_stream_orderings_for_times, 10 * 60 * 1000 - ) - self._rotate_delay = 3 - self._rotate_count = 10000 - - @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) - def get_unread_event_push_actions_by_room_for_user( - self, room_id, user_id, last_read_event_id - ): - ret = yield self.runInteraction( - "get_unread_event_push_actions_by_room", - self._get_unread_counts_by_receipt_txn, - room_id, - user_id, - last_read_event_id, - ) - return ret - - def _get_unread_counts_by_receipt_txn( - self, txn, room_id, user_id, last_read_event_id - ): - sql = ( - "SELECT stream_ordering" - " FROM events" - " WHERE room_id = ? AND event_id = ?" - ) - txn.execute(sql, (room_id, last_read_event_id)) - results = txn.fetchall() - if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0} - - stream_ordering = results[0][0] - - return self._get_unread_counts_by_pos_txn( - txn, room_id, user_id, stream_ordering - ) - - def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): - - # First get number of notifications. - # We don't need to put a notif=1 clause as all rows always have - # notif=1 - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" - ) - - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - notify_count = row[0] if row else 0 - - txn.execute( - """ - SELECT notif_count FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering > ? - """, - (room_id, user_id, stream_ordering), - ) - rows = txn.fetchall() - if rows: - notify_count += rows[0][0] - - # Now get the number of highlights - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " highlight = 1" - " AND user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" - ) - - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - highlight_count = row[0] if row else 0 - - return {"notify_count": notify_count, "highlight_count": highlight_count} - - @defer.inlineCallbacks - def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): - def f(txn): - sql = ( - "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" - " stream_ordering >= ? AND stream_ordering <= ?" - ) - txn.execute(sql, (min_stream_ordering, max_stream_ordering)) - return [r[0] for r in txn] - - ret = yield self.runInteraction("get_push_action_users_in_range", f) - return ret - - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_http( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): - """Get a list of the most recent unread push actions for a given user, - within the given stream ordering range. Called by the httppusher. - - Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the - stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the - stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. - Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions". - The list will be ordered by ascending stream_ordering. - The list will have between 0~limit entries. - """ - # find rooms that have a read receipt in them and return the next - # push actions - def get_after_receipt(txn): - # find rooms that have a read receipt in them and return the next - # push actions - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight " - " FROM (" - " SELECT room_id," - " MAX(stream_ordering) as stream_ordering" - " FROM events" - " INNER JOIN receipts_linearized USING (room_id, event_id)" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - ") AS rl," - " event_push_actions AS ep" - " WHERE" - " ep.room_id = rl.room_id" - " AND ep.stream_ordering > rl.stream_ordering" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " ORDER BY ep.stream_ordering ASC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) - return txn.fetchall() - - after_read_receipt = yield self.runInteraction( - "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt - ) - - # There are rooms with push actions in them but you don't have a read receipt in - # them e.g. rooms you've been invited to, so get push actions for rooms which do - # not have read receipts in them too. - def get_no_receipt(txn): - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight " - " FROM event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id NOT IN (" - " SELECT room_id FROM receipts_linearized" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - " )" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " ORDER BY ep.stream_ordering ASC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) - return txn.fetchall() - - no_read_receipt = yield self.runInteraction( - "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt - ) - - notifs = [ - { - "event_id": row[0], - "room_id": row[1], - "stream_ordering": row[2], - "actions": _deserialize_action(row[3], row[4]), - } - for row in after_read_receipt + no_read_receipt - ] - - # Now sort it so it's ordered correctly, since currently it will - # contain results from the first query, correctly ordered, followed - # by results from the second query, but we want them all ordered - # by stream_ordering, oldest first. - notifs.sort(key=lambda r: r["stream_ordering"]) - - # Take only up to the limit. We have to stop at the limit because - # one of the subqueries may have hit the limit. - return notifs[:limit] - - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_email( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): - """Get a list of the most recent unread push actions for a given user, - within the given stream ordering range. Called by the emailpusher - - Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the - stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the - stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. - Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions", "received_ts". - The list will be ordered by descending received_ts. - The list will have between 0~limit entries. - """ - # find rooms that have a read receipt in them and return the most recent - # push actions - def get_after_receipt(txn): - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight, e.received_ts" - " FROM (" - " SELECT room_id," - " MAX(stream_ordering) as stream_ordering" - " FROM events" - " INNER JOIN receipts_linearized USING (room_id, event_id)" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - ") AS rl," - " event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id = rl.room_id" - " AND ep.stream_ordering > rl.stream_ordering" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " ORDER BY ep.stream_ordering DESC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) - return txn.fetchall() - - after_read_receipt = yield self.runInteraction( - "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt - ) - - # There are rooms with push actions in them but you don't have a read receipt in - # them e.g. rooms you've been invited to, so get push actions for rooms which do - # not have read receipts in them too. - def get_no_receipt(txn): - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight, e.received_ts" - " FROM event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id NOT IN (" - " SELECT room_id FROM receipts_linearized" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - " )" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " ORDER BY ep.stream_ordering DESC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) - return txn.fetchall() - - no_read_receipt = yield self.runInteraction( - "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt - ) - - # Make a list of dicts from the two sets of results. - notifs = [ - { - "event_id": row[0], - "room_id": row[1], - "stream_ordering": row[2], - "actions": _deserialize_action(row[3], row[4]), - "received_ts": row[5], - } - for row in after_read_receipt + no_read_receipt - ] - - # Now sort it so it's ordered correctly, since currently it will - # contain results from the first query, correctly ordered, followed - # by results from the second query, but we want them all ordered - # by received_ts (most recent first) - notifs.sort(key=lambda r: -(r["received_ts"] or 0)) - - # Now return the first `limit` - return notifs[:limit] - - def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): - """A fast check to see if there might be something to push for the - user since the given stream ordering. May return false positives. - - Useful to know whether to bother starting a pusher on start up or not. - - Args: - user_id (str) - min_stream_ordering (int) - - Returns: - Deferred[bool]: True if there may be push to process, False if - there definitely isn't. - """ - - def _get_if_maybe_push_in_range_for_user_txn(txn): - sql = """ - SELECT 1 FROM event_push_actions - WHERE user_id = ? AND stream_ordering > ? - LIMIT 1 - """ - - txn.execute(sql, (user_id, min_stream_ordering)) - return bool(txn.fetchone()) - - return self.runInteraction( - "get_if_maybe_push_in_range_for_user", - _get_if_maybe_push_in_range_for_user_txn, - ) - - def add_push_actions_to_staging(self, event_id, user_id_actions): - """Add the push actions for the event to the push action staging area. - - Args: - event_id (str) - user_id_actions (dict[str, list[dict|str])]): A dictionary mapping - user_id to list of push actions, where an action can either be - a string or dict. - - Returns: - Deferred - """ - - if not user_id_actions: - return - - # This is a helper function for generating the necessary tuple that - # can be used to inert into the `event_push_actions_staging` table. - def _gen_entry(user_id, actions): - is_highlight = 1 if _action_has_highlight(actions) else 0 - return ( - event_id, # event_id column - user_id, # user_id column - _serialize_action(actions, is_highlight), # actions column - 1, # notif column - is_highlight, # highlight column - ) - - def _add_push_actions_to_staging_txn(txn): - # We don't use _simple_insert_many here to avoid the overhead - # of generating lists of dicts. - - sql = """ - INSERT INTO event_push_actions_staging - (event_id, user_id, actions, notif, highlight) - VALUES (?, ?, ?, ?, ?) - """ - - txn.executemany( - sql, - ( - _gen_entry(user_id, actions) - for user_id, actions in iteritems(user_id_actions) - ), - ) - - return self.runInteraction( - "add_push_actions_to_staging", _add_push_actions_to_staging_txn - ) - - @defer.inlineCallbacks - def remove_push_actions_from_staging(self, event_id): - """Called if we failed to persist the event to ensure that stale push - actions don't build up in the DB - - Args: - event_id (str) - """ - - try: - res = yield self._simple_delete( - table="event_push_actions_staging", - keyvalues={"event_id": event_id}, - desc="remove_push_actions_from_staging", - ) - return res - except Exception: - # this method is called from an exception handler, so propagating - # another exception here really isn't helpful - there's nothing - # the caller can do about it. Just log the exception and move on. - logger.exception( - "Error removing push actions after event persistence failure" - ) - - def _find_stream_orderings_for_times(self): - return run_as_background_process( - "event_push_action_stream_orderings", - self.runInteraction, - "_find_stream_orderings_for_times", - self._find_stream_orderings_for_times_txn, - ) - - def _find_stream_orderings_for_times_txn(self, txn): - logger.info("Searching for stream ordering 1 month ago") - self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 - ) - logger.info( - "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago - ) - logger.info("Searching for stream ordering 1 day ago") - self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 - ) - logger.info( - "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago - ) - - def find_first_stream_ordering_after_ts(self, ts): - """Gets the stream ordering corresponding to a given timestamp. - - Specifically, finds the stream_ordering of the first event that was - received on or after the timestamp. This is done by a binary search on - the events table, since there is no index on received_ts, so is - relatively slow. - - Args: - ts (int): timestamp in millis - - Returns: - Deferred[int]: stream ordering of the first event received on/after - the timestamp - """ - return self.runInteraction( - "_find_first_stream_ordering_after_ts_txn", - self._find_first_stream_ordering_after_ts_txn, - ts, - ) - - @staticmethod - def _find_first_stream_ordering_after_ts_txn(txn, ts): - """ - Find the stream_ordering of the first event that was received on or - after a given timestamp. This is relatively slow as there is no index - on received_ts but we can then use this to delete push actions before - this. - - received_ts must necessarily be in the same order as stream_ordering - and stream_ordering is indexed, so we manually binary search using - stream_ordering - - Args: - txn (twisted.enterprise.adbapi.Transaction): - ts (int): timestamp to search for - - Returns: - int: stream ordering - """ - txn.execute("SELECT MAX(stream_ordering) FROM events") - max_stream_ordering = txn.fetchone()[0] - - if max_stream_ordering is None: - return 0 - - # We want the first stream_ordering in which received_ts is greater - # than or equal to ts. Call this point X. - # - # We maintain the invariants: - # - # range_start <= X <= range_end - # - range_start = 0 - range_end = max_stream_ordering + 1 - - # Given a stream_ordering, look up the timestamp at that - # stream_ordering. - # - # The array may be sparse (we may be missing some stream_orderings). - # We treat the gaps as the same as having the same value as the - # preceding entry, because we will pick the lowest stream_ordering - # which satisfies our requirement of received_ts >= ts. - # - # For example, if our array of events indexed by stream_ordering is - # [10, , 20], we should treat this as being equivalent to - # [10, 10, 20]. - # - sql = ( - "SELECT received_ts FROM events" - " WHERE stream_ordering <= ?" - " ORDER BY stream_ordering DESC" - " LIMIT 1" - ) - - while range_end - range_start > 0: - middle = (range_end + range_start) // 2 - txn.execute(sql, (middle,)) - row = txn.fetchone() - if row is None: - # no rows with stream_ordering<=middle - range_start = middle + 1 - continue - - middle_ts = row[0] - if ts > middle_ts: - # we got a timestamp lower than the one we were looking for. - # definitely need to look higher: X > middle. - range_start = middle + 1 - else: - # we got a timestamp higher than (or the same as) the one we - # were looking for. We aren't yet sure about the point we - # looked up, but we can be sure that X <= middle. - range_end = middle - - return range_end - - -class EventPushActionsStore(EventPushActionsWorkerStore): - EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - - def __init__(self, db_conn, hs): - super(EventPushActionsStore, self).__init__(db_conn, hs) - - self.register_background_index_update( - self.EPA_HIGHLIGHT_INDEX, - index_name="event_push_actions_u_highlight", - table="event_push_actions", - columns=["user_id", "stream_ordering"], - ) - - self.register_background_index_update( - "event_push_actions_highlights_index", - index_name="event_push_actions_highlights_index", - table="event_push_actions", - columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1", - ) - - self._doing_notif_rotation = False - self._rotate_notif_loop = self._clock.looping_call( - self._start_rotate_notifs, 30 * 60 * 1000 - ) - - def _set_push_actions_for_event_and_users_txn( - self, txn, events_and_contexts, all_events_and_contexts - ): - """Handles moving push actions from staging table to main - event_push_actions table for all events in `events_and_contexts`. - - Also ensures that all events in `all_events_and_contexts` are removed - from the push action staging area. - - Args: - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. - """ - - sql = """ - INSERT INTO event_push_actions ( - room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight - ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight - FROM event_push_actions_staging - WHERE event_id = ? - """ - - if events_and_contexts: - txn.executemany( - sql, - ( - ( - event.room_id, - event.internal_metadata.stream_ordering, - event.depth, - event.event_id, - ) - for event, _ in events_and_contexts - ), - ) - - for event, _ in events_and_contexts: - user_ids = self._simple_select_onecol_txn( - txn, - table="event_push_actions_staging", - keyvalues={"event_id": event.event_id}, - retcol="user_id", - ) - - for uid in user_ids: - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid), - ) - - # Now we delete the staging area for *all* events that were being - # persisted. - txn.executemany( - "DELETE FROM event_push_actions_staging WHERE event_id = ?", - ((event.event_id,) for event, _ in all_events_and_contexts), - ) - - @defer.inlineCallbacks - def get_push_actions_for_user( - self, user_id, before=None, limit=50, only_highlight=False - ): - def f(txn): - before_clause = "" - if before: - before_clause = "AND epa.stream_ordering < ?" - args = [user_id, before, limit] - else: - args = [user_id, limit] - - if only_highlight: - if len(before_clause) > 0: - before_clause += " " - before_clause += "AND epa.highlight = 1" - - # NB. This assumes event_ids are globally unique since - # it makes the query easier to index - sql = ( - "SELECT epa.event_id, epa.room_id," - " epa.stream_ordering, epa.topological_ordering," - " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" - " FROM event_push_actions epa, events e" - " WHERE epa.event_id = e.event_id" - " AND epa.user_id = ? %s" - " ORDER BY epa.stream_ordering DESC" - " LIMIT ?" % (before_clause,) - ) - txn.execute(sql, args) - return self.cursor_to_dict(txn) - - push_actions = yield self.runInteraction("get_push_actions_for_user", f) - for pa in push_actions: - pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) - return push_actions - - @defer.inlineCallbacks - def get_time_of_last_push_action_before(self, stream_ordering): - def f(txn): - sql = ( - "SELECT e.received_ts" - " FROM event_push_actions AS ep" - " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" - " WHERE ep.stream_ordering > ?" - " ORDER BY ep.stream_ordering ASC" - " LIMIT 1" - ) - txn.execute(sql, (stream_ordering,)) - return txn.fetchone() - - result = yield self.runInteraction("get_time_of_last_push_action_before", f) - return result[0] if result else None - - @defer.inlineCallbacks - def get_latest_push_action_stream_ordering(self): - def f(txn): - txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") - return txn.fetchone() - - result = yield self.runInteraction("get_latest_push_action_stream_ordering", f) - return result[0] or 0 - - def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): - # Sad that we have to blow away the cache for the whole room here - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id,), - ) - txn.execute( - "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", - (room_id, event_id), - ) - - def _remove_old_push_actions_before_txn( - self, txn, room_id, user_id, stream_ordering - ): - """ - Purges old push actions for a user and room before a given - stream_ordering. - - We however keep a months worth of highlighted notifications, so that - users can still get a list of recent highlights. - - Args: - txn: The transcation - room_id: Room ID to delete from - user_id: user ID to delete for - stream_ordering: The lowest stream ordering which will - not be deleted. - """ - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id, user_id), - ) - - # We need to join on the events table to get the received_ts for - # event_push_actions and sqlite won't let us use a join in a delete so - # we can't just delete where received_ts < x. Furthermore we can - # only identify event_push_actions by a tuple of room_id, event_id - # we we can't use a subquery. - # Instead, we look up the stream ordering for the last event in that - # room received before the threshold time and delete event_push_actions - # in the room with a stream_odering before that. - txn.execute( - "DELETE FROM event_push_actions " - " WHERE user_id = ? AND room_id = ? AND " - " stream_ordering <= ?" - " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", - (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), - ) - - txn.execute( - """ - DELETE FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? - """, - (room_id, user_id, stream_ordering), - ) - - def _start_rotate_notifs(self): - return run_as_background_process("rotate_notifs", self._rotate_notifs) - - @defer.inlineCallbacks - def _rotate_notifs(self): - if self._doing_notif_rotation or self.stream_ordering_day_ago is None: - return - self._doing_notif_rotation = True - - try: - while True: - logger.info("Rotating notifications") - - caught_up = yield self.runInteraction( - "_rotate_notifs", self._rotate_notifs_txn - ) - if caught_up: - break - yield self.hs.get_clock().sleep(self._rotate_delay) - finally: - self._doing_notif_rotation = False - - def _rotate_notifs_txn(self, txn): - """Archives older notifications into event_push_summary. Returns whether - the archiving process has caught up or not. - """ - - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( - txn, - table="event_push_summary_stream_ordering", - keyvalues={}, - retcol="stream_ordering", - ) - - # We don't to try and rotate millions of rows at once, so we cap the - # maximum stream ordering we'll rotate before. - txn.execute( - """ - SELECT stream_ordering FROM event_push_actions - WHERE stream_ordering > ? - ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? - """, - (old_rotate_stream_ordering, self._rotate_count), - ) - stream_row = txn.fetchone() - if stream_row: - offset_stream_ordering, = stream_row - rotate_to_stream_ordering = min( - self.stream_ordering_day_ago, offset_stream_ordering - ) - caught_up = offset_stream_ordering >= self.stream_ordering_day_ago - else: - rotate_to_stream_ordering = self.stream_ordering_day_ago - caught_up = True - - logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) - - self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) - - # We have caught up iff we were limited by `stream_ordering_day_ago` - return caught_up - - def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( - txn, - table="event_push_summary_stream_ordering", - keyvalues={}, - retcol="stream_ordering", - ) - - # Calculate the new counts that should be upserted into event_push_summary - sql = """ - SELECT user_id, room_id, - coalesce(old.notif_count, 0) + upd.notif_count, - upd.stream_ordering, - old.user_id - FROM ( - SELECT user_id, room_id, count(*) as notif_count, - max(stream_ordering) as stream_ordering - FROM event_push_actions - WHERE ? <= stream_ordering AND stream_ordering < ? - AND highlight = 0 - GROUP BY user_id, room_id - ) AS upd - LEFT JOIN event_push_summary AS old USING (user_id, room_id) - """ - - txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) - rows = txn.fetchall() - - logger.info("Rotating notifications, handling %d rows", len(rows)) - - # If the `old.user_id` above is NULL then we know there isn't already an - # entry in the table, so we simply insert it. Otherwise we update the - # existing table. - self._simple_insert_many_txn( - txn, - table="event_push_summary", - values=[ - { - "user_id": row[0], - "room_id": row[1], - "notif_count": row[2], - "stream_ordering": row[3], - } - for row in rows - if row[4] is None - ], - ) - - txn.executemany( - """ - UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? - WHERE user_id = ? AND room_id = ? - """, - ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), - ) - - txn.execute( - "DELETE FROM event_push_actions" - " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", - (old_rotate_stream_ordering, rotate_to_stream_ordering), - ) - - logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) - - txn.execute( - "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", - (rotate_to_stream_ordering,), - ) - - -def _action_has_highlight(actions): - for action in actions: - try: - if action.get("set_tweak", None) == "highlight": - return action.get("value", True) - except AttributeError: - pass - - return False diff --git a/synapse/storage/events.py b/synapse/storage/events.py deleted file mode 100644 index ee49ef235d..0000000000 --- a/synapse/storage/events.py +++ /dev/null @@ -1,2489 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import itertools -import logging -from collections import Counter as c_counter, OrderedDict, deque, namedtuple -from functools import wraps - -from six import iteritems, text_type -from six.moves import range - -from canonicaljson import json -from prometheus_client import Counter, Histogram - -from twisted.internet import defer - -import synapse.metrics -from synapse.api.constants import EventTypes -from synapse.api.errors import SynapseError -from synapse.events import EventBase # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 -from synapse.events.utils import prune_event_dict -from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable -from synapse.logging.utils import log_function -from synapse.metrics import BucketCollector -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.state import StateResolutionStore -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.storage.event_federation import EventFederationStore -from synapse.storage.events_worker import EventsWorkerStore -from synapse.storage.state import StateGroupWorkerStore -from synapse.types import RoomStreamToken, get_domain_from_id -from synapse.util import batch_iter -from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from synapse.util.frozenutils import frozendict_json_encoder -from synapse.util.metrics import Measure - -logger = logging.getLogger(__name__) - -persist_event_counter = Counter("synapse_storage_events_persisted_events", "") -event_counter = Counter( - "synapse_storage_events_persisted_events_sep", - "", - ["type", "origin_type", "origin_entity"], -) - -# The number of times we are recalculating the current state -state_delta_counter = Counter("synapse_storage_events_state_delta", "") - -# The number of times we are recalculating state when there is only a -# single forward extremity -state_delta_single_event_counter = Counter( - "synapse_storage_events_state_delta_single_event", "" -) - -# The number of times we are reculating state when we could have resonably -# calculated the delta when we calculated the state for an event we were -# persisting. -state_delta_reuse_delta_counter = Counter( - "synapse_storage_events_state_delta_reuse_delta", "" -) - -# The number of forward extremities for each new event. -forward_extremities_counter = Histogram( - "synapse_storage_events_forward_extremities_persisted", - "Number of forward extremities for each new event", - buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - -# The number of stale forward extremities for each new event. Stale extremities -# are those that were in the previous set of extremities as well as the new. -stale_forward_extremities_counter = Histogram( - "synapse_storage_events_stale_forward_extremities_persisted", - "Number of unchanged forward extremities for each new event", - buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - - -def encode_json(json_object): - """ - Encode a Python object as JSON and return it in a Unicode string. - """ - out = frozendict_json_encoder.encode(json_object) - if isinstance(out, bytes): - out = out.decode("utf8") - return out - - -class _EventPeristenceQueue(object): - """Queues up events so that they can be persisted in bulk with only one - concurrent transaction per room. - """ - - _EventPersistQueueItem = namedtuple( - "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") - ) - - def __init__(self): - self._event_persist_queues = {} - self._currently_persisting_rooms = set() - - def add_to_queue(self, room_id, events_and_contexts, backfilled): - """Add events to the queue, with the given persist_event options. - - NB: due to the normal usage pattern of this method, it does *not* - follow the synapse logcontext rules, and leaves the logcontext in - place whether or not the returned deferred is ready. - - Args: - room_id (str): - events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): - - Returns: - defer.Deferred: a deferred which will resolve once the events are - persisted. Runs its callbacks *without* a logcontext. - """ - queue = self._event_persist_queues.setdefault(room_id, deque()) - if queue: - # if the last item in the queue has the same `backfilled` setting, - # we can just add these new events to that item. - end_item = queue[-1] - if end_item.backfilled == backfilled: - end_item.events_and_contexts.extend(events_and_contexts) - return end_item.deferred.observe() - - deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - - queue.append( - self._EventPersistQueueItem( - events_and_contexts=events_and_contexts, - backfilled=backfilled, - deferred=deferred, - ) - ) - - return deferred.observe() - - def handle_queue(self, room_id, per_item_callback): - """Attempts to handle the queue for a room if not already being handled. - - The given callback will be invoked with for each item in the queue, - of type _EventPersistQueueItem. The per_item_callback will continuously - be called with new items, unless the queue becomnes empty. The return - value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferreds as well. - - This function should therefore be called whenever anything is added - to the queue. - - If another callback is currently handling the queue then it will not be - invoked. - """ - - if room_id in self._currently_persisting_rooms: - return - - self._currently_persisting_rooms.add(room_id) - - @defer.inlineCallbacks - def handle_queue_loop(): - try: - queue = self._get_drainining_queue(room_id) - for item in queue: - try: - ret = yield per_item_callback(item) - except Exception: - with PreserveLoggingContext(): - item.deferred.errback() - else: - with PreserveLoggingContext(): - item.deferred.callback(ret) - finally: - queue = self._event_persist_queues.pop(room_id, None) - if queue: - self._event_persist_queues[room_id] = queue - self._currently_persisting_rooms.discard(room_id) - - # set handle_queue_loop off in the background - run_as_background_process("persist_events", handle_queue_loop) - - def _get_drainining_queue(self, room_id): - queue = self._event_persist_queues.setdefault(room_id, deque()) - - try: - while True: - yield queue.popleft() - except IndexError: - # Queue has been drained. - pass - - -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) - - -def _retry_on_integrity_error(func): - """Wraps a database function so that it gets retried on IntegrityError, - with `delete_existing=True` passed in. - - Args: - func: function that returns a Deferred and accepts a `delete_existing` arg - """ - - @wraps(func) - @defer.inlineCallbacks - def f(self, *args, **kwargs): - try: - res = yield func(self, *args, **kwargs) - except self.database_engine.module.IntegrityError: - logger.exception("IntegrityError, retrying.") - res = yield func(self, *args, delete_existing=True, **kwargs) - return res - - return f - - -# inherits from EventFederationStore so that we can call _update_backward_extremities -# and _handle_mult_prev_events (though arguably those could both be moved in here) -class EventsStore( - StateGroupWorkerStore, - EventFederationStore, - EventsWorkerStore, - BackgroundUpdateStore, -): - def __init__(self, db_conn, hs): - super(EventsStore, self).__init__(db_conn, hs) - - self._event_persist_queue = _EventPeristenceQueue() - self._state_resolution_handler = hs.get_state_resolution_handler() - - # Collect metrics on the number of forward extremities that exist. - # Counter of number of extremities to count - self._current_forward_extremities_amount = c_counter() - - BucketCollector( - "synapse_forward_extremities", - lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], - ) - - # Read the extrems every 60 minutes - def read_forward_extremities(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "read_forward_extremities", self._read_forward_extremities - ) - - hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) - - def _censor_redactions(): - return run_as_background_process( - "_censor_redactions", self._censor_redactions - ) - - if self.hs.config.redaction_retention_period is not None: - hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) - - @defer.inlineCallbacks - def _read_forward_extremities(self): - def fetch(txn): - txn.execute( - """ - select count(*) c from event_forward_extremities - group by room_id - """ - ) - return txn.fetchall() - - res = yield self.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) - - @defer.inlineCallbacks - def persist_events(self, events_and_contexts, backfilled=False): - """ - Write events to the database - Args: - events_and_contexts: list of tuples of (event, context) - backfilled (bool): Whether the results are retrieved from federation - via backfill or not. Used to determine if they're "new" events - which might update the current state etc. - - Returns: - Deferred[int]: the stream ordering of the latest persisted event - """ - partitioned = {} - for event, ctx in events_and_contexts: - partitioned.setdefault(event.room_id, []).append((event, ctx)) - - deferreds = [] - for room_id, evs_ctxs in iteritems(partitioned): - d = self._event_persist_queue.add_to_queue( - room_id, evs_ctxs, backfilled=backfilled - ) - deferreds.append(d) - - for room_id in partitioned: - self._maybe_start_persisting(room_id) - - yield make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - - max_persisted_id = yield self._stream_id_gen.get_current_token() - - return max_persisted_id - - @defer.inlineCallbacks - @log_function - def persist_event(self, event, context, backfilled=False): - """ - - Args: - event (EventBase): - context (EventContext): - backfilled (bool): - - Returns: - Deferred: resolves to (int, int): the stream ordering of ``event``, - and the stream ordering of the latest persisted event - """ - deferred = self._event_persist_queue.add_to_queue( - event.room_id, [(event, context)], backfilled=backfilled - ) - - self._maybe_start_persisting(event.room_id) - - yield make_deferred_yieldable(deferred) - - max_persisted_id = yield self._stream_id_gen.get_current_token() - return (event.internal_metadata.stream_ordering, max_persisted_id) - - def _maybe_start_persisting(self, room_id): - @defer.inlineCallbacks - def persisting_queue(item): - with Measure(self._clock, "persist_events"): - yield self._persist_events( - item.events_and_contexts, backfilled=item.backfilled - ) - - self._event_persist_queue.handle_queue(room_id, persisting_queue) - - @_retry_on_integrity_error - @defer.inlineCallbacks - def _persist_events( - self, events_and_contexts, backfilled=False, delete_existing=False - ): - """Persist events to db - - Args: - events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): - delete_existing (bool): - - Returns: - Deferred: resolves when the events have been persisted - """ - if not events_and_contexts: - return - - chunks = [ - events_and_contexts[x : x + 100] - for x in range(0, len(events_and_contexts), 100) - ] - - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( - - # NB: Assumes that we are only persisting events for one room - # at a time. - - # map room_id->list[event_ids] giving the new forward - # extremities in each room - new_forward_extremeties = {} - - # map room_id->(type,state_key)->event_id tracking the full - # state in each room after adding these events. - # This is simply used to prefill the get_current_state_ids - # cache - current_state_for_room = {} - - # map room_id->(to_delete, to_insert) where to_delete is a list - # of type/state keys to remove from current state, and to_insert - # is a map (type,key)->event_id giving the state delta in each - # room - state_delta_for_room = {} - - if not backfilled: - with Measure(self._clock, "_calculate_state_and_extrem"): - # Work out the new "current state" for each room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - events_by_room = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) - - for room_id, ev_ctx_rm in iteritems(events_by_room): - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) - new_latest_event_ids = yield self._calculate_new_extremities( - room_id, ev_ctx_rm, latest_event_ids - ) - - latest_event_ids = set(latest_event_ids) - if new_latest_event_ids == latest_event_ids: - # No change in extremities, so no change in state - continue - - # there should always be at least one forward extremity. - # (except during the initial persistence of the send_join - # results, in which case there will be no existing - # extremities, so we'll `continue` above and skip this bit.) - assert new_latest_event_ids, "No forward extremities left!" - - new_forward_extremeties[room_id] = new_latest_event_ids - - len_1 = ( - len(latest_event_ids) == 1 - and len(new_latest_event_ids) == 1 - ) - if len_1: - all_single_prev_not_state = all( - len(event.prev_event_ids()) == 1 - and not event.is_state() - for event, ctx in ev_ctx_rm - ) - # Don't bother calculating state if they're just - # a long chain of single ancestor non-state events. - if all_single_prev_not_state: - continue - - state_delta_counter.inc() - if len(new_latest_event_ids) == 1: - state_delta_single_event_counter.inc() - - # This is a fairly handwavey check to see if we could - # have guessed what the delta would have been when - # processing one of these events. - # What we're interested in is if the latest extremities - # were the same when we created the event as they are - # now. When this server creates a new event (as opposed - # to receiving it over federation) it will use the - # forward extremities as the prev_events, so we can - # guess this by looking at the prev_events and checking - # if they match the current forward extremities. - for ev, _ in ev_ctx_rm: - prev_event_ids = set(ev.prev_event_ids()) - if latest_event_ids == prev_event_ids: - state_delta_reuse_delta_counter.inc() - break - - logger.info("Calculating state delta for room %s", room_id) - with Measure( - self._clock, "persist_events.get_new_state_after_events" - ): - res = yield self._get_new_state_after_events( - room_id, - ev_ctx_rm, - latest_event_ids, - new_latest_event_ids, - ) - current_state, delta_ids = res - - # If either are not None then there has been a change, - # and we need to work out the delta (or use that - # given) - if delta_ids is not None: - # If there is a delta we know that we've - # only added or replaced state, never - # removed keys entirely. - state_delta_for_room[room_id] = ([], delta_ids) - elif current_state is not None: - with Measure( - self._clock, "persist_events.calculate_state_delta" - ): - delta = yield self._calculate_state_delta( - room_id, current_state - ) - state_delta_for_room[room_id] = delta - - # If we have the current_state then lets prefill - # the cache with it. - if current_state is not None: - current_state_for_room[room_id] = current_state - - # We want to calculate the stream orderings as late as possible, as - # we only notify after all events with a lesser stream ordering have - # been persisted. I.e. if we spend 10s inside the with block then - # that will delay all subsequent events from being notified about. - # Hence why we do it down here rather than wrapping the entire - # function. - # - # Its safe to do this after calculating the state deltas etc as we - # only need to protect the *persistence* of the events. This is to - # ensure that queries of the form "fetch events since X" don't - # return events and stream positions after events that are still in - # flight, as otherwise subsequent requests "fetch event since Y" - # will not return those events. - # - # Note: Multiple instances of this function cannot be in flight at - # the same time for the same room. - if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( - len(chunk) - ) - else: - stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk)) - - with stream_ordering_manager as stream_orderings: - for (event, context), stream in zip(chunk, stream_orderings): - event.internal_metadata.stream_ordering = stream - - yield self.runInteraction( - "persist_events", - self._persist_events_txn, - events_and_contexts=chunk, - backfilled=backfilled, - delete_existing=delete_existing, - state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, - ) - persist_event_counter.inc(len(chunk)) - - if not backfilled: - # backfilled events have negative stream orderings, so we don't - # want to set the event_persisted_position to that. - synapse.metrics.event_persisted_position.set( - chunk[-1][0].internal_metadata.stream_ordering - ) - - for event, context in chunk: - if context.app_service: - origin_type = "local" - origin_entity = context.app_service.id - elif self.hs.is_mine_id(event.sender): - origin_type = "local" - origin_entity = "*client*" - else: - origin_type = "remote" - origin_entity = get_domain_from_id(event.sender) - - event_counter.labels(event.type, origin_type, origin_entity).inc() - - for room_id, new_state in iteritems(current_state_for_room): - self.get_current_state_ids.prefill((room_id,), new_state) - - for room_id, latest_event_ids in iteritems(new_forward_extremeties): - self.get_latest_event_ids_in_room.prefill( - (room_id,), list(latest_event_ids) - ) - - @defer.inlineCallbacks - def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): - """Calculates the new forward extremities for a room given events to - persist. - - Assumes that we are only persisting events for one room at a time. - """ - - # we're only interested in new events which aren't outliers and which aren't - # being rejected. - new_events = [ - event - for event, ctx in event_contexts - if not event.internal_metadata.is_outlier() - and not ctx.rejected - and not event.internal_metadata.is_soft_failed() - ] - - latest_event_ids = set(latest_event_ids) - - # start with the existing forward extremities - result = set(latest_event_ids) - - # add all the new events to the list - result.update(event.event_id for event in new_events) - - # Now remove all events which are prev_events of any of the new events - result.difference_update( - e_id for event in new_events for e_id in event.prev_event_ids() - ) - - # Remove any events which are prev_events of any existing events. - existing_prevs = yield self._get_events_which_are_prevs(result) - result.difference_update(existing_prevs) - - # Finally handle the case where the new events have soft-failed prev - # events. If they do we need to remove them and their prev events, - # otherwise we end up with dangling extremities. - existing_prevs = yield self._get_prevs_before_rejected( - e_id for event in new_events for e_id in event.prev_event_ids() - ) - result.difference_update(existing_prevs) - - # We only update metrics for events that change forward extremities - # (e.g. we ignore backfill/outliers/etc) - if result != latest_event_ids: - forward_extremities_counter.observe(len(result)) - stale = latest_event_ids & result - stale_forward_extremities_counter.observe(len(stale)) - - return result - - @defer.inlineCallbacks - def _get_events_which_are_prevs(self, event_ids): - """Filter the supplied list of event_ids to get those which are prev_events of - existing (non-outlier/rejected) events. - - Args: - event_ids (Iterable[str]): event ids to filter - - Returns: - Deferred[List[str]]: filtered event ids - """ - results = [] - - def _get_events_which_are_prevs_txn(txn, batch): - sql = """ - SELECT prev_event_id, internal_metadata - FROM event_edges - INNER JOIN events USING (event_id) - LEFT JOIN rejections USING (event_id) - LEFT JOIN event_json USING (event_id) - WHERE - NOT events.outlier - AND rejections.event_id IS NULL - AND - """ - - clause, args = make_in_list_sql_clause( - self.database_engine, "prev_event_id", batch - ) - - txn.execute(sql + clause, args) - results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) - - for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( - "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk - ) - - return results - - @defer.inlineCallbacks - def _get_prevs_before_rejected(self, event_ids): - """Get soft-failed ancestors to remove from the extremities. - - Given a set of events, find all those that have been soft-failed or - rejected. Returns those soft failed/rejected events and their prev - events (whether soft-failed/rejected or not), and recurses up the - prev-event graph until it finds no more soft-failed/rejected events. - - This is used to find extremities that are ancestors of new events, but - are separated by soft failed events. - - Args: - event_ids (Iterable[str]): Events to find prev events for. Note - that these must have already been persisted. - - Returns: - Deferred[set[str]] - """ - - # The set of event_ids to return. This includes all soft-failed events - # and their prev events. - existing_prevs = set() - - def _get_prevs_before_rejected_txn(txn, batch): - to_recursively_check = batch - - while to_recursively_check: - sql = """ - SELECT - event_id, prev_event_id, internal_metadata, - rejections.event_id IS NOT NULL - FROM event_edges - INNER JOIN events USING (event_id) - LEFT JOIN rejections USING (event_id) - LEFT JOIN event_json USING (event_id) - WHERE - NOT events.outlier - AND - """ - - clause, args = make_in_list_sql_clause( - self.database_engine, "event_id", to_recursively_check - ) - - txn.execute(sql + clause, args) - to_recursively_check = [] - - for event_id, prev_event_id, metadata, rejected in txn: - if prev_event_id in existing_prevs: - continue - - soft_failed = json.loads(metadata).get("soft_failed") - if soft_failed or rejected: - to_recursively_check.append(prev_event_id) - existing_prevs.add(prev_event_id) - - for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( - "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk - ) - - return existing_prevs - - @defer.inlineCallbacks - def _get_new_state_after_events( - self, room_id, events_context, old_latest_event_ids, new_latest_event_ids - ): - """Calculate the current state dict after adding some new events to - a room - - Args: - room_id (str): - room to which the events are being added. Used for logging etc - - events_context (list[(EventBase, EventContext)]): - events and contexts which are being added to the room - - old_latest_event_ids (iterable[str]): - the old forward extremities for the room. - - new_latest_event_ids (iterable[str]): - the new forward extremities for the room. - - Returns: - Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]: - Returns a tuple of two state maps, the first being the full new current - state and the second being the delta to the existing current state. - If both are None then there has been no change. - - If there has been a change then we only return the delta if its - already been calculated. Conversely if we do know the delta then - the new current state is only returned if we've already calculated - it. - """ - # map from state_group to ((type, key) -> event_id) state map - state_groups_map = {} - - # Map from (prev state group, new state group) -> delta state dict - state_group_deltas = {} - - for ev, ctx in events_context: - if ctx.state_group is None: - # This should only happen for outlier events. - if not ev.internal_metadata.is_outlier(): - raise Exception( - "Context for new event %s has no state " - "group" % (ev.event_id,) - ) - continue - - if ctx.state_group in state_groups_map: - continue - - # We're only interested in pulling out state that has already - # been cached in the context. We'll pull stuff out of the DB later - # if necessary. - current_state_ids = ctx.get_cached_current_state_ids() - if current_state_ids is not None: - state_groups_map[ctx.state_group] = current_state_ids - - if ctx.prev_group: - state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids - - # We need to map the event_ids to their state groups. First, let's - # check if the event is one we're persisting, in which case we can - # pull the state group from its context. - # Otherwise we need to pull the state group from the database. - - # Set of events we need to fetch groups for. (We know none of the old - # extremities are going to be in events_context). - missing_event_ids = set(old_latest_event_ids) - - event_id_to_state_group = {} - for event_id in new_latest_event_ids: - # First search in the list of new events we're adding. - for ev, ctx in events_context: - if event_id == ev.event_id and ctx.state_group is not None: - event_id_to_state_group[event_id] = ctx.state_group - break - else: - # If we couldn't find it, then we'll need to pull - # the state from the database - missing_event_ids.add(event_id) - - if missing_event_ids: - # Now pull out the state groups for any missing events from DB - event_to_groups = yield self._get_state_group_for_events(missing_event_ids) - event_id_to_state_group.update(event_to_groups) - - # State groups of old_latest_event_ids - old_state_groups = set( - event_id_to_state_group[evid] for evid in old_latest_event_ids - ) - - # State groups of new_latest_event_ids - new_state_groups = set( - event_id_to_state_group[evid] for evid in new_latest_event_ids - ) - - # If they old and new groups are the same then we don't need to do - # anything. - if old_state_groups == new_state_groups: - return None, None - - if len(new_state_groups) == 1 and len(old_state_groups) == 1: - # If we're going from one state group to another, lets check if - # we have a delta for that transition. If we do then we can just - # return that. - - new_state_group = next(iter(new_state_groups)) - old_state_group = next(iter(old_state_groups)) - - delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) - if delta_ids is not None: - # We have a delta from the existing to new current state, - # so lets just return that. If we happen to already have - # the current state in memory then lets also return that, - # but it doesn't matter if we don't. - new_state = state_groups_map.get(new_state_group) - return new_state, delta_ids - - # Now that we have calculated new_state_groups we need to get - # their state IDs so we can resolve to a single state set. - missing_state = new_state_groups - set(state_groups_map) - if missing_state: - group_to_state = yield self._get_state_for_groups(missing_state) - state_groups_map.update(group_to_state) - - if len(new_state_groups) == 1: - # If there is only one state group, then we know what the current - # state is. - return state_groups_map[new_state_groups.pop()], None - - # Ok, we need to defer to the state handler to resolve our state sets. - - state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} - - events_map = {ev.event_id: ev for ev, _ in events_context} - - # We need to get the room version, which is in the create event. - # Normally that'd be in the database, but its also possible that we're - # currently trying to persist it. - room_version = None - for ev, _ in events_context: - if ev.type == EventTypes.Create and ev.state_key == "": - room_version = ev.content.get("room_version", "1") - break - - if not room_version: - room_version = yield self.get_room_version(room_id) - - logger.debug("calling resolve_state_groups from preserve_events") - res = yield self._state_resolution_handler.resolve_state_groups( - room_id, - room_version, - state_groups, - events_map, - state_res_store=StateResolutionStore(self), - ) - - return res.state, None - - @defer.inlineCallbacks - def _calculate_state_delta(self, room_id, current_state): - """Calculate the new state deltas for a room. - - Assumes that we are only persisting events for one room at a time. - - Returns: - tuple[list, dict] (to_delete, to_insert): where to_delete are the - type/state_keys to remove from current_state_events and `to_insert` - are the updates to current_state_events. - """ - existing_state = yield self.get_current_state_ids(room_id) - - to_delete = [key for key in existing_state if key not in current_state] - - to_insert = { - key: ev_id - for key, ev_id in iteritems(current_state) - if ev_id != existing_state.get(key) - } - - return to_delete, to_insert - - @log_function - def _persist_events_txn( - self, - txn, - events_and_contexts, - backfilled, - delete_existing=False, - state_delta_for_room={}, - new_forward_extremeties={}, - ): - """Insert some number of room events into the necessary database tables. - - Rejected events are only inserted into the events table, the events_json table, - and the rejections table. Things reading from those table will need to check - whether the event was rejected. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): - events to persist - backfilled (bool): True if the events were backfilled - delete_existing (bool): True to purge existing table rows for the - events from the database. This is useful when retrying due to - IntegrityError. - state_delta_for_room (dict[str, (list, dict)]): - The current-state delta for each room. For each room, a tuple - (to_delete, to_insert), being a list of type/state keys to be - removed from the current state, and a state set to be added to - the current state. - new_forward_extremeties (dict[str, list[str]]): - The new forward extremities for each room. For each room, a - list of the event ids which are the forward extremities. - - """ - all_events_and_contexts = events_and_contexts - - min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering - max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - - self._update_forward_extremities_txn( - txn, - new_forward_extremities=new_forward_extremeties, - max_stream_order=max_stream_order, - ) - - # Ensure that we don't have the same event twice. - events_and_contexts = self._filter_events_and_contexts_for_duplicates( - events_and_contexts - ) - - self._update_room_depths_txn( - txn, events_and_contexts=events_and_contexts, backfilled=backfilled - ) - - # _update_outliers_txn filters out any events which have already been - # persisted, and returns the filtered list. - events_and_contexts = self._update_outliers_txn( - txn, events_and_contexts=events_and_contexts - ) - - # From this point onwards the events are only events that we haven't - # seen before. - - if delete_existing: - # For paranoia reasons, we go and delete all the existing entries - # for these events so we can reinsert them. - # This gets around any problems with some tables already having - # entries. - self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts) - - self._store_event_txn(txn, events_and_contexts=events_and_contexts) - - # Insert into event_to_state_groups. - self._store_event_state_mappings_txn(txn, events_and_contexts) - - # We want to store event_auth mappings for rejected events, as they're - # used in state res v2. - # This is only necessary if the rejected event appears in an accepted - # event's auth chain, but its easier for now just to store them (and - # it doesn't take much storage compared to storing the entire event - # anyway). - self._simple_insert_many_txn( - txn, - table="event_auth", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "auth_id": auth_id, - } - for event, _ in events_and_contexts - for auth_id in event.auth_event_ids() - if event.is_state() - ], - ) - - # _store_rejected_events_txn filters out any events which were - # rejected, and returns the filtered list. - events_and_contexts = self._store_rejected_events_txn( - txn, events_and_contexts=events_and_contexts - ) - - # From this point onwards the events are only ones that weren't - # rejected. - - self._update_metadata_tables_txn( - txn, - events_and_contexts=events_and_contexts, - all_events_and_contexts=all_events_and_contexts, - backfilled=backfilled, - ) - - # We call this last as it assumes we've inserted the events into - # room_memberships, where applicable. - self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) - - def _update_current_state_txn(self, txn, state_delta_by_room, stream_id): - for room_id, current_state_tuple in iteritems(state_delta_by_room): - to_delete, to_insert = current_state_tuple - - # First we add entries to the current_state_delta_stream. We - # do this before updating the current_state_events table so - # that we can use it to calculate the `prev_event_id`. (This - # allows us to not have to pull out the existing state - # unnecessarily). - # - # The stream_id for the update is chosen to be the minimum of the stream_ids - # for the batch of the events that we are persisting; that means we do not - # end up in a situation where workers see events before the - # current_state_delta updates. - # - sql = """ - INSERT INTO current_state_delta_stream - (stream_id, room_id, type, state_key, event_id, prev_event_id) - SELECT ?, ?, ?, ?, ?, ( - SELECT event_id FROM current_state_events - WHERE room_id = ? AND type = ? AND state_key = ? - ) - """ - txn.executemany( - sql, - ( - ( - stream_id, - room_id, - etype, - state_key, - None, - room_id, - etype, - state_key, - ) - for etype, state_key in to_delete - # We sanity check that we're deleting rather than updating - if (etype, state_key) not in to_insert - ), - ) - txn.executemany( - sql, - ( - ( - stream_id, - room_id, - etype, - state_key, - ev_id, - room_id, - etype, - state_key, - ) - for (etype, state_key), ev_id in iteritems(to_insert) - ), - ) - - # Now we actually update the current_state_events table - - txn.executemany( - "DELETE FROM current_state_events" - " WHERE room_id = ? AND type = ? AND state_key = ?", - ( - (room_id, etype, state_key) - for etype, state_key in itertools.chain(to_delete, to_insert) - ), - ) - - # We include the membership in the current state table, hence we do - # a lookup when we insert. This assumes that all events have already - # been inserted into room_memberships. - txn.executemany( - """INSERT INTO current_state_events - (room_id, type, state_key, event_id, membership) - VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) - """, - [ - (room_id, key[0], key[1], ev_id, ev_id) - for key, ev_id in iteritems(to_insert) - ], - ) - - txn.call_after( - self._curr_state_delta_stream_cache.entity_has_changed, - room_id, - stream_id, - ) - - # Invalidate the various caches - - # Figure out the changes of membership to invalidate the - # `get_rooms_for_user` cache. - # We find out which membership events we may have deleted - # and which we have added, then we invlidate the caches for all - # those users. - members_changed = set( - state_key - for ev_type, state_key in itertools.chain(to_delete, to_insert) - if ev_type == EventTypes.Member - ) - - for member in members_changed: - txn.call_after( - self.get_rooms_for_user_with_stream_ordering.invalidate, (member,) - ) - - self._invalidate_state_caches_and_stream(txn, room_id, members_changed) - - def _update_forward_extremities_txn( - self, txn, new_forward_extremities, max_stream_order - ): - for room_id, new_extrem in iteritems(new_forward_extremities): - self._simple_delete_txn( - txn, table="event_forward_extremities", keyvalues={"room_id": room_id} - ) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - - self._simple_insert_many_txn( - txn, - table="event_forward_extremities", - values=[ - {"event_id": ev_id, "room_id": room_id} - for room_id, new_extrem in iteritems(new_forward_extremities) - for ev_id in new_extrem - ], - ) - # We now insert into stream_ordering_to_exterm a mapping from room_id, - # new stream_ordering to new forward extremeties in the room. - # This allows us to later efficiently look up the forward extremeties - # for a room before a given stream_ordering - self._simple_insert_many_txn( - txn, - table="stream_ordering_to_exterm", - values=[ - { - "room_id": room_id, - "event_id": event_id, - "stream_ordering": max_stream_order, - } - for room_id, new_extrem in iteritems(new_forward_extremities) - for event_id in new_extrem - ], - ) - - @classmethod - def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): - """Ensure that we don't have the same event twice. - - Pick the earliest non-outlier if there is one, else the earliest one. - - Args: - events_and_contexts (list[(EventBase, EventContext)]): - Returns: - list[(EventBase, EventContext)]: filtered list - """ - new_events_and_contexts = OrderedDict() - for event, context in events_and_contexts: - prev_event_context = new_events_and_contexts.get(event.event_id) - if prev_event_context: - if not event.internal_metadata.is_outlier(): - if prev_event_context[0].internal_metadata.is_outlier(): - # To ensure correct ordering we pop, as OrderedDict is - # ordered by first insertion. - new_events_and_contexts.pop(event.event_id, None) - new_events_and_contexts[event.event_id] = (event, context) - else: - new_events_and_contexts[event.event_id] = (event, context) - return list(new_events_and_contexts.values()) - - def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): - """Update min_depth for each room - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - backfilled (bool): True if the events were backfilled - """ - depth_updates = {} - for event, context in events_and_contexts: - # Remove the any existing cache entries for the event_ids - txn.call_after(self._invalidate_get_event_cache, event.event_id) - if not backfilled: - txn.call_after( - self._events_stream_cache.entity_has_changed, - event.room_id, - event.internal_metadata.stream_ordering, - ) - - if not event.internal_metadata.is_outlier() and not context.rejected: - depth_updates[event.room_id] = max( - event.depth, depth_updates.get(event.room_id, event.depth) - ) - - for room_id, depth in iteritems(depth_updates): - self._update_min_depth_for_room_txn(txn, room_id, depth) - - def _update_outliers_txn(self, txn, events_and_contexts): - """Update any outliers with new event info. - - This turns outliers into ex-outliers (unless the new event was - rejected). - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - - Returns: - list[(EventBase, EventContext)] new list, without events which - are already in the events table. - """ - txn.execute( - "SELECT event_id, outlier FROM events WHERE event_id in (%s)" - % (",".join(["?"] * len(events_and_contexts)),), - [event.event_id for event, _ in events_and_contexts], - ) - - have_persisted = {event_id: outlier for event_id, outlier in txn} - - to_remove = set() - for event, context in events_and_contexts: - if event.event_id not in have_persisted: - continue - - to_remove.add(event) - - if context.rejected: - # If the event is rejected then we don't care if the event - # was an outlier or not. - continue - - outlier_persisted = have_persisted[event.event_id] - if not event.internal_metadata.is_outlier() and outlier_persisted: - # We received a copy of an event that we had already stored as - # an outlier in the database. We now have some state at that - # so we need to update the state_groups table with that state. - - # insert into event_to_state_groups. - try: - self._store_event_state_mappings_txn(txn, ((event, context),)) - except Exception: - logger.exception("") - raise - - metadata_json = encode_json(event.internal_metadata.get_dict()) - - sql = ( - "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?" - ) - txn.execute(sql, (metadata_json, event.event_id)) - - # Add an entry to the ex_outlier_stream table to replicate the - # change in outlier status to our workers. - stream_order = event.internal_metadata.stream_ordering - state_group_id = context.state_group - self._simple_insert_txn( - txn, - table="ex_outlier_stream", - values={ - "event_stream_ordering": stream_order, - "event_id": event.event_id, - "state_group": state_group_id, - }, - ) - - sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?" - txn.execute(sql, (False, event.event_id)) - - # Update the event_backward_extremities table now that this - # event isn't an outlier any more. - self._update_backward_extremeties(txn, [event]) - - return [ec for ec in events_and_contexts if ec[0] not in to_remove] - - @classmethod - def _delete_existing_rows_txn(cls, txn, events_and_contexts): - if not events_and_contexts: - # nothing to do here - return - - logger.info("Deleting existing") - - for table in ( - "events", - "event_auth", - "event_json", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "event_to_state_groups", - "local_invites", - "state_events", - "rejections", - "redactions", - "room_memberships", - ): - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts], - ) - - for table in ("event_push_actions",): - txn.executemany( - "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,), - [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts], - ) - - def _store_event_txn(self, txn, events_and_contexts): - """Insert new events into the event and event_json tables - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - """ - - if not events_and_contexts: - # nothing to do here - return - - def event_dict(event): - d = event.get_dict() - d.pop("redacted", None) - d.pop("redacted_because", None) - return d - - self._simple_insert_many_txn( - txn, - table="event_json", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "internal_metadata": encode_json( - event.internal_metadata.get_dict() - ), - "json": encode_json(event_dict(event)), - "format_version": event.format_version, - } - for event, _ in events_and_contexts - ], - ) - - self._simple_insert_many_txn( - txn, - table="events", - values=[ - { - "stream_ordering": event.internal_metadata.stream_ordering, - "topological_ordering": event.depth, - "depth": event.depth, - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "processed": True, - "outlier": event.internal_metadata.is_outlier(), - "origin_server_ts": int(event.origin_server_ts), - "received_ts": self._clock.time_msec(), - "sender": event.sender, - "contains_url": ( - "url" in event.content - and isinstance(event.content["url"], text_type) - ), - } - for event, _ in events_and_contexts - ], - ) - - for event, _ in events_and_contexts: - if not event.internal_metadata.is_redacted(): - # If we're persisting an unredacted event we go and ensure - # that we mark any redactions that reference this event as - # requiring censoring. - self._simple_update_txn( - txn, - table="redactions", - keyvalues={"redacts": event.event_id}, - updatevalues={"have_censored": False}, - ) - - def _store_rejected_events_txn(self, txn, events_and_contexts): - """Add rows to the 'rejections' table for received events which were - rejected - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - - Returns: - list[(EventBase, EventContext)] new list, without the rejected - events. - """ - # Remove the rejected events from the list now that we've added them - # to the events table and the events_json table. - to_remove = set() - for event, context in events_and_contexts: - if context.rejected: - # Insert the event_id into the rejections table - self._store_rejections_txn(txn, event.event_id, context.rejected) - to_remove.add(event) - - return [ec for ec in events_and_contexts if ec[0] not in to_remove] - - def _update_metadata_tables_txn( - self, txn, events_and_contexts, all_events_and_contexts, backfilled - ): - """Update all the miscellaneous tables for new events - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. - backfilled (bool): True if the events were backfilled - """ - - # Insert all the push actions into the event_push_actions table. - self._set_push_actions_for_event_and_users_txn( - txn, - events_and_contexts=events_and_contexts, - all_events_and_contexts=all_events_and_contexts, - ) - - if not events_and_contexts: - # nothing to do here - return - - for event, context in events_and_contexts: - if event.type == EventTypes.Redaction and event.redacts is not None: - # Remove the entries in the event_push_actions table for the - # redacted event. - self._remove_push_actions_for_event_id_txn( - txn, event.room_id, event.redacts - ) - - # Remove from relations table. - self._handle_redaction(txn, event.redacts) - - # Update the event_forward_extremities, event_backward_extremities and - # event_edges tables. - self._handle_mult_prev_events( - txn, events=[event for event, _ in events_and_contexts] - ) - - for event, _ in events_and_contexts: - if event.type == EventTypes.Name: - # Insert into the event_search table. - self._store_room_name_txn(txn, event) - elif event.type == EventTypes.Topic: - # Insert into the event_search table. - self._store_room_topic_txn(txn, event) - elif event.type == EventTypes.Message: - # Insert into the event_search table. - self._store_room_message_txn(txn, event) - elif event.type == EventTypes.Redaction: - # Insert into the redactions table. - self._store_redaction(txn, event) - - self._handle_event_relations(txn, event) - - # Insert into the room_memberships table. - self._store_room_members_txn( - txn, - [ - event - for event, _ in events_and_contexts - if event.type == EventTypes.Member - ], - backfilled=backfilled, - ) - - # Insert event_reference_hashes table. - self._store_event_reference_hashes_txn( - txn, [event for event, _ in events_and_contexts] - ) - - state_events_and_contexts = [ - ec for ec in events_and_contexts if ec[0].is_state() - ] - - state_values = [] - for event, context in state_events_and_contexts: - vals = { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - } - - # TODO: How does this work with backfilling? - if hasattr(event, "replaces_state"): - vals["prev_state"] = event.replaces_state - - state_values.append(vals) - - self._simple_insert_many_txn(txn, table="state_events", values=state_values) - - # Prefill the event cache - self._add_to_cache(txn, events_and_contexts) - - def _add_to_cache(self, txn, events_and_contexts): - to_prefill = [] - - rows = [] - N = 200 - for i in range(0, len(events_and_contexts), N): - ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]} - if not ev_map: - break - - sql = ( - "SELECT " - " e.event_id as event_id, " - " r.redacts as redacts," - " rej.event_id as rejects " - " FROM events as e" - " LEFT JOIN rejections as rej USING (event_id)" - " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE " - ) - - clause, args = make_in_list_sql_clause( - self.database_engine, "e.event_id", list(ev_map) - ) - - txn.execute(sql + clause, args) - rows = self.cursor_to_dict(txn) - for row in rows: - event = ev_map[row["event_id"]] - if not row["rejects"] and not row["redacts"]: - to_prefill.append( - _EventCacheEntry(event=event, redacted_event=None) - ) - - def prefill(): - for cache_entry in to_prefill: - self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry) - - txn.call_after(prefill) - - def _store_redaction(self, txn, event): - # invalidate the cache for the redacted event - txn.call_after(self._invalidate_get_event_cache, event.redacts) - - self._simple_insert_txn( - txn, - table="redactions", - values={ - "event_id": event.event_id, - "redacts": event.redacts, - "received_ts": self._clock.time_msec(), - }, - ) - - @defer.inlineCallbacks - def _censor_redactions(self): - """Censors all redactions older than the configured period that haven't - been censored yet. - - By censor we mean update the event_json table with the redacted event. - - Returns: - Deferred - """ - - if self.hs.config.redaction_retention_period is None: - return - - before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period - - # We fetch all redactions that: - # 1. point to an event we have, - # 2. has a received_ts from before the cut off, and - # 3. we haven't yet censored. - # - # This is limited to 100 events to ensure that we don't try and do too - # much at once. We'll get called again so this should eventually catch - # up. - sql = """ - SELECT redactions.event_id, redacts FROM redactions - LEFT JOIN events AS original_event ON ( - redacts = original_event.event_id - ) - WHERE NOT have_censored - AND redactions.received_ts <= ? - ORDER BY redactions.received_ts ASC - LIMIT ? - """ - - rows = yield self._execute( - "_censor_redactions_fetch", None, sql, before_ts, 100 - ) - - updates = [] - - for redaction_id, event_id in rows: - redaction_event = yield self.get_event(redaction_id, allow_none=True) - original_event = yield self.get_event( - event_id, allow_rejected=True, allow_none=True - ) - - # The SQL above ensures that we have both the redaction and - # original event, so if the `get_event` calls return None it - # means that the redaction wasn't allowed. Either way we know that - # the result won't change so we mark the fact that we've checked. - if ( - redaction_event - and original_event - and original_event.internal_metadata.is_redacted() - ): - # Redaction was allowed - pruned_json = encode_json(prune_event_dict(original_event.get_dict())) - else: - # Redaction wasn't allowed - pruned_json = None - - updates.append((redaction_id, event_id, pruned_json)) - - def _update_censor_txn(txn): - for redaction_id, event_id, pruned_json in updates: - if pruned_json: - self._simple_update_one_txn( - txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, - ) - - self._simple_update_one_txn( - txn, - table="redactions", - keyvalues={"event_id": redaction_id}, - updatevalues={"have_censored": True}, - ) - - yield self.runInteraction("_update_censor_txn", _update_censor_txn) - - @defer.inlineCallbacks - def count_daily_messages(self): - """ - Returns an estimate of the number of messages sent in the last day. - - If it has been significantly less or more than one day since the last - call to this function, it will return None. - """ - - def _count_messages(txn): - sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events - WHERE type = 'm.room.message' - AND stream_ordering > ? - """ - txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() - return count - - ret = yield self.runInteraction("count_messages", _count_messages) - return ret - - @defer.inlineCallbacks - def count_daily_sent_messages(self): - def _count_messages(txn): - # This is good enough as if you have silly characters in your own - # hostname then thats your own fault. - like_clause = "%:" + self.hs.hostname - - sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events - WHERE type = 'm.room.message' - AND sender LIKE ? - AND stream_ordering > ? - """ - - txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - count, = txn.fetchone() - return count - - ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) - return ret - - @defer.inlineCallbacks - def count_daily_active_rooms(self): - def _count(txn): - sql = """ - SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events - WHERE type = 'm.room.message' - AND stream_ordering > ? - """ - txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() - return count - - ret = yield self.runInteraction("count_daily_active_rooms", _count) - return ret - - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - - def get_current_events_token(self): - """The current maximum token that events have reached""" - return self._stream_id_gen.get_current_token() - - def get_all_new_forward_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_forward_event_rows(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_id, upper_bound)) - new_event_updates.extend(txn) - - return new_event_updates - - return self.runInteraction( - "get_all_new_forward_event_rows", get_all_new_forward_event_rows - ) - - def get_all_new_backfill_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_backfill_event_rows(txn): - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend(txn.fetchall()) - - return new_event_updates - - return self.runInteraction( - "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows - ) - - @cached(num_args=5, max_entries=10) - def get_all_new_events( - self, - last_backfill_id, - last_forward_id, - current_backfill_id, - current_forward_id, - limit, - ): - """Get all the new events that have arrived at the server either as - new events or as backfilled events""" - have_backfill_events = last_backfill_id != current_backfill_id - have_forward_events = last_forward_id != current_forward_id - - if not have_backfill_events and not have_forward_events: - return defer.succeed(AllNewEventsResult([], [], [], [], [])) - - def get_all_new_events_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - if have_forward_events: - txn.execute(sql, (last_forward_id, current_forward_id, limit)) - new_forward_events = txn.fetchall() - - if len(new_forward_events) == limit: - upper_bound = new_forward_events[-1][0] - else: - upper_bound = current_forward_id - - sql = ( - "SELECT event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - forward_ex_outliers = txn.fetchall() - else: - new_forward_events = [] - forward_ex_outliers = [] - - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - if have_backfill_events: - txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) - new_backfill_events = txn.fetchall() - - if len(new_backfill_events) == limit: - upper_bound = new_backfill_events[-1][0] - else: - upper_bound = current_backfill_id - - sql = ( - "SELECT -event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_backfill_id, -upper_bound)) - backward_ex_outliers = txn.fetchall() - else: - new_backfill_events = [] - backward_ex_outliers = [] - - return AllNewEventsResult( - new_forward_events, - new_backfill_events, - forward_ex_outliers, - backward_ex_outliers, - ) - - return self.runInteraction("get_all_new_events", get_all_new_events_txn) - - def purge_history(self, room_id, token, delete_local_events): - """Deletes room history before a certain point - - Args: - room_id (str): - - token (str): A topological token to delete events before - - delete_local_events (bool): - if True, we will delete local events as well as remote ones - (instead of just marking them as outliers and deleting their - state groups). - """ - - return self.runInteraction( - "purge_history", - self._purge_history_txn, - room_id, - token, - delete_local_events, - ) - - def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): - token = RoomStreamToken.parse(token_str) - - # Tables that should be pruned: - # event_auth - # event_backward_extremities - # event_edges - # event_forward_extremities - # event_json - # event_push_actions - # event_reference_hashes - # event_search - # event_to_state_groups - # events - # rejections - # room_depth - # state_groups - # state_groups_state - - # we will build a temporary table listing the events so that we don't - # have to keep shovelling the list back and forth across the - # connection. Annoyingly the python sqlite driver commits the - # transaction on CREATE, so let's do this first. - # - # furthermore, we might already have the table from a previous (failed) - # purge attempt, so let's drop the table first. - - txn.execute("DROP TABLE IF EXISTS events_to_purge") - - txn.execute( - "CREATE TEMPORARY TABLE events_to_purge (" - " event_id TEXT NOT NULL," - " should_delete BOOLEAN NOT NULL" - ")" - ) - - # First ensure that we're not about to delete all the forward extremeties - txn.execute( - "SELECT e.event_id, e.depth FROM events as e " - "INNER JOIN event_forward_extremities as f " - "ON e.event_id = f.event_id " - "AND e.room_id = f.room_id " - "WHERE f.room_id = ?", - (room_id,), - ) - rows = txn.fetchall() - max_depth = max(row[1] for row in rows) - - if max_depth < token.topological: - # We need to ensure we don't delete all the events from the database - # otherwise we wouldn't be able to send any events (due to not - # having any backwards extremeties) - raise SynapseError( - 400, "topological_ordering is greater than forward extremeties" - ) - - logger.info("[purge] looking for events to delete") - - should_delete_expr = "state_key IS NULL" - should_delete_params = () - if not delete_local_events: - should_delete_expr += " AND event_id NOT LIKE ?" - - # We include the parameter twice since we use the expression twice - should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) - - should_delete_params += (room_id, token.topological) - - # Note that we insert events that are outliers and aren't going to be - # deleted, as nothing will happen to them. - txn.execute( - "INSERT INTO events_to_purge" - " SELECT event_id, %s" - " FROM events AS e LEFT JOIN state_events USING (event_id)" - " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" - % (should_delete_expr, should_delete_expr), - should_delete_params, - ) - - # We create the indices *after* insertion as that's a lot faster. - - # create an index on should_delete because later we'll be looking for - # the should_delete / shouldn't_delete subsets - txn.execute( - "CREATE INDEX events_to_purge_should_delete" - " ON events_to_purge(should_delete)" - ) - - # We do joins against events_to_purge for e.g. calculating state - # groups to purge, etc., so lets make an index. - txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)") - - txn.execute("SELECT event_id, should_delete FROM events_to_purge") - event_rows = txn.fetchall() - logger.info( - "[purge] found %i events before cutoff, of which %i can be deleted", - len(event_rows), - sum(1 for e in event_rows if e[1]), - ) - - logger.info("[purge] Finding new backward extremities") - - # We calculate the new entries for the backward extremeties by finding - # events to be purged that are pointed to by events we're not going to - # purge. - txn.execute( - "SELECT DISTINCT e.event_id FROM events_to_purge AS e" - " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" - " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" - " WHERE ep2.event_id IS NULL" - ) - new_backwards_extrems = txn.fetchall() - - logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) - - txn.execute( - "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) - ) - - # Update backward extremeties - txn.executemany( - "INSERT INTO event_backward_extremities (room_id, event_id)" - " VALUES (?, ?)", - [(room_id, event_id) for event_id, in new_backwards_extrems], - ) - - logger.info("[purge] finding redundant state groups") - - # Get all state groups that are referenced by events that are to be - # deleted. We then go and check if they are referenced by other events - # or state groups, and if not we delete them. - txn.execute( - """ - SELECT DISTINCT state_group FROM events_to_purge - INNER JOIN event_to_state_groups USING (event_id) - """ - ) - - referenced_state_groups = set(sg for sg, in txn) - logger.info( - "[purge] found %i referenced state groups", len(referenced_state_groups) - ) - - logger.info("[purge] finding state groups that can be deleted") - - _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups) - state_groups_to_delete, remaining_state_groups = _ - - logger.info( - "[purge] found %i state groups to delete", len(state_groups_to_delete) - ) - - logger.info( - "[purge] de-delta-ing %i remaining state groups", - len(remaining_state_groups), - ) - - # Now we turn the state groups that reference to-be-deleted state - # groups to non delta versions. - for sg in remaining_state_groups: - logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] - - self._simple_delete_txn( - txn, table="state_groups_state", keyvalues={"state_group": sg} - ) - - self._simple_delete_txn( - txn, table="state_group_edges", keyvalues={"state_group": sg} - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": sg, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(curr_state) - ], - ) - - logger.info("[purge] removing redundant state groups") - txn.executemany( - "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), - ) - txn.executemany( - "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), - ) - - logger.info("[purge] removing events from event_to_state_groups") - txn.execute( - "DELETE FROM event_to_state_groups " - "WHERE event_id IN (SELECT event_id from events_to_purge)" - ) - for event_id, _ in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) - - # Delete all remote non-state events - for table in ( - "events", - "event_json", - "event_auth", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "rejections", - ): - logger.info("[purge] removing events from %s", table) - - txn.execute( - "DELETE FROM %s WHERE event_id IN (" - " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,) - ) - - # event_push_actions lacks an index on event_id, and has one on - # (room_id, event_id) instead. - for table in ("event_push_actions",): - logger.info("[purge] removing events from %s", table) - - txn.execute( - "DELETE FROM %s WHERE room_id = ? AND event_id IN (" - " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,), - (room_id,), - ) - - # Mark all state and own events as outliers - logger.info("[purge] marking remaining events as outliers") - txn.execute( - "UPDATE events SET outlier = ?" - " WHERE event_id IN (" - " SELECT event_id FROM events_to_purge " - " WHERE NOT should_delete" - ")", - (True,), - ) - - # synapse tries to take out an exclusive lock on room_depth whenever it - # persists events (because upsert), and once we run this update, we - # will block that for the rest of our transaction. - # - # So, let's stick it at the end so that we don't block event - # persistence. - # - # We do this by calculating the minimum depth of the backwards - # extremities. However, the events in event_backward_extremities - # are ones we don't have yet so we need to look at the events that - # point to it via event_edges table. - txn.execute( - """ - SELECT COALESCE(MIN(depth), 0) - FROM event_backward_extremities AS eb - INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id - INNER JOIN events AS e ON e.event_id = eg.event_id - WHERE eb.room_id = ? - """, - (room_id,), - ) - min_depth, = txn.fetchone() - - logger.info("[purge] updating room_depth to %d", min_depth) - - txn.execute( - "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", - (min_depth, room_id), - ) - - # finally, drop the temp table. this will commit the txn in sqlite, - # so make sure to keep this actually last. - txn.execute("DROP TABLE events_to_purge") - - logger.info("[purge] done") - - def _find_unreferenced_groups_during_purge(self, txn, state_groups): - """Used when purging history to figure out which state groups can be - deleted and which need to be de-delta'ed (due to one of its prev groups - being scheduled for deletion). - - Args: - txn - state_groups (set[int]): Set of state groups referenced by events - that are going to be deleted. - - Returns: - tuple[set[int], set[int]]: The set of state groups that can be - deleted and the set of state groups that need to be de-delta'ed - """ - # Graph of state group -> previous group - graph = {} - - # Set of events that we have found to be referenced by events - referenced_groups = set() - - # Set of state groups we've already seen - state_groups_seen = set(state_groups) - - # Set of state groups to handle next. - next_to_search = set(state_groups) - while next_to_search: - # We bound size of groups we're looking up at once, to stop the - # SQL query getting too big - if len(next_to_search) < 100: - current_search = next_to_search - next_to_search = set() - else: - current_search = set(itertools.islice(next_to_search, 100)) - next_to_search -= current_search - - # Check if state groups are referenced - sql = """ - SELECT DISTINCT state_group FROM event_to_state_groups - LEFT JOIN events_to_purge AS ep USING (event_id) - WHERE ep.event_id IS NULL AND - """ - clause, args = make_in_list_sql_clause( - txn.database_engine, "state_group", current_search - ) - txn.execute(sql + clause, list(args)) - - referenced = set(sg for sg, in txn) - referenced_groups |= referenced - - # We don't continue iterating up the state group graphs for state - # groups that are referenced. - current_search -= referenced - - rows = self._simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=current_search, - keyvalues={}, - retcols=("prev_state_group", "state_group"), - ) - - prevs = set(row["state_group"] for row in rows) - # We don't bother re-handling groups we've already seen - prevs -= state_groups_seen - next_to_search |= prevs - state_groups_seen |= prevs - - for row in rows: - # Note: Each state group can have at most one prev group - graph[row["state_group"]] = row["prev_state_group"] - - to_delete = state_groups_seen - referenced_groups - - to_dedelta = set() - for sg in referenced_groups: - prev_sg = graph.get(sg) - if prev_sg and prev_sg in to_delete: - to_dedelta.add(sg) - - return to_delete, to_dedelta - - def purge_room(self, room_id): - """Deletes all record of a room - - Args: - room_id (str): - """ - - return self.runInteraction("purge_room", self._purge_room_txn, room_id) - - def _purge_room_txn(self, txn, room_id): - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) - - txn.execute( - """ - DELETE FROM state_groups_state WHERE state_group IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) - """, - (room_id,), - ) - - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) - - txn.execute( - """ - DELETE FROM state_group_edges WHERE state_group IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) - """, - (room_id,), - ) - - # ... and the state groups - logger.info("[purge] removing %s from state_groups", room_id) - - txn.execute( - """ - DELETE FROM state_groups WHERE id IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) - """, - (room_id,), - ) - - # and then tables which lack an index on room_id but have one on event_id - for table in ( - "event_auth", - "event_edges", - "event_push_actions_staging", - "event_reference_hashes", - "event_relations", - "event_to_state_groups", - "redactions", - "rejections", - "state_events", - ): - logger.info("[purge] removing %s from %s", room_id, table) - - txn.execute( - """ - DELETE FROM %s WHERE event_id IN ( - SELECT event_id FROM events WHERE room_id=? - ) - """ - % (table,), - (room_id,), - ) - - # and finally, the tables with an index on room_id (or no useful index) - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "topics", - "users_in_public_rooms", - "users_who_share_private_rooms", - # no useful index, but let's clear them anyway - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "local_invites", - "room_account_data", - "room_tags", - ): - logger.info("[purge] removing %s from %s", room_id, table) - txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,)) - - # Other tables we do NOT need to clear out: - # - # - blocked_rooms - # This is important, to make sure that we don't accidentally rejoin a blocked - # room after it was purged - # - # - user_directory - # This has a room_id column, but it is unused - # - - # Other tables that we might want to consider clearing out include: - # - # - event_reports - # Given that these are intended for abuse management my initial - # inclination is to leave them in place. - # - # - current_state_delta_stream - # - ex_outlier_stream - # - room_tags_revisions - # The problem with these is that they are largeish and there is no room_id - # index on them. In any case we should be clearing out 'stream' tables - # periodically anyway (#5888) - - # TODO: we could probably usefully do a bunch of cache invalidation here - - logger.info("[purge] done") - - @defer.inlineCallbacks - def is_event_after(self, event_id1, event_id2): - """Returns True if event_id1 is after event_id2 in the stream - """ - to_1, so_1 = yield self._get_event_ordering(event_id1) - to_2, so_2 = yield self._get_event_ordering(event_id2) - return (to_1, so_1) > (to_2, so_2) - - @cachedInlineCallbacks(max_entries=5000) - def _get_event_ordering(self, event_id): - res = yield self._simple_select_one( - table="events", - retcols=["topological_ordering", "stream_ordering"], - keyvalues={"event_id": event_id}, - allow_none=True, - ) - - if not res: - raise SynapseError(404, "Could not find event %s" % (event_id,)) - - return (int(res["topological_ordering"]), int(res["stream_ordering"])) - - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): - def get_all_updated_current_state_deltas_txn(txn): - sql = """ - SELECT stream_id, room_id, type, state_key, event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC LIMIT ? - """ - txn.execute(sql, (from_token, to_token, limit)) - return txn.fetchall() - - return self.runInteraction( - "get_all_updated_current_state_deltas", - get_all_updated_current_state_deltas_txn, - ) - - -AllNewEventsResult = namedtuple( - "AllNewEventsResult", - [ - "new_forward_events", - "new_backfill_events", - "forward_ex_outliers", - "backward_ex_outliers", - ], -) diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py deleted file mode 100644 index 31ea6f917f..0000000000 --- a/synapse/storage/events_bg_updates.py +++ /dev/null @@ -1,505 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from six import text_type - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore - -logger = logging.getLogger(__name__) - - -class EventsBackgroundUpdatesStore(BackgroundUpdateStore): - - EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" - EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - - def __init__(self, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) - - self.register_background_update_handler( - self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts - ) - self.register_background_update_handler( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, - self._background_reindex_fields_sender, - ) - - self.register_background_index_update( - "event_contains_url_index", - index_name="event_contains_url_index", - table="events", - columns=["room_id", "topological_ordering", "stream_ordering"], - where_clause="contains_url = true AND outlier = false", - ) - - # an event_id index on event_search is useful for the purge_history - # api. Plus it means we get to enforce some integrity with a UNIQUE - # clause - self.register_background_index_update( - "event_search_event_id_idx", - index_name="event_search_event_id_idx", - table="event_search", - columns=["event_id"], - unique=True, - psql_only=True, - ) - - self.register_background_update_handler( - self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update - ) - - self.register_background_update_handler( - "redactions_received_ts", self._redactions_received_ts - ) - - # This index gets deleted in `event_fix_redactions_bytes` update - self.register_background_index_update( - "event_fix_redactions_bytes_create_index", - index_name="redactions_censored_redacts", - table="redactions", - columns=["redacts"], - where_clause="have_censored", - ) - - self.register_background_update_handler( - "event_fix_redactions_bytes", self._event_fix_redactions_bytes - ) - - @defer.inlineCallbacks - def _background_reindex_fields_sender(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - INSERT_CLUMP_SIZE = 1000 - - def reindex_txn(txn): - sql = ( - "SELECT stream_ordering, event_id, json FROM events" - " INNER JOIN event_json USING (event_id)" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - min_stream_id = rows[-1][0] - - update_rows = [] - for row in rows: - try: - event_id = row[1] - event_json = json.loads(row[2]) - sender = event_json["sender"] - content = event_json["content"] - - contains_url = "url" in content - if contains_url: - contains_url &= isinstance(content["url"], text_type) - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - update_rows.append((sender, contains_url, event_id)) - - sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" - - for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): - clump = update_rows[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows), - } - - self._background_update_progress_txn( - txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress - ) - - return len(rows) - - result = yield self.runInteraction( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn - ) - - if not result: - yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) - - return result - - @defer.inlineCallbacks - def _background_reindex_origin_server_ts(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - INSERT_CLUMP_SIZE = 1000 - - def reindex_search_txn(txn): - sql = ( - "SELECT stream_ordering, event_id FROM events" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - min_stream_id = rows[-1][0] - event_ids = [row[1] for row in rows] - - rows_to_update = [] - - chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] - for chunk in chunks: - ev_rows = self._simple_select_many_txn( - txn, - table="event_json", - column="event_id", - iterable=chunk, - retcols=["event_id", "json"], - keyvalues={}, - ) - - for row in ev_rows: - event_id = row["event_id"] - event_json = json.loads(row["json"]) - try: - origin_server_ts = event_json["origin_server_ts"] - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - rows_to_update.append((origin_server_ts, event_id)) - - sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" - - for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): - clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows_to_update), - } - - self._background_update_progress_txn( - txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress - ) - - return len(rows_to_update) - - result = yield self.runInteraction( - self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn - ) - - if not result: - yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) - - return result - - @defer.inlineCallbacks - def _cleanup_extremities_bg_update(self, progress, batch_size): - """Background update to clean out extremities that should have been - deleted previously. - - Mainly used to deal with the aftermath of #5269. - """ - - # This works by first copying all existing forward extremities into the - # `_extremities_to_check` table at start up, and then checking each - # event in that table whether we have any descendants that are not - # soft-failed/rejected. If that is the case then we delete that event - # from the forward extremities table. - # - # For efficiency, we do this in batches by recursively pulling out all - # descendants of a batch until we find the non soft-failed/rejected - # events, i.e. the set of descendants whose chain of prev events back - # to the batch of extremities are all soft-failed or rejected. - # Typically, we won't find any such events as extremities will rarely - # have any descendants, but if they do then we should delete those - # extremities. - - def _cleanup_extremities_bg_update_txn(txn): - # The set of extremity event IDs that we're checking this round - original_set = set() - - # A dict[str, set[str]] of event ID to their prev events. - graph = {} - - # The set of descendants of the original set that are not rejected - # nor soft-failed. Ancestors of these events should be removed - # from the forward extremities table. - non_rejected_leaves = set() - - # Set of event IDs that have been soft failed, and for which we - # should check if they have descendants which haven't been soft - # failed. - soft_failed_events_to_lookup = set() - - # First, we get `batch_size` events from the table, pulling out - # their successor events, if any, and the successor events' - # rejection status. - txn.execute( - """SELECT prev_event_id, event_id, internal_metadata, - rejections.event_id IS NOT NULL, events.outlier - FROM ( - SELECT event_id AS prev_event_id - FROM _extremities_to_check - LIMIT ? - ) AS f - LEFT JOIN event_edges USING (prev_event_id) - LEFT JOIN events USING (event_id) - LEFT JOIN event_json USING (event_id) - LEFT JOIN rejections USING (event_id) - """, - (batch_size,), - ) - - for prev_event_id, event_id, metadata, rejected, outlier in txn: - original_set.add(prev_event_id) - - if not event_id or outlier: - # Common case where the forward extremity doesn't have any - # descendants. - continue - - graph.setdefault(event_id, set()).add(prev_event_id) - - soft_failed = False - if metadata: - soft_failed = json.loads(metadata).get("soft_failed") - - if soft_failed or rejected: - soft_failed_events_to_lookup.add(event_id) - else: - non_rejected_leaves.add(event_id) - - # Now we recursively check all the soft-failed descendants we - # found above in the same way, until we have nothing left to - # check. - while soft_failed_events_to_lookup: - # We only want to do 100 at a time, so we split given list - # into two. - batch = list(soft_failed_events_to_lookup) - to_check, to_defer = batch[:100], batch[100:] - soft_failed_events_to_lookup = set(to_defer) - - sql = """SELECT prev_event_id, event_id, internal_metadata, - rejections.event_id IS NOT NULL - FROM event_edges - INNER JOIN events USING (event_id) - INNER JOIN event_json USING (event_id) - LEFT JOIN rejections USING (event_id) - WHERE - NOT events.outlier - AND - """ - clause, args = make_in_list_sql_clause( - self.database_engine, "prev_event_id", to_check - ) - txn.execute(sql + clause, list(args)) - - for prev_event_id, event_id, metadata, rejected in txn: - if event_id in graph: - # Already handled this event previously, but we still - # want to record the edge. - graph[event_id].add(prev_event_id) - continue - - graph[event_id] = {prev_event_id} - - soft_failed = json.loads(metadata).get("soft_failed") - if soft_failed or rejected: - soft_failed_events_to_lookup.add(event_id) - else: - non_rejected_leaves.add(event_id) - - # We have a set of non-soft-failed descendants, so we recurse up - # the graph to find all ancestors and add them to the set of event - # IDs that we can delete from forward extremities table. - to_delete = set() - while non_rejected_leaves: - event_id = non_rejected_leaves.pop() - prev_event_ids = graph.get(event_id, set()) - non_rejected_leaves.update(prev_event_ids) - to_delete.update(prev_event_ids) - - to_delete.intersection_update(original_set) - - deleted = self._simple_delete_many_txn( - txn=txn, - table="event_forward_extremities", - column="event_id", - iterable=to_delete, - keyvalues={}, - ) - - logger.info( - "Deleted %d forward extremities of %d checked, to clean up #5269", - deleted, - len(original_set), - ) - - if deleted: - # We now need to invalidate the caches of these rooms - rows = self._simple_select_many_txn( - txn, - table="events", - column="event_id", - iterable=to_delete, - keyvalues={}, - retcols=("room_id",), - ) - room_ids = set(row["room_id"] for row in rows) - for room_id in room_ids: - txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) - ) - - self._simple_delete_many_txn( - txn=txn, - table="_extremities_to_check", - column="event_id", - iterable=original_set, - keyvalues={}, - ) - - return len(original_set) - - num_handled = yield self.runInteraction( - "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn - ) - - if not num_handled: - yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES) - - def _drop_table_txn(txn): - txn.execute("DROP TABLE _extremities_to_check") - - yield self.runInteraction( - "_cleanup_extremities_bg_update_drop_table", _drop_table_txn - ) - - return num_handled - - @defer.inlineCallbacks - def _redactions_received_ts(self, progress, batch_size): - """Handles filling out the `received_ts` column in redactions. - """ - last_event_id = progress.get("last_event_id", "") - - def _redactions_received_ts_txn(txn): - # Fetch the set of event IDs that we want to update - sql = """ - SELECT event_id FROM redactions - WHERE event_id > ? - ORDER BY event_id ASC - LIMIT ? - """ - - txn.execute(sql, (last_event_id, batch_size)) - - rows = txn.fetchall() - if not rows: - return 0 - - upper_event_id, = rows[-1] - - # Update the redactions with the received_ts. - # - # Note: Not all events have an associated received_ts, so we - # fallback to using origin_server_ts. If we for some reason don't - # have an origin_server_ts, lets just use the current timestamp. - # - # We don't want to leave it null, as then we'll never try and - # censor those redactions. - sql = """ - UPDATE redactions - SET received_ts = ( - SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events - WHERE events.event_id = redactions.event_id - ) - WHERE ? <= event_id AND event_id <= ? - """ - - txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) - - self._background_update_progress_txn( - txn, "redactions_received_ts", {"last_event_id": upper_event_id} - ) - - return len(rows) - - count = yield self.runInteraction( - "_redactions_received_ts", _redactions_received_ts_txn - ) - - if not count: - yield self._end_background_update("redactions_received_ts") - - return count - - @defer.inlineCallbacks - def _event_fix_redactions_bytes(self, progress, batch_size): - """Undoes hex encoded censored redacted event JSON. - """ - - def _event_fix_redactions_bytes_txn(txn): - # This update is quite fast due to new index. - txn.execute( - """ - UPDATE event_json - SET - json = convert_from(json::bytea, 'utf8') - FROM redactions - WHERE - redactions.have_censored - AND event_json.event_id = redactions.redacts - AND json NOT LIKE '{%'; - """ - ) - - txn.execute("DROP INDEX redactions_censored_redacts") - - yield self.runInteraction( - "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn - ) - - yield self._end_background_update("event_fix_redactions_bytes") - - return 1 diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py deleted file mode 100644 index 4c4b76bd93..0000000000 --- a/synapse/storage/events_worker.py +++ /dev/null @@ -1,882 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import division - -import itertools -import logging -from collections import namedtuple - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.constants import EventTypes -from synapse.api.errors import NotFoundError -from synapse.api.room_versions import EventFormatVersions -from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 -from synapse.events.utils import prune_event -from synapse.logging.context import LoggingContext, PreserveLoggingContext -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.types import get_domain_from_id -from synapse.util import batch_iter -from synapse.util.metrics import Measure - -logger = logging.getLogger(__name__) - - -# These values are used in the `enqueus_event` and `_do_fetch` methods to -# control how we batch/bulk fetch events from the database. -# The values are plucked out of thing air to make initial sync run faster -# on jki.re -# TODO: Make these configurable. -EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events -EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events -EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events - - -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) - - -class EventsWorkerStore(SQLBaseStore): - def get_received_ts(self, event_id): - """Get received_ts (when it was persisted) for the event. - - Raises an exception for unknown events. - - Args: - event_id (str) - - Returns: - Deferred[int|None]: Timestamp in milliseconds, or None for events - that were persisted before received_ts was implemented. - """ - return self._simple_select_one_onecol( - table="events", - keyvalues={"event_id": event_id}, - retcol="received_ts", - desc="get_received_ts", - ) - - def get_received_ts_by_stream_pos(self, stream_ordering): - """Given a stream ordering get an approximate timestamp of when it - happened. - - This is done by simply taking the received ts of the first event that - has a stream ordering greater than or equal to the given stream pos. - If none exists returns the current time, on the assumption that it must - have happened recently. - - Args: - stream_ordering (int) - - Returns: - Deferred[int] - """ - - def _get_approximate_received_ts_txn(txn): - sql = """ - SELECT received_ts FROM events - WHERE stream_ordering >= ? - LIMIT 1 - """ - - txn.execute(sql, (stream_ordering,)) - row = txn.fetchone() - if row and row[0]: - ts = row[0] - else: - ts = self.clock.time_msec() - - return ts - - return self.runInteraction( - "get_approximate_received_ts", _get_approximate_received_ts_txn - ) - - @defer.inlineCallbacks - def get_event( - self, - event_id, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, - allow_none=False, - check_room_id=None, - ): - """Get an event from the database by event_id. - - Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if - False throw a NotFoundError - check_room_id (str|None): if not None, check the room of the found event. - If there is a mismatch, behave as per allow_none. - - Returns: - Deferred[EventBase|None] - """ - if not isinstance(event_id, str): - raise TypeError("Invalid event event_id %r" % (event_id,)) - - events = yield self.get_events_as_list( - [event_id], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - event = events[0] if events else None - - if event is not None and check_room_id is not None: - if event.room_id != check_room_id: - event = None - - if event is None and not allow_none: - raise NotFoundError("Could not find event %s" % (event_id,)) - - return event - - @defer.inlineCallbacks - def get_events( - self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, - ): - """Get events from the database - - Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - - Returns: - Deferred : Dict from event_id to event. - """ - events = yield self.get_events_as_list( - event_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - return {e.event_id: e for e in events} - - @defer.inlineCallbacks - def get_events_as_list( - self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, - ): - """Get events from the database and return in a list in the same order - as given by `event_ids` arg. - - Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - - Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. - - Note that the returned list may be smaller than the list of event - IDs if not all events could be fetched. - """ - - if not event_ids: - return [] - - # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( - set(event_ids), allow_rejected=allow_rejected - ) - - events = [] - for event_id in event_ids: - entry = event_entry_map.get(event_id, None) - if not entry: - continue - - if not allow_rejected: - assert not entry.event.rejected_reason, ( - "rejected event returned from _get_events_from_cache_or_db despite " - "allow_rejected=False" - ) - - # We may not have had the original event when we received a redaction, so - # we have to recheck auth now. - - if not allow_rejected and entry.event.type == EventTypes.Redaction: - if not hasattr(entry.event, "redacts"): - # A redacted redaction doesn't have a `redacts` key, in - # which case lets just withhold the event. - # - # Note: Most of the time if the redactions has been - # redacted we still have the un-redacted event in the DB - # and so we'll still see the `redacts` key. However, this - # isn't always true e.g. if we have censored the event. - logger.debug( - "Withholding redaction event %s as we don't have redacts key", - event_id, - ) - continue - - redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) - original_event_entry = event_map.get(redacted_event_id) - if not original_event_entry: - # we don't have the redacted event (or it was rejected). - # - # We assume that the redaction isn't authorized for now; if the - # redacted event later turns up, the redaction will be re-checked, - # and if it is found valid, the original will get redacted before it - # is served to the client. - logger.debug( - "Withholding redaction event %s since we don't (yet) have the " - "original %s", - event_id, - redacted_event_id, - ) - continue - - original_event = original_event_entry.event - if original_event.type == EventTypes.Create: - # we never serve redactions of Creates to clients. - logger.info( - "Withholding redaction %s of create event %s", - event_id, - redacted_event_id, - ) - continue - - if original_event.room_id != entry.event.room_id: - logger.info( - "Withholding redaction %s of event %s from a different room", - event_id, - redacted_event_id, - ) - continue - - if entry.event.internal_metadata.need_to_check_redaction(): - original_domain = get_domain_from_id(original_event.sender) - redaction_domain = get_domain_from_id(entry.event.sender) - if original_domain != redaction_domain: - # the senders don't match, so this is forbidden - logger.info( - "Withholding redaction %s whose sender domain %s doesn't " - "match that of redacted event %s %s", - event_id, - redaction_domain, - redacted_event_id, - original_domain, - ) - continue - - # Update the cache to save doing the checks again. - entry.event.internal_metadata.recheck_redaction = False - - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event - - events.append(event) - - if get_prev_content: - if "replaces_state" in event.unsigned: - prev = yield self.get_event( - event.unsigned["replaces_state"], - get_prev_content=False, - allow_none=True, - ) - if prev: - event.unsigned = dict(event.unsigned) - event.unsigned["prev_content"] = prev.content - event.unsigned["prev_sender"] = prev.sender - - return events - - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): - """Fetch a bunch of events from the cache or the database. - - If events are pulled from the database, they will be cached for future lookups. - - Args: - event_ids (Iterable[str]): The event_ids of the events to fetch - allow_rejected (bool): Whether to include rejected events - - Returns: - Deferred[Dict[str, _EventCacheEntry]]: - map from event id to result - """ - event_entry_map = self._get_events_from_cache( - event_ids, allow_rejected=allow_rejected - ) - - missing_events_ids = [e for e in event_ids if e not in event_entry_map] - - if missing_events_ids: - log_ctx = LoggingContext.current_context() - log_ctx.record_event_fetch(len(missing_events_ids)) - - # Note that _get_events_from_db is also responsible for turning db rows - # into FrozenEvents (via _get_event_from_row), which involves seeing if - # the events have been redacted, and if so pulling the redaction event out - # of the database to check it. - # - missing_events = yield self._get_events_from_db( - missing_events_ids, allow_rejected=allow_rejected - ) - - event_entry_map.update(missing_events) - - return event_entry_map - - def _invalidate_get_event_cache(self, event_id): - self._get_event_cache.invalidate((event_id,)) - - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches - - Args: - events (Iterable[str]): list of event_ids to fetch - allow_rejected (bool): Whether to return events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics - - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. If - allow_rejected is `False` then there will still be an entry but it - will be `None` - """ - event_map = {} - - for event_id in events: - ret = self._get_event_cache.get( - (event_id,), None, update_metrics=update_metrics - ) - if not ret: - continue - - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None - - return event_map - - def _do_fetch(self, conn): - """Takes a database connection and waits for requests for events from - the _event_fetch_list queue. - """ - i = 0 - while True: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: - self._event_fetch_ongoing -= 1 - return - else: - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 - - self._fetch_event_list(conn, event_list) - - def _fetch_event_list(self, conn, event_list): - """Handle a load of requests from the _event_fetch_list queue - - Args: - conn (twisted.enterprise.adbapi.Connection): database connection - - event_list (list[Tuple[list[str], Deferred]]): - The fetch requests. Each entry consists of a list of event - ids to be fetched, and a deferred to be completed once the - events have been fetched. - - The deferreds are callbacked with a dictionary mapping from event id - to event row. Note that it may well contain additional events that - were not part of this request. - """ - with Measure(self._clock, "_fetch_event_list"): - try: - events_to_fetch = set( - event_id for events, _ in event_list for event_id in events - ) - - row_dict = self._new_transaction( - conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch - ) - - # We only want to resolve deferreds from the main thread - def fire(): - for _, d in event_list: - d.callback(row_dict) - - with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire) - except Exception as e: - logger.exception("do_fetch") - - # We only want to resolve deferreds from the main thread - def fire(evs, exc): - for _, d in evs: - if not d.called: - with PreserveLoggingContext(): - d.errback(exc) - - with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list, e) - - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): - """Fetch a bunch of events from the database. - - Returned events will be added to the cache for future lookups. - - Args: - event_ids (Iterable[str]): The event_ids of the events to fetch - allow_rejected (bool): Whether to include rejected events - - Returns: - Deferred[Dict[str, _EventCacheEntry]]: - map from event id to result. May return extra events which - weren't asked for. - """ - fetched_events = {} - events_to_fetch = event_ids - - while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) - - # we need to recursively fetch any redactions of those events - redaction_ids = set() - for event_id in events_to_fetch: - row = row_map.get(event_id) - fetched_events[event_id] = row - if row: - redaction_ids.update(row["redactions"]) - - events_to_fetch = redaction_ids.difference(fetched_events.keys()) - if events_to_fetch: - logger.debug("Also fetching redaction events %s", events_to_fetch) - - # build a map from event_id to EventBase - event_map = {} - for event_id, row in fetched_events.items(): - if not row: - continue - assert row["event_id"] == event_id - - rejected_reason = row["rejected_reason"] - - if not allow_rejected and rejected_reason: - continue - - d = json.loads(row["json"]) - internal_metadata = json.loads(row["internal_metadata"]) - - format_version = row["format_version"] - if format_version is None: - # This means that we stored the event before we had the concept - # of a event format version, so it must be a V1 event. - format_version = EventFormatVersions.V1 - - original_ev = event_type_from_format_version(format_version)( - event_dict=d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - event_map[event_id] = original_ev - - # finally, we can decide whether each one nededs redacting, and build - # the cache entries. - result_map = {} - for event_id, original_ev in event_map.items(): - redactions = fetched_events[event_id]["redactions"] - redacted_event = self._maybe_redact_event_row( - original_ev, redactions, event_map - ) - - cache_entry = _EventCacheEntry( - event=original_ev, redacted_event=redacted_event - ) - - self._get_event_cache.prefill((event_id,), cache_entry) - result_map[event_id] = cache_entry - - return result_map - - @defer.inlineCallbacks - def _enqueue_events(self, events): - """Fetches events from the database using the _event_fetch_list. This - allows batch and bulk fetching of events - it allows us to fetch events - without having to create a new transaction for each request for events. - - Args: - events (Iterable[str]): events to be fetched. - - Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. - May contain events that weren't requested. - """ - - events_d = defer.Deferred() - with self._event_fetch_lock: - self._event_fetch_list.append((events, events_d)) - - self._event_fetch_lock.notify() - - if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: - self._event_fetch_ongoing += 1 - should_start = True - else: - should_start = False - - if should_start: - run_as_background_process( - "fetch_events", self.runWithConnection, self._do_fetch - ) - - logger.debug("Loading %d events: %s", len(events), events) - with PreserveLoggingContext(): - row_map = yield events_d - logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) - - return row_map - - def _fetch_event_rows(self, txn, event_ids): - """Fetch event rows from the database - - Events which are not found are omitted from the result. - - The returned per-event dicts contain the following keys: - - * event_id (str) - - * json (str): json-encoded event structure - - * internal_metadata (str): json-encoded internal metadata dict - - * format_version (int|None): The format of the event. Hopefully one - of EventFormatVersions. 'None' means the event predates - EventFormatVersions (so the event is format V1). - - * rejected_reason (str|None): if the event was rejected, the reason - why. - - * redactions (List[str]): a list of event-ids which (claim to) redact - this event. - - Args: - txn (twisted.enterprise.adbapi.Connection): - event_ids (Iterable[str]): event IDs to fetch - - Returns: - Dict[str, Dict]: a map from event id to event info. - """ - event_dict = {} - for evs in batch_iter(event_ids, 200): - sql = ( - "SELECT " - " e.event_id, " - " e.internal_metadata," - " e.json," - " e.format_version, " - " rej.reason " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " WHERE " - ) - - clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", evs - ) - - txn.execute(sql + clause, args) - - for row in txn: - event_id = row[0] - event_dict[event_id] = { - "event_id": event_id, - "internal_metadata": row[1], - "json": row[2], - "format_version": row[3], - "rejected_reason": row[4], - "redactions": [], - } - - # check for redactions - redactions_sql = "SELECT event_id, redacts FROM redactions WHERE " - - clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs) - - txn.execute(redactions_sql + clause, args) - - for (redacter, redacted) in txn: - d = event_dict.get(redacted) - if d: - d["redactions"].append(redacter) - - return event_dict - - def _maybe_redact_event_row(self, original_ev, redactions, event_map): - """Given an event object and a list of possible redacting event ids, - determine whether to honour any of those redactions and if so return a redacted - event. - - Args: - original_ev (EventBase): - redactions (iterable[str]): list of event ids of potential redaction events - event_map (dict[str, EventBase]): other events which have been fetched, in - which we can look up the redaaction events. Map from event id to event. - - Returns: - Deferred[EventBase|None]: if the event should be redacted, a pruned - event object. Otherwise, None. - """ - if original_ev.type == "m.room.create": - # we choose to ignore redactions of m.room.create events. - return None - - for redaction_id in redactions: - redaction_event = event_map.get(redaction_id) - if not redaction_event or redaction_event.rejected_reason: - # we don't have the redaction event, or the redaction event was not - # authorized. - logger.debug( - "%s was redacted by %s but redaction not found/authed", - original_ev.event_id, - redaction_id, - ) - continue - - if redaction_event.room_id != original_ev.room_id: - logger.debug( - "%s was redacted by %s but redaction was in a different room!", - original_ev.event_id, - redaction_id, - ) - continue - - # Starting in room version v3, some redactions need to be - # rechecked if we didn't have the redacted event at the - # time, so we recheck on read instead. - if redaction_event.internal_metadata.need_to_check_redaction(): - expected_domain = get_domain_from_id(original_ev.sender) - if get_domain_from_id(redaction_event.sender) == expected_domain: - # This redaction event is allowed. Mark as not needing a recheck. - redaction_event.internal_metadata.recheck_redaction = False - else: - # Senders don't match, so the event isn't actually redacted - logger.debug( - "%s was redacted by %s but the senders don't match", - original_ev.event_id, - redaction_id, - ) - continue - - logger.debug("Redacting %s due to %s", original_ev.event_id, redaction_id) - - # we found a good redaction event. Redact! - redacted_event = prune_event(original_ev) - redacted_event.unsigned["redacted_by"] = redaction_id - - # It's fine to add the event directly, since get_pdu_json - # will serialise this field correctly - redacted_event.unsigned["redacted_because"] = redaction_event - - return redacted_event - - # no valid redaction found for this event - return None - - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): - """Given a list of event ids, check if we have already processed and - stored them as non outliers. - """ - rows = yield self._simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) - - return set(r["event_id"] for r in rows) - - @defer.inlineCallbacks - def have_seen_events(self, event_ids): - """Given a list of event ids, check if we have already processed them. - - Args: - event_ids (iterable[str]): - - Returns: - Deferred[set[str]]: The events we have already seen. - """ - results = set() - - def have_seen_events_txn(txn, chunk): - sql = "SELECT event_id FROM events as e WHERE " - clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", chunk - ) - txn.execute(sql + clause, args) - for (event_id,) in txn: - results.add(event_id) - - # break the input up into chunks of 100 - input_iterator = iter(event_ids) - for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk) - return results - - def get_seen_events_with_rejections(self, event_ids): - """Given a list of event ids, check if we rejected them. - - Args: - event_ids (list[str]) - - Returns: - Deferred[dict[str, str|None): - Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps - to None. - """ - if not event_ids: - return defer.succeed({}) - - def f(txn): - sql = ( - "SELECT e.event_id, reason FROM events as e " - "LEFT JOIN rejections as r ON e.event_id = r.event_id " - "WHERE e.event_id = ?" - ) - - res = {} - for event_id in event_ids: - txn.execute(sql, (event_id,)) - row = txn.fetchone() - if row: - _, rejected = row - res[event_id] = rejected - - return res - - return self.runInteraction("get_seen_events_with_rejections", f) - - def _get_total_state_event_counts_txn(self, txn, room_id): - """ - See get_total_state_event_counts. - """ - # We join against the events table as that has an index on room_id - sql = """ - SELECT COUNT(*) FROM state_events - INNER JOIN events USING (room_id, event_id) - WHERE room_id=? - """ - txn.execute(sql, (room_id,)) - row = txn.fetchone() - return row[0] if row else 0 - - def get_total_state_event_counts(self, room_id): - """ - Gets the total number of state events in a room. - - Args: - room_id (str) - - Returns: - Deferred[int] - """ - return self.runInteraction( - "get_total_state_event_counts", - self._get_total_state_event_counts_txn, - room_id, - ) - - def _get_current_state_event_counts_txn(self, txn, room_id): - """ - See get_current_state_event_counts. - """ - sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?" - txn.execute(sql, (room_id,)) - row = txn.fetchone() - return row[0] if row else 0 - - def get_current_state_event_counts(self, room_id): - """ - Gets the current number of state events in a room. - - Args: - room_id (str) - - Returns: - Deferred[int] - """ - return self.runInteraction( - "get_current_state_event_counts", - self._get_current_state_event_counts_txn, - room_id, - ) - - @defer.inlineCallbacks - def get_room_complexity(self, room_id): - """ - Get a rough approximation of the complexity of the room. This is used by - remote servers to decide whether they wish to join the room or not. - Higher complexity value indicates that being in the room will consume - more resources. - - Args: - room_id (str) - - Returns: - Deferred[dict[str:int]] of complexity version to complexity. - """ - state_events = yield self.get_current_state_event_counts(room_id) - - # Call this one "v1", so we can introduce new ones as we want to develop - # it. - complexity_v1 = round(state_events / 500, 2) - - return {"v1": complexity_v1} diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py deleted file mode 100644 index 7c2a7da836..0000000000 --- a/synapse/storage/filtering.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from canonicaljson import encode_canonical_json - -from synapse.api.errors import Codes, SynapseError -from synapse.util.caches.descriptors import cachedInlineCallbacks - -from ._base import SQLBaseStore, db_to_json - - -class FilteringStore(SQLBaseStore): - @cachedInlineCallbacks(num_args=2) - def get_user_filter(self, user_localpart, filter_id): - # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail - # with a coherent error message rather than 500 M_UNKNOWN. - try: - int(filter_id) - except ValueError: - raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - - def_json = yield self._simple_select_one_onecol( - table="user_filters", - keyvalues={"user_id": user_localpart, "filter_id": filter_id}, - retcol="filter_json", - allow_none=False, - desc="get_user_filter", - ) - - return db_to_json(def_json) - - def add_user_filter(self, user_localpart, user_filter): - def_json = encode_canonical_json(user_filter) - - # Need an atomic transaction to SELECT the maximal ID so far then - # INSERT a new one - def _do_txn(txn): - sql = ( - "SELECT filter_id FROM user_filters " - "WHERE user_id = ? AND filter_json = ?" - ) - txn.execute(sql, (user_localpart, bytearray(def_json))) - filter_id_response = txn.fetchone() - if filter_id_response is not None: - return filter_id_response[0] - - sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?" - txn.execute(sql, (user_localpart,)) - max_id = txn.fetchone()[0] - if max_id is None: - filter_id = 0 - else: - filter_id = max_id + 1 - - sql = ( - "INSERT INTO user_filters (user_id, filter_id, filter_json)" - "VALUES(?, ?, ?)" - ) - txn.execute(sql, (user_localpart, filter_id, bytearray(def_json))) - - return filter_id - - return self.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py deleted file mode 100644 index 15b01c6958..0000000000 --- a/synapse/storage/group_server.py +++ /dev/null @@ -1,1181 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.errors import SynapseError - -from ._base import SQLBaseStore - -# The category ID for the "default" category. We don't store as null in the -# database to avoid the fun of null != null -_DEFAULT_CATEGORY_ID = "" -_DEFAULT_ROLE_ID = "" - - -class GroupServerStore(SQLBaseStore): - def set_group_join_policy(self, group_id, join_policy): - """Set the join policy of a group. - - join_policy can be one of: - * "invite" - * "open" - """ - return self._simple_update_one( - table="groups", - keyvalues={"group_id": group_id}, - updatevalues={"join_policy": join_policy}, - desc="set_group_join_policy", - ) - - def get_group(self, group_id): - return self._simple_select_one( - table="groups", - keyvalues={"group_id": group_id}, - retcols=( - "name", - "short_description", - "long_description", - "avatar_url", - "is_public", - "join_policy", - ), - allow_none=True, - desc="get_group", - ) - - def get_users_in_group(self, group_id, include_private=False): - # TODO: Pagination - - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - return self._simple_select_list( - table="group_users", - keyvalues=keyvalues, - retcols=("user_id", "is_public", "is_admin"), - desc="get_users_in_group", - ) - - def get_invited_users_in_group(self, group_id): - # TODO: Pagination - - return self._simple_select_onecol( - table="group_invites", - keyvalues={"group_id": group_id}, - retcol="user_id", - desc="get_invited_users_in_group", - ) - - def get_rooms_in_group(self, group_id, include_private=False): - # TODO: Pagination - - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - return self._simple_select_list( - table="group_rooms", - keyvalues=keyvalues, - retcols=("room_id", "is_public"), - desc="get_rooms_in_group", - ) - - def get_rooms_for_summary_by_category(self, group_id, include_private=False): - """Get the rooms and categories that should be included in a summary request - - Returns ([rooms], [categories]) - """ - - def _get_rooms_for_summary_txn(txn): - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - sql = """ - SELECT room_id, is_public, category_id, room_order - FROM group_summary_rooms - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - rooms = [ - { - "room_id": row[0], - "is_public": row[1], - "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None, - "order": row[3], - } - for row in txn - ] - - sql = """ - SELECT category_id, is_public, profile, cat_order - FROM group_summary_room_categories - INNER JOIN group_room_categories USING (group_id, category_id) - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - categories = { - row[0]: { - "is_public": row[1], - "profile": json.loads(row[2]), - "order": row[3], - } - for row in txn - } - - return rooms, categories - - return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn) - - def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.runInteraction( - "add_room_to_summary", - self._add_room_to_summary_txn, - group_id, - room_id, - category_id, - order, - is_public, - ) - - def _add_room_to_summary_txn( - self, txn, group_id, room_id, category_id, order, is_public - ): - """Add (or update) room's entry in summary. - - Args: - group_id (str) - room_id (str) - category_id (str): If not None then adds the category to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the room at that position, e.g. - an order of 1 will put the room first. Otherwise, the room gets - added to the end. - """ - room_in_group = self._simple_select_one_onecol_txn( - txn, - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - retcol="room_id", - allow_none=True, - ) - if not room_in_group: - raise SynapseError(400, "room not in group") - - if category_id is None: - category_id = _DEFAULT_CATEGORY_ID - else: - cat_exists = self._simple_select_one_onecol_txn( - txn, - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcol="group_id", - allow_none=True, - ) - if not cat_exists: - raise SynapseError(400, "Category doesn't exist") - - # TODO: Check category is part of summary already - cat_exists = self._simple_select_one_onecol_txn( - txn, - table="group_summary_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcol="group_id", - allow_none=True, - ) - if not cat_exists: - # If not, add it with an order larger than all others - txn.execute( - """ - INSERT INTO group_summary_room_categories - (group_id, category_id, cat_order) - SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 - FROM group_summary_room_categories - WHERE group_id = ? AND category_id = ? - """, - (group_id, category_id, group_id, category_id), - ) - - existing = self._simple_select_one_txn( - txn, - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "room_id": room_id, - "category_id": category_id, - }, - retcols=("room_order", "is_public"), - allow_none=True, - ) - - if order is not None: - # Shuffle other room orders that come after the given order - sql = """ - UPDATE group_summary_rooms SET room_order = room_order + 1 - WHERE group_id = ? AND category_id = ? AND room_order >= ? - """ - txn.execute(sql, (group_id, category_id, order)) - elif not existing: - sql = """ - SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms - WHERE group_id = ? AND category_id = ? - """ - txn.execute(sql, (group_id, category_id)) - order, = txn.fetchone() - - if existing: - to_update = {} - if order is not None: - to_update["room_order"] = order - if is_public is not None: - to_update["is_public"] = is_public - self._simple_update_txn( - txn, - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - }, - values=to_update, - ) - else: - if is_public is None: - is_public = True - - self._simple_insert_txn( - txn, - table="group_summary_rooms", - values={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - "room_order": order, - "is_public": is_public, - }, - ) - - def remove_room_from_summary(self, group_id, room_id, category_id): - if category_id is None: - category_id = _DEFAULT_CATEGORY_ID - - return self._simple_delete( - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - }, - desc="remove_room_from_summary", - ) - - @defer.inlineCallbacks - def get_group_categories(self, group_id): - rows = yield self._simple_select_list( - table="group_room_categories", - keyvalues={"group_id": group_id}, - retcols=("category_id", "is_public", "profile"), - desc="get_group_categories", - ) - - return { - row["category_id"]: { - "is_public": row["is_public"], - "profile": json.loads(row["profile"]), - } - for row in rows - } - - @defer.inlineCallbacks - def get_group_category(self, group_id, category_id): - category = yield self._simple_select_one( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcols=("is_public", "profile"), - desc="get_group_category", - ) - - category["profile"] = json.loads(category["profile"]) - - return category - - def upsert_group_category(self, group_id, category_id, profile, is_public): - """Add/update room category for group - """ - insertion_values = {} - update_values = {"category_id": category_id} # This cannot be empty - - if profile is None: - insertion_values["profile"] = "{}" - else: - update_values["profile"] = json.dumps(profile) - - if is_public is None: - insertion_values["is_public"] = True - else: - update_values["is_public"] = is_public - - return self._simple_upsert( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - values=update_values, - insertion_values=insertion_values, - desc="upsert_group_category", - ) - - def remove_group_category(self, group_id, category_id): - return self._simple_delete( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - desc="remove_group_category", - ) - - @defer.inlineCallbacks - def get_group_roles(self, group_id): - rows = yield self._simple_select_list( - table="group_roles", - keyvalues={"group_id": group_id}, - retcols=("role_id", "is_public", "profile"), - desc="get_group_roles", - ) - - return { - row["role_id"]: { - "is_public": row["is_public"], - "profile": json.loads(row["profile"]), - } - for row in rows - } - - @defer.inlineCallbacks - def get_group_role(self, group_id, role_id): - role = yield self._simple_select_one( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcols=("is_public", "profile"), - desc="get_group_role", - ) - - role["profile"] = json.loads(role["profile"]) - - return role - - def upsert_group_role(self, group_id, role_id, profile, is_public): - """Add/remove user role - """ - insertion_values = {} - update_values = {"role_id": role_id} # This cannot be empty - - if profile is None: - insertion_values["profile"] = "{}" - else: - update_values["profile"] = json.dumps(profile) - - if is_public is None: - insertion_values["is_public"] = True - else: - update_values["is_public"] = is_public - - return self._simple_upsert( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - values=update_values, - insertion_values=insertion_values, - desc="upsert_group_role", - ) - - def remove_group_role(self, group_id, role_id): - return self._simple_delete( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - desc="remove_group_role", - ) - - def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.runInteraction( - "add_user_to_summary", - self._add_user_to_summary_txn, - group_id, - user_id, - role_id, - order, - is_public, - ) - - def _add_user_to_summary_txn( - self, txn, group_id, user_id, role_id, order, is_public - ): - """Add (or update) user's entry in summary. - - Args: - group_id (str) - user_id (str) - role_id (str): If not None then adds the role to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the user at that position, e.g. - an order of 1 will put the user first. Otherwise, the user gets - added to the end. - """ - user_in_group = self._simple_select_one_onecol_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - ) - if not user_in_group: - raise SynapseError(400, "user not in group") - - if role_id is None: - role_id = _DEFAULT_ROLE_ID - else: - role_exists = self._simple_select_one_onecol_txn( - txn, - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcol="group_id", - allow_none=True, - ) - if not role_exists: - raise SynapseError(400, "Role doesn't exist") - - # TODO: Check role is part of the summary already - role_exists = self._simple_select_one_onecol_txn( - txn, - table="group_summary_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcol="group_id", - allow_none=True, - ) - if not role_exists: - # If not, add it with an order larger than all others - txn.execute( - """ - INSERT INTO group_summary_roles - (group_id, role_id, role_order) - SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 - FROM group_summary_roles - WHERE group_id = ? AND role_id = ? - """, - (group_id, role_id, group_id, role_id), - ) - - existing = self._simple_select_one_txn( - txn, - table="group_summary_users", - keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, - retcols=("user_order", "is_public"), - allow_none=True, - ) - - if order is not None: - # Shuffle other users orders that come after the given order - sql = """ - UPDATE group_summary_users SET user_order = user_order + 1 - WHERE group_id = ? AND role_id = ? AND user_order >= ? - """ - txn.execute(sql, (group_id, role_id, order)) - elif not existing: - sql = """ - SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users - WHERE group_id = ? AND role_id = ? - """ - txn.execute(sql, (group_id, role_id)) - order, = txn.fetchone() - - if existing: - to_update = {} - if order is not None: - to_update["user_order"] = order - if is_public is not None: - to_update["is_public"] = is_public - self._simple_update_txn( - txn, - table="group_summary_users", - keyvalues={ - "group_id": group_id, - "role_id": role_id, - "user_id": user_id, - }, - values=to_update, - ) - else: - if is_public is None: - is_public = True - - self._simple_insert_txn( - txn, - table="group_summary_users", - values={ - "group_id": group_id, - "role_id": role_id, - "user_id": user_id, - "user_order": order, - "is_public": is_public, - }, - ) - - def remove_user_from_summary(self, group_id, user_id, role_id): - if role_id is None: - role_id = _DEFAULT_ROLE_ID - - return self._simple_delete( - table="group_summary_users", - keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, - desc="remove_user_from_summary", - ) - - def get_users_for_summary_by_role(self, group_id, include_private=False): - """Get the users and roles that should be included in a summary request - - Returns ([users], [roles]) - """ - - def _get_users_for_summary_txn(txn): - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - sql = """ - SELECT user_id, is_public, role_id, user_order - FROM group_summary_users - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - users = [ - { - "user_id": row[0], - "is_public": row[1], - "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, - "order": row[3], - } - for row in txn - ] - - sql = """ - SELECT role_id, is_public, profile, role_order - FROM group_summary_roles - INNER JOIN group_roles USING (group_id, role_id) - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - roles = { - row[0]: { - "is_public": row[1], - "profile": json.loads(row[2]), - "order": row[3], - } - for row in txn - } - - return users, roles - - return self.runInteraction( - "get_users_for_summary_by_role", _get_users_for_summary_txn - ) - - def is_user_in_group(self, user_id, group_id): - return self._simple_select_one_onecol( - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="is_user_in_group", - ).addCallback(lambda r: bool(r)) - - def is_user_admin_in_group(self, group_id, user_id): - return self._simple_select_one_onecol( - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="is_admin", - allow_none=True, - desc="is_user_admin_in_group", - ) - - def add_group_invite(self, group_id, user_id): - """Record that the group server has invited a user - """ - return self._simple_insert( - table="group_invites", - values={"group_id": group_id, "user_id": user_id}, - desc="add_group_invite", - ) - - def is_user_invited_to_local_group(self, group_id, user_id): - """Has the group server invited a user? - """ - return self._simple_select_one_onecol( - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - desc="is_user_invited_to_local_group", - allow_none=True, - ) - - def get_users_membership_info_in_group(self, group_id, user_id): - """Get a dict describing the membership of a user in a group. - - Example if joined: - - { - "membership": "join", - "is_public": True, - "is_privileged": False, - } - - Returns an empty dict if the user is not join/invite/etc - """ - - def _get_users_membership_in_group_txn(txn): - row = self._simple_select_one_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcols=("is_admin", "is_public"), - allow_none=True, - ) - - if row: - return { - "membership": "join", - "is_public": row["is_public"], - "is_privileged": row["is_admin"], - } - - row = self._simple_select_one_onecol_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - ) - - if row: - return {"membership": "invite"} - - return {} - - return self.runInteraction( - "get_users_membership_info_in_group", _get_users_membership_in_group_txn - ) - - def add_user_to_group( - self, - group_id, - user_id, - is_admin=False, - is_public=True, - local_attestation=None, - remote_attestation=None, - ): - """Add a user to the group server. - - Args: - group_id (str) - user_id (str) - is_admin (bool) - is_public (bool) - local_attestation (dict): The attestation the GS created to give - to the remote server. Optional if the user and group are on the - same server - remote_attestation (dict): The attestation given to GS by remote - server. Optional if the user and group are on the same server - """ - - def _add_user_to_group_txn(txn): - self._simple_insert_txn( - txn, - table="group_users", - values={ - "group_id": group_id, - "user_id": user_id, - "is_admin": is_admin, - "is_public": is_public, - }, - ) - - self._simple_delete_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - if local_attestation: - self._simple_insert_txn( - txn, - table="group_attestations_renewals", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": local_attestation["valid_until_ms"], - }, - ) - if remote_attestation: - self._simple_insert_txn( - txn, - table="group_attestations_remote", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), - }, - ) - - return self.runInteraction("add_user_to_group", _add_user_to_group_txn) - - def remove_user_from_group(self, group_id, user_id): - def _remove_user_from_group_txn(txn): - self._simple_delete_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self._simple_delete_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self._simple_delete_txn( - txn, - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self._simple_delete_txn( - txn, - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self._simple_delete_txn( - txn, - table="group_summary_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - return self.runInteraction( - "remove_user_from_group", _remove_user_from_group_txn - ) - - def add_room_to_group(self, group_id, room_id, is_public): - return self._simple_insert( - table="group_rooms", - values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, - desc="add_room_to_group", - ) - - def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self._simple_update( - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - updatevalues={"is_public": is_public}, - desc="update_room_in_group_visibility", - ) - - def remove_room_from_group(self, group_id, room_id): - def _remove_room_from_group_txn(txn): - self._simple_delete_txn( - txn, - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - ) - - self._simple_delete_txn( - txn, - table="group_summary_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - ) - - return self.runInteraction( - "remove_room_from_group", _remove_room_from_group_txn - ) - - def get_publicised_groups_for_user(self, user_id): - """Get all groups a user is publicising - """ - return self._simple_select_onecol( - table="local_group_membership", - keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, - retcol="group_id", - desc="get_publicised_groups_for_user", - ) - - def update_group_publicity(self, group_id, user_id, publicise): - """Update whether the user is publicising their membership of the group - """ - return self._simple_update_one( - table="local_group_membership", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={"is_publicised": publicise}, - desc="update_group_publicity", - ) - - @defer.inlineCallbacks - def register_user_group_membership( - self, - group_id, - user_id, - membership, - is_admin=False, - content={}, - local_attestation=None, - remote_attestation=None, - is_publicised=False, - ): - """Registers that a local user is a member of a (local or remote) group. - - Args: - group_id (str) - user_id (str) - membership (str) - is_admin (bool) - content (dict): Content of the membership, e.g. includes the inviter - if the user has been invited. - local_attestation (dict): If remote group then store the fact that we - have given out an attestation, else None. - remote_attestation (dict): If remote group then store the remote - attestation from the group, else None. - """ - - def _register_user_group_membership_txn(txn, next_id): - # TODO: Upsert? - self._simple_delete_txn( - txn, - table="local_group_membership", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self._simple_insert_txn( - txn, - table="local_group_membership", - values={ - "group_id": group_id, - "user_id": user_id, - "is_admin": is_admin, - "membership": membership, - "is_publicised": is_publicised, - "content": json.dumps(content), - }, - ) - - self._simple_insert_txn( - txn, - table="local_group_updates", - values={ - "stream_id": next_id, - "group_id": group_id, - "user_id": user_id, - "type": "membership", - "content": json.dumps( - {"membership": membership, "content": content} - ), - }, - ) - self._group_updates_stream_cache.entity_has_changed(user_id, next_id) - - # TODO: Insert profile to ensure it comes down stream if its a join. - - if membership == "join": - if local_attestation: - self._simple_insert_txn( - txn, - table="group_attestations_renewals", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": local_attestation["valid_until_ms"], - }, - ) - if remote_attestation: - self._simple_insert_txn( - txn, - table="group_attestations_remote", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), - }, - ) - else: - self._simple_delete_txn( - txn, - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self._simple_delete_txn( - txn, - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - return next_id - - with self._group_updates_id_gen.get_next() as next_id: - res = yield self.runInteraction( - "register_user_group_membership", - _register_user_group_membership_txn, - next_id, - ) - return res - - @defer.inlineCallbacks - def create_group( - self, group_id, user_id, name, avatar_url, short_description, long_description - ): - yield self._simple_insert( - table="groups", - values={ - "group_id": group_id, - "name": name, - "avatar_url": avatar_url, - "short_description": short_description, - "long_description": long_description, - "is_public": True, - }, - desc="create_group", - ) - - @defer.inlineCallbacks - def update_group_profile(self, group_id, profile): - yield self._simple_update_one( - table="groups", - keyvalues={"group_id": group_id}, - updatevalues=profile, - desc="update_group_profile", - ) - - def get_attestations_need_renewals(self, valid_until_ms): - """Get all attestations that need to be renewed until givent time - """ - - def _get_attestations_need_renewals_txn(txn): - sql = """ - SELECT group_id, user_id FROM group_attestations_renewals - WHERE valid_until_ms <= ? - """ - txn.execute(sql, (valid_until_ms,)) - return self.cursor_to_dict(txn) - - return self.runInteraction( - "get_attestations_need_renewals", _get_attestations_need_renewals_txn - ) - - def update_attestation_renewal(self, group_id, user_id, attestation): - """Update an attestation that we have renewed - """ - return self._simple_update_one( - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, - desc="update_attestation_renewal", - ) - - def update_remote_attestion(self, group_id, user_id, attestation): - """Update an attestation that a remote has renewed - """ - return self._simple_update_one( - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={ - "valid_until_ms": attestation["valid_until_ms"], - "attestation_json": json.dumps(attestation), - }, - desc="update_remote_attestion", - ) - - def remove_attestation_renewal(self, group_id, user_id): - """Remove an attestation that we thought we should renew, but actually - shouldn't. Ideally this would never get called as we would never - incorrectly try and do attestations for local users on local groups. - - Args: - group_id (str) - user_id (str) - """ - return self._simple_delete( - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - desc="remove_attestation_renewal", - ) - - @defer.inlineCallbacks - def get_remote_attestation(self, group_id, user_id): - """Get the attestation that proves the remote agrees that the user is - in the group. - """ - row = yield self._simple_select_one( - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcols=("valid_until_ms", "attestation_json"), - desc="get_remote_attestation", - allow_none=True, - ) - - now = int(self._clock.time_msec()) - if row and now < row["valid_until_ms"]: - return json.loads(row["attestation_json"]) - - return None - - def get_joined_groups(self, user_id): - return self._simple_select_onecol( - table="local_group_membership", - keyvalues={"user_id": user_id, "membership": "join"}, - retcol="group_id", - desc="get_joined_groups", - ) - - def get_all_groups_for_user(self, user_id, now_token): - def _get_all_groups_for_user_txn(txn): - sql = """ - SELECT group_id, type, membership, u.content - FROM local_group_updates AS u - INNER JOIN local_group_membership USING (group_id, user_id) - WHERE user_id = ? AND membership != 'leave' - AND stream_id <= ? - """ - txn.execute(sql, (user_id, now_token)) - return [ - { - "group_id": row[0], - "type": row[1], - "membership": row[2], - "content": json.loads(row[3]), - } - for row in txn - ] - - return self.runInteraction( - "get_all_groups_for_user", _get_all_groups_for_user_txn - ) - - def get_groups_changes_for_user(self, user_id, from_token, to_token): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_entity_changed( - user_id, from_token - ) - if not has_changed: - return [] - - def _get_groups_changes_for_user_txn(txn): - sql = """ - SELECT group_id, membership, type, u.content - FROM local_group_updates AS u - INNER JOIN local_group_membership USING (group_id, user_id) - WHERE user_id = ? AND ? < stream_id AND stream_id <= ? - """ - txn.execute(sql, (user_id, from_token, to_token)) - return [ - { - "group_id": group_id, - "membership": membership, - "type": gtype, - "content": json.loads(content_json), - } - for group_id, membership, gtype, content_json in txn - ] - - return self.runInteraction( - "get_groups_changes_for_user", _get_groups_changes_for_user_txn - ) - - def get_all_groups_changes(self, from_token, to_token, limit): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_any_entity_changed( - from_token - ) - if not has_changed: - return [] - - def _get_all_groups_changes_txn(txn): - sql = """ - SELECT stream_id, group_id, user_id, type, content - FROM local_group_updates - WHERE ? < stream_id AND stream_id <= ? - LIMIT ? - """ - txn.execute(sql, (from_token, to_token, limit)) - return [ - (stream_id, group_id, user_id, gtype, json.loads(content_json)) - for stream_id, group_id, user_id, gtype, content_json in txn - ] - - return self.runInteraction( - "get_all_groups_changes", _get_all_groups_changes_txn - ) - - def get_group_stream_token(self): - return self._group_updates_id_gen.get_current_token() - - def delete_group(self, group_id): - """Deletes a group fully from the database. - - Args: - group_id (str) - - Returns: - Deferred - """ - - def _delete_group_txn(txn): - tables = [ - "groups", - "group_users", - "group_invites", - "group_rooms", - "group_summary_rooms", - "group_summary_room_categories", - "group_room_categories", - "group_summary_users", - "group_summary_roles", - "group_roles", - "group_attestations_renewals", - "group_attestations_remote", - ] - - for table in tables: - self._simple_delete_txn( - txn, table=table, keyvalues={"group_id": group_id} - ) - - return self.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index e72f89e446..4769b21529 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -14,208 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging -import six - import attr -from signedjson.key import decode_verify_key_bytes - -from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedList - -from ._base import SQLBaseStore logger = logging.getLogger(__name__) -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - @attr.s(slots=True, frozen=True) class FetchKeyResult(object): verify_key = attr.ib() # VerifyKey: the key itself valid_until_ts = attr.ib() # int: how long we can use this key for - - -class KeyStore(SQLBaseStore): - """Persistence for signature verification keys - """ - - @cached() - def _get_server_verify_key(self, server_name_and_key_id): - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" - ) - def get_server_verify_keys(self, server_name_and_key_ids): - """ - Args: - server_name_and_key_ids (iterable[Tuple[str, str]]): - iterable of (server_name, key-id) tuples to fetch keys for - - Returns: - Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: - map from (server_name, key_id) -> FetchKeyResult, or None if the key is - unknown - """ - keys = {} - - def _get_keys(txn, batch): - """Processes a batch of keys to fetch, and adds the result to `keys`.""" - - # batch_iter always returns tuples so it's safe to do len(batch) - sql = ( - "SELECT server_name, key_id, verify_key, ts_valid_until_ms " - "FROM server_signature_keys WHERE 1=0" - ) + " OR (server_name=? AND key_id=?)" * len(batch) - - txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) - - for row in txn: - server_name, key_id, key_bytes, ts_valid_until_ms = row - - if ts_valid_until_ms is None: - # Old keys may be stored with a ts_valid_until_ms of null, - # in which case we treat this as if it was set to `0`, i.e. - # it won't match key requests that define a minimum - # `ts_valid_until_ms`. - ts_valid_until_ms = 0 - - res = FetchKeyResult( - verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), - valid_until_ts=ts_valid_until_ms, - ) - keys[(server_name, key_id)] = res - - def _txn(txn): - for batch in batch_iter(server_name_and_key_ids, 50): - _get_keys(txn, batch) - return keys - - return self.runInteraction("get_server_verify_keys", _txn) - - def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): - """Stores NACL verification keys for remote servers. - Args: - from_server (str): Where the verification keys were looked up - ts_added_ms (int): The time to record that the key was added - verify_keys (iterable[tuple[str, str, FetchKeyResult]]): - keys to be stored. Each entry is a triplet of - (server_name, key_id, key). - """ - key_values = [] - value_values = [] - invalidations = [] - for server_name, key_id, fetch_result in verify_keys: - key_values.append((server_name, key_id)) - value_values.append( - ( - from_server, - ts_added_ms, - fetch_result.valid_until_ts, - db_binary_type(fetch_result.verify_key.encode()), - ) - ) - # invalidate takes a tuple corresponding to the params of - # _get_server_verify_key. _get_server_verify_key only takes one - # param, which is itself the 2-tuple (server_name, key_id). - invalidations.append((server_name, key_id)) - - def _invalidate(res): - f = self._get_server_verify_key.invalidate - for i in invalidations: - f((i,)) - return res - - return self.runInteraction( - "store_server_verify_keys", - self._simple_upsert_many_txn, - table="server_signature_keys", - key_names=("server_name", "key_id"), - key_values=key_values, - value_names=( - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "verify_key", - ), - value_values=value_values, - ).addCallback(_invalidate) - - def store_server_keys_json( - self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes - ): - """Stores the JSON bytes for a set of keys from a server - The JSON should be signed by the originating server, the intermediate - server, and by this server. Updates the value for the - (server_name, key_id, from_server) triplet if one already existed. - Args: - server_name (str): The name of the server. - key_id (str): The identifer of the key this JSON is for. - from_server (str): The server this JSON was fetched from. - ts_now_ms (int): The time now in milliseconds. - ts_valid_until_ms (int): The time when this json stops being valid. - key_json (bytes): The encoded JSON. - """ - return self._simple_upsert( - table="server_keys_json", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - }, - values={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - "ts_added_ms": ts_now_ms, - "ts_valid_until_ms": ts_expires_ms, - "key_json": db_binary_type(key_json_bytes), - }, - desc="store_server_keys_json", - ) - - def get_server_keys_json(self, server_keys): - """Retrive the key json for a list of server_keys and key ids. - If no keys are found for a given server, key_id and source then - that server, key_id, and source triplet entry will be an empty list. - The JSON is returned as a byte array so that it can be efficiently - used in an HTTP response. - Args: - server_keys (list): List of (server_name, key_id, source) triplets. - Returns: - Deferred[dict[Tuple[str, str, str|None], list[dict]]]: - Dict mapping (server_name, key_id, source) triplets to lists of dicts - """ - - def _get_server_keys_json_txn(txn): - results = {} - for server_name, key_id, from_server in server_keys: - keyvalues = {"server_name": server_name} - if key_id is not None: - keyvalues["key_id"] = key_id - if from_server is not None: - keyvalues["from_server"] = from_server - rows = self._simple_select_list_txn( - txn, - "server_keys_json", - keyvalues=keyvalues, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", - ), - ) - results[(server_name, key_id, from_server)] = rows - return results - - return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py deleted file mode 100644 index 84b5f3ad5e..0000000000 --- a/synapse/storage/media_repository.py +++ /dev/null @@ -1,378 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from synapse.storage.background_updates import BackgroundUpdateStore - - -class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) - - self.register_background_index_update( - update_name="local_media_repository_url_idx", - index_name="local_media_repository_url_idx", - table="local_media_repository", - columns=["created_ts"], - where_clause="url_cache IS NOT NULL", - ) - - -class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): - """Persistence for attachments and avatars""" - - def __init__(self, db_conn, hs): - super(MediaRepositoryStore, self).__init__(db_conn, hs) - - def get_local_media(self, media_id): - """Get the metadata for a local piece of media - Returns: - None if the media_id doesn't exist. - """ - return self._simple_select_one( - "local_media_repository", - {"media_id": media_id}, - ( - "media_type", - "media_length", - "upload_name", - "created_ts", - "quarantined_by", - "url_cache", - ), - allow_none=True, - desc="get_local_media", - ) - - def store_local_media( - self, - media_id, - media_type, - time_now_ms, - upload_name, - media_length, - user_id, - url_cache=None, - ): - return self._simple_insert( - "local_media_repository", - { - "media_id": media_id, - "media_type": media_type, - "created_ts": time_now_ms, - "upload_name": upload_name, - "media_length": media_length, - "user_id": user_id.to_string(), - "url_cache": url_cache, - }, - desc="store_local_media", - ) - - def get_url_cache(self, url, ts): - """Get the media_id and ts for a cached URL as of the given timestamp - Returns: - None if the URL isn't cached. - """ - - def get_url_cache_txn(txn): - # get the most recently cached result (relative to the given ts) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts <= ?" - " ORDER BY download_ts DESC LIMIT 1" - ) - txn.execute(sql, (url, ts)) - row = txn.fetchone() - - if not row: - # ...or if we've requested a timestamp older than the oldest - # copy in the cache, return the oldest copy (if any) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts > ?" - " ORDER BY download_ts ASC LIMIT 1" - ) - txn.execute(sql, (url, ts)) - row = txn.fetchone() - - if not row: - return None - - return dict( - zip( - ( - "response_code", - "etag", - "expires_ts", - "og", - "media_id", - "download_ts", - ), - row, - ) - ) - - return self.runInteraction("get_url_cache", get_url_cache_txn) - - def store_url_cache( - self, url, response_code, etag, expires_ts, og, media_id, download_ts - ): - return self._simple_insert( - "local_media_repository_url_cache", - { - "url": url, - "response_code": response_code, - "etag": etag, - "expires_ts": expires_ts, - "og": og, - "media_id": media_id, - "download_ts": download_ts, - }, - desc="store_url_cache", - ) - - def get_local_media_thumbnails(self, media_id): - return self._simple_select_list( - "local_media_repository_thumbnails", - {"media_id": media_id}, - ( - "thumbnail_width", - "thumbnail_height", - "thumbnail_method", - "thumbnail_type", - "thumbnail_length", - ), - desc="get_local_media_thumbnails", - ) - - def store_local_thumbnail( - self, - media_id, - thumbnail_width, - thumbnail_height, - thumbnail_type, - thumbnail_method, - thumbnail_length, - ): - return self._simple_insert( - "local_media_repository_thumbnails", - { - "media_id": media_id, - "thumbnail_width": thumbnail_width, - "thumbnail_height": thumbnail_height, - "thumbnail_method": thumbnail_method, - "thumbnail_type": thumbnail_type, - "thumbnail_length": thumbnail_length, - }, - desc="store_local_thumbnail", - ) - - def get_cached_remote_media(self, origin, media_id): - return self._simple_select_one( - "remote_media_cache", - {"media_origin": origin, "media_id": media_id}, - ( - "media_type", - "media_length", - "upload_name", - "created_ts", - "filesystem_id", - "quarantined_by", - ), - allow_none=True, - desc="get_cached_remote_media", - ) - - def store_cached_remote_media( - self, - origin, - media_id, - media_type, - media_length, - time_now_ms, - upload_name, - filesystem_id, - ): - return self._simple_insert( - "remote_media_cache", - { - "media_origin": origin, - "media_id": media_id, - "media_type": media_type, - "media_length": media_length, - "created_ts": time_now_ms, - "upload_name": upload_name, - "filesystem_id": filesystem_id, - "last_access_ts": time_now_ms, - }, - desc="store_cached_remote_media", - ) - - def update_cached_last_access_time(self, local_media, remote_media, time_ms): - """Updates the last access time of the given media - - Args: - local_media (iterable[str]): Set of media_ids - remote_media (iterable[(str, str)]): Set of (server_name, media_id) - time_ms: Current time in milliseconds - """ - - def update_cache_txn(txn): - sql = ( - "UPDATE remote_media_cache SET last_access_ts = ?" - " WHERE media_origin = ? AND media_id = ?" - ) - - txn.executemany( - sql, - ( - (time_ms, media_origin, media_id) - for media_origin, media_id in remote_media - ), - ) - - sql = ( - "UPDATE local_media_repository SET last_access_ts = ?" - " WHERE media_id = ?" - ) - - txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) - - return self.runInteraction("update_cached_last_access_time", update_cache_txn) - - def get_remote_media_thumbnails(self, origin, media_id): - return self._simple_select_list( - "remote_media_cache_thumbnails", - {"media_origin": origin, "media_id": media_id}, - ( - "thumbnail_width", - "thumbnail_height", - "thumbnail_method", - "thumbnail_type", - "thumbnail_length", - "filesystem_id", - ), - desc="get_remote_media_thumbnails", - ) - - def store_remote_media_thumbnail( - self, - origin, - media_id, - filesystem_id, - thumbnail_width, - thumbnail_height, - thumbnail_type, - thumbnail_method, - thumbnail_length, - ): - return self._simple_insert( - "remote_media_cache_thumbnails", - { - "media_origin": origin, - "media_id": media_id, - "thumbnail_width": thumbnail_width, - "thumbnail_height": thumbnail_height, - "thumbnail_method": thumbnail_method, - "thumbnail_type": thumbnail_type, - "thumbnail_length": thumbnail_length, - "filesystem_id": filesystem_id, - }, - desc="store_remote_media_thumbnail", - ) - - def get_remote_media_before(self, before_ts): - sql = ( - "SELECT media_origin, media_id, filesystem_id" - " FROM remote_media_cache" - " WHERE last_access_ts < ?" - ) - - return self._execute( - "get_remote_media_before", self.cursor_to_dict, sql, before_ts - ) - - def delete_remote_media(self, media_origin, media_id): - def delete_remote_media_txn(txn): - self._simple_delete_txn( - txn, - "remote_media_cache", - keyvalues={"media_origin": media_origin, "media_id": media_id}, - ) - self._simple_delete_txn( - txn, - "remote_media_cache_thumbnails", - keyvalues={"media_origin": media_origin, "media_id": media_id}, - ) - - return self.runInteraction("delete_remote_media", delete_remote_media_txn) - - def get_expired_url_cache(self, now_ts): - sql = ( - "SELECT media_id FROM local_media_repository_url_cache" - " WHERE expires_ts < ?" - " ORDER BY expires_ts ASC" - " LIMIT 500" - ) - - def _get_expired_url_cache_txn(txn): - txn.execute(sql, (now_ts,)) - return [row[0] for row in txn] - - return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) - - def delete_url_cache(self, media_ids): - if len(media_ids) == 0: - return - - sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?" - - def _delete_url_cache_txn(txn): - txn.executemany(sql, [(media_id,) for media_id in media_ids]) - - return self.runInteraction("delete_url_cache", _delete_url_cache_txn) - - def get_url_cache_media_before(self, before_ts): - sql = ( - "SELECT media_id FROM local_media_repository" - " WHERE created_ts < ? AND url_cache IS NOT NULL" - " ORDER BY created_ts ASC" - " LIMIT 500" - ) - - def _get_url_cache_media_before_txn(txn): - txn.execute(sql, (before_ts,)) - return [row[0] for row in txn] - - return self.runInteraction( - "get_url_cache_media_before", _get_url_cache_media_before_txn - ) - - def delete_url_cache_media(self, media_ids): - if len(media_ids) == 0: - return - - def _delete_url_cache_media_txn(txn): - sql = "DELETE FROM local_media_repository" " WHERE media_id = ?" - - txn.executemany(sql, [(media_id,) for media_id in media_ids]) - - sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?" - - txn.executemany(sql, [(media_id,) for media_id in media_ids]) - - return self.runInteraction( - "delete_url_cache_media", _delete_url_cache_media_txn - ) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py deleted file mode 100644 index 3803604be7..0000000000 --- a/synapse/storage/monthly_active_users.py +++ /dev/null @@ -1,329 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from twisted.internet import defer - -from synapse.util.caches.descriptors import cached - -from ._base import SQLBaseStore - -logger = logging.getLogger(__name__) - -# Number of msec of granularity to store the monthly_active_user timestamp -# This means it is not necessary to update the table on every request -LAST_SEEN_GRANULARITY = 60 * 60 * 1000 - - -class MonthlyActiveUsersStore(SQLBaseStore): - def __init__(self, dbconn, hs): - super(MonthlyActiveUsersStore, self).__init__(None, hs) - self._clock = hs.get_clock() - self.hs = hs - # Do not add more reserved users than the total allowable number - self._new_transaction( - dbconn, - "initialise_mau_threepids", - [], - [], - self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value], - ) - - def _initialise_reserved_users(self, txn, threepids): - """Ensures that reserved threepids are accounted for in the MAU table, should - be called on start up. - - Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve - """ - - for tp in threepids: - user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) - - if user_id: - is_support = self.is_support_user_txn(txn, user_id) - if not is_support: - self.upsert_monthly_active_user_txn(txn, user_id) - else: - logger.warning("mau limit reserved threepid %s not found in db" % tp) - - @defer.inlineCallbacks - def reap_monthly_active_users(self): - """Cleans out monthly active user table to ensure that no stale - entries exist. - - Returns: - Deferred[] - """ - - def _reap_users(txn, reserved_users): - """ - Args: - reserved_users (tuple): reserved users to preserve - """ - - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - query_args = [thirty_days_ago] - base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" - - # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres - # when len(reserved_users) == 0. Works fine on sqlite. - if len(reserved_users) > 0: - # questionmarks is a hack to overcome sqlite not supporting - # tuples in 'WHERE IN %s' - question_marks = ",".join("?" * len(reserved_users)) - - query_args.extend(reserved_users) - sql = base_sql + " AND user_id NOT IN ({})".format(question_marks) - else: - sql = base_sql - - txn.execute(sql, query_args) - - max_mau_value = self.hs.config.max_mau_value - if self.hs.config.limit_usage_by_mau: - # If MAU user count still exceeds the MAU threshold, then delete on - # a least recently active basis. - # Note it is not possible to write this query using OFFSET due to - # incompatibilities in how sqlite and postgres support the feature. - # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present - # While Postgres does not require 'LIMIT', but also does not support - # negative LIMIT values. So there is no way to write it that both can - # support - if len(reserved_users) == 0: - sql = """ - DELETE FROM monthly_active_users - WHERE user_id NOT IN ( - SELECT user_id FROM monthly_active_users - ORDER BY timestamp DESC - LIMIT ? - ) - """ - txn.execute(sql, (max_mau_value,)) - # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres - # when len(reserved_users) == 0. Works fine on sqlite. - else: - # Must be >= 0 for postgres - num_of_non_reserved_users_to_remove = max( - max_mau_value - len(reserved_users), 0 - ) - - # It is important to filter reserved users twice to guard - # against the case where the reserved user is present in the - # SELECT, meaning that a legitmate mau is deleted. - sql = """ - DELETE FROM monthly_active_users - WHERE user_id NOT IN ( - SELECT user_id FROM monthly_active_users - WHERE user_id NOT IN ({}) - ORDER BY timestamp DESC - LIMIT ? - ) - AND user_id NOT IN ({}) - """.format( - question_marks, question_marks - ) - - query_args = [ - *reserved_users, - num_of_non_reserved_users_to_remove, - *reserved_users, - ] - - txn.execute(sql, query_args) - - reserved_users = yield self.get_registered_reserved_users() - yield self.runInteraction( - "reap_monthly_active_users", _reap_users, reserved_users - ) - # It seems poor to invalidate the whole cache, Postgres supports - # 'Returning' which would allow me to invalidate only the - # specific users, but sqlite has no way to do this and instead - # I would need to SELECT and the DELETE which without locking - # is racy. - # Have resolved to invalidate the whole cache for now and do - # something about it if and when the perf becomes significant - self.user_last_seen_monthly_active.invalidate_all() - self.get_monthly_active_count.invalidate_all() - - @cached(num_args=0) - def get_monthly_active_count(self): - """Generates current count of monthly active users - - Returns: - Defered[int]: Number of current monthly active users - """ - - def _count_users(txn): - sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" - - txn.execute(sql) - count, = txn.fetchone() - return count - - return self.runInteraction("count_users", _count_users) - - @defer.inlineCallbacks - def get_registered_reserved_users(self): - """Of the reserved threepids defined in config, which are associated - with registered users? - - Returns: - Defered[list]: Real reserved users - """ - users = [] - - for tp in self.hs.config.mau_limits_reserved_threepids[ - : self.hs.config.max_mau_value - ]: - user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - tp["medium"], tp["address"] - ) - if user_id: - users.append(user_id) - - return users - - @defer.inlineCallbacks - def upsert_monthly_active_user(self, user_id): - """Updates or inserts the user into the monthly active user table, which - is used to track the current MAU usage of the server - - Args: - user_id (str): user to add/update - """ - # Support user never to be included in MAU stats. Note I can't easily call this - # from upsert_monthly_active_user_txn because then I need a _txn form of - # is_support_user which is complicated because I want to cache the result. - # Therefore I call it here and ignore the case where - # upsert_monthly_active_user_txn is called directly from - # _initialise_reserved_users reasoning that it would be very strange to - # include a support user in this context. - - is_support = yield self.is_support_user(user_id) - if is_support: - return - - yield self.runInteraction( - "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id - ) - - user_in_mau = self.user_last_seen_monthly_active.cache.get( - (user_id,), None, update_metrics=False - ) - if user_in_mau is None: - self.get_monthly_active_count.invalidate(()) - - self.user_last_seen_monthly_active.invalidate((user_id,)) - - def upsert_monthly_active_user_txn(self, txn, user_id): - """Updates or inserts monthly active user member - - Note that, after calling this method, it will generally be necessary - to invalidate the caches on user_last_seen_monthly_active and - get_monthly_active_count. We can't do that here, because we are running - in a database thread rather than the main thread, and we can't call - txn.call_after because txn may not be a LoggingTransaction. - - We consciously do not call is_support_txn from this method because it - is not possible to cache the response. is_support_txn will be false in - almost all cases, so it seems reasonable to call it only for - upsert_monthly_active_user and to call is_support_txn manually - for cases where upsert_monthly_active_user_txn is called directly, - like _initialise_reserved_users - - In short, don't call this method with support users. (Support users - should not appear in the MAU stats). - - Args: - txn (cursor): - user_id (str): user to add/update - - Returns: - bool: True if a new entry was created, False if an - existing one was updated. - """ - - # Am consciously deciding to lock the table on the basis that is ought - # never be a big table and alternative approaches (batching multiple - # upserts into a single txn) introduced a lot of extra complexity. - # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self._simple_upsert_txn( - txn, - table="monthly_active_users", - keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, - ) - - return is_insert - - @cached(num_args=1) - def user_last_seen_monthly_active(self, user_id): - """ - Checks if a given user is part of the monthly active user group - Arguments: - user_id (str): user to add/update - Return: - Deferred[int] : timestamp since last seen, None if never seen - - """ - - return self._simple_select_one_onecol( - table="monthly_active_users", - keyvalues={"user_id": user_id}, - retcol="timestamp", - allow_none=True, - desc="user_last_seen_monthly_active", - ) - - @defer.inlineCallbacks - def populate_monthly_active_users(self, user_id): - """Checks on the state of monthly active user limits and optionally - add the user to the monthly active tables - - Args: - user_id(str): the user_id to query - """ - if self.hs.config.limit_usage_by_mau or self.hs.config.mau_stats_only: - # Trial users and guests should not be included as part of MAU group - is_guest = yield self.is_guest(user_id) - if is_guest: - return - is_trial = yield self.is_trial_user(user_id) - if is_trial: - return - - last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) - now = self.hs.get_clock().time_msec() - - # We want to reduce to the total number of db writes, and are happy - # to trade accuracy of timestamp in order to lighten load. This means - # We always insert new users (where MAU threshold has not been reached), - # but only update if we have not previously seen the user for - # LAST_SEEN_GRANULARITY ms - if last_seen_timestamp is None: - # In the case where mau_stats_only is True and limit_usage_by_mau is - # False, there is no point in checking get_monthly_active_count - it - # adds no value and will break the logic if max_mau_value is exceeded. - if not self.hs.config.limit_usage_by_mau: - yield self.upsert_monthly_active_user(user_id) - else: - count = yield self.get_monthly_active_count() - if count < self.hs.config.max_mau_value: - yield self.upsert_monthly_active_user(user_id) - elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: - yield self.upsert_monthly_active_user(user_id) diff --git a/synapse/storage/openid.py b/synapse/storage/openid.py deleted file mode 100644 index b3318045ee..0000000000 --- a/synapse/storage/openid.py +++ /dev/null @@ -1,31 +0,0 @@ -from ._base import SQLBaseStore - - -class OpenIdStore(SQLBaseStore): - def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self._simple_insert( - table="open_id_tokens", - values={ - "token": token, - "ts_valid_until_ms": ts_valid_until_ms, - "user_id": user_id, - }, - desc="insert_open_id_token", - ) - - def get_user_id_for_open_id_token(self, token, ts_now_ms): - def get_user_id_for_token_txn(txn): - sql = ( - "SELECT user_id FROM open_id_tokens" - " WHERE token = ? AND ? <= ts_valid_until_ms" - ) - - txn.execute(sql, (token, ts_now_ms)) - - rows = txn.fetchall() - if not rows: - return None - else: - return rows[0][0] - - return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 3a641f538b..18a462f0ee 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -15,12 +15,7 @@ from collections import namedtuple -from twisted.internet import defer - from synapse.api.constants import PresenceState -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedList class UserPresenceState( @@ -72,132 +67,3 @@ class UserPresenceState( status_msg=None, currently_active=False, ) - - -class PresenceStore(SQLBaseStore): - @defer.inlineCallbacks - def update_presence(self, presence_states): - stream_ordering_manager = self._presence_id_gen.get_next_mult( - len(presence_states) - ) - - with stream_ordering_manager as stream_orderings: - yield self.runInteraction( - "update_presence", - self._update_presence_txn, - stream_orderings, - presence_states, - ) - - return stream_orderings[-1], self._presence_id_gen.get_current_token() - - def _update_presence_txn(self, txn, stream_orderings, presence_states): - for stream_id, state in zip(stream_orderings, presence_states): - txn.call_after( - self.presence_stream_cache.entity_has_changed, state.user_id, stream_id - ) - txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) - - # Actually insert new rows - self._simple_insert_many_txn( - txn, - table="presence_stream", - values=[ - { - "stream_id": stream_id, - "user_id": state.user_id, - "state": state.state, - "last_active_ts": state.last_active_ts, - "last_federation_update_ts": state.last_federation_update_ts, - "last_user_sync_ts": state.last_user_sync_ts, - "status_msg": state.status_msg, - "currently_active": state.currently_active, - } - for state in presence_states - ], - ) - - # Delete old rows to stop database from getting really big - sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " - - for states in batch_iter(presence_states, 50): - clause, args = make_in_list_sql_clause( - self.database_engine, "user_id", [s.user_id for s in states] - ) - txn.execute(sql + clause, [stream_id] + list(args)) - - def get_all_presence_updates(self, last_id, current_id): - if last_id == current_id: - return defer.succeed([]) - - def get_all_presence_updates_txn(txn): - sql = ( - "SELECT stream_id, user_id, state, last_active_ts," - " last_federation_update_ts, last_user_sync_ts, status_msg," - " currently_active" - " FROM presence_stream" - " WHERE ? < stream_id AND stream_id <= ?" - ) - txn.execute(sql, (last_id, current_id)) - return txn.fetchall() - - return self.runInteraction( - "get_all_presence_updates", get_all_presence_updates_txn - ) - - @cached() - def _get_presence_for_user(self, user_id): - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_presence_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def get_presence_for_users(self, user_ids): - rows = yield self._simple_select_many_batch( - table="presence_stream", - column="user_id", - iterable=user_ids, - keyvalues={}, - retcols=( - "user_id", - "state", - "last_active_ts", - "last_federation_update_ts", - "last_user_sync_ts", - "status_msg", - "currently_active", - ), - desc="get_presence_for_users", - ) - - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return {row["user_id"]: UserPresenceState(**row) for row in rows} - - def get_current_presence_token(self): - return self._presence_id_gen.get_current_token() - - def allow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_insert( - table="presence_allow_inbound", - values={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="allow_presence_visible", - or_ignore=True, - ) - - def disallow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_delete_one( - table="presence_allow_inbound", - keyvalues={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="disallow_presence_visible", - ) diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py deleted file mode 100644 index 912c1df6be..0000000000 --- a/synapse/storage/profile.py +++ /dev/null @@ -1,179 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.storage.roommember import ProfileInfo - -from ._base import SQLBaseStore - - -class ProfileWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_profileinfo(self, user_localpart): - try: - profile = yield self._simple_select_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcols=("displayname", "avatar_url"), - desc="get_profileinfo", - ) - except StoreError as e: - if e.code == 404: - # no match - return ProfileInfo(None, None) - else: - raise - - return ProfileInfo( - avatar_url=profile["avatar_url"], display_name=profile["displayname"] - ) - - def get_profile_displayname(self, user_localpart): - return self._simple_select_one_onecol( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcol="displayname", - desc="get_profile_displayname", - ) - - def get_profile_avatar_url(self, user_localpart): - return self._simple_select_one_onecol( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcol="avatar_url", - desc="get_profile_avatar_url", - ) - - def get_from_remote_profile_cache(self, user_id): - return self._simple_select_one( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - retcols=("displayname", "avatar_url"), - allow_none=True, - desc="get_from_remote_profile_cache", - ) - - def create_profile(self, user_localpart): - return self._simple_insert( - table="profiles", values={"user_id": user_localpart}, desc="create_profile" - ) - - def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, - desc="set_profile_displayname", - ) - - def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self._simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"avatar_url": new_avatar_url}, - desc="set_profile_avatar_url", - ) - - -class ProfileStore(ProfileWorkerStore): - def add_remote_profile_cache(self, user_id, displayname, avatar_url): - """Ensure we are caching the remote user's profiles. - - This should only be called when `is_subscribed_remote_profile_for_user` - would return true for the user. - """ - return self._simple_upsert( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - values={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": self._clock.time_msec(), - }, - desc="add_remote_profile_cache", - ) - - def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self._simple_update( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - values={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": self._clock.time_msec(), - }, - desc="update_remote_profile_cache", - ) - - @defer.inlineCallbacks - def maybe_delete_remote_profile_cache(self, user_id): - """Check if we still care about the remote user's profile, and if we - don't then remove their profile from the cache - """ - subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) - if not subscribed: - yield self._simple_delete( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - desc="delete_remote_profile_cache", - ) - - def get_remote_profile_cache_entries_that_expire(self, last_checked): - """Get all users who haven't been checked since `last_checked` - """ - - def _get_remote_profile_cache_entries_that_expire_txn(txn): - sql = """ - SELECT user_id, displayname, avatar_url - FROM remote_profile_cache - WHERE last_check < ? - """ - - txn.execute(sql, (last_checked,)) - - return self.cursor_to_dict(txn) - - return self.runInteraction( - "get_remote_profile_cache_entries_that_expire", - _get_remote_profile_cache_entries_that_expire_txn, - ) - - @defer.inlineCallbacks - def is_subscribed_remote_profile_for_user(self, user_id): - """Check whether we are interested in a remote user's profile. - """ - res = yield self._simple_select_one_onecol( - table="group_users", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - - if res: - return True - - res = yield self._simple_select_one_onecol( - table="group_invites", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - - if res: - return True diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index c4e24edff2..f47cec0d86 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -14,704 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc -import logging - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.push.baserules import list_with_base_rules -from synapse.storage.appservice import ApplicationServiceWorkerStore -from synapse.storage.pusher import PusherWorkerStore -from synapse.storage.receipts import ReceiptsWorkerStore -from synapse.storage.roommember import RoomMemberWorkerStore -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList -from synapse.util.caches.stream_change_cache import StreamChangeCache - -from ._base import SQLBaseStore - -logger = logging.getLogger(__name__) - - -def _load_rules(rawrules, enabled_map): - ruleslist = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) - ruleslist.append(rule) - - # We're going to be mutating this a lot, so do a deep copy - rules = list(list_with_base_rules(ruleslist)) - - for i, rule in enumerate(rules): - rule_id = rule["rule_id"] - if rule_id in enabled_map: - if rule.get("enabled", True) != bool(enabled_map[rule_id]): - # Rules are cached across users. - rule = dict(rule) - rule["enabled"] = bool(enabled_map[rule_id]) - rules[i] = rule - - return rules - - -class PushRulesWorkerStore( - ApplicationServiceWorkerStore, - ReceiptsWorkerStore, - PusherWorkerStore, - RoomMemberWorkerStore, - SQLBaseStore, -): - """This is an abstract base class where subclasses must implement - `get_max_push_rules_stream_id` which can be called in the initializer. - """ - - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - - def __init__(self, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(db_conn, hs) - - push_rules_prefill, push_rules_id = self._get_cache_dict( - db_conn, - "push_rules_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self.get_max_push_rules_stream_id(), - ) - - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", - push_rules_id, - prefilled_cache=push_rules_prefill, - ) - - @abc.abstractmethod - def get_max_push_rules_stream_id(self): - """Get the position of the push rules stream. - - Returns: - int - """ - raise NotImplementedError() - - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_for_user(self, user_id): - rows = yield self._simple_select_list( - table="push_rules", - keyvalues={"user_name": user_id}, - retcols=( - "user_name", - "rule_id", - "priority_class", - "priority", - "conditions", - "actions", - ), - desc="get_push_rules_enabled_for_user", - ) - - rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) - - enabled_map = yield self.get_push_rules_enabled_for_user(user_id) - - rules = _load_rules(rows, enabled_map) - - return rules - - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_enabled_for_user(self, user_id): - results = yield self._simple_select_list( - table="push_rules_enable", - keyvalues={"user_name": user_id}, - retcols=("user_name", "rule_id", "enabled"), - desc="get_push_rules_enabled_for_user", - ) - return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - - def have_push_rules_changed_for_user(self, user_id, last_id): - if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) - else: - - def have_push_rules_changed_txn(txn): - sql = ( - "SELECT COUNT(stream_id) FROM push_rules_stream" - " WHERE user_id = ? AND ? < stream_id" - ) - txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() - return bool(count) - - return self.runInteraction( - "have_push_rules_changed", have_push_rules_changed_txn - ) - - @cachedList( - cached_method_name="get_push_rules_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def bulk_get_push_rules(self, user_ids): - if not user_ids: - return {} - - results = {user_id: [] for user_id in user_ids} - - rows = yield self._simple_select_many_batch( - table="push_rules", - column="user_name", - iterable=user_ids, - retcols=("*",), - desc="bulk_get_push_rules", - ) - - rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) - - for row in rows: - results.setdefault(row["user_name"], []).append(row) - - enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) - - for user_id, rules in results.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) - - return results - - @defer.inlineCallbacks - def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): - """Copy a single push rule from one room to another for a specific user. - - Args: - new_room_id (str): ID of the new room. - user_id (str): ID of user the push rule belongs to. - rule (Dict): A push rule. - """ - # Create new rule id - rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) - new_rule_id = rule_id_scope + "/" + new_room_id - - # Change room id in each condition - for condition in rule.get("conditions", []): - if condition.get("key") == "room_id": - condition["pattern"] = new_room_id - - # Add the rule for the new room - yield self.add_push_rule( - user_id=user_id, - rule_id=new_rule_id, - priority_class=rule["priority_class"], - conditions=rule["conditions"], - actions=rule["actions"], - ) - - @defer.inlineCallbacks - def copy_push_rules_from_room_to_room_for_user( - self, old_room_id, new_room_id, user_id - ): - """Copy all of the push rules from one room to another for a specific - user. - - Args: - old_room_id (str): ID of the old room. - new_room_id (str): ID of the new room. - user_id (str): ID of user to copy push rules for. - """ - # Retrieve push rules for this user - user_push_rules = yield self.get_push_rules_for_user(user_id) - - # Get rules relating to the old room and copy them to the new room - for rule in user_push_rules: - conditions = rule.get("conditions", []) - if any( - (c.get("key") == "room_id" and c.get("pattern") == old_room_id) - for c in conditions - ): - yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) - - @defer.inlineCallbacks - def bulk_get_push_rules_for_room(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield context.get_current_state_ids(self) - result = yield self._bulk_get_push_rules_for_room( - event.room_id, state_group, current_state_ids, event=event - ) - return result - - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room( - self, room_id, state_group, current_state_ids, cache_context, event=None - ): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - # We also will want to generate notifs for other people in the room so - # their unread countss are correct in the event stream, but to avoid - # generating them for bot / AS users etc, we only do so for people who've - # sent a read receipt into the room. - - users_in_room = yield self._get_joined_users_from_context( - room_id, - state_group, - current_state_ids, - on_invalidate=cache_context.invalidate, - event=event, - ) - - # We ignore app service users for now. This is so that we don't fill - # up the `get_if_users_have_pushers` cache with AS entries that we - # know don't have pushers, nor even read receipts. - local_users_in_room = set( - u - for u in users_in_room - if self.hs.is_mine_id(u) - and not self.get_if_app_services_interested_in_user(u) - ) - - # users in the room who have pushers need to get push rules run because - # that's how their pushers work - if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, on_invalidate=cache_context.invalidate - ) - user_ids = set( - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - ) - - users_with_receipts = yield self.get_users_with_read_receipts_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in local_users_in_room: - user_ids.add(uid) - - rules_by_user = yield self.bulk_get_push_rules( - user_ids, on_invalidate=cache_context.invalidate - ) - - rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - - return rules_by_user - - @cachedList( - cached_method_name="get_push_rules_enabled_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def bulk_get_push_rules_enabled(self, user_ids): - if not user_ids: - return {} - - results = {user_id: {} for user_id in user_ids} - - rows = yield self._simple_select_many_batch( - table="push_rules_enable", - column="user_name", - iterable=user_ids, - retcols=("user_name", "rule_id", "enabled"), - desc="bulk_get_push_rules_enabled", - ) - for row in rows: - enabled = bool(row["enabled"]) - results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled - return results - - -class PushRuleStore(PushRulesWorkerStore): - @defer.inlineCallbacks - def add_push_rule( - self, - user_id, - rule_id, - priority_class, - conditions, - actions, - before=None, - after=None, - ): - conditions_json = json.dumps(conditions) - actions_json = json.dumps(actions) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - if before or after: - yield self.runInteraction( - "_add_push_rule_relative_txn", - self._add_push_rule_relative_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - before, - after, - ) - else: - yield self.runInteraction( - "_add_push_rule_highest_priority_txn", - self._add_push_rule_highest_priority_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - ) - - def _add_push_rule_relative_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - before, - after, - ): - # Lock the table since otherwise we'll have annoying races between the - # SELECT here and the UPSERT below. - self.database_engine.lock_table(txn, "push_rules") - - relative_to_rule = before or after - - res = self._simple_select_one_txn( - txn, - table="push_rules", - keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, - retcols=["priority_class", "priority"], - allow_none=True, - ) - - if not res: - raise RuleNotFoundException( - "before/after rule not found: %s" % (relative_to_rule,) - ) - - base_priority_class = res["priority_class"] - base_rule_priority = res["priority"] - - if base_priority_class != priority_class: - raise InconsistentRuleException( - "Given priority class does not match class of relative rule" - ) - - if before: - # Higher priority rules are executed first, So adding a rule before - # a rule means giving it a higher priority than that rule. - new_rule_priority = base_rule_priority + 1 - else: - # We increment the priority of the existing rules to make space for - # the new rule. Therefore if we want this rule to appear after - # an existing rule we give it the priority of the existing rule, - # and then increment the priority of the existing rule. - new_rule_priority = base_rule_priority - - sql = ( - "UPDATE push_rules SET priority = priority + 1" - " WHERE user_name = ? AND priority_class = ? AND priority >= ?" - ) - - txn.execute(sql, (user_id, priority_class, new_rule_priority)) - - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - new_rule_priority, - conditions_json, - actions_json, - ) - - def _add_push_rule_highest_priority_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - ): - # Lock the table since otherwise we'll have annoying races between the - # SELECT here and the UPSERT below. - self.database_engine.lock_table(txn, "push_rules") - - # find the highest priority rule in that class - sql = ( - "SELECT COUNT(*), MAX(priority) FROM push_rules" - " WHERE user_name = ? and priority_class = ?" - ) - txn.execute(sql, (user_id, priority_class)) - res = txn.fetchall() - (how_many, highest_prio) = res[0] - - new_prio = 0 - if how_many > 0: - new_prio = highest_prio + 1 - - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - new_prio, - conditions_json, - actions_json, - ) - - def _upsert_push_rule_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - priority, - conditions_json, - actions_json, - update_stream=True, - ): - """Specialised version of _simple_upsert_txn that picks a push_rule_id - using the _push_rule_id_gen if it needs to insert the rule. It assumes - that the "push_rules" table is locked""" - - sql = ( - "UPDATE push_rules" - " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" - " WHERE user_name = ? AND rule_id = ?" - ) - - txn.execute( - sql, - (priority_class, priority, conditions_json, actions_json, user_id, rule_id), - ) - - if txn.rowcount == 0: - # We didn't update a row with the given rule_id so insert one - push_rule_id = self._push_rule_id_gen.get_next() - - self._simple_insert_txn( - txn, - table="push_rules", - values={ - "id": push_rule_id, - "user_name": user_id, - "rule_id": rule_id, - "priority_class": priority_class, - "priority": priority, - "conditions": conditions_json, - "actions": actions_json, - }, - ) - - if update_stream: - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ADD", - data={ - "priority_class": priority_class, - "priority": priority, - "conditions": conditions_json, - "actions": actions_json, - }, - ) - - @defer.inlineCallbacks - def delete_push_rule(self, user_id, rule_id): - """ - Delete a push rule. Args specify the row to be deleted and can be - any of the columns in the push_rule table, but below are the - standard ones - - Args: - user_id (str): The matrix ID of the push rule owner - rule_id (str): The rule_id of the rule to be deleted - """ - - def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self._simple_delete_one_txn( - txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} - ) - - self._insert_push_rules_update_txn( - txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" - ) - - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.runInteraction( - "delete_push_rule", - delete_push_rule_txn, - stream_id, - event_stream_ordering, - ) - - @defer.inlineCallbacks - def set_push_rule_enabled(self, user_id, rule_id, enabled): - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.runInteraction( - "_set_push_rule_enabled_txn", - self._set_push_rule_enabled_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - enabled, - ) - - def _set_push_rule_enabled_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled - ): - new_id = self._push_rules_enable_id_gen.get_next() - self._simple_upsert_txn( - txn, - "push_rules_enable", - {"user_name": user_id, "rule_id": rule_id}, - {"enabled": 1 if enabled else 0}, - {"id": new_id}, - ) - - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ENABLE" if enabled else "DISABLE", - ) - - @defer.inlineCallbacks - def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): - actions_json = json.dumps(actions) - - def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): - if is_default_rule: - # Add a dummy rule to the rules table with the user specified - # actions. - priority_class = -1 - priority = 1 - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - priority, - "[]", - actions_json, - update_stream=False, - ) - else: - self._simple_update_one_txn( - txn, - "push_rules", - {"user_name": user_id, "rule_id": rule_id}, - {"actions": actions_json}, - ) - - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ACTIONS", - data={"actions": actions_json}, - ) - - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.runInteraction( - "set_push_rule_actions", - set_push_rule_actions_txn, - stream_id, - event_stream_ordering, - ) - - def _insert_push_rules_update_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None - ): - values = { - "stream_id": stream_id, - "event_stream_ordering": event_stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": op, - } - if data is not None: - values.update(data) - - self._simple_insert_txn(txn, "push_rules_stream", values=values) - - txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) - txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) - txn.call_after( - self.push_rules_stream_cache.entity_has_changed, user_id, stream_id - ) - - def get_all_push_rule_updates(self, last_id, current_id, limit): - """Get all the push rules changes that have happend on the server""" - if last_id == current_id: - return defer.succeed([]) - - def get_all_push_rule_updates_txn(txn): - sql = ( - "SELECT stream_id, event_stream_ordering, user_id, rule_id," - " op, priority_class, priority, conditions, actions" - " FROM push_rules_stream" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() - - return self.runInteraction( - "get_all_push_rule_updates", get_all_push_rule_updates_txn - ) - - def get_push_rules_stream_token(self): - """Get the position of the push rules stream. - Returns a pair of a stream id for the push_rules stream and the - room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_current_token() - - def get_max_push_rules_stream_id(self): - return self.get_push_rules_stream_token()[0] - class RuleNotFoundException(Exception): pass diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py deleted file mode 100644 index b12e80440a..0000000000 --- a/synapse/storage/pusher.py +++ /dev/null @@ -1,372 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -import six - -from canonicaljson import encode_canonical_json, json - -from twisted.internet import defer - -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList - -from ._base import SQLBaseStore - -logger = logging.getLogger(__name__) - -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - - -class PusherWorkerStore(SQLBaseStore): - def _decode_pushers_rows(self, rows): - for r in rows: - dataJson = r["data"] - r["data"] = None - try: - if isinstance(dataJson, db_binary_type): - dataJson = str(dataJson).decode("UTF8") - - r["data"] = json.loads(dataJson) - except Exception as e: - logger.warn( - "Invalid JSON in data for pusher %d: %s, %s", - r["id"], - dataJson, - e.args[0], - ) - pass - - if isinstance(r["pushkey"], db_binary_type): - r["pushkey"] = str(r["pushkey"]).decode("UTF8") - - return rows - - @defer.inlineCallbacks - def user_has_pusher(self, user_id): - ret = yield self._simple_select_one_onecol( - "pushers", {"user_name": user_id}, "id", allow_none=True - ) - return ret is not None - - def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): - return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) - - def get_pushers_by_user_id(self, user_id): - return self.get_pushers_by({"user_name": user_id}) - - @defer.inlineCallbacks - def get_pushers_by(self, keyvalues): - ret = yield self._simple_select_list( - "pushers", - keyvalues, - [ - "id", - "user_name", - "access_token", - "profile_tag", - "kind", - "app_id", - "app_display_name", - "device_display_name", - "pushkey", - "ts", - "lang", - "data", - "last_stream_ordering", - "last_success", - "failing_since", - ], - desc="get_pushers_by", - ) - return self._decode_pushers_rows(ret) - - @defer.inlineCallbacks - def get_all_pushers(self): - def get_pushers(txn): - txn.execute("SELECT * FROM pushers") - rows = self.cursor_to_dict(txn) - - return self._decode_pushers_rows(rows) - - rows = yield self.runInteraction("get_all_pushers", get_pushers) - return rows - - def get_all_updated_pushers(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed(([], [])) - - def get_all_updated_pushers_txn(txn): - sql = ( - "SELECT id, user_name, access_token, profile_tag, kind," - " app_id, app_display_name, device_display_name, pushkey, ts," - " lang, data" - " FROM pushers" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - updated = txn.fetchall() - - sql = ( - "SELECT stream_id, user_id, app_id, pushkey" - " FROM deleted_pushers" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - deleted = txn.fetchall() - - return updated, deleted - - return self.runInteraction( - "get_all_updated_pushers", get_all_updated_pushers_txn - ) - - def get_all_updated_pushers_rows(self, last_id, current_id, limit): - """Get all the pushers that have changed between the given tokens. - - Returns: - Deferred(list(tuple)): each tuple consists of: - stream_id (str) - user_id (str) - app_id (str) - pushkey (str) - was_deleted (bool): whether the pusher was added/updated (False) - or deleted (True) - """ - - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_pushers_rows_txn(txn): - sql = ( - "SELECT id, user_name, app_id, pushkey" - " FROM pushers" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - results = [list(row) + [False] for row in txn] - - sql = ( - "SELECT stream_id, user_id, app_id, pushkey" - " FROM deleted_pushers" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - - results.extend(list(row) + [True] for row in txn) - results.sort() # Sort so that they're ordered by stream id - - return results - - return self.runInteraction( - "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn - ) - - @cachedInlineCallbacks(num_args=1, max_entries=15000) - def get_if_user_has_pusher(self, user_id): - # This only exists for the cachedList decorator - raise NotImplementedError() - - @cachedList( - cached_method_name="get_if_user_has_pusher", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def get_if_users_have_pushers(self, user_ids): - rows = yield self._simple_select_many_batch( - table="pushers", - column="user_name", - iterable=user_ids, - retcols=["user_name"], - desc="get_if_users_have_pushers", - ) - - result = {user_id: False for user_id in user_ids} - result.update({r["user_name"]: True for r in rows}) - - return result - - -class PusherStore(PusherWorkerStore): - def get_pushers_stream_token(self): - return self._pushers_id_gen.get_current_token() - - @defer.inlineCallbacks - def add_pusher( - self, - user_id, - access_token, - kind, - app_id, - app_display_name, - device_display_name, - pushkey, - pushkey_ts, - lang, - data, - last_stream_ordering, - profile_tag="", - ): - with self._pushers_id_gen.get_next() as stream_id: - # no need to lock because `pushers` has a unique key on - # (app_id, pushkey, user_name) so _simple_upsert will retry - yield self._simple_upsert( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - values={ - "access_token": access_token, - "kind": kind, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "ts": pushkey_ts, - "lang": lang, - "data": bytearray(encode_canonical_json(data)), - "last_stream_ordering": last_stream_ordering, - "profile_tag": profile_tag, - "id": stream_id, - }, - desc="add_pusher", - lock=False, - ) - - user_has_pusher = self.get_if_user_has_pusher.cache.get( - (user_id,), None, update_metrics=False - ) - - if user_has_pusher is not True: - # invalidate, since we the user might not have had a pusher before - yield self.runInteraction( - "add_pusher", - self._invalidate_cache_and_stream, - self.get_if_user_has_pusher, - (user_id,), - ) - - @defer.inlineCallbacks - def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): - def delete_pusher_txn(txn, stream_id): - self._invalidate_cache_and_stream( - txn, self.get_if_user_has_pusher, (user_id,) - ) - - self._simple_delete_one_txn( - txn, - "pushers", - {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - ) - - # it's possible for us to end up with duplicate rows for - # (app_id, pushkey, user_id) at different stream_ids, but that - # doesn't really matter. - self._simple_insert_txn( - txn, - table="deleted_pushers", - values={ - "stream_id": stream_id, - "app_id": app_id, - "pushkey": pushkey, - "user_id": user_id, - }, - ) - - with self._pushers_id_gen.get_next() as stream_id: - yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id) - - @defer.inlineCallbacks - def update_pusher_last_stream_ordering( - self, app_id, pushkey, user_id, last_stream_ordering - ): - yield self._simple_update_one( - "pushers", - {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - {"last_stream_ordering": last_stream_ordering}, - desc="update_pusher_last_stream_ordering", - ) - - @defer.inlineCallbacks - def update_pusher_last_stream_ordering_and_success( - self, app_id, pushkey, user_id, last_stream_ordering, last_success - ): - """Update the last stream ordering position we've processed up to for - the given pusher. - - Args: - app_id (str) - pushkey (str) - last_stream_ordering (int) - last_success (int) - - Returns: - Deferred[bool]: True if the pusher still exists; False if it has been deleted. - """ - updated = yield self._simple_update( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - updatevalues={ - "last_stream_ordering": last_stream_ordering, - "last_success": last_success, - }, - desc="update_pusher_last_stream_ordering_and_success", - ) - - return bool(updated) - - @defer.inlineCallbacks - def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self._simple_update( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - updatevalues={"failing_since": failing_since}, - desc="update_pusher_failing_since", - ) - - @defer.inlineCallbacks - def get_throttle_params_by_room(self, pusher_id): - res = yield self._simple_select_list( - "pusher_throttle", - {"pusher": pusher_id}, - ["room_id", "last_sent_ts", "throttle_ms"], - desc="get_throttle_params_by_room", - ) - - params_by_room = {} - for row in res: - params_by_room[row["room_id"]] = { - "last_sent_ts": row["last_sent_ts"], - "throttle_ms": row["throttle_ms"], - } - - return params_by_room - - @defer.inlineCallbacks - def set_throttle_params(self, pusher_id, room_id, params): - # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so _simple_upsert will retry - yield self._simple_upsert( - "pusher_throttle", - {"pusher": pusher_id, "room_id": room_id}, - params, - desc="set_throttle_params", - lock=False, - ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py deleted file mode 100644 index 0c24430f28..0000000000 --- a/synapse/storage/receipts.py +++ /dev/null @@ -1,536 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import logging - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from synapse.util.caches.stream_change_cache import StreamChangeCache - -logger = logging.getLogger(__name__) - - -class ReceiptsWorkerStore(SQLBaseStore): - """This is an abstract base class where subclasses must implement - `get_max_receipt_stream_id` which can be called in the initializer. - """ - - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - - def __init__(self, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(db_conn, hs) - - self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() - ) - - @abc.abstractmethod - def get_max_receipt_stream_id(self): - """Get the current max stream ID for receipts stream - - Returns: - int - """ - raise NotImplementedError() - - @cachedInlineCallbacks() - def get_users_with_read_receipts_in_room(self, room_id): - receipts = yield self.get_receipts_for_room(room_id, "m.read") - return set(r["user_id"] for r in receipts) - - @cached(num_args=2) - def get_receipts_for_room(self, room_id, receipt_type): - return self._simple_select_list( - table="receipts_linearized", - keyvalues={"room_id": room_id, "receipt_type": receipt_type}, - retcols=("user_id", "event_id"), - desc="get_receipts_for_room", - ) - - @cached(num_args=3) - def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self._simple_select_one_onecol( - table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, - retcol="event_id", - desc="get_own_receipt_for_user", - allow_none=True, - ) - - @cachedInlineCallbacks(num_args=2) - def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self._simple_select_list( - table="receipts_linearized", - keyvalues={"user_id": user_id, "receipt_type": receipt_type}, - retcols=("room_id", "event_id"), - desc="get_receipts_for_user", - ) - - return {row["room_id"]: row["event_id"] for row in rows} - - @defer.inlineCallbacks - def get_receipts_for_user_with_orderings(self, user_id, receipt_type): - def f(txn): - sql = ( - "SELECT rl.room_id, rl.event_id," - " e.topological_ordering, e.stream_ordering" - " FROM receipts_linearized AS rl" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE rl.room_id = e.room_id" - " AND rl.event_id = e.event_id" - " AND user_id = ?" - ) - txn.execute(sql, (user_id,)) - return txn.fetchall() - - rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f) - return { - row[0]: { - "event_id": row[1], - "topological_ordering": row[2], - "stream_ordering": row[3], - } - for row in rows - } - - @defer.inlineCallbacks - def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): - """Get receipts for multiple rooms for sending to clients. - - Args: - room_ids (list): List of room_ids. - to_key (int): Max stream id to fetch receipts upto. - from_key (int): Min stream id to fetch receipts from. None fetches - from the start. - - Returns: - list: A list of receipts. - """ - room_ids = set(room_ids) - - if from_key is not None: - # Only ask the database about rooms where there have been new - # receipts added since `from_key` - room_ids = yield self._receipts_stream_cache.get_entities_changed( - room_ids, from_key - ) - - results = yield self._get_linearized_receipts_for_rooms( - room_ids, to_key, from_key=from_key - ) - - return [ev for res in results.values() for ev in res] - - def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): - """Get receipts for a single room for sending to clients. - - Args: - room_ids (str): The room id. - to_key (int): Max stream id to fetch receipts upto. - from_key (int): Min stream id to fetch receipts from. None fetches - from the start. - - Returns: - Deferred[list]: A list of receipts. - """ - if from_key is not None: - # Check the cache first to see if any new receipts have been added - # since`from_key`. If not we can no-op. - if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): - defer.succeed([]) - - return self._get_linearized_receipts_for_room(room_id, to_key, from_key) - - @cachedInlineCallbacks(num_args=3, tree=True) - def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): - """See get_linearized_receipts_for_room - """ - - def f(txn): - if from_key: - sql = ( - "SELECT * FROM receipts_linearized WHERE" - " room_id = ? AND stream_id > ? AND stream_id <= ?" - ) - - txn.execute(sql, (room_id, from_key, to_key)) - else: - sql = ( - "SELECT * FROM receipts_linearized WHERE" - " room_id = ? AND stream_id <= ?" - ) - - txn.execute(sql, (room_id, to_key)) - - rows = self.cursor_to_dict(txn) - - return rows - - rows = yield self.runInteraction("get_linearized_receipts_for_room", f) - - if not rows: - return [] - - content = {} - for row in rows: - content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ - row["user_id"] - ] = json.loads(row["data"]) - - return [{"type": "m.receipt", "room_id": room_id, "content": content}] - - @cachedList( - cached_method_name="_get_linearized_receipts_for_room", - list_name="room_ids", - num_args=3, - inlineCallbacks=True, - ) - def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): - if not room_ids: - return {} - - def f(txn): - if from_key: - sql = """ - SELECT * FROM receipts_linearized WHERE - stream_id > ? AND stream_id <= ? AND - """ - clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", room_ids - ) - - txn.execute(sql + clause, [from_key, to_key] + list(args)) - else: - sql = """ - SELECT * FROM receipts_linearized WHERE - stream_id <= ? AND - """ - - clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", room_ids - ) - - txn.execute(sql + clause, [to_key] + list(args)) - - return self.cursor_to_dict(txn) - - txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f) - - results = {} - for row in txn_results: - # We want a single event per room, since we want to batch the - # receipts by room, event and type. - room_event = results.setdefault( - row["room_id"], - {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, - ) - - # The content is of the form: - # {"$foo:bar": { "read": { "@user:host": }, .. }, .. } - event_entry = room_event["content"].setdefault(row["event_id"], {}) - receipt_type = event_entry.setdefault(row["receipt_type"], {}) - - receipt_type[row["user_id"]] = json.loads(row["data"]) - - results = { - room_id: [results[room_id]] if room_id in results else [] - for room_id in room_ids - } - return results - - def get_all_updated_receipts(self, last_id, current_id, limit=None): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_receipts_txn(txn): - sql = ( - "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" - " FROM receipts_linearized" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - ) - args = [last_id, current_id] - if limit is not None: - sql += " LIMIT ?" - args.append(limit) - txn.execute(sql, args) - - return (r[0:5] + (json.loads(r[5]),) for r in txn) - - return self.runInteraction( - "get_all_updated_receipts", get_all_updated_receipts_txn - ) - - def _invalidate_get_users_with_receipts_in_room( - self, room_id, receipt_type, user_id - ): - if receipt_type != "m.read": - return - - # Returns either an ObservableDeferred or the raw result - res = self.get_users_with_read_receipts_in_room.cache.get( - room_id, None, update_metrics=False - ) - - # first handle the Deferred case - if isinstance(res, defer.Deferred): - if res.called: - res = res.result - else: - res = None - - if res and user_id in res: - # We'd only be adding to the set, so no point invalidating if the - # user is already there - return - - self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - - -class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, db_conn, hs): - # We instantiate this first as the ReceiptsWorkerStore constructor - # needs to be able to call get_max_receipt_stream_id - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - - super(ReceiptsStore, self).__init__(db_conn, hs) - - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() - - def insert_linearized_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_id, data, stream_id - ): - """Inserts a read-receipt into the database if it's newer than the current RR - - Returns: int|None - None if the RR is older than the current RR - otherwise, the rx timestamp of the event that the RR corresponds to - (or 0 if the event is unknown) - """ - res = self._simple_select_one_txn( - txn, - table="events", - retcols=["stream_ordering", "received_ts"], - keyvalues={"event_id": event_id}, - allow_none=True, - ) - - stream_ordering = int(res["stream_ordering"]) if res else None - rx_ts = res["received_ts"] if res else 0 - - # We don't want to clobber receipts for more recent events, so we - # have to compare orderings of existing receipts - if stream_ordering is not None: - sql = ( - "SELECT stream_ordering, event_id FROM events" - " INNER JOIN receipts_linearized as r USING (event_id, room_id)" - " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" - ) - txn.execute(sql, (room_id, receipt_type, user_id)) - - for so, eid in txn: - if int(so) >= stream_ordering: - logger.debug( - "Ignoring new receipt for %s in favour of existing " - "one for later event %s", - event_id, - eid, - ) - return None - - txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) - txn.call_after( - self._invalidate_get_users_with_receipts_in_room, - room_id, - receipt_type, - user_id, - ) - txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) - # FIXME: This shouldn't invalidate the whole cache - txn.call_after( - self._get_linearized_receipts_for_room.invalidate_many, (room_id,) - ) - - txn.call_after( - self._receipts_stream_cache.entity_has_changed, room_id, stream_id - ) - - txn.call_after( - self.get_last_receipt_event_id_for_user.invalidate, - (user_id, room_id, receipt_type), - ) - - self._simple_delete_txn( - txn, - table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, - ) - - self._simple_insert_txn( - txn, - table="receipts_linearized", - values={ - "stream_id": stream_id, - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - "event_id": event_id, - "data": json.dumps(data), - }, - ) - - if receipt_type == "m.read" and stream_ordering is not None: - self._remove_old_push_actions_before_txn( - txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering - ) - - return rx_ts - - @defer.inlineCallbacks - def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): - """Insert a receipt, either from local client or remote server. - - Automatically does conversion between linearized and graph - representations. - """ - if not event_ids: - return - - if len(event_ids) == 1: - linearized_event_id = event_ids[0] - else: - # we need to points in graph -> linearized form. - # TODO: Make this better. - def graph_to_linear(txn): - clause, args = make_in_list_sql_clause( - self.database_engine, "event_id", event_ids - ) - - sql = """ - SELECT event_id WHERE room_id = ? AND stream_ordering IN ( - SELECT max(stream_ordering) WHERE %s - ) - """ % ( - clause, - ) - - txn.execute(sql, [room_id] + list(args)) - rows = txn.fetchall() - if rows: - return rows[0][0] - else: - raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - - linearized_event_id = yield self.runInteraction( - "insert_receipt_conv", graph_to_linear - ) - - stream_id_manager = self._receipts_id_gen.get_next() - with stream_id_manager as stream_id: - event_ts = yield self.runInteraction( - "insert_linearized_receipt", - self.insert_linearized_receipt_txn, - room_id, - receipt_type, - user_id, - linearized_event_id, - data, - stream_id=stream_id, - ) - - if event_ts is None: - return None - - now = self._clock.time_msec() - logger.debug( - "RR for event %s in %s (%i ms old)", - linearized_event_id, - room_id, - now - event_ts, - ) - - yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) - - max_persisted_id = self._receipts_id_gen.get_current_token() - - return stream_id, max_persisted_id - - def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.runInteraction( - "insert_graph_receipt", - self.insert_graph_receipt_txn, - room_id, - receipt_type, - user_id, - event_ids, - data, - ) - - def insert_graph_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_ids, data - ): - txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) - txn.call_after( - self._invalidate_get_users_with_receipts_in_room, - room_id, - receipt_type, - user_id, - ) - txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) - # FIXME: This shouldn't invalidate the whole cache - txn.call_after( - self._get_linearized_receipts_for_room.invalidate_many, (room_id,) - ) - - self._simple_delete_txn( - txn, - table="receipts_graph", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, - ) - self._simple_insert_txn( - txn, - table="receipts_graph", - values={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - "event_ids": json.dumps(event_ids), - "data": json.dumps(data), - }, - ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py deleted file mode 100644 index 6c5b29288a..0000000000 --- a/synapse/storage/registration.py +++ /dev/null @@ -1,1499 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re - -from six import iterkeys -from six.moves import range - -from twisted.internet import defer -from twisted.internet.defer import Deferred - -from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage import background_updates -from synapse.storage._base import SQLBaseStore -from synapse.types import UserID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks - -THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 - -logger = logging.getLogger(__name__) - - -class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(db_conn, hs) - - self.config = hs.config - self.clock = hs.get_clock() - - @cached() - def get_user_by_id(self, user_id): - return self._simple_select_one( - table="users", - keyvalues={"name": user_id}, - retcols=[ - "name", - "password_hash", - "is_guest", - "consent_version", - "consent_server_notice_sent", - "appservice_id", - "creation_ts", - "user_type", - ], - allow_none=True, - desc="get_user_by_id", - ) - - @defer.inlineCallbacks - def is_trial_user(self, user_id): - """Checks if user is in the "trial" period, i.e. within the first - N days of registration defined by `mau_trial_days` config - - Args: - user_id (str) - - Returns: - Deferred[bool] - """ - - info = yield self.get_user_by_id(user_id) - if not info: - return False - - now = self.clock.time_msec() - trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 - is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms - return is_trial - - @cached() - def get_user_by_access_token(self, token): - """Get a user from the given access token. - - Args: - token (str): The access token of a user. - Returns: - defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`, - `valid_until_ms`. - """ - return self.runInteraction( - "get_user_by_access_token", self._query_for_auth, token - ) - - @cachedInlineCallbacks() - def get_expiration_ts_for_user(self, user_id): - """Get the expiration timestamp for the account bearing a given user ID. - - Args: - user_id (str): The ID of the user. - Returns: - defer.Deferred: None, if the account has no expiration timestamp, - otherwise int representation of the timestamp (as a number of - milliseconds since epoch). - """ - res = yield self._simple_select_one_onecol( - table="account_validity", - keyvalues={"user_id": user_id}, - retcol="expiration_ts_ms", - allow_none=True, - desc="get_expiration_ts_for_user", - ) - return res - - @defer.inlineCallbacks - def set_account_validity_for_user( - self, user_id, expiration_ts, email_sent, renewal_token=None - ): - """Updates the account validity properties of the given account, with the - given values. - - Args: - user_id (str): ID of the account to update properties for. - expiration_ts (int): New expiration date, as a timestamp in milliseconds - since epoch. - email_sent (bool): True means a renewal email has been sent for this - account and there's no need to send another one for the current validity - period. - renewal_token (str): Renewal token the user can use to extend the validity - of their account. Defaults to no token. - """ - - def set_account_validity_for_user_txn(txn): - self._simple_update_txn( - txn=txn, - table="account_validity", - keyvalues={"user_id": user_id}, - updatevalues={ - "expiration_ts_ms": expiration_ts, - "email_sent": email_sent, - "renewal_token": renewal_token, - }, - ) - self._invalidate_cache_and_stream( - txn, self.get_expiration_ts_for_user, (user_id,) - ) - - yield self.runInteraction( - "set_account_validity_for_user", set_account_validity_for_user_txn - ) - - @defer.inlineCallbacks - def set_renewal_token_for_user(self, user_id, renewal_token): - """Defines a renewal token for a given user. - - Args: - user_id (str): ID of the user to set the renewal token for. - renewal_token (str): Random unique string that will be used to renew the - user's account. - - Raises: - StoreError: The provided token is already set for another user. - """ - yield self._simple_update_one( - table="account_validity", - keyvalues={"user_id": user_id}, - updatevalues={"renewal_token": renewal_token}, - desc="set_renewal_token_for_user", - ) - - @defer.inlineCallbacks - def get_user_from_renewal_token(self, renewal_token): - """Get a user ID from a renewal token. - - Args: - renewal_token (str): The renewal token to perform the lookup with. - - Returns: - defer.Deferred[str]: The ID of the user to which the token belongs. - """ - res = yield self._simple_select_one_onecol( - table="account_validity", - keyvalues={"renewal_token": renewal_token}, - retcol="user_id", - desc="get_user_from_renewal_token", - ) - - return res - - @defer.inlineCallbacks - def get_renewal_token_for_user(self, user_id): - """Get the renewal token associated with a given user ID. - - Args: - user_id (str): The user ID to lookup a token for. - - Returns: - defer.Deferred[str]: The renewal token associated with this user ID. - """ - res = yield self._simple_select_one_onecol( - table="account_validity", - keyvalues={"user_id": user_id}, - retcol="renewal_token", - desc="get_renewal_token_for_user", - ) - - return res - - @defer.inlineCallbacks - def get_users_expiring_soon(self): - """Selects users whose account will expire in the [now, now + renew_at] time - window (see configuration for account_validity for information on what renew_at - refers to). - - Returns: - Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] - """ - - def select_users_txn(txn, now_ms, renew_at): - sql = ( - "SELECT user_id, expiration_ts_ms FROM account_validity" - " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" - ) - values = [False, now_ms, renew_at] - txn.execute(sql, values) - return self.cursor_to_dict(txn) - - res = yield self.runInteraction( - "get_users_expiring_soon", - select_users_txn, - self.clock.time_msec(), - self.config.account_validity.renew_at, - ) - - return res - - @defer.inlineCallbacks - def set_renewal_mail_status(self, user_id, email_sent): - """Sets or unsets the flag that indicates whether a renewal email has been sent - to the user (and the user hasn't renewed their account yet). - - Args: - user_id (str): ID of the user to set/unset the flag for. - email_sent (bool): Flag which indicates whether a renewal email has been sent - to this user. - """ - yield self._simple_update_one( - table="account_validity", - keyvalues={"user_id": user_id}, - updatevalues={"email_sent": email_sent}, - desc="set_renewal_mail_status", - ) - - @defer.inlineCallbacks - def delete_account_validity_for_user(self, user_id): - """Deletes the entry for the given user in the account validity table, removing - their expiration date and renewal token. - - Args: - user_id (str): ID of the user to remove from the account validity table. - """ - yield self._simple_delete_one( - table="account_validity", - keyvalues={"user_id": user_id}, - desc="delete_account_validity_for_user", - ) - - @defer.inlineCallbacks - def is_server_admin(self, user): - """Determines if a user is an admin of this homeserver. - - Args: - user (UserID): user ID of the user to test - - Returns (bool): - true iff the user is a server admin, false otherwise. - """ - res = yield self._simple_select_one_onecol( - table="users", - keyvalues={"name": user.to_string()}, - retcol="admin", - allow_none=True, - desc="is_server_admin", - ) - - return res if res else False - - def set_server_admin(self, user, admin): - """Sets whether a user is an admin of this homeserver. - - Args: - user (UserID): user ID of the user to test - admin (bool): true iff the user is to be a server admin, - false otherwise. - """ - return self._simple_update_one( - table="users", - keyvalues={"name": user.to_string()}, - updatevalues={"admin": 1 if admin else 0}, - desc="set_server_admin", - ) - - def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," - " access_tokens.device_id, access_tokens.valid_until_ms" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" - ) - - txn.execute(sql, (token,)) - rows = self.cursor_to_dict(txn) - if rows: - return rows[0] - - return None - - @cachedInlineCallbacks() - def is_real_user(self, user_id): - """Determines if the user is a real user, ie does not have a 'user_type'. - - Args: - user_id (str): user id to test - - Returns: - Deferred[bool]: True if user 'user_type' is null or empty string - """ - res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id) - return res - - @cachedInlineCallbacks() - def is_support_user(self, user_id): - """Determines if the user is of type UserTypes.SUPPORT - - Args: - user_id (str): user id to test - - Returns: - Deferred[bool]: True if user is of type UserTypes.SUPPORT - """ - res = yield self.runInteraction( - "is_support_user", self.is_support_user_txn, user_id - ) - return res - - def is_real_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - retcol="user_type", - allow_none=True, - ) - return res is None - - def is_support_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - retcol="user_type", - allow_none=True, - ) - return True if res == UserTypes.SUPPORT else False - - def get_users_by_id_case_insensitive(self, user_id): - """Gets users that match user_id case insensitively. - Returns a mapping of user_id -> password_hash. - """ - - def f(txn): - sql = ( - "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)" - ) - txn.execute(sql, (user_id,)) - return dict(txn) - - return self.runInteraction("get_users_by_id_case_insensitive", f) - - async def get_user_by_external_id( - self, auth_provider: str, external_id: str - ) -> str: - """Look up a user by their external auth id - - Args: - auth_provider: identifier for the remote auth provider - external_id: id on that system - - Returns: - str|None: the mxid of the user, or None if they are not known - """ - return await self._simple_select_one_onecol( - table="user_external_ids", - keyvalues={"auth_provider": auth_provider, "external_id": external_id}, - retcol="user_id", - allow_none=True, - desc="get_user_by_external_id", - ) - - @defer.inlineCallbacks - def count_all_users(self): - """Counts all users registered on the homeserver.""" - - def _count_users(txn): - txn.execute("SELECT COUNT(*) AS users FROM users") - rows = self.cursor_to_dict(txn) - if rows: - return rows[0]["users"] - return 0 - - ret = yield self.runInteraction("count_users", _count_users) - return ret - - def count_daily_user_type(self): - """ - Counts 1) native non guest users - 2) native guests users - 3) bridged users - who registered on the homeserver in the past 24 hours - """ - - def _count_daily_user_type(txn): - yesterday = int(self._clock.time()) - (60 * 60 * 24) - - sql = """ - SELECT user_type, COALESCE(count(*), 0) AS count FROM ( - SELECT - CASE - WHEN is_guest=0 AND appservice_id IS NULL THEN 'native' - WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest' - WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged' - END AS user_type - FROM users - WHERE creation_ts > ? - ) AS t GROUP BY user_type - """ - results = {"native": 0, "guest": 0, "bridged": 0} - txn.execute(sql, (yesterday,)) - for row in txn: - results[row[0]] = row[1] - return results - - return self.runInteraction("count_daily_user_type", _count_daily_user_type) - - @defer.inlineCallbacks - def count_nonbridged_users(self): - def _count_users(txn): - txn.execute( - """ - SELECT COALESCE(COUNT(*), 0) FROM users - WHERE appservice_id IS NULL - """ - ) - count, = txn.fetchone() - return count - - ret = yield self.runInteraction("count_users", _count_users) - return ret - - @defer.inlineCallbacks - def count_real_users(self): - """Counts all users without a special user_type registered on the homeserver.""" - - def _count_users(txn): - txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") - rows = self.cursor_to_dict(txn) - if rows: - return rows[0]["users"] - return 0 - - ret = yield self.runInteraction("count_real_users", _count_users) - return ret - - @defer.inlineCallbacks - def find_next_generated_user_id_localpart(self): - """ - Gets the localpart of the next generated user ID. - - Generated user IDs are integers, and we aim for them to be as small as - we can. Unfortunately, it's possible some of them are already taken by - existing users, and there may be gaps in the already taken range. This - function returns the start of the first allocatable gap. This is to - avoid the case of ID 10000000 being pre-allocated, so us wasting the - first (and shortest) many generated user IDs. - """ - - def _find_next_generated_user_id(txn): - # We bound between '@1' and '@a' to avoid pulling the entire table - # out. - txn.execute("SELECT name FROM users WHERE '@1' <= name AND name < '@a'") - - regex = re.compile(r"^@(\d+):") - - found = set() - - for (user_id,) in txn: - match = regex.search(user_id) - if match: - found.add(int(match.group(1))) - for i in range(len(found) + 1): - if i not in found: - return i - - return ( - ( - yield self.runInteraction( - "find_next_generated_user_id", _find_next_generated_user_id - ) - ) - ) - - @defer.inlineCallbacks - def get_user_id_by_threepid(self, medium, address): - """Returns user id from threepid - - Args: - medium (str): threepid medium e.g. email - address (str): threepid address e.g. me@example.com - - Returns: - Deferred[str|None]: user id or None if no user id/threepid mapping exists - """ - user_id = yield self.runInteraction( - "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address - ) - return user_id - - def get_user_id_by_threepid_txn(self, txn, medium, address): - """Returns user id from threepid - - Args: - txn (cursor): - medium (str): threepid medium e.g. email - address (str): threepid address e.g. me@example.com - - Returns: - str|None: user id or None if no user id/threepid mapping exists - """ - ret = self._simple_select_one_txn( - txn, - "user_threepids", - {"medium": medium, "address": address}, - ["user_id"], - True, - ) - if ret: - return ret["user_id"] - return None - - @defer.inlineCallbacks - def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self._simple_upsert( - "user_threepids", - {"medium": medium, "address": address}, - {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, - ) - - @defer.inlineCallbacks - def user_get_threepids(self, user_id): - ret = yield self._simple_select_list( - "user_threepids", - {"user_id": user_id}, - ["medium", "address", "validated_at", "added_at"], - "user_get_threepids", - ) - return ret - - def user_delete_threepid(self, user_id, medium, address): - return self._simple_delete( - "user_threepids", - keyvalues={"user_id": user_id, "medium": medium, "address": address}, - desc="user_delete_threepids", - ) - - def add_user_bound_threepid(self, user_id, medium, address, id_server): - """The server proxied a bind request to the given identity server on - behalf of the given user. We need to remember this in case the user - asks us to unbind the threepid. - - Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Deferred - """ - # We need to use an upsert, in case they user had already bound the - # threepid - return self._simple_upsert( - table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - "id_server": id_server, - }, - values={}, - insertion_values={}, - desc="add_user_bound_threepid", - ) - - def user_get_bound_threepids(self, user_id): - """Get the threepids that a user has bound to an identity server through the homeserver - The homeserver remembers where binds to an identity server occurred. Using this - method can retrieve those threepids. - - Args: - user_id (str): The ID of the user to retrieve threepids for - - Returns: - Deferred[list[dict]]: List of dictionaries containing the following: - medium (str): The medium of the threepid (e.g "email") - address (str): The address of the threepid (e.g "bob@example.com") - """ - return self._simple_select_list( - table="user_threepid_id_server", - keyvalues={"user_id": user_id}, - retcols=["medium", "address"], - desc="user_get_bound_threepids", - ) - - def remove_user_bound_threepid(self, user_id, medium, address, id_server): - """The server proxied an unbind request to the given identity server on - behalf of the given user, so we remove the mapping of threepid to - identity server. - - Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Deferred - """ - return self._simple_delete( - table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - "id_server": id_server, - }, - desc="remove_user_bound_threepid", - ) - - def get_id_servers_user_bound(self, user_id, medium, address): - """Get the list of identity servers that the server proxied bind - requests to for given user and threepid - - Args: - user_id (str) - medium (str) - address (str) - - Returns: - Deferred[list[str]]: Resolves to a list of identity servers - """ - return self._simple_select_onecol( - table="user_threepid_id_server", - keyvalues={"user_id": user_id, "medium": medium, "address": address}, - retcol="id_server", - desc="get_id_servers_user_bound", - ) - - @cachedInlineCallbacks() - def get_user_deactivated_status(self, user_id): - """Retrieve the value for the `deactivated` property for the provided user. - - Args: - user_id (str): The ID of the user to retrieve the status for. - - Returns: - defer.Deferred(bool): The requested value. - """ - - res = yield self._simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="deactivated", - desc="get_user_deactivated_status", - ) - - # Convert the integer into a boolean. - return res == 1 - - def get_threepid_validation_session( - self, medium, client_secret, address=None, sid=None, validated=True - ): - """Gets a session_id and last_send_attempt (if available) for a - combination of validation metadata - - Args: - medium (str|None): The medium of the 3PID - address (str|None): The address of the 3PID - sid (str|None): The ID of the validation session - client_secret (str): A unique string provided by the client to help identify this - validation attempt - validated (bool|None): Whether sessions should be filtered by - whether they have been validated already or not. None to - perform no filtering - - Returns: - Deferred[dict|None]: A dict containing the following: - * address - address of the 3pid - * medium - medium of the 3pid - * client_secret - a secret provided by the client for this validation session - * session_id - ID of the validation session - * send_attempt - a number serving to dedupe send attempts for this session - * validated_at - timestamp of when this session was validated if so - - Otherwise None if a validation session is not found - """ - if not client_secret: - raise SynapseError( - 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM - ) - - keyvalues = {"client_secret": client_secret} - if medium: - keyvalues["medium"] = medium - if address: - keyvalues["address"] = address - if sid: - keyvalues["session_id"] = sid - - assert address or sid - - def get_threepid_validation_session_txn(txn): - sql = """ - SELECT address, session_id, medium, client_secret, - last_send_attempt, validated_at - FROM threepid_validation_session WHERE %s - """ % ( - " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), - ) - - if validated is not None: - sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") - - sql += " LIMIT 1" - - txn.execute(sql, list(keyvalues.values())) - rows = self.cursor_to_dict(txn) - if not rows: - return None - - return rows[0] - - return self.runInteraction( - "get_threepid_validation_session", get_threepid_validation_session_txn - ) - - def delete_threepid_session(self, session_id): - """Removes a threepid validation session from the database. This can - be done after validation has been performed and whatever action was - waiting on it has been carried out - - Args: - session_id (str): The ID of the session to delete - """ - - def delete_threepid_session_txn(txn): - self._simple_delete_txn( - txn, - table="threepid_validation_token", - keyvalues={"session_id": session_id}, - ) - self._simple_delete_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - ) - - return self.runInteraction( - "delete_threepid_session", delete_threepid_session_txn - ) - - -class RegistrationBackgroundUpdateStore( - RegistrationWorkerStore, background_updates.BackgroundUpdateStore -): - def __init__(self, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) - - self.clock = hs.get_clock() - self.config = hs.config - - self.register_background_index_update( - "access_tokens_device_index", - index_name="access_tokens_device_id", - table="access_tokens", - columns=["user_id", "device_id"], - ) - - self.register_background_index_update( - "users_creation_ts", - index_name="users_creation_ts", - table="users", - columns=["creation_ts"], - ) - - # we no longer use refresh tokens, but it's possible that some people - # might have a background update queued to build this index. Just - # clear the background update. - self.register_noop_background_update("refresh_tokens_device_index") - - self.register_background_update_handler( - "user_threepids_grandfather", self._bg_user_threepids_grandfather - ) - - self.register_background_update_handler( - "users_set_deactivated_flag", self._background_update_set_deactivated_flag - ) - - @defer.inlineCallbacks - def _background_update_set_deactivated_flag(self, progress, batch_size): - """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 - for each of them. - """ - - last_user = progress.get("user_id", "") - - def _background_update_set_deactivated_flag_txn(txn): - txn.execute( - """ - SELECT - users.name, - COUNT(access_tokens.token) AS count_tokens, - COUNT(user_threepids.address) AS count_threepids - FROM users - LEFT JOIN access_tokens ON (access_tokens.user_id = users.name) - LEFT JOIN user_threepids ON (user_threepids.user_id = users.name) - WHERE (users.password_hash IS NULL OR users.password_hash = '') - AND (users.appservice_id IS NULL OR users.appservice_id = '') - AND users.is_guest = 0 - AND users.name > ? - GROUP BY users.name - ORDER BY users.name ASC - LIMIT ?; - """, - (last_user, batch_size), - ) - - rows = self.cursor_to_dict(txn) - - if not rows: - return True, 0 - - rows_processed_nb = 0 - - for user in rows: - if not user["count_tokens"] and not user["count_threepids"]: - self.set_user_deactivated_status_txn(txn, user["name"], True) - rows_processed_nb += 1 - - logger.info("Marked %d rows as deactivated", rows_processed_nb) - - self._background_update_progress_txn( - txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} - ) - - if batch_size > len(rows): - return True, len(rows) - else: - return False, len(rows) - - end, nb_processed = yield self.runInteraction( - "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn - ) - - if end: - yield self._end_background_update("users_set_deactivated_flag") - - return nb_processed - - @defer.inlineCallbacks - def _bg_user_threepids_grandfather(self, progress, batch_size): - """We now track which identity servers a user binds their 3PID to, so - we need to handle the case of existing bindings where we didn't track - this. - - We do this by grandfathering in existing user threepids assuming that - they used one of the server configured trusted identity servers. - """ - id_servers = set(self.config.trusted_third_party_id_servers) - - def _bg_user_threepids_grandfather_txn(txn): - sql = """ - INSERT INTO user_threepid_id_server - (user_id, medium, address, id_server) - SELECT user_id, medium, address, ? - FROM user_threepids - """ - - txn.executemany(sql, [(id_server,) for id_server in id_servers]) - - if id_servers: - yield self.runInteraction( - "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn - ) - - yield self._end_background_update("user_threepids_grandfather") - - return 1 - - -class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RegistrationStore, self).__init__(db_conn, hs) - - self._account_validity = hs.config.account_validity - - # Create a background job for culling expired 3PID validity tokens - def start_cull(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "cull_expired_threepid_validation_tokens", - self.cull_expired_threepid_validation_tokens, - ) - - hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - - @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): - """Adds an access token for the given user. - - Args: - user_id (str): The user ID. - token (str): The new access token to add. - device_id (str): ID of the device to associate with the access - token - valid_until_ms (int|None): when the token is valid until. None for - no expiry. - Raises: - StoreError if there was a problem adding this. - """ - next_id = self._access_tokens_id_gen.get_next() - - yield self._simple_insert( - "access_tokens", - { - "id": next_id, - "user_id": user_id, - "token": token, - "device_id": device_id, - "valid_until_ms": valid_until_ms, - }, - desc="add_access_token_to_user", - ) - - def register_user( - self, - user_id, - password_hash=None, - was_guest=False, - make_guest=False, - appservice_id=None, - create_profile_with_displayname=None, - admin=False, - user_type=None, - ): - """Attempts to register an account. - - Args: - user_id (str): The desired user ID to register. - password_hash (str): Optional. The password hash for this user. - was_guest (bool): Optional. Whether this is a guest account being - upgraded to a non-guest account. - make_guest (boolean): True if the the new user should be guest, - false to add a regular user account. - appservice_id (str): The ID of the appservice registering the user. - create_profile_with_displayname (unicode): Optionally create a profile for - the user, setting their displayname to the given value - admin (boolean): is an admin user? - user_type (str|None): type of user. One of the values from - api.constants.UserTypes, or None for a normal user. - - Raises: - StoreError if the user_id could not be registered. - """ - return self.runInteraction( - "register_user", - self._register_user, - user_id, - password_hash, - was_guest, - make_guest, - appservice_id, - create_profile_with_displayname, - admin, - user_type, - ) - - def _register_user( - self, - txn, - user_id, - password_hash, - was_guest, - make_guest, - appservice_id, - create_profile_with_displayname, - admin, - user_type, - ): - user_id_obj = UserID.from_string(user_id) - - now = int(self.clock.time()) - - try: - if was_guest: - # Ensure that the guest user actually exists - # ``allow_none=False`` makes this raise an exception - # if the row isn't in the database. - self._simple_select_one_txn( - txn, - "users", - keyvalues={"name": user_id, "is_guest": 1}, - retcols=("name",), - allow_none=False, - ) - - self._simple_update_one_txn( - txn, - "users", - keyvalues={"name": user_id, "is_guest": 1}, - updatevalues={ - "password_hash": password_hash, - "upgrade_ts": now, - "is_guest": 1 if make_guest else 0, - "appservice_id": appservice_id, - "admin": 1 if admin else 0, - "user_type": user_type, - }, - ) - else: - self._simple_insert_txn( - txn, - "users", - values={ - "name": user_id, - "password_hash": password_hash, - "creation_ts": now, - "is_guest": 1 if make_guest else 0, - "appservice_id": appservice_id, - "admin": 1 if admin else 0, - "user_type": user_type, - }, - ) - - except self.database_engine.module.IntegrityError: - raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) - - if self._account_validity.enabled: - self.set_expiration_date_for_user_txn(txn, user_id) - - if create_profile_with_displayname: - # set a default displayname serverside to avoid ugly race - # between auto-joins and clients trying to set displaynames - # - # *obviously* the 'profiles' table uses localpart for user_id - # while everything else uses the full mxid. - txn.execute( - "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", - (user_id_obj.localpart, create_profile_with_displayname), - ) - - if self.hs.config.stats_enabled: - # we create a new completed user statistics row - - # we don't strictly need current_token since this user really can't - # have any state deltas before now (as it is a new user), but still, - # we include it for completeness. - current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) - self._update_stats_delta_txn( - txn, now, "user", user_id, {}, complete_with_stream_id=current_token - ) - - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - txn.call_after(self.is_guest.invalidate, (user_id,)) - - def record_user_external_id( - self, auth_provider: str, external_id: str, user_id: str - ) -> Deferred: - """Record a mapping from an external user id to a mxid - - Args: - auth_provider: identifier for the remote auth provider - external_id: id on that system - user_id: complete mxid that it is mapped to - """ - return self._simple_insert( - table="user_external_ids", - values={ - "auth_provider": auth_provider, - "external_id": external_id, - "user_id": user_id, - }, - desc="record_user_external_id", - ) - - def user_set_password_hash(self, user_id, password_hash): - """ - NB. This does *not* evict any cache because the one use for this - removes most of the entries subsequently anyway so it would be - pointless. Use flush_user separately. - """ - - def user_set_password_hash_txn(txn): - self._simple_update_one_txn( - txn, "users", {"name": user_id}, {"password_hash": password_hash} - ) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - - return self.runInteraction("user_set_password_hash", user_set_password_hash_txn) - - def user_set_consent_version(self, user_id, consent_version): - """Updates the user table to record privacy policy consent - - Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy the user has consented - to - - Raises: - StoreError(404) if user not found - """ - - def f(txn): - self._simple_update_one_txn( - txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"consent_version": consent_version}, - ) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - - return self.runInteraction("user_set_consent_version", f) - - def user_set_consent_server_notice_sent(self, user_id, consent_version): - """Updates the user table to record that we have sent the user a server - notice about privacy policy consent - - Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy we have notified the - user about - - Raises: - StoreError(404) if user not found - """ - - def f(txn): - self._simple_update_one_txn( - txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"consent_server_notice_sent": consent_version}, - ) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - - return self.runInteraction("user_set_consent_server_notice_sent", f) - - def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): - """ - Invalidate access tokens belonging to a user - - Args: - user_id (str): ID of user the tokens belong to - except_token_id (str): list of access_tokens IDs which should - *not* be deleted - device_id (str|None): ID of device the tokens are associated with. - If None, tokens associated with any device (or no device) will - be deleted - Returns: - defer.Deferred[list[str, int, str|None, int]]: a list of - (token, token id, device id) for each of the deleted tokens - """ - - def f(txn): - keyvalues = {"user_id": user_id} - if device_id is not None: - keyvalues["device_id"] = device_id - - items = keyvalues.items() - where_clause = " AND ".join(k + " = ?" for k, _ in items) - values = [v for _, v in items] - if except_token_id: - where_clause += " AND id != ?" - values.append(except_token_id) - - txn.execute( - "SELECT token, id, device_id FROM access_tokens WHERE %s" - % where_clause, - values, - ) - tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] - - for token, _, _ in tokens_and_devices: - self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (token,) - ) - - txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) - - return tokens_and_devices - - return self.runInteraction("user_delete_access_tokens", f) - - def delete_access_token(self, access_token): - def f(txn): - self._simple_delete_one_txn( - txn, table="access_tokens", keyvalues={"token": access_token} - ) - - self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (access_token,) - ) - - return self.runInteraction("delete_access_token", f) - - @cachedInlineCallbacks() - def is_guest(self, user_id): - res = yield self._simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="is_guest", - allow_none=True, - desc="is_guest", - ) - - return res if res else False - - def add_user_pending_deactivation(self, user_id): - """ - Adds a user to the table of users who need to be parted from all the rooms they're - in - """ - return self._simple_insert( - "users_pending_deactivation", - values={"user_id": user_id}, - desc="add_user_pending_deactivation", - ) - - def del_user_pending_deactivation(self, user_id): - """ - Removes the given user to the table of users who need to be parted from all the - rooms they're in, effectively marking that user as fully deactivated. - """ - # XXX: This should be simple_delete_one but we failed to put a unique index on - # the table, so somehow duplicate entries have ended up in it. - return self._simple_delete( - "users_pending_deactivation", - keyvalues={"user_id": user_id}, - desc="del_user_pending_deactivation", - ) - - def get_user_pending_deactivation(self): - """ - Gets one user from the table of users waiting to be parted from all the rooms - they're in. - """ - return self._simple_select_one_onecol( - "users_pending_deactivation", - keyvalues={}, - retcol="user_id", - allow_none=True, - desc="get_users_pending_deactivation", - ) - - def validate_threepid_session(self, session_id, client_secret, token, current_ts): - """Attempt to validate a threepid session using a token - - Args: - session_id (str): The id of a validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - token (str): A validation token - current_ts (int): The current unix time in milliseconds. Used for - checking token expiry status - - Raises: - ThreepidValidationError: if a matching validation token was not found or has - expired - - Returns: - deferred str|None: A str representing a link to redirect the user - to if there is one. - """ - - # Insert everything into a transaction in order to run atomically - def validate_threepid_session_txn(txn): - row = self._simple_select_one_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - retcols=["client_secret", "validated_at"], - allow_none=True, - ) - - if not row: - raise ThreepidValidationError(400, "Unknown session_id") - retrieved_client_secret = row["client_secret"] - validated_at = row["validated_at"] - - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id" - ) - - row = self._simple_select_one_txn( - txn, - table="threepid_validation_token", - keyvalues={"session_id": session_id, "token": token}, - retcols=["expires", "next_link"], - allow_none=True, - ) - - if not row: - raise ThreepidValidationError( - 400, "Validation token not found or has expired" - ) - expires = row["expires"] - next_link = row["next_link"] - - # If the session is already validated, no need to revalidate - if validated_at: - return next_link - - if expires <= current_ts: - raise ThreepidValidationError( - 400, "This token has expired. Please request a new one" - ) - - # Looks good. Validate the session - self._simple_update_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self.clock.time_msec()}, - ) - - return next_link - - # Return next_link if it exists - return self.runInteraction( - "validate_threepid_session_txn", validate_threepid_session_txn - ) - - def upsert_threepid_validation_session( - self, - medium, - address, - client_secret, - send_attempt, - session_id, - validated_at=None, - ): - """Upsert a threepid validation session - Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - session_id (str): The id of this validation session - validated_at (int|None): The unix timestamp in milliseconds of - when the session was marked as valid - """ - insertion_values = { - "medium": medium, - "address": address, - "client_secret": client_secret, - } - - if validated_at: - insertion_values["validated_at"] = validated_at - - return self._simple_upsert( - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - values={"last_send_attempt": send_attempt}, - insertion_values=insertion_values, - desc="upsert_threepid_validation_session", - ) - - def start_or_continue_validation_session( - self, - medium, - address, - session_id, - client_secret, - send_attempt, - next_link, - token, - token_expires, - ): - """Creates a new threepid validation session if it does not already - exist and associates a new validation token with it - - Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - session_id (str): The id of this validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - next_link (str|None): The link to redirect the user to upon - successful validation - token (str): The validation token - token_expires (int): The timestamp for which after the token - will no longer be valid - """ - - def start_or_continue_validation_session_txn(txn): - # Create or update a validation session - self._simple_upsert_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - values={"last_send_attempt": send_attempt}, - insertion_values={ - "medium": medium, - "address": address, - "client_secret": client_secret, - }, - ) - - # Create a new validation token with this session ID - self._simple_insert_txn( - txn, - table="threepid_validation_token", - values={ - "session_id": session_id, - "token": token, - "next_link": next_link, - "expires": token_expires, - }, - ) - - return self.runInteraction( - "start_or_continue_validation_session", - start_or_continue_validation_session_txn, - ) - - def cull_expired_threepid_validation_tokens(self): - """Remove threepid validation tokens with expiry dates that have passed""" - - def cull_expired_threepid_validation_tokens_txn(txn, ts): - sql = """ - DELETE FROM threepid_validation_token WHERE - expires < ? - """ - return txn.execute(sql, (ts,)) - - return self.runInteraction( - "cull_expired_threepid_validation_tokens", - cull_expired_threepid_validation_tokens_txn, - self.clock.time_msec(), - ) - - @defer.inlineCallbacks - def set_user_deactivated_status(self, user_id, deactivated): - """Set the `deactivated` property for the provided user to the provided value. - - Args: - user_id (str): The ID of the user to set the status for. - deactivated (bool): The value to set for `deactivated`. - """ - - yield self.runInteraction( - "set_user_deactivated_status", - self.set_user_deactivated_status_txn, - user_id, - deactivated, - ) - - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self._simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"deactivated": 1 if deactivated else 0}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,) - ) diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py deleted file mode 100644 index f4c1c2a457..0000000000 --- a/synapse/storage/rejections.py +++ /dev/null @@ -1,42 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from ._base import SQLBaseStore - -logger = logging.getLogger(__name__) - - -class RejectionsStore(SQLBaseStore): - def _store_rejections_txn(self, txn, event_id, reason): - self._simple_insert_txn( - txn, - table="rejections", - values={ - "event_id": event_id, - "reason": reason, - "last_check": self._clock.time_msec(), - }, - ) - - def get_rejection_reason(self, event_id): - return self._simple_select_one_onecol( - table="rejections", - retcol="reason", - keyvalues={"event_id": event_id}, - allow_none=True, - desc="get_rejection_reason", - ) diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index fcb5f2f23a..d471ec9860 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -17,11 +17,7 @@ import logging import attr -from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore -from synapse.storage.stream import generate_pagination_where_clause -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -113,358 +109,3 @@ class AggregationPaginationToken(object): def as_tuple(self): return attr.astuple(self) - - -class RelationsWorkerStore(SQLBaseStore): - @cached(tree=True) - def get_relations_for_event( - self, - event_id, - relation_type=None, - event_type=None, - aggregation_key=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): - """Get a list of relations for an event, ordered by topological ordering. - - Args: - event_id (str): Fetch events that relate to this event ID. - relation_type (str|None): Only fetch events with this relation - type, if given. - event_type (str|None): Only fetch events with this event type, if - given. - aggregation_key (str|None): Only fetch events with this aggregation - key, if given. - limit (int): Only fetch the most recent `limit` events. - direction (str): Whether to fetch the most recent first (`"b"`) or - the oldest first (`"f"`). - from_token (RelationPaginationToken|None): Fetch rows from the given - token, or from the start if None. - to_token (RelationPaginationToken|None): Fetch rows up to the given - token, or up to the end if None. - - Returns: - Deferred[PaginationChunk]: List of event IDs that match relations - requested. The rows are of the form `{"event_id": "..."}`. - """ - - where_clause = ["relates_to_id = ?"] - where_args = [event_id] - - if relation_type is not None: - where_clause.append("relation_type = ?") - where_args.append(relation_type) - - if event_type is not None: - where_clause.append("type = ?") - where_args.append(event_type) - - if aggregation_key: - where_clause.append("aggregation_key = ?") - where_args.append(aggregation_key) - - pagination_clause = generate_pagination_where_clause( - direction=direction, - column_names=("topological_ordering", "stream_ordering"), - from_token=attr.astuple(from_token) if from_token else None, - to_token=attr.astuple(to_token) if to_token else None, - engine=self.database_engine, - ) - - if pagination_clause: - where_clause.append(pagination_clause) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - sql = """ - SELECT event_id, topological_ordering, stream_ordering - FROM event_relations - INNER JOIN events USING (event_id) - WHERE %s - ORDER BY topological_ordering %s, stream_ordering %s - LIMIT ? - """ % ( - " AND ".join(where_clause), - order, - order, - ) - - def _get_recent_references_for_event_txn(txn): - txn.execute(sql, where_args + [limit + 1]) - - last_topo_id = None - last_stream_id = None - events = [] - for row in txn: - events.append({"event_id": row[0]}) - last_topo_id = row[1] - last_stream_id = row[2] - - next_batch = None - if len(events) > limit and last_topo_id and last_stream_id: - next_batch = RelationPaginationToken(last_topo_id, last_stream_id) - - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) - - return self.runInteraction( - "get_recent_references_for_event", _get_recent_references_for_event_txn - ) - - @cached(tree=True) - def get_aggregation_groups_for_event( - self, - event_id, - event_type=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): - """Get a list of annotations on the event, grouped by event type and - aggregation key, sorted by count. - - This is used e.g. to get the what and how many reactions have happend - on an event. - - Args: - event_id (str): Fetch events that relate to this event ID. - event_type (str|None): Only fetch events with this event type, if - given. - limit (int): Only fetch the `limit` groups. - direction (str): Whether to fetch the highest count first (`"b"`) or - the lowest count first (`"f"`). - from_token (AggregationPaginationToken|None): Fetch rows from the - given token, or from the start if None. - to_token (AggregationPaginationToken|None): Fetch rows up to the - given token, or up to the end if None. - - - Returns: - Deferred[PaginationChunk]: List of groups of annotations that - match. Each row is a dict with `type`, `key` and `count` fields. - """ - - where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args = [event_id, RelationTypes.ANNOTATION] - - if event_type: - where_clause.append("type = ?") - where_args.append(event_type) - - having_clause = generate_pagination_where_clause( - direction=direction, - column_names=("COUNT(*)", "MAX(stream_ordering)"), - from_token=attr.astuple(from_token) if from_token else None, - to_token=attr.astuple(to_token) if to_token else None, - engine=self.database_engine, - ) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - if having_clause: - having_clause = "HAVING " + having_clause - else: - having_clause = "" - - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE {where_clause} - GROUP BY relation_type, type, aggregation_key - {having_clause} - ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} - LIMIT ? - """.format( - where_clause=" AND ".join(where_clause), - order=order, - having_clause=having_clause, - ) - - def _get_aggregation_groups_for_event_txn(txn): - txn.execute(sql, where_args + [limit + 1]) - - next_batch = None - events = [] - for row in txn: - events.append({"type": row[0], "key": row[1], "count": row[2]}) - next_batch = AggregationPaginationToken(row[2], row[3]) - - if len(events) <= limit: - next_batch = None - - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) - - return self.runInteraction( - "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn - ) - - @cachedInlineCallbacks() - def get_applicable_edit(self, event_id): - """Get the most recent edit (if any) that has happened for the given - event. - - Correctly handles checking whether edits were allowed to happen. - - Args: - event_id (str): The original event ID - - Returns: - Deferred[EventBase|None]: Returns the most recent edit, if any. - """ - - # We only allow edits for `m.room.message` events that have the same sender - # and event type. We can't assert these things during regular event auth so - # we have to do the checks post hoc. - - # Fetches latest edit that has the same type and sender as the - # original, and is an `m.room.message`. - sql = """ - SELECT edit.event_id FROM events AS edit - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS original ON - original.event_id = relates_to_id - AND edit.type = original.type - AND edit.sender = original.sender - WHERE - relates_to_id = ? - AND relation_type = ? - AND edit.type = 'm.room.message' - ORDER by edit.origin_server_ts DESC, edit.event_id DESC - LIMIT 1 - """ - - def _get_applicable_edit_txn(txn): - txn.execute(sql, (event_id, RelationTypes.REPLACE)) - row = txn.fetchone() - if row: - return row[0] - - edit_id = yield self.runInteraction( - "get_applicable_edit", _get_applicable_edit_txn - ) - - if not edit_id: - return - - edit_event = yield self.get_event(edit_id, allow_none=True) - return edit_event - - def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): - """Check if a user has already annotated an event with the same key - (e.g. already liked an event). - - Args: - parent_id (str): The event being annotated - event_type (str): The event type of the annotation - aggregation_key (str): The aggregation key of the annotation - sender (str): The sender of the annotation - - Returns: - Deferred[bool] - """ - - sql = """ - SELECT 1 FROM event_relations - INNER JOIN events USING (event_id) - WHERE - relates_to_id = ? - AND relation_type = ? - AND type = ? - AND sender = ? - AND aggregation_key = ? - LIMIT 1; - """ - - def _get_if_user_has_annotated_event(txn): - txn.execute( - sql, - ( - parent_id, - RelationTypes.ANNOTATION, - event_type, - sender, - aggregation_key, - ), - ) - - return bool(txn.fetchone()) - - return self.runInteraction( - "get_if_user_has_annotated_event", _get_if_user_has_annotated_event - ) - - -class RelationsStore(RelationsWorkerStore): - def _handle_event_relations(self, txn, event): - """Handles inserting relation data during peristence of events - - Args: - txn - event (EventBase) - """ - relation = event.content.get("m.relates_to") - if not relation: - # No relations - return - - rel_type = relation.get("rel_type") - if rel_type not in ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.REPLACE, - ): - # Unknown relation type - return - - parent_id = relation.get("event_id") - if not parent_id: - # Invalid relation - return - - aggregation_key = relation.get("key") - - self._simple_insert_txn( - txn, - table="event_relations", - values={ - "event_id": event.event_id, - "relates_to_id": parent_id, - "relation_type": rel_type, - "aggregation_key": aggregation_key, - }, - ) - - txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) - txn.call_after( - self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) - ) - - if rel_type == RelationTypes.REPLACE: - txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) - - def _handle_redaction(self, txn, redacted_event_id): - """Handles receiving a redaction and checking whether we need to remove - any redacted relations from the database. - - Args: - txn - redacted_event_id (str): The event that was redacted. - """ - - self._simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) diff --git a/synapse/storage/room.py b/synapse/storage/room.py deleted file mode 100644 index 43cc56fa6f..0000000000 --- a/synapse/storage/room.py +++ /dev/null @@ -1,681 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import logging -import re -from typing import Optional, Tuple - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.storage._base import SQLBaseStore -from synapse.storage.search import SearchStore -from synapse.types import ThirdPartyInstanceID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks - -logger = logging.getLogger(__name__) - - -OpsLevel = collections.namedtuple( - "OpsLevel", ("ban_level", "kick_level", "redact_level") -) - -RatelimitOverride = collections.namedtuple( - "RatelimitOverride", ("messages_per_second", "burst_count") -) - - -class RoomWorkerStore(SQLBaseStore): - def get_room(self, room_id): - """Retrieve a room. - - Args: - room_id (str): The ID of the room to retrieve. - Returns: - A dict containing the room information, or None if the room is unknown. - """ - return self._simple_select_one( - table="rooms", - keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator"), - desc="get_room", - allow_none=True, - ) - - def get_public_room_ids(self): - return self._simple_select_onecol( - table="rooms", - keyvalues={"is_public": True}, - retcol="room_id", - desc="get_public_room_ids", - ) - - def count_public_rooms(self, network_tuple, ignore_non_federatable): - """Counts the number of public rooms as tracked in the room_stats_current - and room_stats_state table. - - Args: - network_tuple (ThirdPartyInstanceID|None) - ignore_non_federatable (bool): If true filters out non-federatable rooms - """ - - def _count_public_rooms_txn(txn): - query_args = [] - - if network_tuple: - if network_tuple.appservice_id: - published_sql = """ - SELECT room_id from appservice_room_list - WHERE appservice_id = ? AND network_id = ? - """ - query_args.append(network_tuple.appservice_id) - query_args.append(network_tuple.network_id) - else: - published_sql = """ - SELECT room_id FROM rooms WHERE is_public - """ - else: - published_sql = """ - SELECT room_id FROM rooms WHERE is_public - UNION SELECT room_id from appservice_room_list - """ - - sql = """ - SELECT - COALESCE(COUNT(*), 0) - FROM ( - %(published_sql)s - ) published - INNER JOIN room_stats_state USING (room_id) - INNER JOIN room_stats_current USING (room_id) - WHERE - ( - join_rules = 'public' OR history_visibility = 'world_readable' - ) - AND joined_members > 0 - """ % { - "published_sql": published_sql - } - - txn.execute(sql, query_args) - return txn.fetchone()[0] - - return self.runInteraction("count_public_rooms", _count_public_rooms_txn) - - @defer.inlineCallbacks - def get_largest_public_rooms( - self, - network_tuple: Optional[ThirdPartyInstanceID], - search_filter: Optional[dict], - limit: Optional[int], - bounds: Optional[Tuple[int, str]], - forwards: bool, - ignore_non_federatable: bool = False, - ): - """Gets the largest public rooms (where largest is in terms of joined - members, as tracked in the statistics table). - - Args: - network_tuple - search_filter - limit: Maxmimum number of rows to return, unlimited otherwise. - bounds: An uppoer or lower bound to apply to result set if given, - consists of a joined member count and room_id (these are - excluded from result set). - forwards: true iff going forwards, going backwards otherwise - ignore_non_federatable: If true filters out non-federatable rooms. - - Returns: - Rooms in order: biggest number of joined users first. - We then arbitrarily use the room_id as a tie breaker. - - """ - - where_clauses = [] - query_args = [] - - if network_tuple: - if network_tuple.appservice_id: - published_sql = """ - SELECT room_id from appservice_room_list - WHERE appservice_id = ? AND network_id = ? - """ - query_args.append(network_tuple.appservice_id) - query_args.append(network_tuple.network_id) - else: - published_sql = """ - SELECT room_id FROM rooms WHERE is_public - """ - else: - published_sql = """ - SELECT room_id FROM rooms WHERE is_public - UNION SELECT room_id from appservice_room_list - """ - - # Work out the bounds if we're given them, these bounds look slightly - # odd, but are designed to help query planner use indices by pulling - # out a common bound. - if bounds: - last_joined_members, last_room_id = bounds - if forwards: - where_clauses.append( - """ - joined_members <= ? AND ( - joined_members < ? OR room_id < ? - ) - """ - ) - else: - where_clauses.append( - """ - joined_members >= ? AND ( - joined_members > ? OR room_id > ? - ) - """ - ) - - query_args += [last_joined_members, last_joined_members, last_room_id] - - if ignore_non_federatable: - where_clauses.append("is_federatable") - - if search_filter and search_filter.get("generic_search_term", None): - search_term = "%" + search_filter["generic_search_term"] + "%" - - where_clauses.append( - """ - ( - name LIKE ? - OR topic LIKE ? - OR canonical_alias LIKE ? - ) - """ - ) - query_args += [search_term, search_term, search_term] - - where_clause = "" - if where_clauses: - where_clause = " AND " + " AND ".join(where_clauses) - - sql = """ - SELECT - room_id, name, topic, canonical_alias, joined_members, - avatar, history_visibility, joined_members, guest_access - FROM ( - %(published_sql)s - ) published - INNER JOIN room_stats_state USING (room_id) - INNER JOIN room_stats_current USING (room_id) - WHERE - ( - join_rules = 'public' OR history_visibility = 'world_readable' - ) - AND joined_members > 0 - %(where_clause)s - ORDER BY joined_members %(dir)s, room_id %(dir)s - """ % { - "published_sql": published_sql, - "where_clause": where_clause, - "dir": "DESC" if forwards else "ASC", - } - - if limit is not None: - query_args.append(limit) - - sql += """ - LIMIT ? - """ - - def _get_largest_public_rooms_txn(txn): - txn.execute(sql, query_args) - - results = self.cursor_to_dict(txn) - - if not forwards: - results.reverse() - - return results - - ret_val = yield self.runInteraction( - "get_largest_public_rooms", _get_largest_public_rooms_txn - ) - defer.returnValue(ret_val) - - @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self._simple_select_one_onecol( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - retcol="1", - allow_none=True, - desc="is_room_blocked", - ) - - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): - """Check if there are any overrides for ratelimiting for the given - user - - Args: - user_id (str) - - Returns: - RatelimitOverride if there is an override, else None. If the contents - of RatelimitOverride are None or 0 then ratelimitng has been - disabled for that user entirely. - """ - row = yield self._simple_select_one( - table="ratelimit_override", - keyvalues={"user_id": user_id}, - retcols=("messages_per_second", "burst_count"), - allow_none=True, - desc="get_ratelimit_for_user", - ) - - if row: - return RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - ) - else: - return None - - -class RoomStore(RoomWorkerStore, SearchStore): - @defer.inlineCallbacks - def store_room(self, room_id, room_creator_user_id, is_public): - """Stores a room. - - Args: - room_id (str): The desired room ID, can be None. - room_creator_user_id (str): The user ID of the room creator. - is_public (bool): True to indicate that this room should appear in - public room lists. - Raises: - StoreError if the room could not be stored. - """ - try: - - def store_room_txn(txn, next_id): - self._simple_insert_txn( - txn, - "rooms", - { - "room_id": room_id, - "creator": room_creator_user_id, - "is_public": is_public, - }, - ) - if is_public: - self._simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - }, - ) - - with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction("store_room_txn", store_room_txn, next_id) - except Exception as e: - logger.error("store_room with room_id=%s failed: %s", room_id, e) - raise StoreError(500, "Problem creating room.") - - @defer.inlineCallbacks - def set_room_is_public(self, room_id, is_public): - def set_room_is_public_txn(txn, next_id): - self._simple_update_one_txn( - txn, - table="rooms", - keyvalues={"room_id": room_id}, - updatevalues={"is_public": is_public}, - ) - - entries = self._simple_select_list_txn( - txn, - table="public_room_list_stream", - keyvalues={ - "room_id": room_id, - "appservice_id": None, - "network_id": None, - }, - retcols=("stream_id", "visibility"), - ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self._simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": None, - "network_id": None, - }, - ) - - with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( - "set_room_is_public", set_room_is_public_txn, next_id - ) - self.hs.get_notifier().on_new_replication_data() - - @defer.inlineCallbacks - def set_room_is_public_appservice( - self, room_id, appservice_id, network_id, is_public - ): - """Edit the appservice/network specific public room list. - - Each appservice can have a number of published room lists associated - with them, keyed off of an appservice defined `network_id`, which - basically represents a single instance of a bridge to a third party - network. - - Args: - room_id (str) - appservice_id (str) - network_id (str) - is_public (bool): Whether to publish or unpublish the room from the - list. - """ - - def set_room_is_public_appservice_txn(txn, next_id): - if is_public: - try: - self._simple_insert_txn( - txn, - table="appservice_room_list", - values={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - except self.database_engine.module.IntegrityError: - # We've already inserted, nothing to do. - return - else: - self._simple_delete_txn( - txn, - table="appservice_room_list", - keyvalues={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - - entries = self._simple_select_list_txn( - txn, - table="public_room_list_stream", - keyvalues={ - "room_id": room_id, - "appservice_id": appservice_id, - "network_id": network_id, - }, - retcols=("stream_id", "visibility"), - ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self._simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": appservice_id, - "network_id": network_id, - }, - ) - - with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( - "set_room_is_public_appservice", - set_room_is_public_appservice_txn, - next_id, - ) - self.hs.get_notifier().on_new_replication_data() - - def get_room_count(self): - """Retrieve a list of all rooms - """ - - def f(txn): - sql = "SELECT count(*) FROM rooms" - txn.execute(sql) - row = txn.fetchone() - return row[0] or 0 - - return self.runInteraction("get_rooms", f) - - def _store_room_topic_txn(self, txn, event): - if hasattr(event, "content") and "topic" in event.content: - self.store_event_search_txn( - txn, event, "content.topic", event.content["topic"] - ) - - def _store_room_name_txn(self, txn, event): - if hasattr(event, "content") and "name" in event.content: - self.store_event_search_txn( - txn, event, "content.name", event.content["name"] - ) - - def _store_room_message_txn(self, txn, event): - if hasattr(event, "content") and "body" in event.content: - self.store_event_search_txn( - txn, event, "content.body", event.content["body"] - ) - - def add_event_report( - self, room_id, event_id, user_id, reason, content, received_ts - ): - next_id = self._event_reports_id_gen.get_next() - return self._simple_insert( - table="event_reports", - values={ - "id": next_id, - "received_ts": received_ts, - "room_id": room_id, - "event_id": event_id, - "user_id": user_id, - "reason": reason, - "content": json.dumps(content), - }, - desc="add_event_report", - ) - - def get_current_public_room_stream_id(self): - return self._public_room_id_gen.get_current_token() - - def get_all_new_public_rooms(self, prev_id, current_id, limit): - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (prev_id, current_id, limit)) - return txn.fetchall() - - if prev_id == current_id: - return defer.succeed([]) - - return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms) - - @defer.inlineCallbacks - def block_room(self, room_id, user_id): - """Marks the room as blocked. Can be called multiple times. - - Args: - room_id (str): Room to block - user_id (str): Who blocked it - - Returns: - Deferred - """ - yield self._simple_upsert( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - values={}, - insertion_values={"user_id": user_id}, - desc="block_room", - ) - yield self.runInteraction( - "block_room_invalidation", - self._invalidate_cache_and_stream, - self.is_room_blocked, - (room_id,), - ) - - def get_media_mxcs_in_room(self, room_id): - """Retrieves all the local and remote media MXC URIs in a given room - - Args: - room_id (str) - - Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. - """ - - def _get_media_mxcs_in_room_txn(txn): - local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) - local_media_mxcs = [] - remote_media_mxcs = [] - - # Convert the IDs to MXC URIs - for media_id in local_mxcs: - local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) - for hostname, media_id in remote_mxcs: - remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) - - return local_media_mxcs, remote_media_mxcs - - return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) - - def quarantine_media_ids_in_room(self, room_id, quarantined_by): - """For a room loops through all events with media and quarantines - the associated media - """ - - def _quarantine_media_in_room_txn(txn): - local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) - total_media_quarantined = 0 - - # Now update all the tables to set the quarantined_by flag - - txn.executemany( - """ - UPDATE local_media_repository - SET quarantined_by = ? - WHERE media_id = ? - """, - ((quarantined_by, media_id) for media_id in local_mxcs), - ) - - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin = ? AND media_id = ? - """, - ( - (quarantined_by, origin, media_id) - for origin, media_id in remote_mxcs - ), - ) - - total_media_quarantined += len(local_mxcs) - total_media_quarantined += len(remote_mxcs) - - return total_media_quarantined - - return self.runInteraction( - "quarantine_media_in_room", _quarantine_media_in_room_txn - ) - - def _get_media_mxcs_in_room_txn(self, txn, room_id): - """Retrieves all the local and remote media MXC URIs in a given room - - Args: - txn (cursor) - room_id (str) - - Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. - """ - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") - - next_token = self.get_current_events_token() + 1 - local_media_mxcs = [] - remote_media_mxcs = [] - - while next_token: - sql = """ - SELECT stream_ordering, json FROM events - JOIN event_json USING (room_id, event_id) - WHERE room_id = ? - AND stream_ordering < ? - AND contains_url = ? AND outlier = ? - ORDER BY stream_ordering DESC - LIMIT ? - """ - txn.execute(sql, (room_id, next_token, True, False, 100)) - - next_token = None - for stream_ordering, content_json in txn: - next_token = stream_ordering - event_json = json.loads(content_json) - content = event_json["content"] - content_url = content.get("url") - thumbnail_url = content.get("info", {}).get("thumbnail_url") - - for url in (content_url, thumbnail_url): - if not url: - continue - matches = mxc_re.match(url) - if matches: - hostname = matches.group(1) - media_id = matches.group(2) - if hostname == self.hs.hostname: - local_media_mxcs.append(media_id) - else: - remote_media_mxcs.append((hostname, media_id)) - - return local_media_mxcs, remote_media_mxcs diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index ff63487823..8c4a83a840 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -17,26 +17,6 @@ import logging from collections import namedtuple -from six import iteritems, itervalues - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.constants import EventTypes, Membership -from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.storage.engines import Sqlite3Engine -from synapse.storage.events_worker import EventsWorkerStore -from synapse.types import get_domain_from_id -from synapse.util.async_helpers import Linearizer -from synapse.util.caches import intern_string -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from synapse.util.metrics import Measure -from synapse.util.stringutils import to_ascii - logger = logging.getLogger(__name__) @@ -57,1102 +37,3 @@ ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name")) # a given membership type, suitable for use in calculating heroes for a room. # "count" points to the total numberr of users of a given membership type. MemberSummary = namedtuple("MemberSummary", ("members", "count")) - -_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" -_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" - - -class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(db_conn, hs) - - # Is the current_state_events.membership up to date? Or is the - # background update still running? - self._current_state_events_membership_up_to_date = False - - txn = LoggingTransaction( - db_conn.cursor(), - name="_check_safe_current_state_events_membership_updated", - database_engine=self.database_engine, - ) - self._check_safe_current_state_events_membership_updated_txn(txn) - txn.close() - - if self.hs.config.metrics_flags.known_servers: - self._known_servers_count = 1 - self.hs.get_clock().looping_call( - run_as_background_process, - 60 * 1000, - "_count_known_servers", - self._count_known_servers, - ) - self.hs.get_clock().call_later( - 1000, - run_as_background_process, - "_count_known_servers", - self._count_known_servers, - ) - LaterGauge( - "synapse_federation_known_servers", - "", - [], - lambda: self._known_servers_count, - ) - - @defer.inlineCallbacks - def _count_known_servers(self): - """ - Count the servers that this server knows about. - - The statistic is stored on the class for the - `synapse_federation_known_servers` LaterGauge to collect. - """ - - def _transact(txn): - if isinstance(self.database_engine, Sqlite3Engine): - query = """ - SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) - FROM ( - SELECT rm.user_id as user_id, instr(rm.user_id, ':') - AS pos FROM room_memberships as rm - INNER JOIN current_state_events as c ON rm.event_id = c.event_id - WHERE c.type = 'm.room.member' - ) as out - """ - else: - query = """ - SELECT COUNT(DISTINCT split_part(state_key, ':', 2)) - FROM current_state_events - WHERE type = 'm.room.member' AND membership = 'join'; - """ - txn.execute(query) - return list(txn)[0][0] - - count = yield self.runInteraction("get_known_servers", _transact) - - # We always know about ourselves, even if we have nothing in - # room_memberships (for example, the server is new). - self._known_servers_count = max([count, 1]) - return self._known_servers_count - - def _check_safe_current_state_events_membership_updated_txn(self, txn): - """Checks if it is safe to assume the new current_state_events - membership column is up to date - """ - - pending_update = self._simple_select_one_txn( - txn, - table="background_updates", - keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, - retcols=["update_name"], - allow_none=True, - ) - - self._current_state_events_membership_up_to_date = not pending_update - - # If the update is still running, reschedule to run. - if pending_update: - self._clock.call_later( - 15.0, - run_as_background_process, - "_check_safe_current_state_events_membership_updated", - self.runInteraction, - "_check_safe_current_state_events_membership_updated", - self._check_safe_current_state_events_membership_updated_txn, - ) - - @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) - def get_hosts_in_room(self, room_id, cache_context): - """Returns the set of all hosts currently in the room - """ - user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) - return hosts - - @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id): - return self.runInteraction( - "get_users_in_room", self.get_users_in_room_txn, room_id - ) - - def get_users_in_room_txn(self, txn, room_id): - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - sql = """ - SELECT state_key FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? AND membership = ? - """ - else: - sql = """ - SELECT state_key FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? - """ - - txn.execute(sql, (room_id, Membership.JOIN)) - return [to_ascii(r[0]) for r in txn] - - @cached(max_entries=100000) - def get_room_summary(self, room_id): - """ Get the details of a room roughly suitable for use by the room - summary extension to /sync. Useful when lazy loading room members. - Args: - room_id (str): The room ID to query - Returns: - Deferred[dict[str, MemberSummary]: - dict of membership states, pointing to a MemberSummary named tuple. - """ - - def _get_room_summary_txn(txn): - # first get counts. - # We do this all in one transaction to keep the cache small. - # FIXME: get rid of this when we have room_stats - - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT count(*), membership FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - GROUP BY membership - """ - else: - sql = """ - SELECT count(*), m.membership FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - res = {} - for count, membership in txn: - summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) - - # we order by membership and then fairly arbitrarily by event_id so - # heroes are consistent - if self._current_state_events_membership_up_to_date: - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT state_key, membership, event_id - FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - ORDER BY - CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, - event_id ASC - LIMIT ? - """ - else: - sql = """ - SELECT c.state_key, m.membership, c.event_id - FROM room_memberships as m - INNER JOIN current_state_events as c USING (room_id, event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - ORDER BY - CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, - c.event_id ASC - LIMIT ? - """ - - # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. - txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) - for user_id, membership, event_id in txn: - summary = res[to_ascii(membership)] - # we will always have a summary for this membership type at this - # point given the summary currently contains the counts. - members = summary.members - members.append((to_ascii(user_id), to_ascii(event_id))) - - return res - - return self.runInteraction("get_room_summary", _get_room_summary_txn) - - def _get_user_counts_in_room_txn(self, txn, room_id): - """ - Get the user count in a room by membership. - - Args: - room_id (str) - membership (Membership) - - Returns: - Deferred[int] - """ - sql = """ - SELECT m.membership, count(*) FROM room_memberships as m - INNER JOIN current_state_events as c USING(event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - return {row[0]: row[1] for row in txn} - - @cached() - def get_invited_rooms_for_user(self, user_id): - """ Get all the rooms the user is invited to - Args: - user_id (str): The user ID. - Returns: - A deferred list of RoomsForUser. - """ - - return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE]) - - @defer.inlineCallbacks - def get_invite_for_user_in_room(self, user_id, room_id): - """Gets the invite for the given user and room - - Args: - user_id (str) - room_id (str) - - Returns: - Deferred: Resolves to either a RoomsForUser or None if no invite was - found. - """ - invites = yield self.get_invited_rooms_for_user(user_id) - for invite in invites: - if invite.room_id == room_id: - return invite - return None - - @defer.inlineCallbacks - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): - """ Get all the rooms for this user where the membership for this user - matches one in the membership list. - - Filters out forgotten rooms. - - Args: - user_id (str): The user ID. - membership_list (list): A list of synapse.api.constants.Membership - values which the user must be in. - - Returns: - Deferred[list[RoomsForUser]] - """ - if not membership_list: - return defer.succeed(None) - - rooms = yield self.runInteraction( - "get_rooms_for_user_where_membership_is", - self._get_rooms_for_user_where_membership_is_txn, - user_id, - membership_list, - ) - - # Now we filter out forgotten rooms - forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) - return [room for room in rooms if room.room_id not in forgotten_rooms] - - def _get_rooms_for_user_where_membership_is_txn( - self, txn, user_id, membership_list - ): - - do_invite = Membership.INVITE in membership_list - membership_list = [m for m in membership_list if m != Membership.INVITE] - - results = [] - if membership_list: - if self._current_state_events_membership_up_to_date: - clause, args = make_in_list_sql_clause( - self.database_engine, "c.membership", membership_list - ) - sql = """ - SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND %s - """ % ( - clause, - ) - else: - clause, args = make_in_list_sql_clause( - self.database_engine, "m.membership", membership_list - ) - sql = """ - SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (room_id, event_id) - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND %s - """ % ( - clause, - ) - - txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] - - if do_invite: - sql = ( - "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" - " FROM local_invites as i" - " INNER JOIN events as e USING (event_id)" - " WHERE invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute(sql, (user_id,)) - results.extend( - RoomsForUser( - room_id=r["room_id"], - sender=r["inviter"], - event_id=r["event_id"], - stream_ordering=r["stream_ordering"], - membership=Membership.INVITE, - ) - for r in self.cursor_to_dict(txn) - ) - - return results - - @cachedInlineCallbacks(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id): - """Returns a set of room_ids the user is currently joined to - - Args: - user_id (str) - - Returns: - Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns - the rooms the user is in currently, along with the stream ordering - of the most recent join for that user and room. - """ - rooms = yield self.get_rooms_for_user_where_membership_is( - user_id, membership_list=[Membership.JOIN] - ) - return frozenset( - GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) - for r in rooms - ) - - @defer.inlineCallbacks - def get_rooms_for_user(self, user_id, on_invalidate=None): - """Returns a set of room_ids the user is currently joined to - """ - rooms = yield self.get_rooms_for_user_with_stream_ordering( - user_id, on_invalidate=on_invalidate - ) - return frozenset(r.room_id for r in rooms) - - @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) - def get_users_who_share_room_with_user(self, user_id, cache_context): - """Returns the set of users who share a room with `user_id` - """ - room_ids = yield self.get_rooms_for_user( - user_id, on_invalidate=cache_context.invalidate - ) - - user_who_share_room = set() - for room_id in room_ids: - user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - user_who_share_room.update(user_ids) - - return user_who_share_room - - @defer.inlineCallbacks - def get_joined_users_from_context(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield context.get_current_state_ids(self) - result = yield self._get_joined_users_from_context( - event.room_id, state_group, current_state_ids, event=event, context=context - ) - return result - - @defer.inlineCallbacks - def get_joined_users_from_state(self, room_id, state_entry): - state_group = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - with Measure(self._clock, "get_joined_users_from_state"): - return ( - yield self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry - ) - ) - - @cachedInlineCallbacks( - num_args=2, cache_context=True, iterable=True, max_entries=100000 - ) - def _get_joined_users_from_context( - self, - room_id, - state_group, - current_state_ids, - cache_context, - event=None, - context=None, - ): - # We don't use `state_group`, it's there so that we can cache based - # on it. However, it's important that it's never None, since two current_states - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - users_in_room = {} - member_event_ids = [ - e_id - for key, e_id in iteritems(current_state_ids) - if key[0] == EventTypes.Member - ] - - if context is not None: - # If we have a context with a delta from a previous state group, - # check if we also have the result from the previous group in cache. - # If we do then we can reuse that result and simply update it with - # any membership changes in `delta_ids` - if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get( - (room_id, context.prev_group), None - ) - if prev_res and isinstance(prev_res, dict): - users_in_room = dict(prev_res) - member_event_ids = [ - e_id - for key, e_id in iteritems(context.delta_ids) - if key[0] == EventTypes.Member - ] - for etype, state_key in context.delta_ids: - users_in_room.pop(state_key, None) - - # We check if we have any of the member event ids in the event cache - # before we ask the DB - - # We don't update the event cache hit ratio as it completely throws off - # the hit ratio counts. After all, we don't populate the cache if we - # miss it here - event_map = self._get_events_from_cache( - member_event_ids, allow_rejected=False, update_metrics=False - ) - - missing_member_event_ids = [] - for event_id in member_event_ids: - ev_entry = event_map.get(event_id) - if ev_entry: - if ev_entry.event.membership == Membership.JOIN: - users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( - display_name=to_ascii( - ev_entry.event.content.get("displayname", None) - ), - avatar_url=to_ascii( - ev_entry.event.content.get("avatar_url", None) - ), - ) - else: - missing_member_event_ids.append(event_id) - - if missing_member_event_ids: - event_to_memberships = yield self._get_joined_profiles_from_event_ids( - missing_member_event_ids - ) - users_in_room.update((row for row in event_to_memberships.values() if row)) - - if event is not None and event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - if event.event_id in member_event_ids: - users_in_room[to_ascii(event.state_key)] = ProfileInfo( - display_name=to_ascii(event.content.get("displayname", None)), - avatar_url=to_ascii(event.content.get("avatar_url", None)), - ) - - return users_in_room - - @cached(max_entries=10000) - def _get_joined_profile_from_event_id(self, event_id): - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_joined_profile_from_event_id", - list_name="event_ids", - inlineCallbacks=True, - ) - def _get_joined_profiles_from_event_ids(self, event_ids): - """For given set of member event_ids check if they point to a join - event and if so return the associated user and profile info. - - Args: - event_ids (Iterable[str]): The member event IDs to lookup - - Returns: - Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID - to `user_id` and ProfileInfo (or None if not join event). - """ - - rows = yield self._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=event_ids, - retcols=("user_id", "display_name", "avatar_url", "event_id"), - keyvalues={"membership": Membership.JOIN}, - batch_size=500, - desc="_get_membership_from_event_ids", - ) - - return { - row["event_id"]: ( - row["user_id"], - ProfileInfo( - avatar_url=row["avatar_url"], display_name=row["display_name"] - ), - ) - for row in rows - } - - @cachedInlineCallbacks(max_entries=10000) - def is_host_joined(self, room_id, host): - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT state_key FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE m.membership = 'join' - AND type = 'm.room.member' - AND c.room_id = ? - AND state_key LIKE ? - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. - like_clause = "%:" + host - - rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - @cachedInlineCallbacks() - def was_host_joined(self, room_id, host): - """Check whether the server is or ever was in the room. - - Args: - room_id (str) - host (str) - - Returns: - Deferred: Resolves to True if the host is/was in the room, otherwise - False. - """ - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT user_id FROM room_memberships - WHERE room_id = ? - AND user_id LIKE ? - AND membership = 'join' - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. - like_clause = "%:" + host - - rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - @defer.inlineCallbacks - def get_joined_hosts(self, room_id, state_entry): - state_group = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - with Measure(self._clock, "get_joined_hosts"): - return ( - yield self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry - ) - ) - - @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) - # @defer.inlineCallbacks - def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - cache = self._get_joined_hosts_cache(room_id) - joined_hosts = yield cache.get_destinations(state_entry) - - return joined_hosts - - @cached(max_entries=10000) - def _get_joined_hosts_cache(self, room_id): - return _JoinedHostsCache(self, room_id) - - @cachedInlineCallbacks(num_args=2) - def did_forget(self, user_id, room_id): - """Returns whether user_id has elected to discard history for room_id. - - Returns False if they have since re-joined.""" - - def f(txn): - sql = ( - "SELECT" - " COUNT(*)" - " FROM" - " room_memberships" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - " AND" - " forgotten = 0" - ) - txn.execute(sql, (user_id, room_id)) - rows = txn.fetchall() - return rows[0][0] - - count = yield self.runInteraction("did_forget_membership", f) - return count == 0 - - @cached() - def get_forgotten_rooms_for_user(self, user_id): - """Gets all rooms the user has forgotten. - - Args: - user_id (str) - - Returns: - Deferred[set[str]] - """ - - def _get_forgotten_rooms_for_user_txn(txn): - # This is a slightly convoluted query that first looks up all rooms - # that the user has forgotten in the past, then rechecks that list - # to see if any have subsequently been updated. This is done so that - # we can use a partial index on `forgotten = 1` on the assumption - # that few users will actually forget many rooms. - # - # Note that a room is considered "forgotten" if *all* membership - # events for that user and room have the forgotten field set (as - # when a user forgets a room we update all rows for that user and - # room, not just the current one). - sql = """ - SELECT room_id, ( - SELECT count(*) FROM room_memberships - WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0 - ) AS count - FROM room_memberships AS m - WHERE user_id = ? AND forgotten = 1 - GROUP BY room_id, user_id; - """ - txn.execute(sql, (user_id,)) - return set(row[0] for row in txn if row[1] == 0) - - return self.runInteraction( - "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn - ) - - @defer.inlineCallbacks - def get_rooms_user_has_been_in(self, user_id): - """Get all rooms that the user has ever been in. - - Args: - user_id (str) - - Returns: - Deferred[set[str]]: Set of room IDs. - """ - - room_ids = yield self._simple_select_onecol( - table="room_memberships", - keyvalues={"membership": Membership.JOIN, "user_id": user_id}, - retcol="room_id", - desc="get_rooms_user_has_been_in", - ) - - return set(room_ids) - - -class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile - ) - self.register_background_update_handler( - _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, - self._background_current_state_membership, - ) - self.register_background_index_update( - "room_membership_forgotten_idx", - index_name="room_memberships_user_room_forgotten", - table="room_memberships", - columns=["user_id", "room_id"], - where_clause="forgotten = 1", - ) - - @defer.inlineCallbacks - def _background_add_membership_profile(self, progress, batch_size): - target_min_stream_id = progress.get( - "target_min_stream_id_inclusive", self._min_stream_order_on_start - ) - max_stream_id = progress.get( - "max_stream_id_exclusive", self._stream_order_on_start + 1 - ) - - INSERT_CLUMP_SIZE = 1000 - - def add_membership_profile_txn(txn): - sql = """ - SELECT stream_ordering, event_id, events.room_id, event_json.json - FROM events - INNER JOIN event_json USING (event_id) - INNER JOIN room_memberships USING (event_id) - WHERE ? <= stream_ordering AND stream_ordering < ? - AND type = 'm.room.member' - ORDER BY stream_ordering DESC - LIMIT ? - """ - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = self.cursor_to_dict(txn) - if not rows: - return 0 - - min_stream_id = rows[-1]["stream_ordering"] - - to_update = [] - for row in rows: - event_id = row["event_id"] - room_id = row["room_id"] - try: - event_json = json.loads(row["json"]) - content = event_json["content"] - except Exception: - continue - - display_name = content.get("displayname", None) - avatar_url = content.get("avatar_url", None) - - if display_name or avatar_url: - to_update.append((display_name, avatar_url, event_id, room_id)) - - to_update_sql = """ - UPDATE room_memberships SET display_name = ?, avatar_url = ? - WHERE event_id = ? AND room_id = ? - """ - for index in range(0, len(to_update), INSERT_CLUMP_SIZE): - clump = to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(to_update_sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - } - - self._background_update_progress_txn( - txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress - ) - - return len(rows) - - result = yield self.runInteraction( - _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn - ) - - if not result: - yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) - - return result - - @defer.inlineCallbacks - def _background_current_state_membership(self, progress, batch_size): - """Update the new membership column on current_state_events. - - This works by iterating over all rooms in alphebetical order. - """ - - def _background_current_state_membership_txn(txn, last_processed_room): - processed = 0 - while processed < batch_size: - txn.execute( - """ - SELECT MIN(room_id) FROM current_state_events WHERE room_id > ? - """, - (last_processed_room,), - ) - row = txn.fetchone() - if not row or not row[0]: - return processed, True - - next_room, = row - - sql = """ - UPDATE current_state_events - SET membership = ( - SELECT membership FROM room_memberships - WHERE event_id = current_state_events.event_id - ) - WHERE room_id = ? - """ - txn.execute(sql, (next_room,)) - processed += txn.rowcount - - last_processed_room = next_room - - self._background_update_progress_txn( - txn, - _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, - {"last_processed_room": last_processed_room}, - ) - - return processed, False - - # If we haven't got a last processed room then just use the empty - # string, which will compare before all room IDs correctly. - last_processed_room = progress.get("last_processed_room", "") - - row_count, finished = yield self.runInteraction( - "_background_current_state_membership_update", - _background_current_state_membership_txn, - last_processed_room, - ) - - if finished: - yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) - - return row_count - - -class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberStore, self).__init__(db_conn, hs) - - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ - self._simple_insert_many_txn( - txn, - table="room_memberships", - values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - for event in events - ], - ) - - for event in events: - txn.call_after( - self._membership_stream_cache.entity_has_changed, - event.state_key, - event.internal_metadata.stream_ordering, - ) - txn.call_after( - self.get_invited_rooms_for_user.invalidate, (event.state_key,) - ) - - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. If the event is an - # outlier it is only current if its an "out of band membership", - # like a remote invite or a rejection of a remote invite. - is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_out_of_band_membership() - ) - is_mine = self.hs.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self._simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - }, - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute( - sql, - ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - ), - ) - - @defer.inlineCallbacks - def locally_reject_invite(self, user_id, room_id): - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - def f(txn, stream_ordering): - txn.execute(sql, (stream_ordering, True, room_id, user_id)) - - with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) - - def forget(self, user_id, room_id): - """Indicate that user_id wishes to discard history for room_id.""" - - def f(txn): - sql = ( - "UPDATE" - " room_memberships" - " SET" - " forgotten = 1" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - ) - txn.execute(sql, (user_id, room_id)) - - self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) - self._invalidate_cache_and_stream( - txn, self.get_forgotten_rooms_for_user, (user_id,) - ) - - return self.runInteraction("forget_membership", f) - - -class _JoinedHostsCache(object): - """Cache for joined hosts in a room that is optimised to handle updates - via state deltas. - """ - - def __init__(self, store, room_id): - self.store = store - self.room_id = room_id - - self.hosts_to_joined_users = {} - - self.state_group = object() - - self.linearizer = Linearizer("_JoinedHostsCache") - - self._len = 0 - - @defer.inlineCallbacks - def get_destinations(self, state_entry): - """Get set of destinations for a state entry - - Args: - state_entry(synapse.state._StateCacheEntry) - """ - if state_entry.state_group == self.state_group: - return frozenset(self.hosts_to_joined_users) - - with (yield self.linearizer.queue(())): - if state_entry.state_group == self.state_group: - pass - elif state_entry.prev_group == self.state_group: - for (typ, state_key), event_id in iteritems(state_entry.delta_ids): - if typ != EventTypes.Member: - continue - - host = intern_string(get_domain_from_id(state_key)) - user_id = state_key - known_joins = self.hosts_to_joined_users.setdefault(host, set()) - - event = yield self.store.get_event(event_id) - if event.membership == Membership.JOIN: - known_joins.add(user_id) - else: - known_joins.discard(user_id) - - if not known_joins: - self.hosts_to_joined_users.pop(host, None) - else: - joined_users = yield self.store.get_joined_users_from_state( - self.room_id, state_entry - ) - - self.hosts_to_joined_users = {} - for user_id in joined_users: - host = intern_string(get_domain_from_id(user_id)) - self.hosts_to_joined_users.setdefault(host, set()).add(user_id) - - if state_entry.state_group: - self.state_group = state_entry.state_group - else: - self.state_group = object() - self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) - return frozenset(self.hosts_to_joined_users) - - def __len__(self): - return self._len diff --git a/synapse/storage/schema/delta/12/v12.sql b/synapse/storage/schema/delta/12/v12.sql deleted file mode 100644 index 5964c5aaac..0000000000 --- a/synapse/storage/schema/delta/12/v12.sql +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS rejections( - event_id TEXT NOT NULL, - reason TEXT NOT NULL, - last_check TEXT NOT NULL, - UNIQUE (event_id) -); - --- Push notification endpoints that users have configured -CREATE TABLE IF NOT EXISTS pushers ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_name TEXT NOT NULL, - profile_tag VARCHAR(32) NOT NULL, - kind VARCHAR(8) NOT NULL, - app_id VARCHAR(64) NOT NULL, - app_display_name VARCHAR(64) NOT NULL, - device_display_name VARCHAR(128) NOT NULL, - pushkey VARBINARY(512) NOT NULL, - ts BIGINT UNSIGNED NOT NULL, - lang VARCHAR(8), - data LONGBLOB, - last_token TEXT, - last_success BIGINT UNSIGNED, - failing_since BIGINT UNSIGNED, - UNIQUE (app_id, pushkey) -); - -CREATE TABLE IF NOT EXISTS push_rules ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_name TEXT NOT NULL, - rule_id TEXT NOT NULL, - priority_class TINYINT NOT NULL, - priority INTEGER NOT NULL DEFAULT 0, - conditions TEXT NOT NULL, - actions TEXT NOT NULL, - UNIQUE(user_name, rule_id) -); - -CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); - -CREATE TABLE IF NOT EXISTS user_filters( - user_id TEXT, - filter_id BIGINT UNSIGNED, - filter_json LONGBLOB -); - -CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( - user_id, filter_id -); diff --git a/synapse/storage/schema/delta/13/v13.sql b/synapse/storage/schema/delta/13/v13.sql deleted file mode 100644 index f8649e5d99..0000000000 --- a/synapse/storage/schema/delta/13/v13.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* We used to create a tables called application_services and - * application_services_regex, but these are no longer used and are removed in - * delta 54. - */ diff --git a/synapse/storage/schema/delta/14/v14.sql b/synapse/storage/schema/delta/14/v14.sql deleted file mode 100644 index a831920da6..0000000000 --- a/synapse/storage/schema/delta/14/v14.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS push_rules_enable ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_name TEXT NOT NULL, - rule_id TEXT NOT NULL, - enabled TINYINT, - UNIQUE(user_name, rule_id) -); - -CREATE INDEX IF NOT EXISTS push_rules_enable_user_name on push_rules_enable (user_name); diff --git a/synapse/storage/schema/delta/15/appservice_txns.sql b/synapse/storage/schema/delta/15/appservice_txns.sql deleted file mode 100644 index e4f5e76aec..0000000000 --- a/synapse/storage/schema/delta/15/appservice_txns.sql +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS application_services_state( - as_id TEXT PRIMARY KEY, - state VARCHAR(5), - last_txn INTEGER -); - -CREATE TABLE IF NOT EXISTS application_services_txns( - as_id TEXT NOT NULL, - txn_id INTEGER NOT NULL, - event_ids TEXT NOT NULL, - UNIQUE(as_id, txn_id) -); - -CREATE INDEX IF NOT EXISTS application_services_txns_id ON application_services_txns ( - as_id -); diff --git a/synapse/storage/schema/delta/15/presence_indices.sql b/synapse/storage/schema/delta/15/presence_indices.sql deleted file mode 100644 index 6b8d0f1ca7..0000000000 --- a/synapse/storage/schema/delta/15/presence_indices.sql +++ /dev/null @@ -1,2 +0,0 @@ - -CREATE INDEX IF NOT EXISTS presence_list_user_id ON presence_list (user_id); diff --git a/synapse/storage/schema/delta/15/v15.sql b/synapse/storage/schema/delta/15/v15.sql deleted file mode 100644 index 9523d2bcc3..0000000000 --- a/synapse/storage/schema/delta/15/v15.sql +++ /dev/null @@ -1,24 +0,0 @@ --- Drop, copy & recreate pushers table to change unique key --- Also add access_token column at the same time -CREATE TABLE IF NOT EXISTS pushers2 ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - access_token BIGINT DEFAULT NULL, - profile_tag VARCHAR(32) NOT NULL, - kind VARCHAR(8) NOT NULL, - app_id VARCHAR(64) NOT NULL, - app_display_name VARCHAR(64) NOT NULL, - device_display_name VARCHAR(128) NOT NULL, - pushkey bytea NOT NULL, - ts BIGINT NOT NULL, - lang VARCHAR(8), - data bytea, - last_token TEXT, - last_success BIGINT, - failing_since BIGINT, - UNIQUE (app_id, pushkey) -); -INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since) - SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers; -DROP TABLE pushers; -ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/schema/delta/16/events_order_index.sql b/synapse/storage/schema/delta/16/events_order_index.sql deleted file mode 100644 index a48f215170..0000000000 --- a/synapse/storage/schema/delta/16/events_order_index.sql +++ /dev/null @@ -1,4 +0,0 @@ -CREATE INDEX events_order ON events (topological_ordering, stream_ordering); -CREATE INDEX events_order_room ON events ( - room_id, topological_ordering, stream_ordering -); diff --git a/synapse/storage/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/schema/delta/16/remote_media_cache_index.sql deleted file mode 100644 index 7a15265cb1..0000000000 --- a/synapse/storage/schema/delta/16/remote_media_cache_index.sql +++ /dev/null @@ -1,2 +0,0 @@ -CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id - ON remote_media_cache_thumbnails (media_id); \ No newline at end of file diff --git a/synapse/storage/schema/delta/16/remove_duplicates.sql b/synapse/storage/schema/delta/16/remove_duplicates.sql deleted file mode 100644 index 65c97b5e2f..0000000000 --- a/synapse/storage/schema/delta/16/remove_duplicates.sql +++ /dev/null @@ -1,9 +0,0 @@ - - -DELETE FROM event_to_state_groups WHERE state_group not in ( - SELECT MAX(state_group) FROM event_to_state_groups GROUP BY event_id -); - -DELETE FROM event_to_state_groups WHERE rowid not in ( - SELECT MIN(rowid) FROM event_to_state_groups GROUP BY event_id -); diff --git a/synapse/storage/schema/delta/16/room_alias_index.sql b/synapse/storage/schema/delta/16/room_alias_index.sql deleted file mode 100644 index f82486132b..0000000000 --- a/synapse/storage/schema/delta/16/room_alias_index.sql +++ /dev/null @@ -1,3 +0,0 @@ - -CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id); -CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias); diff --git a/synapse/storage/schema/delta/16/unique_constraints.sql b/synapse/storage/schema/delta/16/unique_constraints.sql deleted file mode 100644 index 5b8de52c33..0000000000 --- a/synapse/storage/schema/delta/16/unique_constraints.sql +++ /dev/null @@ -1,72 +0,0 @@ - --- We can use SQLite features here, since other db support was only added in v16 - --- -DELETE FROM current_state_events WHERE rowid not in ( - SELECT MIN(rowid) FROM current_state_events GROUP BY event_id -); - -DROP INDEX IF EXISTS current_state_events_event_id; -CREATE UNIQUE INDEX current_state_events_event_id ON current_state_events(event_id); - --- -DELETE FROM room_memberships WHERE rowid not in ( - SELECT MIN(rowid) FROM room_memberships GROUP BY event_id -); - -DROP INDEX IF EXISTS room_memberships_event_id; -CREATE UNIQUE INDEX room_memberships_event_id ON room_memberships(event_id); - --- -DELETE FROM topics WHERE rowid not in ( - SELECT MIN(rowid) FROM topics GROUP BY event_id -); - -DROP INDEX IF EXISTS topics_event_id; -CREATE UNIQUE INDEX topics_event_id ON topics(event_id); - --- -DELETE FROM room_names WHERE rowid not in ( - SELECT MIN(rowid) FROM room_names GROUP BY event_id -); - -DROP INDEX IF EXISTS room_names_id; -CREATE UNIQUE INDEX room_names_id ON room_names(event_id); - --- -DELETE FROM presence WHERE rowid not in ( - SELECT MIN(rowid) FROM presence GROUP BY user_id -); - -DROP INDEX IF EXISTS presence_id; -CREATE UNIQUE INDEX presence_id ON presence(user_id); - --- -DELETE FROM presence_allow_inbound WHERE rowid not in ( - SELECT MIN(rowid) FROM presence_allow_inbound - GROUP BY observed_user_id, observer_user_id -); - -DROP INDEX IF EXISTS presence_allow_inbound_observers; -CREATE UNIQUE INDEX presence_allow_inbound_observers ON presence_allow_inbound( - observed_user_id, observer_user_id -); - --- -DELETE FROM presence_list WHERE rowid not in ( - SELECT MIN(rowid) FROM presence_list - GROUP BY user_id, observed_user_id -); - -DROP INDEX IF EXISTS presence_list_observers; -CREATE UNIQUE INDEX presence_list_observers ON presence_list( - user_id, observed_user_id -); - --- -DELETE FROM room_aliases WHERE rowid not in ( - SELECT MIN(rowid) FROM room_aliases GROUP BY room_alias -); - -DROP INDEX IF EXISTS room_aliases_id; -CREATE INDEX room_aliases_id ON room_aliases(room_id); diff --git a/synapse/storage/schema/delta/16/users.sql b/synapse/storage/schema/delta/16/users.sql deleted file mode 100644 index cd0709250d..0000000000 --- a/synapse/storage/schema/delta/16/users.sql +++ /dev/null @@ -1,56 +0,0 @@ --- Convert `access_tokens`.user from rowids to user strings. --- MUST BE DONE BEFORE REMOVING ID COLUMN FROM USERS TABLE BELOW -CREATE TABLE IF NOT EXISTS new_access_tokens( - id BIGINT UNSIGNED PRIMARY KEY, - user_id TEXT NOT NULL, - device_id TEXT, - token TEXT NOT NULL, - last_used BIGINT UNSIGNED, - UNIQUE(token) -); - -INSERT INTO new_access_tokens - SELECT a.id, u.name, a.device_id, a.token, a.last_used - FROM access_tokens as a - INNER JOIN users as u ON u.id = a.user_id; - -DROP TABLE access_tokens; - -ALTER TABLE new_access_tokens RENAME TO access_tokens; - --- Remove ID column from `users` table -CREATE TABLE IF NOT EXISTS new_users( - name TEXT, - password_hash TEXT, - creation_ts BIGINT UNSIGNED, - admin BOOL DEFAULT 0 NOT NULL, - UNIQUE(name) -); - -INSERT INTO new_users SELECT name, password_hash, creation_ts, admin FROM users; - -DROP TABLE users; - -ALTER TABLE new_users RENAME TO users; - - --- Remove UNIQUE constraint from `user_ips` table -CREATE TABLE IF NOT EXISTS new_user_ips ( - user_id TEXT NOT NULL, - access_token TEXT NOT NULL, - device_id TEXT, - ip TEXT NOT NULL, - user_agent TEXT NOT NULL, - last_seen BIGINT UNSIGNED NOT NULL -); - -INSERT INTO new_user_ips - SELECT user, access_token, device_id, ip, user_agent, last_seen FROM user_ips; - -DROP TABLE user_ips; - -ALTER TABLE new_user_ips RENAME TO user_ips; - -CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user_id); -CREATE INDEX IF NOT EXISTS user_ips_user_ip ON user_ips(user_id, access_token, ip); - diff --git a/synapse/storage/schema/delta/17/drop_indexes.sql b/synapse/storage/schema/delta/17/drop_indexes.sql deleted file mode 100644 index 7c9a90e27f..0000000000 --- a/synapse/storage/schema/delta/17/drop_indexes.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -DROP INDEX IF EXISTS sent_transaction_dest; -DROP INDEX IF EXISTS sent_transaction_sent; -DROP INDEX IF EXISTS user_ips_user; diff --git a/synapse/storage/schema/delta/17/server_keys.sql b/synapse/storage/schema/delta/17/server_keys.sql deleted file mode 100644 index 70b247a06b..0000000000 --- a/synapse/storage/schema/delta/17/server_keys.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS server_keys_json ( - server_name TEXT, -- Server name. - key_id TEXT, -- Requested key id. - from_server TEXT, -- Which server the keys were fetched from. - ts_added_ms INTEGER, -- When the keys were fetched - ts_valid_until_ms INTEGER, -- When this version of the keys exipires. - key_json bytea, -- JSON certificate for the remote server. - CONSTRAINT uniqueness UNIQUE (server_name, key_id, from_server) -); diff --git a/synapse/storage/schema/delta/17/user_threepids.sql b/synapse/storage/schema/delta/17/user_threepids.sql deleted file mode 100644 index c17715ac80..0000000000 --- a/synapse/storage/schema/delta/17/user_threepids.sql +++ /dev/null @@ -1,9 +0,0 @@ -CREATE TABLE user_threepids ( - user_id TEXT NOT NULL, - medium TEXT NOT NULL, - address TEXT NOT NULL, - validated_at BIGINT NOT NULL, - added_at BIGINT NOT NULL, - CONSTRAINT user_medium_address UNIQUE (user_id, medium, address) -); -CREATE INDEX user_threepids_user_id ON user_threepids(user_id); diff --git a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql deleted file mode 100644 index 6e0871c92b..0000000000 --- a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE IF NOT EXISTS new_server_keys_json ( - server_name TEXT NOT NULL, -- Server name. - key_id TEXT NOT NULL, -- Requested key id. - from_server TEXT NOT NULL, -- Which server the keys were fetched from. - ts_added_ms BIGINT NOT NULL, -- When the keys were fetched - ts_valid_until_ms BIGINT NOT NULL, -- When this version of the keys exipires. - key_json bytea NOT NULL, -- JSON certificate for the remote server. - CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server) -); - -INSERT INTO new_server_keys_json - SELECT server_name, key_id, from_server,ts_added_ms, ts_valid_until_ms, key_json FROM server_keys_json ; - -DROP TABLE server_keys_json; - -ALTER TABLE new_server_keys_json RENAME TO server_keys_json; diff --git a/synapse/storage/schema/delta/19/event_index.sql b/synapse/storage/schema/delta/19/event_index.sql deleted file mode 100644 index 18b97b4332..0000000000 --- a/synapse/storage/schema/delta/19/event_index.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE INDEX events_order_topo_stream_room ON events( - topological_ordering, stream_ordering, room_id -); diff --git a/synapse/storage/schema/delta/20/dummy.sql b/synapse/storage/schema/delta/20/dummy.sql deleted file mode 100644 index e0ac49d1ec..0000000000 --- a/synapse/storage/schema/delta/20/dummy.sql +++ /dev/null @@ -1 +0,0 @@ -SELECT 1; diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py deleted file mode 100644 index 3edfcfd783..0000000000 --- a/synapse/storage/schema/delta/20/pushers.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -Main purpose of this upgrade is to change the unique key on the -pushers table again (it was missed when the v16 full schema was -made) but this also changes the pushkey and data columns to text. -When selecting a bytea column into a text column, postgres inserts -the hex encoded data, and there's no portable way of getting the -UTF-8 bytes, so we have to do it in Python. -""" - -import logging - -logger = logging.getLogger(__name__) - - -def run_create(cur, database_engine, *args, **kwargs): - logger.info("Porting pushers table...") - cur.execute( - """ - CREATE TABLE IF NOT EXISTS pushers2 ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - access_token BIGINT DEFAULT NULL, - profile_tag VARCHAR(32) NOT NULL, - kind VARCHAR(8) NOT NULL, - app_id VARCHAR(64) NOT NULL, - app_display_name VARCHAR(64) NOT NULL, - device_display_name VARCHAR(128) NOT NULL, - pushkey TEXT NOT NULL, - ts BIGINT NOT NULL, - lang VARCHAR(8), - data TEXT, - last_token TEXT, - last_success BIGINT, - failing_since BIGINT, - UNIQUE (app_id, pushkey, user_name) - ) - """ - ) - cur.execute( - """SELECT - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - FROM pushers - """ - ) - count = 0 - for row in cur.fetchall(): - row = list(row) - row[8] = bytes(row[8]).decode("utf-8") - row[11] = bytes(row[11]).decode("utf-8") - cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), - row, - ) - count += 1 - cur.execute("DROP TABLE pushers") - cur.execute("ALTER TABLE pushers2 RENAME TO pushers") - logger.info("Moved %d pushers to new table", count) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/schema/delta/21/end_to_end_keys.sql deleted file mode 100644 index 4c2fb20b77..0000000000 --- a/synapse/storage/schema/delta/21/end_to_end_keys.sql +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE IF NOT EXISTS e2e_device_keys_json ( - user_id TEXT NOT NULL, -- The user these keys are for. - device_id TEXT NOT NULL, -- Which of the user's devices these keys are for. - ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded. - key_json TEXT NOT NULL, -- The keys for the device as a JSON blob. - CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) -); - - -CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json ( - user_id TEXT NOT NULL, -- The user this one-time key is for. - device_id TEXT NOT NULL, -- The device this one-time key is for. - algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for. - key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. - ts_added_ms BIGINT NOT NULL, -- When this key was uploaded. - key_json TEXT NOT NULL, -- The key as a JSON blob. - CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) -); diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/schema/delta/21/receipts.sql deleted file mode 100644 index d070845477..0000000000 --- a/synapse/storage/schema/delta/21/receipts.sql +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE IF NOT EXISTS receipts_graph( - room_id TEXT NOT NULL, - receipt_type TEXT NOT NULL, - user_id TEXT NOT NULL, - event_ids TEXT NOT NULL, - data TEXT NOT NULL, - CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id) -); - -CREATE TABLE IF NOT EXISTS receipts_linearized ( - stream_id BIGINT NOT NULL, - room_id TEXT NOT NULL, - receipt_type TEXT NOT NULL, - user_id TEXT NOT NULL, - event_id TEXT NOT NULL, - data TEXT NOT NULL, - CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id) -); - -CREATE INDEX receipts_linearized_id ON receipts_linearized( - stream_id -); diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/schema/delta/22/receipts_index.sql deleted file mode 100644 index bfc0b3bcaa..0000000000 --- a/synapse/storage/schema/delta/22/receipts_index.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ -CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( - room_id, stream_id -); diff --git a/synapse/storage/schema/delta/22/user_threepids_unique.sql b/synapse/storage/schema/delta/22/user_threepids_unique.sql deleted file mode 100644 index 87edfa454c..0000000000 --- a/synapse/storage/schema/delta/22/user_threepids_unique.sql +++ /dev/null @@ -1,19 +0,0 @@ -CREATE TABLE IF NOT EXISTS user_threepids2 ( - user_id TEXT NOT NULL, - medium TEXT NOT NULL, - address TEXT NOT NULL, - validated_at BIGINT NOT NULL, - added_at BIGINT NOT NULL, - CONSTRAINT medium_address UNIQUE (medium, address) -); - -INSERT INTO user_threepids2 - SELECT * FROM user_threepids WHERE added_at IN ( - SELECT max(added_at) FROM user_threepids GROUP BY medium, address - ) -; - -DROP TABLE user_threepids; -ALTER TABLE user_threepids2 RENAME TO user_threepids; - -CREATE INDEX user_threepids_user_id ON user_threepids(user_id); diff --git a/synapse/storage/schema/delta/23/drop_state_index.sql b/synapse/storage/schema/delta/23/drop_state_index.sql deleted file mode 100644 index ae09fa0065..0000000000 --- a/synapse/storage/schema/delta/23/drop_state_index.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -DROP INDEX IF EXISTS state_groups_state_tuple; diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/schema/delta/24/stats_reporting.sql deleted file mode 100644 index acea7483bd..0000000000 --- a/synapse/storage/schema/delta/24/stats_reporting.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - /* We used to create a table called stats_reporting, but this is no longer - * used and is removed in delta 54. - */ \ No newline at end of file diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py deleted file mode 100644 index 4b2ffd35fd..0000000000 --- a/synapse/storage/schema/delta/25/fts.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -import simplejson - -from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -POSTGRES_TABLE = """ -CREATE TABLE IF NOT EXISTS event_search ( - event_id TEXT, - room_id TEXT, - sender TEXT, - key TEXT, - vector tsvector -); - -CREATE INDEX event_search_fts_idx ON event_search USING gin(vector); -CREATE INDEX event_search_ev_idx ON event_search(event_id); -CREATE INDEX event_search_ev_ridx ON event_search(room_id); -""" - - -SQLITE_TABLE = ( - "CREATE VIRTUAL TABLE event_search" - " USING fts4 ( event_id, room_id, sender, key, value )" -) - - -def run_create(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - for statement in get_statements(POSTGRES_TABLE.splitlines()): - cur.execute(statement) - elif isinstance(database_engine, Sqlite3Engine): - cur.execute(SQLITE_TABLE) - else: - raise Exception("Unrecognized database engine") - - cur.execute("SELECT MIN(stream_ordering) FROM events") - rows = cur.fetchall() - min_stream_id = rows[0][0] - - cur.execute("SELECT MAX(stream_ordering) FROM events") - rows = cur.fetchall() - max_stream_id = rows[0][0] - - if min_stream_id is not None and max_stream_id is not None: - progress = { - "target_min_stream_id_inclusive": min_stream_id, - "max_stream_id_exclusive": max_stream_id + 1, - "rows_inserted": 0, - } - progress_json = simplejson.dumps(progress) - - sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" - ) - - sql = database_engine.convert_param_style(sql) - - cur.execute(sql, ("event_search", progress_json)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/25/guest_access.sql b/synapse/storage/schema/delta/25/guest_access.sql deleted file mode 100644 index 1ea389b471..0000000000 --- a/synapse/storage/schema/delta/25/guest_access.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * This is a manual index of guest_access content of state events, - * so that we can join on them in SELECT statements. - */ -CREATE TABLE IF NOT EXISTS guest_access( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - guest_access TEXT NOT NULL, - UNIQUE (event_id) -); diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/schema/delta/25/history_visibility.sql deleted file mode 100644 index f468fc1897..0000000000 --- a/synapse/storage/schema/delta/25/history_visibility.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * This is a manual index of history_visibility content of state events, - * so that we can join on them in SELECT statements. - */ -CREATE TABLE IF NOT EXISTS history_visibility( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - history_visibility TEXT NOT NULL, - UNIQUE (event_id) -); diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/schema/delta/25/tags.sql deleted file mode 100644 index 7a32ce68e4..0000000000 --- a/synapse/storage/schema/delta/25/tags.sql +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE IF NOT EXISTS room_tags( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - tag TEXT NOT NULL, -- The name of the tag. - content TEXT NOT NULL, -- The JSON content of the tag. - CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) -); - -CREATE TABLE IF NOT EXISTS room_tags_revisions ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - stream_id BIGINT NOT NULL, -- The current version of the room tags. - CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) -); - -CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. - stream_id BIGINT NOT NULL, - CHECK (Lock='X') -); - -INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0); diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/schema/delta/26/account_data.sql deleted file mode 100644 index e395de2b5e..0000000000 --- a/synapse/storage/schema/delta/26/account_data.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id; diff --git a/synapse/storage/schema/delta/27/account_data.sql b/synapse/storage/schema/delta/27/account_data.sql deleted file mode 100644 index bf0558b5b3..0000000000 --- a/synapse/storage/schema/delta/27/account_data.sql +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS account_data( - user_id TEXT NOT NULL, - account_data_type TEXT NOT NULL, -- The type of the account_data. - stream_id BIGINT NOT NULL, -- The version of the account_data. - content TEXT NOT NULL, -- The JSON content of the account_data - CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) -); - - -CREATE TABLE IF NOT EXISTS room_account_data( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - account_data_type TEXT NOT NULL, -- The type of the account_data. - stream_id BIGINT NOT NULL, -- The version of the account_data. - content TEXT NOT NULL, -- The JSON content of the account_data - CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) -); - - -CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); -CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); diff --git a/synapse/storage/schema/delta/27/forgotten_memberships.sql b/synapse/storage/schema/delta/27/forgotten_memberships.sql deleted file mode 100644 index e2094f37fe..0000000000 --- a/synapse/storage/schema/delta/27/forgotten_memberships.sql +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Keeps track of what rooms users have left and don't want to be able to - * access again. - * - * If all users on this server have left a room, we can delete the room - * entirely. - * - * This column should always contain either 0 or 1. - */ - - ALTER TABLE room_memberships ADD COLUMN forgotten INTEGER DEFAULT 0; diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py deleted file mode 100644 index 414f9f5aa0..0000000000 --- a/synapse/storage/schema/delta/27/ts.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -import simplejson - -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -ALTER_TABLE = ( - "ALTER TABLE events ADD COLUMN origin_server_ts BIGINT;" - "CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);" -) - - -def run_create(cur, database_engine, *args, **kwargs): - for statement in get_statements(ALTER_TABLE.splitlines()): - cur.execute(statement) - - cur.execute("SELECT MIN(stream_ordering) FROM events") - rows = cur.fetchall() - min_stream_id = rows[0][0] - - cur.execute("SELECT MAX(stream_ordering) FROM events") - rows = cur.fetchall() - max_stream_id = rows[0][0] - - if min_stream_id is not None and max_stream_id is not None: - progress = { - "target_min_stream_id_inclusive": min_stream_id, - "max_stream_id_exclusive": max_stream_id + 1, - "rows_inserted": 0, - } - progress_json = simplejson.dumps(progress) - - sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" - ) - - sql = database_engine.convert_param_style(sql) - - cur.execute(sql, ("event_origin_server_ts", progress_json)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/schema/delta/28/event_push_actions.sql deleted file mode 100644 index 4d519849df..0000000000 --- a/synapse/storage/schema/delta/28/event_push_actions.sql +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2015 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS event_push_actions( - room_id TEXT NOT NULL, - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - profile_tag VARCHAR(32), - actions TEXT NOT NULL, - CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) -); - - -CREATE INDEX event_push_actions_room_id_event_id_user_id_profile_tag on event_push_actions(room_id, event_id, user_id, profile_tag); -CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id); diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/schema/delta/28/events_room_stream.sql deleted file mode 100644 index 36609475f1..0000000000 --- a/synapse/storage/schema/delta/28/events_room_stream.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ -CREATE INDEX events_room_stream on events(room_id, stream_ordering); diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/schema/delta/28/public_roms_index.sql deleted file mode 100644 index 6c1fd68c5b..0000000000 --- a/synapse/storage/schema/delta/28/public_roms_index.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ -CREATE INDEX public_room_index on rooms(is_public); diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/schema/delta/28/receipts_user_id_index.sql deleted file mode 100644 index cb84c69baa..0000000000 --- a/synapse/storage/schema/delta/28/receipts_user_id_index.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ -CREATE INDEX receipts_linearized_user ON receipts_linearized( - user_id -); diff --git a/synapse/storage/schema/delta/28/upgrade_times.sql b/synapse/storage/schema/delta/28/upgrade_times.sql deleted file mode 100644 index 3e4a9ab455..0000000000 --- a/synapse/storage/schema/delta/28/upgrade_times.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Stores the timestamp when a user upgraded from a guest to a full user, if - * that happened. - */ - -ALTER TABLE users ADD COLUMN upgrade_ts BIGINT; diff --git a/synapse/storage/schema/delta/28/users_is_guest.sql b/synapse/storage/schema/delta/28/users_is_guest.sql deleted file mode 100644 index 21d2b420bf..0000000000 --- a/synapse/storage/schema/delta/28/users_is_guest.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE users ADD is_guest SMALLINT DEFAULT 0 NOT NULL; -/* - * NB: any guest users created between 27 and 28 will be incorrectly - * marked as not guests: we don't bother to fill these in correctly - * because guest access is not really complete in 27 anyway so it's - * very unlikley there will be any guest users created. - */ diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/schema/delta/29/push_actions.sql deleted file mode 100644 index 84b21cf813..0000000000 --- a/synapse/storage/schema/delta/29/push_actions.sql +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE event_push_actions ADD COLUMN topological_ordering BIGINT; -ALTER TABLE event_push_actions ADD COLUMN stream_ordering BIGINT; -ALTER TABLE event_push_actions ADD COLUMN notif SMALLINT; -ALTER TABLE event_push_actions ADD COLUMN highlight SMALLINT; - -UPDATE event_push_actions SET stream_ordering = ( - SELECT stream_ordering FROM events WHERE event_id = event_push_actions.event_id -), topological_ordering = ( - SELECT topological_ordering FROM events WHERE event_id = event_push_actions.event_id -); - -UPDATE event_push_actions SET notif = 1, highlight = 0; - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ -CREATE INDEX event_push_actions_rm_tokens on event_push_actions( - user_id, room_id, topological_ordering, stream_ordering -); diff --git a/synapse/storage/schema/delta/30/alias_creator.sql b/synapse/storage/schema/delta/30/alias_creator.sql deleted file mode 100644 index c9d0dde638..0000000000 --- a/synapse/storage/schema/delta/30/alias_creator.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE room_aliases ADD COLUMN creator TEXT; diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py deleted file mode 100644 index 9b95411fb6..0000000000 --- a/synapse/storage/schema/delta/30/as_users.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from six.moves import range - -from synapse.config.appservice import load_appservices - -logger = logging.getLogger(__name__) - - -def run_create(cur, database_engine, *args, **kwargs): - # NULL indicates user was not registered by an appservice. - try: - cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") - except Exception: - # Maybe we already added the column? Hope so... - pass - - -def run_upgrade(cur, database_engine, config, *args, **kwargs): - cur.execute("SELECT name FROM users") - rows = cur.fetchall() - - config_files = [] - try: - config_files = config.app_service_config_files - except AttributeError: - logger.warning("Could not get app_service_config_files from config") - pass - - appservices = load_appservices(config.server_name, config_files) - - owned = {} - - for row in rows: - user_id = row[0] - for appservice in appservices: - if appservice.is_exclusive_user(user_id): - if user_id in owned.keys(): - logger.error( - "user_id %s was owned by more than one application" - " service (IDs %s and %s); assigning arbitrarily to %s" - % (user_id, owned[user_id], appservice.id, owned[user_id]) - ) - owned.setdefault(appservice.id, []).append(user_id) - - for as_id, user_ids in owned.items(): - n = 100 - user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) - for chunk in user_chunks: - cur.execute( - database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name IN (%s)" - % (",".join("?" for _ in chunk),) - ), - [as_id] + chunk, - ) diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/schema/delta/30/deleted_pushers.sql deleted file mode 100644 index 712c454aa1..0000000000 --- a/synapse/storage/schema/delta/30/deleted_pushers.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS deleted_pushers( - stream_id BIGINT NOT NULL, - app_id TEXT NOT NULL, - pushkey TEXT NOT NULL, - user_id TEXT NOT NULL, - /* We only track the most recent delete for each app_id, pushkey and user_id. */ - UNIQUE (app_id, pushkey, user_id) -); - -CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/schema/delta/30/presence_stream.sql deleted file mode 100644 index 606bbb037d..0000000000 --- a/synapse/storage/schema/delta/30/presence_stream.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - - CREATE TABLE presence_stream( - stream_id BIGINT, - user_id TEXT, - state TEXT, - last_active_ts BIGINT, - last_federation_update_ts BIGINT, - last_user_sync_ts BIGINT, - status_msg TEXT, - currently_active BOOLEAN - ); - - CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); - CREATE INDEX presence_stream_user_id ON presence_stream(user_id); - CREATE INDEX presence_stream_state ON presence_stream(state); diff --git a/synapse/storage/schema/delta/30/public_rooms.sql b/synapse/storage/schema/delta/30/public_rooms.sql deleted file mode 100644 index f09db4faa6..0000000000 --- a/synapse/storage/schema/delta/30/public_rooms.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -/* This release removes the restriction that published rooms must have an alias, - * so we go back and ensure the only 'public' rooms are ones with an alias. - * We use (1 = 0) and (1 = 1) so that it works in both postgres and sqlite - */ -UPDATE rooms SET is_public = (1 = 0) WHERE is_public = (1 = 1) AND room_id not in ( - SELECT room_id FROM room_aliases -); diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/schema/delta/30/push_rule_stream.sql deleted file mode 100644 index 735aa8d5f6..0000000000 --- a/synapse/storage/schema/delta/30/push_rule_stream.sql +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - - -CREATE TABLE push_rules_stream( - stream_id BIGINT NOT NULL, - event_stream_ordering BIGINT NOT NULL, - user_id TEXT NOT NULL, - rule_id TEXT NOT NULL, - op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE" - priority_class SMALLINT, - priority INTEGER, - conditions TEXT, - actions TEXT -); - --- The extra data for each operation is: --- * ENABLE, DISABLE, DELETE: [] --- * ACTIONS: ["actions"] --- * ADD: ["priority_class", "priority", "actions", "conditions"] - --- Index for replication queries. -CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id); --- Index for /sync queries. -CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id); diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/schema/delta/30/state_stream.sql deleted file mode 100644 index e85699e82e..0000000000 --- a/synapse/storage/schema/delta/30/state_stream.sql +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -/* We used to create a table called current_state_resets, but this is no - * longer used and is removed in delta 54. - */ - -/* The outlier events that have aquired a state group typically through - * backfill. This is tracked separately to the events table, as assigning a - * state group change the position of the existing event in the stream - * ordering. - * However since a stream_ordering is assigned in persist_event for the - * (event, state) pair, we can use that stream_ordering to identify when - * the new state was assigned for the event. - */ -CREATE TABLE IF NOT EXISTS ex_outlier_stream( - event_stream_ordering BIGINT PRIMARY KEY NOT NULL, - event_id TEXT NOT NULL, - state_group BIGINT NOT NULL -); diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql deleted file mode 100644 index 0dd2f1360c..0000000000 --- a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Stores guest account access tokens generated for unbound 3pids. -CREATE TABLE threepid_guest_access_tokens( - medium TEXT, -- The medium of the 3pid. Must be "email". - address TEXT, -- The 3pid address. - guest_access_token TEXT, -- The access token for a guest user for this 3pid. - first_inviter TEXT -- User ID of the first user to invite this 3pid to a room. -); - -CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql deleted file mode 100644 index 2c57846d5a..0000000000 --- a/synapse/storage/schema/delta/31/invites.sql +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE local_invites( - stream_id BIGINT NOT NULL, - inviter TEXT NOT NULL, - invitee TEXT NOT NULL, - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - locally_rejected TEXT, - replaced_by TEXT -); - --- Insert all invites for local users into new `invites` table -INSERT INTO local_invites SELECT - stream_ordering as stream_id, - sender as inviter, - state_key as invitee, - event_id, - room_id, - NULL as locally_rejected, - NULL as replaced_by - FROM events - NATURAL JOIN current_state_events - NATURAL JOIN room_memberships - WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); - -CREATE INDEX local_invites_id ON local_invites(stream_id); -CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); diff --git a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql deleted file mode 100644 index 9efb4280eb..0000000000 --- a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE local_media_repository_url_cache( - url TEXT, -- the URL being cached - response_code INTEGER, -- the HTTP response code of this download attempt - etag TEXT, -- the etag header of this response - expires INTEGER, -- the number of ms this response was valid for - og TEXT, -- cache of the OG metadata of this URL as JSON - media_id TEXT, -- the media_id, if any, of the URL's content in the repo - download_ts BIGINT -- the timestamp of this download attempt -); - -CREATE INDEX local_media_repository_url_cache_by_url_download_ts - ON local_media_repository_url_cache(url, download_ts); diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/schema/delta/31/pushers.py deleted file mode 100644 index 9bb504aad5..0000000000 --- a/synapse/storage/schema/delta/31/pushers.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Change the last_token to last_stream_ordering now that pushers no longer -# listen on an event stream but instead select out of the event_push_actions -# table. - - -import logging - -logger = logging.getLogger(__name__) - - -def token_to_stream_ordering(token): - return int(token[1:].split("_")[0]) - - -def run_create(cur, database_engine, *args, **kwargs): - logger.info("Porting pushers table, delta 31...") - cur.execute( - """ - CREATE TABLE IF NOT EXISTS pushers2 ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - access_token BIGINT DEFAULT NULL, - profile_tag VARCHAR(32) NOT NULL, - kind VARCHAR(8) NOT NULL, - app_id VARCHAR(64) NOT NULL, - app_display_name VARCHAR(64) NOT NULL, - device_display_name VARCHAR(128) NOT NULL, - pushkey TEXT NOT NULL, - ts BIGINT NOT NULL, - lang VARCHAR(8), - data TEXT, - last_stream_ordering INTEGER, - last_success BIGINT, - failing_since BIGINT, - UNIQUE (app_id, pushkey, user_name) - ) - """ - ) - cur.execute( - """SELECT - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - FROM pushers - """ - ) - count = 0 - for row in cur.fetchall(): - row = list(row) - row[12] = token_to_stream_ordering(row[12]) - cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_stream_ordering, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), - row, - ) - count += 1 - cur.execute("DROP TABLE pushers") - cur.execute("ALTER TABLE pushers2 RENAME TO pushers") - logger.info("Moved %d pushers to new table", count) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/31/pushers_index.sql b/synapse/storage/schema/delta/31/pushers_index.sql deleted file mode 100644 index a82add88fd..0000000000 --- a/synapse/storage/schema/delta/31/pushers_index.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** Using CREATE INDEX directly is deprecated in favour of using background - * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql - * and synapse/storage/registration.py for an example using - * "access_tokens_device_index" **/ - CREATE INDEX event_push_actions_stream_ordering on event_push_actions( - stream_ordering, user_id - ); diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py deleted file mode 100644 index 7d8ca5f93f..0000000000 --- a/synapse/storage/schema/delta/31/search_update.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -import simplejson - -from synapse.storage.engines import PostgresEngine -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -ALTER_TABLE = """ -ALTER TABLE event_search ADD COLUMN origin_server_ts BIGINT; -ALTER TABLE event_search ADD COLUMN stream_ordering BIGINT; -""" - - -def run_create(cur, database_engine, *args, **kwargs): - if not isinstance(database_engine, PostgresEngine): - return - - for statement in get_statements(ALTER_TABLE.splitlines()): - cur.execute(statement) - - cur.execute("SELECT MIN(stream_ordering) FROM events") - rows = cur.fetchall() - min_stream_id = rows[0][0] - - cur.execute("SELECT MAX(stream_ordering) FROM events") - rows = cur.fetchall() - max_stream_id = rows[0][0] - - if min_stream_id is not None and max_stream_id is not None: - progress = { - "target_min_stream_id_inclusive": min_stream_id, - "max_stream_id_exclusive": max_stream_id + 1, - "rows_inserted": 0, - "have_added_indexes": False, - } - progress_json = simplejson.dumps(progress) - - sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" - ) - - sql = database_engine.convert_param_style(sql) - - cur.execute(sql, ("event_search_order", progress_json)) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/32/events.sql b/synapse/storage/schema/delta/32/events.sql deleted file mode 100644 index 1dd0f9e170..0000000000 --- a/synapse/storage/schema/delta/32/events.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE events ADD COLUMN received_ts BIGINT; diff --git a/synapse/storage/schema/delta/32/openid.sql b/synapse/storage/schema/delta/32/openid.sql deleted file mode 100644 index 36f37b11c8..0000000000 --- a/synapse/storage/schema/delta/32/openid.sql +++ /dev/null @@ -1,9 +0,0 @@ - -CREATE TABLE open_id_tokens ( - token TEXT NOT NULL PRIMARY KEY, - ts_valid_until_ms bigint NOT NULL, - user_id TEXT NOT NULL, - UNIQUE (token) -); - -CREATE index open_id_tokens_ts_valid_until_ms ON open_id_tokens(ts_valid_until_ms); diff --git a/synapse/storage/schema/delta/32/pusher_throttle.sql b/synapse/storage/schema/delta/32/pusher_throttle.sql deleted file mode 100644 index d86d30c13c..0000000000 --- a/synapse/storage/schema/delta/32/pusher_throttle.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE pusher_throttle( - pusher BIGINT NOT NULL, - room_id TEXT NOT NULL, - last_sent_ts BIGINT, - throttle_ms BIGINT, - PRIMARY KEY (pusher, room_id) -); diff --git a/synapse/storage/schema/delta/32/remove_indices.sql b/synapse/storage/schema/delta/32/remove_indices.sql deleted file mode 100644 index 4219cdd06a..0000000000 --- a/synapse/storage/schema/delta/32/remove_indices.sql +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - --- The following indices are redundant, other indices are equivalent or --- supersets -DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream -DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room -DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room -DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY -DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY -DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY -DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT - -DROP INDEX IF EXISTS st_extrem_id; -- Prefix of UNIQUE CONSTRAINT -DROP INDEX IF EXISTS event_signatures_id; -- Prefix of UNIQUE CONSTRAINT -DROP INDEX IF EXISTS redactions_event_id; -- Duplicate of UNIQUE CONSTRAINT - --- The following indices were unused -DROP INDEX IF EXISTS remote_media_cache_thumbnails_media_id; -DROP INDEX IF EXISTS evauth_edges_auth_id; -DROP INDEX IF EXISTS presence_stream_state; diff --git a/synapse/storage/schema/delta/32/reports.sql b/synapse/storage/schema/delta/32/reports.sql deleted file mode 100644 index d13609776f..0000000000 --- a/synapse/storage/schema/delta/32/reports.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE event_reports( - id BIGINT NOT NULL PRIMARY KEY, - received_ts BIGINT NOT NULL, - room_id TEXT NOT NULL, - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - reason TEXT, - content TEXT -); diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/schema/delta/33/access_tokens_device_index.sql deleted file mode 100644 index 61ad3fe3e8..0000000000 --- a/synapse/storage/schema/delta/33/access_tokens_device_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('access_tokens_device_index', '{}'); diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/schema/delta/33/devices.sql deleted file mode 100644 index eca7268d82..0000000000 --- a/synapse/storage/schema/delta/33/devices.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE devices ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - display_name TEXT, - CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) -); diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql deleted file mode 100644 index aa4a3b9f2f..0000000000 --- a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- make sure that we have a device record for each set of E2E keys, so that the --- user can delete them if they like. -INSERT INTO devices - SELECT user_id, device_id, NULL FROM e2e_device_keys_json; diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql deleted file mode 100644 index 6671573398..0000000000 --- a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- a previous version of the "devices_for_e2e_keys" delta set all the device --- names to "unknown device". This wasn't terribly helpful -UPDATE devices - SET display_name = NULL - WHERE display_name = 'unknown device'; diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py deleted file mode 100644 index bff1256a7b..0000000000 --- a/synapse/storage/schema/delta/33/event_fields.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -import simplejson - -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -ALTER_TABLE = """ -ALTER TABLE events ADD COLUMN sender TEXT; -ALTER TABLE events ADD COLUMN contains_url BOOLEAN; -""" - - -def run_create(cur, database_engine, *args, **kwargs): - for statement in get_statements(ALTER_TABLE.splitlines()): - cur.execute(statement) - - cur.execute("SELECT MIN(stream_ordering) FROM events") - rows = cur.fetchall() - min_stream_id = rows[0][0] - - cur.execute("SELECT MAX(stream_ordering) FROM events") - rows = cur.fetchall() - max_stream_id = rows[0][0] - - if min_stream_id is not None and max_stream_id is not None: - progress = { - "target_min_stream_id_inclusive": min_stream_id, - "max_stream_id_exclusive": max_stream_id + 1, - "rows_inserted": 0, - } - progress_json = simplejson.dumps(progress) - - sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" - ) - - sql = database_engine.convert_param_style(sql) - - cur.execute(sql, ("event_fields_sender_url", progress_json)) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py deleted file mode 100644 index a26057dfb6..0000000000 --- a/synapse/storage/schema/delta/33/remote_media_ts.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT" - - -def run_create(cur, database_engine, *args, **kwargs): - cur.execute(ALTER_TABLE) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - cur.execute( - database_engine.convert_param_style( - "UPDATE remote_media_cache SET last_access_ts = ?" - ), - (int(time.time() * 1000),), - ) diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/schema/delta/33/user_ips_index.sql deleted file mode 100644 index 473f75a78e..0000000000 --- a/synapse/storage/schema/delta/33/user_ips_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('user_ips_device_index', '{}'); diff --git a/synapse/storage/schema/delta/34/appservice_stream.sql b/synapse/storage/schema/delta/34/appservice_stream.sql deleted file mode 100644 index 69e16eda0f..0000000000 --- a/synapse/storage/schema/delta/34/appservice_stream.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS appservice_stream_position( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. - stream_ordering BIGINT, - CHECK (Lock='X') -); - -INSERT INTO appservice_stream_position (stream_ordering) - SELECT COALESCE(MAX(stream_ordering), 0) FROM events; diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py deleted file mode 100644 index cf09e43e2b..0000000000 --- a/synapse/storage/schema/delta/34/cache_stream.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.storage.engines import PostgresEngine -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -# This stream is used to notify replication slaves that some caches have -# been invalidated that they cannot infer from the other streams. -CREATE_TABLE = """ -CREATE TABLE cache_invalidation_stream ( - stream_id BIGINT, - cache_func TEXT, - keys TEXT[], - invalidation_ts BIGINT -); - -CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id); -""" - - -def run_create(cur, database_engine, *args, **kwargs): - if not isinstance(database_engine, PostgresEngine): - return - - for statement in get_statements(CREATE_TABLE.splitlines()): - cur.execute(statement) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/schema/delta/34/device_inbox.sql deleted file mode 100644 index e68844c74a..0000000000 --- a/synapse/storage/schema/delta/34/device_inbox.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE device_inbox ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - stream_id BIGINT NOT NULL, - message_json TEXT NOT NULL -- {"type":, "sender":, "content",} -); - -CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); -CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id); diff --git a/synapse/storage/schema/delta/34/push_display_name_rename.sql b/synapse/storage/schema/delta/34/push_display_name_rename.sql deleted file mode 100644 index 0d9fe1a99a..0000000000 --- a/synapse/storage/schema/delta/34/push_display_name_rename.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name'; -UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; - -DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name'; -UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/schema/delta/34/received_txn_purge.py deleted file mode 100644 index 67d505e68b..0000000000 --- a/synapse/storage/schema/delta/34/received_txn_purge.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.storage.engines import PostgresEngine - -logger = logging.getLogger(__name__) - - -def run_create(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - cur.execute("TRUNCATE received_transactions") - else: - cur.execute("DELETE FROM received_transactions") - - cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)") - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/35/00background_updates_add_col.sql b/synapse/storage/schema/delta/35/00background_updates_add_col.sql new file mode 100644 index 0000000000..c2d2a4f836 --- /dev/null +++ b/synapse/storage/schema/delta/35/00background_updates_add_col.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +ALTER TABLE background_updates ADD COLUMN depends_on TEXT; diff --git a/synapse/storage/schema/delta/35/add_state_index.sql b/synapse/storage/schema/delta/35/add_state_index.sql deleted file mode 100644 index 0fce26345b..0000000000 --- a/synapse/storage/schema/delta/35/add_state_index.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -ALTER TABLE background_updates ADD COLUMN depends_on TEXT; - -INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication'); diff --git a/synapse/storage/schema/delta/35/contains_url.sql b/synapse/storage/schema/delta/35/contains_url.sql deleted file mode 100644 index 6cd123027b..0000000000 --- a/synapse/storage/schema/delta/35/contains_url.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - INSERT into background_updates (update_name, progress_json) - VALUES ('event_contains_url_index', '{}'); diff --git a/synapse/storage/schema/delta/35/device_outbox.sql b/synapse/storage/schema/delta/35/device_outbox.sql deleted file mode 100644 index 17e6c43105..0000000000 --- a/synapse/storage/schema/delta/35/device_outbox.sql +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -DROP TABLE IF EXISTS device_federation_outbox; -CREATE TABLE device_federation_outbox ( - destination TEXT NOT NULL, - stream_id BIGINT NOT NULL, - queued_ts BIGINT NOT NULL, - messages_json TEXT NOT NULL -); - - -DROP INDEX IF EXISTS device_federation_outbox_destination_id; -CREATE INDEX device_federation_outbox_destination_id - ON device_federation_outbox(destination, stream_id); - - -DROP TABLE IF EXISTS device_federation_inbox; -CREATE TABLE device_federation_inbox ( - origin TEXT NOT NULL, - message_id TEXT NOT NULL, - received_ts BIGINT NOT NULL -); - -DROP INDEX IF EXISTS device_federation_inbox_sender_id; -CREATE INDEX device_federation_inbox_sender_id - ON device_federation_inbox(origin, message_id); diff --git a/synapse/storage/schema/delta/35/device_stream_id.sql b/synapse/storage/schema/delta/35/device_stream_id.sql deleted file mode 100644 index 7ab7d942e2..0000000000 --- a/synapse/storage/schema/delta/35/device_stream_id.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE device_max_stream_id ( - stream_id BIGINT NOT NULL -); - -INSERT INTO device_max_stream_id (stream_id) - SELECT COALESCE(MAX(stream_id), 0) FROM device_inbox; diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/schema/delta/35/event_push_actions_index.sql deleted file mode 100644 index 2e836d8e9c..0000000000 --- a/synapse/storage/schema/delta/35/event_push_actions_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - INSERT into background_updates (update_name, progress_json) - VALUES ('epa_highlight_index', '{}'); diff --git a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/schema/delta/35/public_room_list_change_stream.sql deleted file mode 100644 index dd2bf2e28a..0000000000 --- a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE public_room_list_stream ( - stream_id BIGINT NOT NULL, - room_id TEXT NOT NULL, - visibility BOOLEAN NOT NULL -); - -INSERT INTO public_room_list_stream (stream_id, room_id, visibility) - SELECT 1, room_id, is_public FROM rooms - WHERE is_public = CAST(1 AS BOOLEAN); - -CREATE INDEX public_room_list_stream_idx on public_room_list_stream( - stream_id -); - -CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( - room_id, stream_id -); diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/schema/delta/35/state.sql deleted file mode 100644 index 0f1fa68a89..0000000000 --- a/synapse/storage/schema/delta/35/state.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE state_group_edges( - state_group BIGINT NOT NULL, - prev_state_group BIGINT NOT NULL -); - -CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); -CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); diff --git a/synapse/storage/schema/delta/35/state_dedupe.sql b/synapse/storage/schema/delta/35/state_dedupe.sql deleted file mode 100644 index 97e5067ef4..0000000000 --- a/synapse/storage/schema/delta/35/state_dedupe.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('state_group_state_deduplication', '{}'); diff --git a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/schema/delta/35/stream_order_to_extrem.sql deleted file mode 100644 index 2b945d8a57..0000000000 --- a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE stream_ordering_to_exterm ( - stream_ordering BIGINT NOT NULL, - room_id TEXT NOT NULL, - event_id TEXT NOT NULL -); - -INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id) - SELECT stream_ordering, room_id, event_id FROM event_forward_extremities - INNER JOIN ( - SELECT room_id, max(stream_ordering) as stream_ordering FROM events - INNER JOIN event_forward_extremities USING (room_id, event_id) - GROUP BY room_id - ) AS rms USING (room_id); - -CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( - stream_ordering -); - -CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( - room_id, stream_ordering -); diff --git a/synapse/storage/schema/delta/36/readd_public_rooms.sql b/synapse/storage/schema/delta/36/readd_public_rooms.sql deleted file mode 100644 index 90d8fd18f9..0000000000 --- a/synapse/storage/schema/delta/36/readd_public_rooms.sql +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Re-add some entries to stream_ordering_to_exterm that were incorrectly deleted -INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id) - SELECT - (SELECT stream_ordering FROM events where event_id = e.event_id) AS stream_ordering, - room_id, - event_id - FROM event_forward_extremities AS e - WHERE NOT EXISTS ( - SELECT room_id FROM stream_ordering_to_exterm AS s - WHERE s.room_id = e.room_id - ); diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py deleted file mode 100644 index a377884169..0000000000 --- a/synapse/storage/schema/delta/37/remove_auth_idx.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.storage.engines import PostgresEngine -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - -DROP_INDICES = """ --- We only ever query based on event_id -DROP INDEX IF EXISTS state_events_room_id; -DROP INDEX IF EXISTS state_events_type; -DROP INDEX IF EXISTS state_events_state_key; - --- room_id is indexed elsewhere -DROP INDEX IF EXISTS current_state_events_room_id; -DROP INDEX IF EXISTS current_state_events_state_key; -DROP INDEX IF EXISTS current_state_events_type; - -DROP INDEX IF EXISTS transactions_have_ref; - --- (topological_ordering, stream_ordering, room_id) seems like a strange index, --- and is used incredibly rarely. -DROP INDEX IF EXISTS events_order_topo_stream_room; - --- an equivalent index to this actually gets re-created in delta 41, because it --- turned out that deleting it wasn't a great plan :/. In any case, let's --- delete it here, and delta 41 will create a new one with an added UNIQUE --- constraint -DROP INDEX IF EXISTS event_search_ev_idx; -""" - -POSTGRES_DROP_CONSTRAINT = """ -ALTER TABLE event_auth DROP CONSTRAINT IF EXISTS event_auth_event_id_auth_id_room_id_key; -""" - -SQLITE_DROP_CONSTRAINT = """ -DROP INDEX IF EXISTS evauth_edges_id; - -CREATE TABLE IF NOT EXISTS event_auth_new( - event_id TEXT NOT NULL, - auth_id TEXT NOT NULL, - room_id TEXT NOT NULL -); - -INSERT INTO event_auth_new - SELECT event_id, auth_id, room_id - FROM event_auth; - -DROP TABLE event_auth; - -ALTER TABLE event_auth_new RENAME TO event_auth; - -CREATE INDEX evauth_edges_id ON event_auth(event_id); -""" - - -def run_create(cur, database_engine, *args, **kwargs): - for statement in get_statements(DROP_INDICES.splitlines()): - cur.execute(statement) - - if isinstance(database_engine, PostgresEngine): - drop_constraint = POSTGRES_DROP_CONSTRAINT - else: - drop_constraint = SQLITE_DROP_CONSTRAINT - - for statement in get_statements(drop_constraint.splitlines()): - cur.execute(statement) - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/37/user_threepids.sql b/synapse/storage/schema/delta/37/user_threepids.sql deleted file mode 100644 index cf7a90dd10..0000000000 --- a/synapse/storage/schema/delta/37/user_threepids.sql +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Update any email addresses that were stored with mixed case into all - * lowercase - */ - - -- There may be "duplicate" emails (with different case) already in the table, - -- so we find them and move all but the most recently used account. - UPDATE user_threepids - SET medium = 'email_old' - WHERE medium = 'email' - AND address IN ( - -- We select all the addresses that are linked to the user_id that is NOT - -- the most recently created. - SELECT u.address - FROM - user_threepids AS u, - -- `duplicate_addresses` is a table of all the email addresses that - -- appear multiple times and when the binding was created - ( - SELECT lower(u1.address) AS address, max(u1.added_at) AS max_ts - FROM user_threepids AS u1 - INNER JOIN user_threepids AS u2 ON u1.medium = u2.medium AND lower(u1.address) = lower(u2.address) AND u1.address != u2.address - WHERE u1.medium = 'email' AND u2.medium = 'email' - GROUP BY lower(u1.address) - ) AS duplicate_addresses - WHERE - lower(u.address) = duplicate_addresses.address - AND u.added_at != max_ts -- NOT the most recently created - ); - - --- This update is now safe since we've removed the duplicate addresses. -UPDATE user_threepids SET address = LOWER(address) WHERE medium = 'email'; - - -/* Add an index for the select we do on passwored reset */ -CREATE INDEX user_threepids_medium_address on user_threepids (medium, address); diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/schema/delta/38/postgres_fts_gist.sql deleted file mode 100644 index 515e6b8e84..0000000000 --- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We no longer do this given we back it out again in schema 47 - --- INSERT into background_updates (update_name, progress_json) --- VALUES ('event_search_postgres_gist', '{}'); diff --git a/synapse/storage/schema/delta/39/appservice_room_list.sql b/synapse/storage/schema/delta/39/appservice_room_list.sql deleted file mode 100644 index 74bdc49073..0000000000 --- a/synapse/storage/schema/delta/39/appservice_room_list.sql +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE appservice_room_list( - appservice_id TEXT NOT NULL, - network_id TEXT NOT NULL, - room_id TEXT NOT NULL -); - --- Each appservice can have multiple published room lists associated with them, --- keyed of a particular network_id -CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list( - appservice_id, network_id, room_id -); - -ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT; -ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT; diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/schema/delta/39/device_federation_stream_idx.sql deleted file mode 100644 index 00be801e90..0000000000 --- a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id); diff --git a/synapse/storage/schema/delta/39/event_push_index.sql b/synapse/storage/schema/delta/39/event_push_index.sql deleted file mode 100644 index de2ad93e5c..0000000000 --- a/synapse/storage/schema/delta/39/event_push_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('event_push_actions_highlights_index', '{}'); diff --git a/synapse/storage/schema/delta/39/federation_out_position.sql b/synapse/storage/schema/delta/39/federation_out_position.sql deleted file mode 100644 index 5af814290b..0000000000 --- a/synapse/storage/schema/delta/39/federation_out_position.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - CREATE TABLE federation_stream_position( - type TEXT NOT NULL, - stream_id INTEGER NOT NULL - ); - - INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1); - INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events; diff --git a/synapse/storage/schema/delta/39/membership_profile.sql b/synapse/storage/schema/delta/39/membership_profile.sql deleted file mode 100644 index 1bf911c8ab..0000000000 --- a/synapse/storage/schema/delta/39/membership_profile.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE room_memberships ADD COLUMN display_name TEXT; -ALTER TABLE room_memberships ADD COLUMN avatar_url TEXT; - -INSERT into background_updates (update_name, progress_json) - VALUES ('room_membership_profile_update', '{}'); diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/schema/delta/40/current_state_idx.sql deleted file mode 100644 index 7ffa189f39..0000000000 --- a/synapse/storage/schema/delta/40/current_state_idx.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('current_state_members_idx', '{}'); diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/schema/delta/40/device_inbox.sql deleted file mode 100644 index b9fe1f0480..0000000000 --- a/synapse/storage/schema/delta/40/device_inbox.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- turn the pre-fill startup query into a index-only scan on postgresql. -INSERT into background_updates (update_name, progress_json) - VALUES ('device_inbox_stream_index', '{}'); - -INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index'); diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql deleted file mode 100644 index dd6dcb65f1..0000000000 --- a/synapse/storage/schema/delta/40/device_list_streams.sql +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright 2017 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Cache of remote devices. -CREATE TABLE device_lists_remote_cache ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - content TEXT NOT NULL -); - --- The last update we got for a user. Empty if we're not receiving updates for --- that user. -CREATE TABLE device_lists_remote_extremeties ( - user_id TEXT NOT NULL, - stream_id TEXT NOT NULL -); - --- we used to create non-unique indexes on these tables, but as of update 52 we create --- unique indexes concurrently: --- --- CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); --- CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); - - --- Stream of device lists updates. Includes both local and remotes -CREATE TABLE device_lists_stream ( - stream_id BIGINT NOT NULL, - user_id TEXT NOT NULL, - device_id TEXT NOT NULL -); - -CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); - - --- The stream of updates to send to other servers. We keep at least one row --- per user that was sent so that the prev_id for any new updates can be --- calculated -CREATE TABLE device_lists_outbound_pokes ( - destination TEXT NOT NULL, - stream_id BIGINT NOT NULL, - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - sent BOOLEAN NOT NULL, - ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers -); - -CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); -CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/schema/delta/40/event_push_summary.sql deleted file mode 100644 index 3918f0b794..0000000000 --- a/synapse/storage/schema/delta/40/event_push_summary.sql +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2017 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Aggregate of old notification counts that have been deleted out of the --- main event_push_actions table. This count does not include those that were --- highlights, as they remain in the event_push_actions table. -CREATE TABLE event_push_summary ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - notif_count BIGINT NOT NULL, - stream_ordering BIGINT NOT NULL -); - -CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); - - --- The stream ordering up to which we have aggregated the event_push_actions --- table into event_push_summary -CREATE TABLE event_push_summary_stream_ordering ( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. - stream_ordering BIGINT NOT NULL, - CHECK (Lock='X') -); - -INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/schema/delta/40/pushers.sql deleted file mode 100644 index 054a223f14..0000000000 --- a/synapse/storage/schema/delta/40/pushers.sql +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS pushers2 ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - access_token BIGINT DEFAULT NULL, - profile_tag TEXT NOT NULL, - kind TEXT NOT NULL, - app_id TEXT NOT NULL, - app_display_name TEXT NOT NULL, - device_display_name TEXT NOT NULL, - pushkey TEXT NOT NULL, - ts BIGINT NOT NULL, - lang TEXT, - data TEXT, - last_stream_ordering INTEGER, - last_success BIGINT, - failing_since BIGINT, - UNIQUE (app_id, pushkey, user_name) -); - -INSERT INTO pushers2 SELECT * FROM PUSHERS; - -DROP TABLE PUSHERS; - -ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/schema/delta/41/device_list_stream_idx.sql deleted file mode 100644 index b7bee8b692..0000000000 --- a/synapse/storage/schema/delta/41/device_list_stream_idx.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('device_lists_stream_idx', '{}'); diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/schema/delta/41/device_outbound_index.sql deleted file mode 100644 index 62f0b9892b..0000000000 --- a/synapse/storage/schema/delta/41/device_outbound_index.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql deleted file mode 100644 index 5d9cfecf36..0000000000 --- a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('event_search_event_id_idx', '{}'); diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/schema/delta/41/ratelimit.sql deleted file mode 100644 index a194bf0238..0000000000 --- a/synapse/storage/schema/delta/41/ratelimit.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE ratelimit_override ( - user_id TEXT NOT NULL, - messages_per_second BIGINT, - burst_count BIGINT -); - -CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/schema/delta/42/current_state_delta.sql deleted file mode 100644 index d28851aff8..0000000000 --- a/synapse/storage/schema/delta/42/current_state_delta.sql +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE current_state_delta_stream ( - stream_id BIGINT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - event_id TEXT, -- Is null if the key was removed - prev_event_id TEXT -- Is null if the key was added -); - -CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/schema/delta/42/device_list_last_id.sql deleted file mode 100644 index 9ab8c14fa3..0000000000 --- a/synapse/storage/schema/delta/42/device_list_last_id.sql +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - --- Table of last stream_id that we sent to destination for user_id. This is --- used to fill out the `prev_id` fields of outbound device list updates. -CREATE TABLE device_lists_outbound_last_success ( - destination TEXT NOT NULL, - user_id TEXT NOT NULL, - stream_id BIGINT NOT NULL -); - -INSERT INTO device_lists_outbound_last_success - SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_pokes - WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values - GROUP BY destination, user_id; - -CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( - destination, user_id, stream_id -); diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/schema/delta/42/event_auth_state_only.sql deleted file mode 100644 index b8821ac759..0000000000 --- a/synapse/storage/schema/delta/42/event_auth_state_only.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('event_auth_state_only', '{}'); diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py deleted file mode 100644 index 506f326f4d..0000000000 --- a/synapse/storage/schema/delta/42/user_dir.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.storage.prepare_database import get_statements - -logger = logging.getLogger(__name__) - - -BOTH_TABLES = """ -CREATE TABLE user_directory_stream_pos ( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. - stream_id BIGINT, - CHECK (Lock='X') -); - -INSERT INTO user_directory_stream_pos (stream_id) VALUES (null); - -CREATE TABLE user_directory ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, -- A room_id that we know the user is joined to - display_name TEXT, - avatar_url TEXT -); - -CREATE INDEX user_directory_room_idx ON user_directory(room_id); -CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); - -CREATE TABLE users_in_pubic_room ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL -- A room_id that we know is public -); - -CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id); -CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id); -""" - - -POSTGRES_TABLE = """ -CREATE TABLE user_directory_search ( - user_id TEXT NOT NULL, - vector tsvector -); - -CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector); -CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id); -""" - - -SQLITE_TABLE = """ -CREATE VIRTUAL TABLE user_directory_search - USING fts4 ( user_id, value ); -""" - - -def run_create(cur, database_engine, *args, **kwargs): - for statement in get_statements(BOTH_TABLES.splitlines()): - cur.execute(statement) - - if isinstance(database_engine, PostgresEngine): - for statement in get_statements(POSTGRES_TABLE.splitlines()): - cur.execute(statement) - elif isinstance(database_engine, Sqlite3Engine): - for statement in get_statements(SQLITE_TABLE.splitlines()): - cur.execute(statement) - else: - raise Exception("Unrecognized database engine") - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/schema/delta/43/blocked_rooms.sql deleted file mode 100644 index 0e3cd143ff..0000000000 --- a/synapse/storage/schema/delta/43/blocked_rooms.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE blocked_rooms ( - room_id TEXT NOT NULL, - user_id TEXT NOT NULL -- Admin who blocked the room -); - -CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/schema/delta/43/quarantine_media.sql deleted file mode 100644 index 630907ec4f..0000000000 --- a/synapse/storage/schema/delta/43/quarantine_media.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT; -ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT; diff --git a/synapse/storage/schema/delta/43/url_cache.sql b/synapse/storage/schema/delta/43/url_cache.sql deleted file mode 100644 index 45ebe020da..0000000000 --- a/synapse/storage/schema/delta/43/url_cache.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT; diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/schema/delta/43/user_share.sql deleted file mode 100644 index ee7062abe4..0000000000 --- a/synapse/storage/schema/delta/43/user_share.sql +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Table keeping track of who shares a room with who. We only keep track --- of this for local users, so `user_id` is local users only (but we do keep track --- of which remote users share a room) -CREATE TABLE users_who_share_rooms ( - user_id TEXT NOT NULL, - other_user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - share_private BOOLEAN NOT NULL -- is the shared room private? i.e. they share a private room -); - - -CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id); -CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id); -CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id); - - --- Make sure that we populate the table initially -UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql deleted file mode 100644 index b12f9b2ebf..0000000000 --- a/synapse/storage/schema/delta/44/expire_url_cache.sql +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was --- removed and replaced with 46/local_media_repository_url_idx.sql. --- --- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; - --- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support --- indices on expressions until 3.9. -CREATE TABLE local_media_repository_url_cache_new( - url TEXT, - response_code INTEGER, - etag TEXT, - expires_ts BIGINT, - og TEXT, - media_id TEXT, - download_ts BIGINT -); - -INSERT INTO local_media_repository_url_cache_new - SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache; - -DROP TABLE local_media_repository_url_cache; -ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache; - -CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); -CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); -CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/schema/delta/45/group_server.sql deleted file mode 100644 index b2333848a0..0000000000 --- a/synapse/storage/schema/delta/45/group_server.sql +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE groups ( - group_id TEXT NOT NULL, - name TEXT, -- the display name of the room - avatar_url TEXT, - short_description TEXT, - long_description TEXT -); - -CREATE UNIQUE INDEX groups_idx ON groups(group_id); - - --- list of users the group server thinks are joined -CREATE TABLE group_users ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - is_admin BOOLEAN NOT NULL, - is_public BOOLEAN NOT NULL -- whether the users membership can be seen by everyone -); - - -CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id); -CREATE INDEX groups_users_u_idx ON group_users(user_id); - --- list of users the group server thinks are invited -CREATE TABLE group_invites ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL -); - -CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id); -CREATE INDEX groups_invites_u_idx ON group_invites(user_id); - - -CREATE TABLE group_rooms ( - group_id TEXT NOT NULL, - room_id TEXT NOT NULL, - is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone -); - -CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id); -CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id); - - --- Rooms to include in the summary -CREATE TABLE group_summary_rooms ( - group_id TEXT NOT NULL, - room_id TEXT NOT NULL, - category_id TEXT NOT NULL, - room_order BIGINT NOT NULL, - is_public BOOLEAN NOT NULL, -- whether the room should be show to everyone - UNIQUE (group_id, category_id, room_id, room_order), - CHECK (room_order > 0) -); - -CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id); - - --- Categories to include in the summary -CREATE TABLE group_summary_room_categories ( - group_id TEXT NOT NULL, - category_id TEXT NOT NULL, - cat_order BIGINT NOT NULL, - UNIQUE (group_id, category_id, cat_order), - CHECK (cat_order > 0) -); - --- The categories in the group -CREATE TABLE group_room_categories ( - group_id TEXT NOT NULL, - category_id TEXT NOT NULL, - profile TEXT NOT NULL, - is_public BOOLEAN NOT NULL, -- whether the category should be show to everyone - UNIQUE (group_id, category_id) -); - --- The users to include in the group summary -CREATE TABLE group_summary_users ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - role_id TEXT NOT NULL, - user_order BIGINT NOT NULL, - is_public BOOLEAN NOT NULL -- whether the user should be show to everyone -); - -CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id); - --- The roles to include in the group summary -CREATE TABLE group_summary_roles ( - group_id TEXT NOT NULL, - role_id TEXT NOT NULL, - role_order BIGINT NOT NULL, - UNIQUE (group_id, role_id, role_order), - CHECK (role_order > 0) -); - - --- The roles in a groups -CREATE TABLE group_roles ( - group_id TEXT NOT NULL, - role_id TEXT NOT NULL, - profile TEXT NOT NULL, - is_public BOOLEAN NOT NULL, -- whether the role should be show to everyone - UNIQUE (group_id, role_id) -); - - --- List of attestations we've given out and need to renew -CREATE TABLE group_attestations_renewals ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - valid_until_ms BIGINT NOT NULL -); - -CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id); -CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id); -CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms); - - --- List of attestations we've received from remotes and are interested in. -CREATE TABLE group_attestations_remote ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - valid_until_ms BIGINT NOT NULL, - attestation_json TEXT NOT NULL -); - -CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); -CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); -CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); - - --- The group membership for the HS's users -CREATE TABLE local_group_membership ( - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - is_admin BOOLEAN NOT NULL, - membership TEXT NOT NULL, - is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership - content TEXT NOT NULL -); - -CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); -CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); - - -CREATE TABLE local_group_updates ( - stream_id BIGINT NOT NULL, - group_id TEXT NOT NULL, - user_id TEXT NOT NULL, - type TEXT NOT NULL, - content TEXT NOT NULL -); diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/schema/delta/45/profile_cache.sql deleted file mode 100644 index e5ddc84df0..0000000000 --- a/synapse/storage/schema/delta/45/profile_cache.sql +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - --- A subset of remote users whose profiles we have cached. --- Whether a user is in this table or not is defined by the storage function --- `is_subscribed_remote_profile_for_user` -CREATE TABLE remote_profile_cache ( - user_id TEXT NOT NULL, - displayname TEXT, - avatar_url TEXT, - last_check BIGINT NOT NULL -); - -CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); -CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); diff --git a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql deleted file mode 100644 index 68c48a89a9..0000000000 --- a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* we no longer use (or create) the refresh_tokens table */ -DROP TABLE IF EXISTS refresh_tokens; diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql deleted file mode 100644 index bb307889c1..0000000000 --- a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- drop the unique constraint on deleted_pushers so that we can just insert --- into it rather than upserting. - -CREATE TABLE deleted_pushers2 ( - stream_id BIGINT NOT NULL, - app_id TEXT NOT NULL, - pushkey TEXT NOT NULL, - user_id TEXT NOT NULL -); - -INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id) - SELECT stream_id, app_id, pushkey, user_id from deleted_pushers; - -DROP TABLE deleted_pushers; -ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers; - --- create the index after doing the inserts because that's more efficient. --- it also means we can give it the same name as the old one without renaming. -CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); - diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/schema/delta/46/group_server.sql deleted file mode 100644 index 097679bc9a..0000000000 --- a/synapse/storage/schema/delta/46/group_server.sql +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE groups_new ( - group_id TEXT NOT NULL, - name TEXT, -- the display name of the room - avatar_url TEXT, - short_description TEXT, - long_description TEXT, - is_public BOOL NOT NULL -- whether non-members can access group APIs -); - --- NB: awful hack to get the default to be true on postgres and 1 on sqlite -INSERT INTO groups_new - SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups; - -DROP TABLE groups; -ALTER TABLE groups_new RENAME TO groups; - -CREATE UNIQUE INDEX groups_idx ON groups(group_id); diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql deleted file mode 100644 index bbfc7f5d1a..0000000000 --- a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- register a background update which will recreate the --- local_media_repository_url_idx index. --- --- We do this as a bg update not because it is a particularly onerous --- operation, but because we'd like it to be a partial index if possible, and --- the background_index_update code will understand whether we are on --- postgres or sqlite and behave accordingly. -INSERT INTO background_updates (update_name, progress_json) VALUES - ('local_media_repository_url_idx', '{}'); diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql deleted file mode 100644 index cb0d5a2576..0000000000 --- a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- change the user_directory table to also cover global local user profiles --- rather than just profiles within specific rooms. - -CREATE TABLE user_directory2 ( - user_id TEXT NOT NULL, - room_id TEXT, - display_name TEXT, - avatar_url TEXT -); - -INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url) - SELECT user_id, room_id, display_name, avatar_url from user_directory; - -DROP TABLE user_directory; -ALTER TABLE user_directory2 RENAME TO user_directory; - --- create indexes after doing the inserts because that's more efficient. --- it also means we can give it the same name as the old one without renaming. -CREATE INDEX user_directory_room_idx ON user_directory(room_id); -CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/schema/delta/46/user_dir_typos.sql deleted file mode 100644 index d9505f8da1..0000000000 --- a/synapse/storage/schema/delta/46/user_dir_typos.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- this is just embarassing :| -ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms; - --- this is only 300K rows on matrix.org and takes ~3s to generate the index, --- so is hopefully not going to block anyone else for that long... -CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id); -CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id); -DROP INDEX users_in_pubic_room_room_idx; -DROP INDEX users_in_pubic_room_user_idx; diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql deleted file mode 100644 index f505fb22b5..0000000000 --- a/synapse/storage/schema/delta/47/last_access_media.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT; diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/schema/delta/47/postgres_fts_gin.sql deleted file mode 100644 index 31d7a817eb..0000000000 --- a/synapse/storage/schema/delta/47/postgres_fts_gin.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('event_search_postgres_gin', '{}'); diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/schema/delta/47/push_actions_staging.sql deleted file mode 100644 index edccf4a96f..0000000000 --- a/synapse/storage/schema/delta/47/push_actions_staging.sql +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Temporary staging area for push actions that have been calculated for an --- event, but the event hasn't yet been persisted. --- When the event is persisted the rows are moved over to the --- event_push_actions table. -CREATE TABLE event_push_actions_staging ( - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - actions TEXT NOT NULL, - notif SMALLINT NOT NULL, - highlight SMALLINT NOT NULL -); - -CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py deleted file mode 100644 index 9fd1ccf6f7..0000000000 --- a/synapse/storage/schema/delta/47/state_group_seq.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.storage.engines import PostgresEngine - - -def run_create(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - # if we already have some state groups, we want to start making new - # ones with a higher id. - cur.execute("SELECT max(id) FROM state_groups") - row = cur.fetchone() - - if row[0] is None: - start_val = 1 - else: - start_val = row[0] + 1 - - cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/schema/delta/48/add_user_consent.sql deleted file mode 100644 index 5237491506..0000000000 --- a/synapse/storage/schema/delta/48/add_user_consent.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* record the version of the privacy policy the user has consented to - */ -ALTER TABLE users ADD COLUMN consent_version TEXT; diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql deleted file mode 100644 index 9248b0b24a..0000000000 --- a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('user_ips_last_seen_index', '{}'); diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/schema/delta/48/deactivated_users.sql deleted file mode 100644 index e9013a6969..0000000000 --- a/synapse/storage/schema/delta/48/deactivated_users.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Store any accounts that have been requested to be deactivated. - * We part the account from all the rooms its in when its - * deactivated. This can take some time and synapse may be restarted - * before it completes, so store the user IDs here until the process - * is complete. - */ -CREATE TABLE users_pending_deactivation ( - user_id TEXT NOT NULL -); diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py deleted file mode 100644 index 49f5f2c003..0000000000 --- a/synapse/storage/schema/delta/48/group_unique_indexes.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.storage.engines import PostgresEngine -from synapse.storage.prepare_database import get_statements - -FIX_INDEXES = """ --- rebuild indexes as uniques -DROP INDEX groups_invites_g_idx; -CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); -DROP INDEX groups_users_g_idx; -CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); - --- rename other indexes to actually match their table names.. -DROP INDEX groups_users_u_idx; -CREATE INDEX group_users_u_idx ON group_users(user_id); -DROP INDEX groups_invites_u_idx; -CREATE INDEX group_invites_u_idx ON group_invites(user_id); -DROP INDEX groups_rooms_g_idx; -CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); -DROP INDEX groups_rooms_r_idx; -CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); -""" - - -def run_create(cur, database_engine, *args, **kwargs): - rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" - - # remove duplicates from group_users & group_invites tables - cur.execute( - """ - DELETE FROM group_users WHERE %s NOT IN ( - SELECT min(%s) FROM group_users GROUP BY group_id, user_id - ); - """ - % (rowid, rowid) - ) - cur.execute( - """ - DELETE FROM group_invites WHERE %s NOT IN ( - SELECT min(%s) FROM group_invites GROUP BY group_id, user_id - ); - """ - % (rowid, rowid) - ) - - for statement in get_statements(FIX_INDEXES.splitlines()): - cur.execute(statement) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/schema/delta/48/groups_joinable.sql deleted file mode 100644 index ce26eaf0c9..0000000000 --- a/synapse/storage/schema/delta/48/groups_joinable.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * This isn't a real ENUM because sqlite doesn't support it - * and we use a default of NULL for inserted rows and interpret - * NULL at the python store level as necessary so that existing - * rows are given the correct default policy. - */ -ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite'; diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql deleted file mode 100644 index 14dcf18d73..0000000000 --- a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* record whether we have sent a server notice about consenting to the - * privacy policy. Specifically records the version of the policy we sent - * a message about. - */ -ALTER TABLE users ADD COLUMN consent_server_notice_sent TEXT; diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/schema/delta/49/add_user_daily_visits.sql deleted file mode 100644 index 3dd478196f..0000000000 --- a/synapse/storage/schema/delta/49/add_user_daily_visits.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, - device_id TEXT, - timestamp BIGINT NOT NULL ); -CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); -CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql deleted file mode 100644 index 3a4ed59b5b..0000000000 --- a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('user_ips_last_seen_only_index', '{}'); diff --git a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql deleted file mode 100644 index c93ae47532..0000000000 --- a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - - -INSERT into background_updates (update_name, progress_json) - VALUES ('users_creation_ts', '{}'); diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/schema/delta/50/erasure_store.sql deleted file mode 100644 index 5d8641a9ab..0000000000 --- a/synapse/storage/schema/delta/50/erasure_store.sql +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- a table of users who have requested that their details be erased -CREATE TABLE erased_users ( - user_id TEXT NOT NULL -); - -CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id); diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/schema/delta/50/make_event_content_nullable.py deleted file mode 100644 index b1684a8441..0000000000 --- a/synapse/storage/schema/delta/50/make_event_content_nullable.py +++ /dev/null @@ -1,96 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -We want to stop populating 'event.content', so we need to make it nullable. - -If this has to be rolled back, then the following should populate the missing data: - -Postgres: - - UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej - WHERE ej.event_id = events.event_id AND - stream_ordering < ( - SELECT stream_ordering FROM events WHERE content IS NOT NULL - ORDER BY stream_ordering LIMIT 1 - ); - - UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej - WHERE ej.event_id = events.event_id AND - stream_ordering > ( - SELECT stream_ordering FROM events WHERE content IS NOT NULL - ORDER BY stream_ordering DESC LIMIT 1 - ); - -SQLite: - - UPDATE events SET content=( - SELECT json_extract(json,'$.content') FROM event_json ej - WHERE ej.event_id = events.event_id - ) - WHERE - stream_ordering < ( - SELECT stream_ordering FROM events WHERE content IS NOT NULL - ORDER BY stream_ordering LIMIT 1 - ) - OR stream_ordering > ( - SELECT stream_ordering FROM events WHERE content IS NOT NULL - ORDER BY stream_ordering DESC LIMIT 1 - ); - -""" - -import logging - -from synapse.storage.engines import PostgresEngine - -logger = logging.getLogger(__name__) - - -def run_create(cur, database_engine, *args, **kwargs): - pass - - -def run_upgrade(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - cur.execute( - """ - ALTER TABLE events ALTER COLUMN content DROP NOT NULL; - """ - ) - return - - # sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html - - cur.execute( - "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'" - ) - (oldsql,) = cur.fetchone() - - sql = oldsql.replace("content TEXT NOT NULL", "content TEXT") - if sql == oldsql: - raise Exception("Couldn't find null constraint to drop in %s" % oldsql) - - logger.info("Replacing definition of 'events' with: %s", sql) - - cur.execute("PRAGMA schema_version") - (oldver,) = cur.fetchone() - cur.execute("PRAGMA writable_schema=ON") - cur.execute( - "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", - (sql,), - ) - cur.execute("PRAGMA schema_version=%i" % (oldver + 1,)) - cur.execute("PRAGMA writable_schema=OFF") diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/schema/delta/51/e2e_room_keys.sql deleted file mode 100644 index c0e66a697d..0000000000 --- a/synapse/storage/schema/delta/51/e2e_room_keys.sql +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2017 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- users' optionally backed up encrypted e2e sessions -CREATE TABLE e2e_room_keys ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - session_id TEXT NOT NULL, - version TEXT NOT NULL, - first_message_index INT, - forwarded_count INT, - is_verified BOOLEAN, - session_data TEXT NOT NULL -); - -CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); - --- the metadata for each generation of encrypted e2e session backups -CREATE TABLE e2e_room_keys_versions ( - user_id TEXT NOT NULL, - version TEXT NOT NULL, - algorithm TEXT NOT NULL, - auth_data TEXT NOT NULL, - deleted SMALLINT DEFAULT 0 NOT NULL -); - -CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/schema/delta/51/monthly_active_users.sql deleted file mode 100644 index c9d537d5a3..0000000000 --- a/synapse/storage/schema/delta/51/monthly_active_users.sql +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- a table of monthly active users, for use where blocking based on mau limits -CREATE TABLE monthly_active_users ( - user_id TEXT NOT NULL, - -- Last time we saw the user. Not guaranteed to be accurate due to rate limiting - -- on updates, Granularity of updates governed by - -- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY - -- Measured in ms since epoch. - timestamp BIGINT NOT NULL -); - -CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id); -CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); diff --git a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql deleted file mode 100644 index 91e03d13e1..0000000000 --- a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- This is needed to efficiently check for unreferenced state groups during --- purge. Added events_to_state_group(state_group) index -INSERT into background_updates (update_name, progress_json) - VALUES ('event_to_state_groups_sg_index', '{}'); diff --git a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql deleted file mode 100644 index bfa49e6f92..0000000000 --- a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- register a background update which will create a unique index on --- device_lists_remote_cache -INSERT into background_updates (update_name, progress_json) - VALUES ('device_lists_remote_cache_unique_idx', '{}'); - --- and one on device_lists_remote_extremeties -INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ( - 'device_lists_remote_extremeties_unique_idx', '{}', - - -- doesn't really depend on this, but we need to make sure both happen - -- before we drop the old indexes. - 'device_lists_remote_cache_unique_idx' - ); - --- once they complete, we can drop the old indexes. -INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ( - 'drop_device_list_streams_non_unique_indexes', '{}', - 'device_lists_remote_extremeties_unique_idx' - ); diff --git a/synapse/storage/schema/delta/52/e2e_room_keys.sql b/synapse/storage/schema/delta/52/e2e_room_keys.sql deleted file mode 100644 index db687cccae..0000000000 --- a/synapse/storage/schema/delta/52/e2e_room_keys.sql +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Change version column to an integer so we can do MAX() sensibly - */ -CREATE TABLE e2e_room_keys_versions_new ( - user_id TEXT NOT NULL, - version BIGINT NOT NULL, - algorithm TEXT NOT NULL, - auth_data TEXT NOT NULL, - deleted SMALLINT DEFAULT 0 NOT NULL -); - -INSERT INTO e2e_room_keys_versions_new - SELECT user_id, CAST(version as BIGINT), algorithm, auth_data, deleted FROM e2e_room_keys_versions; - -DROP TABLE e2e_room_keys_versions; -ALTER TABLE e2e_room_keys_versions_new RENAME TO e2e_room_keys_versions; - -CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); - -/* Change e2e_rooms_keys to match - */ -CREATE TABLE e2e_room_keys_new ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - session_id TEXT NOT NULL, - version BIGINT NOT NULL, - first_message_index INT, - forwarded_count INT, - is_verified BOOLEAN, - session_data TEXT NOT NULL -); - -INSERT INTO e2e_room_keys_new - SELECT user_id, room_id, session_id, CAST(version as BIGINT), first_message_index, forwarded_count, is_verified, session_data FROM e2e_room_keys; - -DROP TABLE e2e_room_keys; -ALTER TABLE e2e_room_keys_new RENAME TO e2e_room_keys; - -CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); diff --git a/synapse/storage/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/schema/delta/53/add_user_type_to_users.sql deleted file mode 100644 index 88ec2f83e5..0000000000 --- a/synapse/storage/schema/delta/53/add_user_type_to_users.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* The type of the user: NULL for a regular user, or one of the constants in - * synapse.api.constants.UserTypes - */ -ALTER TABLE users ADD COLUMN user_type TEXT DEFAULT NULL; diff --git a/synapse/storage/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/schema/delta/53/drop_sent_transactions.sql deleted file mode 100644 index e372f5a44a..0000000000 --- a/synapse/storage/schema/delta/53/drop_sent_transactions.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -DROP TABLE IF EXISTS sent_transactions; diff --git a/synapse/storage/schema/delta/53/event_format_version.sql b/synapse/storage/schema/delta/53/event_format_version.sql deleted file mode 100644 index 1d977c2834..0000000000 --- a/synapse/storage/schema/delta/53/event_format_version.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE event_json ADD COLUMN format_version INTEGER; diff --git a/synapse/storage/schema/delta/53/user_dir_populate.sql b/synapse/storage/schema/delta/53/user_dir_populate.sql deleted file mode 100644 index ffcc896b58..0000000000 --- a/synapse/storage/schema/delta/53/user_dir_populate.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Set up staging tables -INSERT INTO background_updates (update_name, progress_json) VALUES - ('populate_user_directory_createtables', '{}'); - --- Run through each room and update the user directory according to who is in it -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_user_directory_process_rooms', '{}', 'populate_user_directory_createtables'); - --- Insert all users, if search_all_users is on -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_user_directory_process_users', '{}', 'populate_user_directory_process_rooms'); - --- Clean up staging tables -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_user_directory_cleanup', '{}', 'populate_user_directory_process_users'); diff --git a/synapse/storage/schema/delta/53/user_ips_index.sql b/synapse/storage/schema/delta/53/user_ips_index.sql deleted file mode 100644 index b812c5794f..0000000000 --- a/synapse/storage/schema/delta/53/user_ips_index.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -- analyze user_ips, to help ensure the correct indices are used -INSERT INTO background_updates (update_name, progress_json) VALUES - ('user_ips_analyze', '{}'); - --- delete duplicates -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('user_ips_remove_dupes', '{}', 'user_ips_analyze'); - --- add a new unique index to user_ips table -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('user_ips_device_unique_index', '{}', 'user_ips_remove_dupes'); - --- drop the old original index -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index'); diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/schema/delta/53/user_share.sql deleted file mode 100644 index 5831b1a6f8..0000000000 --- a/synapse/storage/schema/delta/53/user_share.sql +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2017 Vector Creations Ltd, 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Old disused version of the tables below. -DROP TABLE IF EXISTS users_who_share_rooms; - --- Tables keeping track of what users share rooms. This is a map of local users --- to local or remote users, per room. Remote users cannot be in the user_id --- column, only the other_user_id column. There are two tables, one for public --- rooms and those for private rooms. -CREATE TABLE IF NOT EXISTS users_who_share_public_rooms ( - user_id TEXT NOT NULL, - other_user_id TEXT NOT NULL, - room_id TEXT NOT NULL -); - -CREATE TABLE IF NOT EXISTS users_who_share_private_rooms ( - user_id TEXT NOT NULL, - other_user_id TEXT NOT NULL, - room_id TEXT NOT NULL -); - -CREATE UNIQUE INDEX users_who_share_public_rooms_u_idx ON users_who_share_public_rooms(user_id, other_user_id, room_id); -CREATE INDEX users_who_share_public_rooms_r_idx ON users_who_share_public_rooms(room_id); -CREATE INDEX users_who_share_public_rooms_o_idx ON users_who_share_public_rooms(other_user_id); - -CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id); -CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id); -CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id); - --- Make sure that we populate the tables initially by resetting the stream ID -UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/schema/delta/53/user_threepid_id.sql b/synapse/storage/schema/delta/53/user_threepid_id.sql deleted file mode 100644 index 80c2c573b6..0000000000 --- a/synapse/storage/schema/delta/53/user_threepid_id.sql +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Tracks which identity server a user bound their threepid via. -CREATE TABLE user_threepid_id_server ( - user_id TEXT NOT NULL, - medium TEXT NOT NULL, - address TEXT NOT NULL, - id_server TEXT NOT NULL -); - -CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server( - user_id, medium, address, id_server -); - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('user_threepids_grandfather', '{}'); diff --git a/synapse/storage/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/schema/delta/53/users_in_public_rooms.sql deleted file mode 100644 index f7827ca6d2..0000000000 --- a/synapse/storage/schema/delta/53/users_in_public_rooms.sql +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We don't need the old version of this table. -DROP TABLE IF EXISTS users_in_public_rooms; - --- Old version of users_in_public_rooms -DROP TABLE IF EXISTS users_who_share_public_rooms; - --- Track what users are in public rooms. -CREATE TABLE IF NOT EXISTS users_in_public_rooms ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL -); - -CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id); diff --git a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/schema/delta/54/account_validity_with_renewal.sql deleted file mode 100644 index 0adb2ad55e..0000000000 --- a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We previously changed the schema for this table without renaming the file, which means --- that some databases might still be using the old schema. This ensures Synapse uses the --- right schema for the table. -DROP TABLE IF EXISTS account_validity; - --- Track what users are in public rooms. -CREATE TABLE IF NOT EXISTS account_validity ( - user_id TEXT PRIMARY KEY, - expiration_ts_ms BIGINT NOT NULL, - email_sent BOOLEAN NOT NULL, - renewal_token TEXT -); - -CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms) -CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token) diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql deleted file mode 100644 index c01aa9d2d9..0000000000 --- a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* When we can use this key until, before we have to refresh it. */ -ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT; - -UPDATE server_signature_keys SET ts_valid_until_ms = ( - SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE - skj.server_name = server_signature_keys.server_name AND - skj.key_id = server_signature_keys.key_id -); diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/delta/54/delete_forward_extremities.sql deleted file mode 100644 index b062ec840c..0000000000 --- a/synapse/storage/schema/delta/54/delete_forward_extremities.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Start a background job to cleanup extremities that were incorrectly added --- by bug #5269. -INSERT INTO background_updates (update_name, progress_json) VALUES - ('delete_soft_failed_extremities', '{}'); - -DROP TABLE IF EXISTS _extremities_to_check; -- To make this delta schema file idempotent. -CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities; -CREATE INDEX _extremities_to_check_id ON _extremities_to_check(event_id); diff --git a/synapse/storage/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/schema/delta/54/drop_legacy_tables.sql deleted file mode 100644 index dbbe682697..0000000000 --- a/synapse/storage/schema/delta/54/drop_legacy_tables.sql +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- we need to do this first due to foreign constraints -DROP TABLE IF EXISTS application_services_regex; - -DROP TABLE IF EXISTS application_services; -DROP TABLE IF EXISTS transaction_id_to_pdu; -DROP TABLE IF EXISTS stats_reporting; -DROP TABLE IF EXISTS current_state_resets; -DROP TABLE IF EXISTS event_content_hashes; -DROP TABLE IF EXISTS event_destinations; -DROP TABLE IF EXISTS event_edge_hashes; -DROP TABLE IF EXISTS event_signatures; -DROP TABLE IF EXISTS feedback; -DROP TABLE IF EXISTS room_hosts; -DROP TABLE IF EXISTS server_tls_certificates; -DROP TABLE IF EXISTS state_forward_extremities; diff --git a/synapse/storage/schema/delta/54/drop_presence_list.sql b/synapse/storage/schema/delta/54/drop_presence_list.sql deleted file mode 100644 index e6ee70c623..0000000000 --- a/synapse/storage/schema/delta/54/drop_presence_list.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -DROP TABLE IF EXISTS presence_list; diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/schema/delta/54/relations.sql deleted file mode 100644 index 134862b870..0000000000 --- a/synapse/storage/schema/delta/54/relations.sql +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2019 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Tracks related events, like reactions, replies, edits, etc. Note that things --- in this table are not necessarily "valid", e.g. it may contain edits from --- people who don't have power to edit other peoples events. -CREATE TABLE IF NOT EXISTS event_relations ( - event_id TEXT NOT NULL, - relates_to_id TEXT NOT NULL, - relation_type TEXT NOT NULL, - aggregation_key TEXT -); - -CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); -CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/schema/delta/54/stats.sql deleted file mode 100644 index 652e58308e..0000000000 --- a/synapse/storage/schema/delta/54/stats.sql +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE stats_stream_pos ( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. - stream_id BIGINT, - CHECK (Lock='X') -); - -INSERT INTO stats_stream_pos (stream_id) VALUES (null); - -CREATE TABLE user_stats ( - user_id TEXT NOT NULL, - ts BIGINT NOT NULL, - bucket_size INT NOT NULL, - public_rooms INT NOT NULL, - private_rooms INT NOT NULL -); - -CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); - -CREATE TABLE room_stats ( - room_id TEXT NOT NULL, - ts BIGINT NOT NULL, - bucket_size INT NOT NULL, - current_state_events INT NOT NULL, - joined_members INT NOT NULL, - invited_members INT NOT NULL, - left_members INT NOT NULL, - banned_members INT NOT NULL, - state_events INT NOT NULL -); - -CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); - --- cache of current room state; useful for the publicRooms list -CREATE TABLE room_state ( - room_id TEXT NOT NULL, - join_rules TEXT, - history_visibility TEXT, - encryption TEXT, - name TEXT, - topic TEXT, - avatar TEXT, - canonical_alias TEXT - -- get aliases straight from the right table -); - -CREATE UNIQUE INDEX room_state_room ON room_state(room_id); - -CREATE TABLE room_stats_earliest_token ( - room_id TEXT NOT NULL, - token BIGINT NOT NULL -); - -CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); - --- Set up staging tables -INSERT INTO background_updates (update_name, progress_json) VALUES - ('populate_stats_createtables', '{}'); - --- Run through each room and update stats -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_stats_process_rooms', '{}', 'populate_stats_createtables'); - --- Clean up staging tables -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_stats_cleanup', '{}', 'populate_stats_process_rooms'); diff --git a/synapse/storage/schema/delta/54/stats2.sql b/synapse/storage/schema/delta/54/stats2.sql deleted file mode 100644 index 3b2d48447f..0000000000 --- a/synapse/storage/schema/delta/54/stats2.sql +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- This delta file gets run after `54/stats.sql` delta. - --- We want to add some indices to the temporary stats table, so we re-insert --- 'populate_stats_createtables' if we are still processing the rooms update. -INSERT INTO background_updates (update_name, progress_json) - SELECT 'populate_stats_createtables', '{}' - WHERE - 'populate_stats_process_rooms' IN ( - SELECT update_name FROM background_updates - ) - AND 'populate_stats_createtables' NOT IN ( -- don't insert if already exists - SELECT update_name FROM background_updates - ); diff --git a/synapse/storage/schema/delta/55/access_token_expiry.sql b/synapse/storage/schema/delta/55/access_token_expiry.sql deleted file mode 100644 index 4590604bfd..0000000000 --- a/synapse/storage/schema/delta/55/access_token_expiry.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- when this access token can be used until, in ms since the epoch. NULL means the token --- never expires. -ALTER TABLE access_tokens ADD COLUMN valid_until_ms BIGINT; diff --git a/synapse/storage/schema/delta/55/track_threepid_validations.sql b/synapse/storage/schema/delta/55/track_threepid_validations.sql deleted file mode 100644 index a8eced2e0a..0000000000 --- a/synapse/storage/schema/delta/55/track_threepid_validations.sql +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS threepid_validation_session ( - session_id TEXT PRIMARY KEY, - medium TEXT NOT NULL, - address TEXT NOT NULL, - client_secret TEXT NOT NULL, - last_send_attempt BIGINT NOT NULL, - validated_at BIGINT -); - -CREATE TABLE IF NOT EXISTS threepid_validation_token ( - token TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - next_link TEXT, - expires BIGINT NOT NULL -); - -CREATE INDEX threepid_validation_token_session_id ON threepid_validation_token(session_id); diff --git a/synapse/storage/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/schema/delta/55/users_alter_deactivated.sql deleted file mode 100644 index dabdde489b..0000000000 --- a/synapse/storage/schema/delta/55/users_alter_deactivated.sql +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE users ADD deactivated SMALLINT DEFAULT 0 NOT NULL; - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('users_set_deactivated_flag', '{}'); diff --git a/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql deleted file mode 100644 index 41807eb1e7..0000000000 --- a/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Opentracing context data for inclusion in the device_list_update EDUs, as a - * json-encoded dictionary. NULL if opentracing is disabled (or not enabled for this destination). - */ -ALTER TABLE device_lists_outbound_pokes ADD opentracing_context TEXT; diff --git a/synapse/storage/schema/delta/56/current_state_events_membership.sql b/synapse/storage/schema/delta/56/current_state_events_membership.sql deleted file mode 100644 index 473018676f..0000000000 --- a/synapse/storage/schema/delta/56/current_state_events_membership.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We add membership to current state so that we don't need to join against --- room_memberships, which can be surprisingly costly (we do such queries --- very frequently). --- This will be null for non-membership events and the content.membership key --- for membership events. (Will also be null for membership events until the --- background update job has finished). -ALTER TABLE current_state_events ADD membership TEXT; diff --git a/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql deleted file mode 100644 index 3133d42d4a..0000000000 --- a/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We add membership to current state so that we don't need to join against --- room_memberships, which can be surprisingly costly (we do such queries --- very frequently). --- This will be null for non-membership events and the content.membership key --- for membership events. (Will also be null for membership events until the --- background update job has finished). - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('current_state_events_membership', '{}'); diff --git a/synapse/storage/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/schema/delta/56/destinations_failure_ts.sql deleted file mode 100644 index f00889290b..0000000000 --- a/synapse/storage/schema/delta/56/destinations_failure_ts.sql +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Record the timestamp when a given server started failing - */ -ALTER TABLE destinations ADD failure_ts BIGINT; - -/* as a rough approximation, we assume that the server started failing at - * retry_interval before the last retry - */ -UPDATE destinations SET failure_ts = retry_last_ts - retry_interval - WHERE retry_last_ts > 0; diff --git a/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres deleted file mode 100644 index b9bbb18a91..0000000000 --- a/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- We want to store large retry intervals so we upgrade the column from INT --- to BIGINT. We don't need to do this on SQLite. -ALTER TABLE destinations ALTER retry_interval SET DATA TYPE BIGINT; diff --git a/synapse/storage/schema/delta/56/devices_last_seen.sql b/synapse/storage/schema/delta/56/devices_last_seen.sql deleted file mode 100644 index dfa902d0ba..0000000000 --- a/synapse/storage/schema/delta/56/devices_last_seen.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2019 Matrix.org Foundation CIC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Track last seen information for a device in the devices table, rather --- than relying on it being in the user_ips table (which we want to be able --- to purge old entries from) -ALTER TABLE devices ADD COLUMN last_seen BIGINT; -ALTER TABLE devices ADD COLUMN ip TEXT; -ALTER TABLE devices ADD COLUMN user_agent TEXT; - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('devices_last_seen', '{}'); diff --git a/synapse/storage/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/schema/delta/56/drop_unused_event_tables.sql deleted file mode 100644 index 9f09922c67..0000000000 --- a/synapse/storage/schema/delta/56/drop_unused_event_tables.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- these tables are never used. -DROP TABLE IF EXISTS room_names; -DROP TABLE IF EXISTS topics; -DROP TABLE IF EXISTS history_visibility; -DROP TABLE IF EXISTS guest_access; diff --git a/synapse/storage/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/schema/delta/56/fix_room_keys_index.sql deleted file mode 100644 index 014cb3b538..0000000000 --- a/synapse/storage/schema/delta/56/fix_room_keys_index.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2019 Matrix.org Foundation CIC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- version is supposed to be part of the room keys index -CREATE UNIQUE INDEX e2e_room_keys_with_version_idx ON e2e_room_keys(user_id, version, room_id, session_id); -DROP INDEX IF EXISTS e2e_room_keys_idx; diff --git a/synapse/storage/schema/delta/56/public_room_list_idx.sql b/synapse/storage/schema/delta/56/public_room_list_idx.sql deleted file mode 100644 index 7be31ffebb..0000000000 --- a/synapse/storage/schema/delta/56/public_room_list_idx.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id); diff --git a/synapse/storage/schema/delta/56/redaction_censor.sql b/synapse/storage/schema/delta/56/redaction_censor.sql deleted file mode 100644 index fe51b02309..0000000000 --- a/synapse/storage/schema/delta/56/redaction_censor.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false; -CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored; diff --git a/synapse/storage/schema/delta/56/redaction_censor2.sql b/synapse/storage/schema/delta/56/redaction_censor2.sql deleted file mode 100644 index 77a5eca499..0000000000 --- a/synapse/storage/schema/delta/56/redaction_censor2.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE redactions ADD COLUMN received_ts BIGINT; -CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored; - -INSERT INTO background_updates (update_name, progress_json) VALUES - ('redactions_received_ts', '{}'); diff --git a/synapse/storage/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/schema/delta/56/redaction_censor3_fix_update.sql.postgres deleted file mode 100644 index 67471f3ef5..0000000000 --- a/synapse/storage/schema/delta/56/redaction_censor3_fix_update.sql.postgres +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - --- There was a bug where we may have updated censored redactions as bytes, --- which can (somehow) cause json to be inserted hex encoded. These updates go --- and undoes any such hex encoded JSON. - -INSERT into background_updates (update_name, progress_json) - VALUES ('event_fix_redactions_bytes_create_index', '{}'); - -INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ('event_fix_redactions_bytes', '{}', 'event_fix_redactions_bytes_create_index'); diff --git a/synapse/storage/schema/delta/56/room_membership_idx.sql b/synapse/storage/schema/delta/56/room_membership_idx.sql deleted file mode 100644 index 92ab1f5e65..0000000000 --- a/synapse/storage/schema/delta/56/room_membership_idx.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Adds an index on room_memberships for fetching all forgotten rooms for a user -INSERT INTO background_updates (update_name, progress_json) VALUES - ('room_membership_forgotten_idx', '{}'); diff --git a/synapse/storage/schema/delta/56/stats_separated.sql b/synapse/storage/schema/delta/56/stats_separated.sql deleted file mode 100644 index 163529c071..0000000000 --- a/synapse/storage/schema/delta/56/stats_separated.sql +++ /dev/null @@ -1,152 +0,0 @@ -/* Copyright 2018 New Vector Ltd - * Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - ------ First clean up from previous versions of room stats. - --- First remove old stats stuff -DROP TABLE IF EXISTS room_stats; -DROP TABLE IF EXISTS room_state; -DROP TABLE IF EXISTS room_stats_state; -DROP TABLE IF EXISTS user_stats; -DROP TABLE IF EXISTS room_stats_earliest_tokens; -DROP TABLE IF EXISTS _temp_populate_stats_position; -DROP TABLE IF EXISTS _temp_populate_stats_rooms; -DROP TABLE IF EXISTS stats_stream_pos; - --- Unschedule old background updates if they're still scheduled -DELETE FROM background_updates WHERE update_name IN ( - 'populate_stats_createtables', - 'populate_stats_process_rooms', - 'populate_stats_process_users', - 'populate_stats_cleanup' -); - -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_stats_process_rooms', '{}', ''); - -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); - ------ Create tables for our version of room stats. - --- single-row table to track position of incremental updates -DROP TABLE IF EXISTS stats_incremental_position; -CREATE TABLE stats_incremental_position ( - Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. - stream_id BIGINT NOT NULL, - CHECK (Lock='X') -); - --- insert a null row and make sure it is the only one. -INSERT INTO stats_incremental_position ( - stream_id -) SELECT COALESCE(MAX(stream_ordering), 0) from events; - --- represents PRESENT room statistics for a room --- only holds absolute fields -DROP TABLE IF EXISTS room_stats_current; -CREATE TABLE room_stats_current ( - room_id TEXT NOT NULL PRIMARY KEY, - - -- These are absolute counts - current_state_events INT NOT NULL, - joined_members INT NOT NULL, - invited_members INT NOT NULL, - left_members INT NOT NULL, - banned_members INT NOT NULL, - - local_users_in_room INT NOT NULL, - - -- The maximum delta stream position that this row takes into account. - completed_delta_stream_id BIGINT NOT NULL -); - - --- represents HISTORICAL room statistics for a room -DROP TABLE IF EXISTS room_stats_historical; -CREATE TABLE room_stats_historical ( - room_id TEXT NOT NULL, - -- These stats cover the time from (end_ts - bucket_size)...end_ts (in ms). - -- Note that end_ts is quantised. - end_ts BIGINT NOT NULL, - bucket_size BIGINT NOT NULL, - - -- These stats are absolute counts - current_state_events BIGINT NOT NULL, - joined_members BIGINT NOT NULL, - invited_members BIGINT NOT NULL, - left_members BIGINT NOT NULL, - banned_members BIGINT NOT NULL, - local_users_in_room BIGINT NOT NULL, - - -- These stats are per time slice - total_events BIGINT NOT NULL, - total_event_bytes BIGINT NOT NULL, - - PRIMARY KEY (room_id, end_ts) -); - --- We use this index to speed up deletion of ancient room stats. -CREATE INDEX room_stats_historical_end_ts ON room_stats_historical (end_ts); - --- represents PRESENT statistics for a user --- only holds absolute fields -DROP TABLE IF EXISTS user_stats_current; -CREATE TABLE user_stats_current ( - user_id TEXT NOT NULL PRIMARY KEY, - - joined_rooms BIGINT NOT NULL, - - -- The maximum delta stream position that this row takes into account. - completed_delta_stream_id BIGINT NOT NULL -); - --- represents HISTORICAL statistics for a user -DROP TABLE IF EXISTS user_stats_historical; -CREATE TABLE user_stats_historical ( - user_id TEXT NOT NULL, - end_ts BIGINT NOT NULL, - bucket_size BIGINT NOT NULL, - - joined_rooms BIGINT NOT NULL, - - invites_sent BIGINT NOT NULL, - rooms_created BIGINT NOT NULL, - total_events BIGINT NOT NULL, - total_event_bytes BIGINT NOT NULL, - - PRIMARY KEY (user_id, end_ts) -); - --- We use this index to speed up deletion of ancient user stats. -CREATE INDEX user_stats_historical_end_ts ON user_stats_historical (end_ts); - - -CREATE TABLE room_stats_state ( - room_id TEXT NOT NULL, - name TEXT, - canonical_alias TEXT, - join_rules TEXT, - history_visibility TEXT, - encryption TEXT, - avatar TEXT, - guest_access TEXT, - is_federatable BOOLEAN, - topic TEXT -); - -CREATE UNIQUE INDEX room_stats_state_room ON room_stats_state(room_id); diff --git a/synapse/storage/schema/delta/56/unique_user_filter_index.py b/synapse/storage/schema/delta/56/unique_user_filter_index.py deleted file mode 100644 index 1de8b54961..0000000000 --- a/synapse/storage/schema/delta/56/unique_user_filter_index.py +++ /dev/null @@ -1,52 +0,0 @@ -import logging - -from synapse.storage.engines import PostgresEngine - -logger = logging.getLogger(__name__) - - -""" -This migration updates the user_filters table as follows: - - - drops any (user_id, filter_id) duplicates - - makes the columns NON-NULLable - - turns the index into a UNIQUE index -""" - - -def run_upgrade(cur, database_engine, *args, **kwargs): - pass - - -def run_create(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - select_clause = """ - SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json - FROM user_filters - """ - else: - select_clause = """ - SELECT * FROM user_filters GROUP BY user_id, filter_id - """ - sql = """ - DROP TABLE IF EXISTS user_filters_migration; - DROP INDEX IF EXISTS user_filters_unique; - CREATE TABLE user_filters_migration ( - user_id TEXT NOT NULL, - filter_id BIGINT NOT NULL, - filter_json BYTEA NOT NULL - ); - INSERT INTO user_filters_migration (user_id, filter_id, filter_json) - %s; - CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration - (user_id, filter_id); - DROP TABLE user_filters; - ALTER TABLE user_filters_migration RENAME TO user_filters; - """ % ( - select_clause, - ) - - if isinstance(database_engine, PostgresEngine): - cur.execute(sql) - else: - cur.executescript(sql) diff --git a/synapse/storage/schema/delta/56/user_external_ids.sql b/synapse/storage/schema/delta/56/user_external_ids.sql deleted file mode 100644 index 91390c4527..0000000000 --- a/synapse/storage/schema/delta/56/user_external_ids.sql +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2019 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * a table which records mappings from external auth providers to mxids - */ -CREATE TABLE IF NOT EXISTS user_external_ids ( - auth_provider TEXT NOT NULL, - external_id TEXT NOT NULL, - user_id TEXT NOT NULL, - UNIQUE (auth_provider, external_id) -); diff --git a/synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql deleted file mode 100644 index 149f8be8b6..0000000000 --- a/synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2019 Matrix.org Foundation CIC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- this was apparently forgotten when the table was created back in delta 53. -CREATE INDEX users_in_public_rooms_r_idx ON users_in_public_rooms(room_id); diff --git a/synapse/storage/schema/full_schemas/16/application_services.sql b/synapse/storage/schema/full_schemas/16/application_services.sql deleted file mode 100644 index 883fcd10b2..0000000000 --- a/synapse/storage/schema/full_schemas/16/application_services.sql +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* We used to create tables called application_services and - * application_services_regex, but these are no longer used and are removed in - * delta 54. - */ - - -CREATE TABLE IF NOT EXISTS application_services_state( - as_id TEXT PRIMARY KEY, - state VARCHAR(5), - last_txn INTEGER -); - -CREATE TABLE IF NOT EXISTS application_services_txns( - as_id TEXT NOT NULL, - txn_id INTEGER NOT NULL, - event_ids TEXT NOT NULL, - UNIQUE(as_id, txn_id) -); - -CREATE INDEX application_services_txns_id ON application_services_txns ( - as_id -); diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/schema/full_schemas/16/event_edges.sql deleted file mode 100644 index 10ce2aa7a0..0000000000 --- a/synapse/storage/schema/full_schemas/16/event_edges.sql +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* We used to create tables called event_destinations and - * state_forward_extremities, but these are no longer used and are removed in - * delta 54. - */ - -CREATE TABLE IF NOT EXISTS event_forward_extremities( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - UNIQUE (event_id, room_id) -); - -CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); -CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); - - -CREATE TABLE IF NOT EXISTS event_backward_extremities( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - UNIQUE (event_id, room_id) -); - -CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); -CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); - - -CREATE TABLE IF NOT EXISTS event_edges( - event_id TEXT NOT NULL, - prev_event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - is_state BOOL NOT NULL, -- true if this is a prev_state edge rather than a regular - -- event dag edge. - UNIQUE (event_id, prev_event_id, room_id, is_state) -); - -CREATE INDEX ev_edges_id ON event_edges(event_id); -CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); - - -CREATE TABLE IF NOT EXISTS room_depth( - room_id TEXT NOT NULL, - min_depth INTEGER NOT NULL, - UNIQUE (room_id) -); - -CREATE INDEX room_depth_room ON room_depth(room_id); - -CREATE TABLE IF NOT EXISTS event_auth( - event_id TEXT NOT NULL, - auth_id TEXT NOT NULL, - room_id TEXT NOT NULL, - UNIQUE (event_id, auth_id, room_id) -); - -CREATE INDEX evauth_edges_id ON event_auth(event_id); -CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id); diff --git a/synapse/storage/schema/full_schemas/16/event_signatures.sql b/synapse/storage/schema/full_schemas/16/event_signatures.sql deleted file mode 100644 index 95826da431..0000000000 --- a/synapse/storage/schema/full_schemas/16/event_signatures.sql +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - /* We used to create tables called event_content_hashes and event_edge_hashes, - * but these are no longer used and are removed in delta 54. - */ - -CREATE TABLE IF NOT EXISTS event_reference_hashes ( - event_id TEXT, - algorithm TEXT, - hash bytea, - UNIQUE (event_id, algorithm) -); - -CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id); - - -CREATE TABLE IF NOT EXISTS event_signatures ( - event_id TEXT, - signature_name TEXT, - key_id TEXT, - signature bytea, - UNIQUE (event_id, signature_name, key_id) -); - -CREATE INDEX event_signatures_id ON event_signatures(event_id); diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/schema/full_schemas/16/im.sql deleted file mode 100644 index a1a2aa8e5b..0000000000 --- a/synapse/storage/schema/full_schemas/16/im.sql +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* We used to create tables called room_hosts and feedback, - * but these are no longer used and are removed in delta 54. - */ - -CREATE TABLE IF NOT EXISTS events( - stream_ordering INTEGER PRIMARY KEY, - topological_ordering BIGINT NOT NULL, - event_id TEXT NOT NULL, - type TEXT NOT NULL, - room_id TEXT NOT NULL, - - -- 'content' used to be created NULLable, but as of delta 50 we drop that constraint. - -- the hack we use to drop the constraint doesn't work for an in-memory sqlite - -- database, which breaks the sytests. Hence, we no longer make it nullable. - content TEXT, - - unrecognized_keys TEXT, - processed BOOL NOT NULL, - outlier BOOL NOT NULL, - depth BIGINT DEFAULT 0 NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX events_stream_ordering ON events (stream_ordering); -CREATE INDEX events_topological_ordering ON events (topological_ordering); -CREATE INDEX events_order ON events (topological_ordering, stream_ordering); -CREATE INDEX events_room_id ON events (room_id); -CREATE INDEX events_order_room ON events ( - room_id, topological_ordering, stream_ordering -); - - -CREATE TABLE IF NOT EXISTS event_json( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - internal_metadata TEXT NOT NULL, - json TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX event_json_room_id ON event_json(room_id); - - -CREATE TABLE IF NOT EXISTS state_events( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - prev_state TEXT, - UNIQUE (event_id) -); - -CREATE INDEX state_events_room_id ON state_events (room_id); -CREATE INDEX state_events_type ON state_events (type); -CREATE INDEX state_events_state_key ON state_events (state_key); - - -CREATE TABLE IF NOT EXISTS current_state_events( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - UNIQUE (event_id), - UNIQUE (room_id, type, state_key) -); - -CREATE INDEX current_state_events_room_id ON current_state_events (room_id); -CREATE INDEX current_state_events_type ON current_state_events (type); -CREATE INDEX current_state_events_state_key ON current_state_events (state_key); - -CREATE TABLE IF NOT EXISTS room_memberships( - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - sender TEXT NOT NULL, - room_id TEXT NOT NULL, - membership TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX room_memberships_room_id ON room_memberships (room_id); -CREATE INDEX room_memberships_user_id ON room_memberships (user_id); - -CREATE TABLE IF NOT EXISTS topics( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - topic TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX topics_room_id ON topics(room_id); - -CREATE TABLE IF NOT EXISTS room_names( - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - name TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX room_names_room_id ON room_names(room_id); - -CREATE TABLE IF NOT EXISTS rooms( - room_id TEXT PRIMARY KEY NOT NULL, - is_public BOOL, - creator TEXT -); diff --git a/synapse/storage/schema/full_schemas/16/keys.sql b/synapse/storage/schema/full_schemas/16/keys.sql deleted file mode 100644 index 11cdffdbb3..0000000000 --- a/synapse/storage/schema/full_schemas/16/keys.sql +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- we used to create a table called server_tls_certificates, but this is no --- longer used, and is removed in delta 54. - -CREATE TABLE IF NOT EXISTS server_signature_keys( - server_name TEXT, -- Server name. - key_id TEXT, -- Key version. - from_server TEXT, -- Which key server the key was fetched form. - ts_added_ms BIGINT, -- When the key was added. - verify_key bytea, -- NACL verification key. - UNIQUE (server_name, key_id) -); diff --git a/synapse/storage/schema/full_schemas/16/media_repository.sql b/synapse/storage/schema/full_schemas/16/media_repository.sql deleted file mode 100644 index 8f3759bb2a..0000000000 --- a/synapse/storage/schema/full_schemas/16/media_repository.sql +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS local_media_repository ( - media_id TEXT, -- The id used to refer to the media. - media_type TEXT, -- The MIME-type of the media. - media_length INTEGER, -- Length of the media in bytes. - created_ts BIGINT, -- When the content was uploaded in ms. - upload_name TEXT, -- The name the media was uploaded with. - user_id TEXT, -- The user who uploaded the file. - UNIQUE (media_id) -); - -CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails ( - media_id TEXT, -- The id used to refer to the media. - thumbnail_width INTEGER, -- The width of the thumbnail in pixels. - thumbnail_height INTEGER, -- The height of the thumbnail in pixels. - thumbnail_type TEXT, -- The MIME-type of the thumbnail. - thumbnail_method TEXT, -- The method used to make the thumbnail. - thumbnail_length INTEGER, -- The length of the thumbnail in bytes. - UNIQUE ( - media_id, thumbnail_width, thumbnail_height, thumbnail_type - ) -); - -CREATE INDEX local_media_repository_thumbnails_media_id - ON local_media_repository_thumbnails (media_id); - -CREATE TABLE IF NOT EXISTS remote_media_cache ( - media_origin TEXT, -- The remote HS the media came from. - media_id TEXT, -- The id used to refer to the media on that server. - media_type TEXT, -- The MIME-type of the media. - created_ts BIGINT, -- When the content was uploaded in ms. - upload_name TEXT, -- The name the media was uploaded with. - media_length INTEGER, -- Length of the media in bytes. - filesystem_id TEXT, -- The name used to store the media on disk. - UNIQUE (media_origin, media_id) -); - -CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails ( - media_origin TEXT, -- The remote HS the media came from. - media_id TEXT, -- The id used to refer to the media. - thumbnail_width INTEGER, -- The width of the thumbnail in pixels. - thumbnail_height INTEGER, -- The height of the thumbnail in pixels. - thumbnail_method TEXT, -- The method used to make the thumbnail - thumbnail_type TEXT, -- The MIME-type of the thumbnail. - thumbnail_length INTEGER, -- The length of the thumbnail in bytes. - filesystem_id TEXT, -- The name used to store the media on disk. - UNIQUE ( - media_origin, media_id, thumbnail_width, thumbnail_height, - thumbnail_type - ) -); - -CREATE INDEX remote_media_cache_thumbnails_media_id - ON remote_media_cache_thumbnails (media_id); diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/schema/full_schemas/16/presence.sql deleted file mode 100644 index 01d2d8f833..0000000000 --- a/synapse/storage/schema/full_schemas/16/presence.sql +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS presence( - user_id TEXT NOT NULL, - state VARCHAR(20), - status_msg TEXT, - mtime BIGINT, -- miliseconds since last state change - UNIQUE (user_id) -); - --- For each of /my/ users which possibly-remote users are allowed to see their --- presence state -CREATE TABLE IF NOT EXISTS presence_allow_inbound( - observed_user_id TEXT NOT NULL, - observer_user_id TEXT NOT NULL, -- a UserID, - UNIQUE (observed_user_id, observer_user_id) -); - --- We used to create a table called presence_list, but this is no longer used --- and is removed in delta 54. \ No newline at end of file diff --git a/synapse/storage/schema/full_schemas/16/profiles.sql b/synapse/storage/schema/full_schemas/16/profiles.sql deleted file mode 100644 index c04f4747d9..0000000000 --- a/synapse/storage/schema/full_schemas/16/profiles.sql +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS profiles( - user_id TEXT NOT NULL, - displayname TEXT, - avatar_url TEXT, - UNIQUE(user_id) -); diff --git a/synapse/storage/schema/full_schemas/16/push.sql b/synapse/storage/schema/full_schemas/16/push.sql deleted file mode 100644 index e44465cf45..0000000000 --- a/synapse/storage/schema/full_schemas/16/push.sql +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS rejections( - event_id TEXT NOT NULL, - reason TEXT NOT NULL, - last_check TEXT NOT NULL, - UNIQUE (event_id) -); - --- Push notification endpoints that users have configured -CREATE TABLE IF NOT EXISTS pushers ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - access_token BIGINT DEFAULT NULL, - profile_tag VARCHAR(32) NOT NULL, - kind VARCHAR(8) NOT NULL, - app_id VARCHAR(64) NOT NULL, - app_display_name VARCHAR(64) NOT NULL, - device_display_name VARCHAR(128) NOT NULL, - pushkey bytea NOT NULL, - ts BIGINT NOT NULL, - lang VARCHAR(8), - data bytea, - last_token TEXT, - last_success BIGINT, - failing_since BIGINT, - UNIQUE (app_id, pushkey) -); - -CREATE TABLE IF NOT EXISTS push_rules ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - rule_id TEXT NOT NULL, - priority_class SMALLINT NOT NULL, - priority INTEGER NOT NULL DEFAULT 0, - conditions TEXT NOT NULL, - actions TEXT NOT NULL, - UNIQUE(user_name, rule_id) -); - -CREATE INDEX push_rules_user_name on push_rules (user_name); - -CREATE TABLE IF NOT EXISTS user_filters( - user_id TEXT, - filter_id BIGINT, - filter_json bytea -); - -CREATE INDEX user_filters_by_user_id_filter_id ON user_filters( - user_id, filter_id -); - -CREATE TABLE IF NOT EXISTS push_rules_enable ( - id BIGINT PRIMARY KEY, - user_name TEXT NOT NULL, - rule_id TEXT NOT NULL, - enabled SMALLINT, - UNIQUE(user_name, rule_id) -); - -CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name); diff --git a/synapse/storage/schema/full_schemas/16/redactions.sql b/synapse/storage/schema/full_schemas/16/redactions.sql deleted file mode 100644 index 318f0d9aa5..0000000000 --- a/synapse/storage/schema/full_schemas/16/redactions.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS redactions ( - event_id TEXT NOT NULL, - redacts TEXT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX redactions_event_id ON redactions (event_id); -CREATE INDEX redactions_redacts ON redactions (redacts); diff --git a/synapse/storage/schema/full_schemas/16/room_aliases.sql b/synapse/storage/schema/full_schemas/16/room_aliases.sql deleted file mode 100644 index d47da3b12f..0000000000 --- a/synapse/storage/schema/full_schemas/16/room_aliases.sql +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS room_aliases( - room_alias TEXT NOT NULL, - room_id TEXT NOT NULL, - UNIQUE (room_alias) -); - -CREATE INDEX room_aliases_id ON room_aliases(room_id); - -CREATE TABLE IF NOT EXISTS room_alias_servers( - room_alias TEXT NOT NULL, - server TEXT NOT NULL -); - -CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias); diff --git a/synapse/storage/schema/full_schemas/16/state.sql b/synapse/storage/schema/full_schemas/16/state.sql deleted file mode 100644 index 96391a8f0e..0000000000 --- a/synapse/storage/schema/full_schemas/16/state.sql +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS state_groups( - id BIGINT PRIMARY KEY, - room_id TEXT NOT NULL, - event_id TEXT NOT NULL -); - -CREATE TABLE IF NOT EXISTS state_groups_state( - state_group BIGINT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT NOT NULL, - event_id TEXT NOT NULL -); - -CREATE TABLE IF NOT EXISTS event_to_state_groups( - event_id TEXT NOT NULL, - state_group BIGINT NOT NULL, - UNIQUE (event_id) -); - -CREATE INDEX state_groups_id ON state_groups(id); - -CREATE INDEX state_groups_state_id ON state_groups_state(state_group); -CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key); -CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id); diff --git a/synapse/storage/schema/full_schemas/16/transactions.sql b/synapse/storage/schema/full_schemas/16/transactions.sql deleted file mode 100644 index 17e67bedac..0000000000 --- a/synapse/storage/schema/full_schemas/16/transactions.sql +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ --- Stores what transaction ids we have received and what our response was -CREATE TABLE IF NOT EXISTS received_transactions( - transaction_id TEXT, - origin TEXT, - ts BIGINT, - response_code INTEGER, - response_json bytea, - has_been_referenced smallint default 0, -- Whether thishas been referenced by a prev_tx - UNIQUE (transaction_id, origin) -); - -CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0; - --- For sent transactions only. -CREATE TABLE IF NOT EXISTS transaction_id_to_pdu( - transaction_id INTEGER, - destination TEXT, - pdu_id TEXT, - pdu_origin TEXT, - UNIQUE (transaction_id, destination) -); - -CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination); - --- To track destination health -CREATE TABLE IF NOT EXISTS destinations( - destination TEXT PRIMARY KEY, - retry_last_ts BIGINT, - retry_interval INTEGER -); diff --git a/synapse/storage/schema/full_schemas/16/users.sql b/synapse/storage/schema/full_schemas/16/users.sql deleted file mode 100644 index f013aa8b18..0000000000 --- a/synapse/storage/schema/full_schemas/16/users.sql +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2014-2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -CREATE TABLE IF NOT EXISTS users( - name TEXT, - password_hash TEXT, - creation_ts BIGINT, - admin SMALLINT DEFAULT 0 NOT NULL, - UNIQUE(name) -); - -CREATE TABLE IF NOT EXISTS access_tokens( - id BIGINT PRIMARY KEY, - user_id TEXT NOT NULL, - device_id TEXT, - token TEXT NOT NULL, - last_used BIGINT, - UNIQUE(token) -); - -CREATE TABLE IF NOT EXISTS user_ips ( - user_id TEXT NOT NULL, - access_token TEXT NOT NULL, - device_id TEXT, - ip TEXT NOT NULL, - user_agent TEXT NOT NULL, - last_seen BIGINT NOT NULL -); - -CREATE INDEX user_ips_user ON user_ips(user_id); -CREATE INDEX user_ips_user_ip ON user_ips(user_id, access_token, ip); diff --git a/synapse/storage/schema/full_schemas/54/full.sql b/synapse/storage/schema/full_schemas/54/full.sql new file mode 100644 index 0000000000..1005880466 --- /dev/null +++ b/synapse/storage/schema/full_schemas/54/full.sql @@ -0,0 +1,8 @@ + + +CREATE TABLE background_updates ( + update_name text NOT NULL, + progress_json text NOT NULL, + depends_on text, + CONSTRAINT background_updates_uniqueness UNIQUE (update_name) +); diff --git a/synapse/storage/schema/full_schemas/54/full.sql.postgres b/synapse/storage/schema/full_schemas/54/full.sql.postgres deleted file mode 100644 index 098434356f..0000000000 --- a/synapse/storage/schema/full_schemas/54/full.sql.postgres +++ /dev/null @@ -1,2052 +0,0 @@ - - - - - -CREATE TABLE access_tokens ( - id bigint NOT NULL, - user_id text NOT NULL, - device_id text, - token text NOT NULL, - last_used bigint -); - - - -CREATE TABLE account_data ( - user_id text NOT NULL, - account_data_type text NOT NULL, - stream_id bigint NOT NULL, - content text NOT NULL -); - - - -CREATE TABLE account_data_max_stream_id ( - lock character(1) DEFAULT 'X'::bpchar NOT NULL, - stream_id bigint NOT NULL, - CONSTRAINT private_user_data_max_stream_id_lock_check CHECK ((lock = 'X'::bpchar)) -); - - - -CREATE TABLE account_validity ( - user_id text NOT NULL, - expiration_ts_ms bigint NOT NULL, - email_sent boolean NOT NULL, - renewal_token text -); - - - -CREATE TABLE application_services_state ( - as_id text NOT NULL, - state character varying(5), - last_txn integer -); - - - -CREATE TABLE application_services_txns ( - as_id text NOT NULL, - txn_id integer NOT NULL, - event_ids text NOT NULL -); - - - -CREATE TABLE appservice_room_list ( - appservice_id text NOT NULL, - network_id text NOT NULL, - room_id text NOT NULL -); - - - -CREATE TABLE appservice_stream_position ( - lock character(1) DEFAULT 'X'::bpchar NOT NULL, - stream_ordering bigint, - CONSTRAINT appservice_stream_position_lock_check CHECK ((lock = 'X'::bpchar)) -); - - - -CREATE TABLE background_updates ( - update_name text NOT NULL, - progress_json text NOT NULL, - depends_on text -); - - - -CREATE TABLE blocked_rooms ( - room_id text NOT NULL, - user_id text NOT NULL -); - - - -CREATE TABLE cache_invalidation_stream ( - stream_id bigint, - cache_func text, - keys text[], - invalidation_ts bigint -); - - - -CREATE TABLE current_state_delta_stream ( - stream_id bigint NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL, - event_id text, - prev_event_id text -); - - - -CREATE TABLE current_state_events ( - event_id text NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL -); - - - -CREATE TABLE deleted_pushers ( - stream_id bigint NOT NULL, - app_id text NOT NULL, - pushkey text NOT NULL, - user_id text NOT NULL -); - - - -CREATE TABLE destinations ( - destination text NOT NULL, - retry_last_ts bigint, - retry_interval integer -); - - - -CREATE TABLE device_federation_inbox ( - origin text NOT NULL, - message_id text NOT NULL, - received_ts bigint NOT NULL -); - - - -CREATE TABLE device_federation_outbox ( - destination text NOT NULL, - stream_id bigint NOT NULL, - queued_ts bigint NOT NULL, - messages_json text NOT NULL -); - - - -CREATE TABLE device_inbox ( - user_id text NOT NULL, - device_id text NOT NULL, - stream_id bigint NOT NULL, - message_json text NOT NULL -); - - - -CREATE TABLE device_lists_outbound_last_success ( - destination text NOT NULL, - user_id text NOT NULL, - stream_id bigint NOT NULL -); - - - -CREATE TABLE device_lists_outbound_pokes ( - destination text NOT NULL, - stream_id bigint NOT NULL, - user_id text NOT NULL, - device_id text NOT NULL, - sent boolean NOT NULL, - ts bigint NOT NULL -); - - - -CREATE TABLE device_lists_remote_cache ( - user_id text NOT NULL, - device_id text NOT NULL, - content text NOT NULL -); - - - -CREATE TABLE device_lists_remote_extremeties ( - user_id text NOT NULL, - stream_id text NOT NULL -); - - - -CREATE TABLE device_lists_stream ( - stream_id bigint NOT NULL, - user_id text NOT NULL, - device_id text NOT NULL -); - - - -CREATE TABLE device_max_stream_id ( - stream_id bigint NOT NULL -); - - - -CREATE TABLE devices ( - user_id text NOT NULL, - device_id text NOT NULL, - display_name text -); - - - -CREATE TABLE e2e_device_keys_json ( - user_id text NOT NULL, - device_id text NOT NULL, - ts_added_ms bigint NOT NULL, - key_json text NOT NULL -); - - - -CREATE TABLE e2e_one_time_keys_json ( - user_id text NOT NULL, - device_id text NOT NULL, - algorithm text NOT NULL, - key_id text NOT NULL, - ts_added_ms bigint NOT NULL, - key_json text NOT NULL -); - - - -CREATE TABLE e2e_room_keys ( - user_id text NOT NULL, - room_id text NOT NULL, - session_id text NOT NULL, - version bigint NOT NULL, - first_message_index integer, - forwarded_count integer, - is_verified boolean, - session_data text NOT NULL -); - - - -CREATE TABLE e2e_room_keys_versions ( - user_id text NOT NULL, - version bigint NOT NULL, - algorithm text NOT NULL, - auth_data text NOT NULL, - deleted smallint DEFAULT 0 NOT NULL -); - - - -CREATE TABLE erased_users ( - user_id text NOT NULL -); - - - -CREATE TABLE event_auth ( - event_id text NOT NULL, - auth_id text NOT NULL, - room_id text NOT NULL -); - - - -CREATE TABLE event_backward_extremities ( - event_id text NOT NULL, - room_id text NOT NULL -); - - - -CREATE TABLE event_edges ( - event_id text NOT NULL, - prev_event_id text NOT NULL, - room_id text NOT NULL, - is_state boolean NOT NULL -); - - - -CREATE TABLE event_forward_extremities ( - event_id text NOT NULL, - room_id text NOT NULL -); - - - -CREATE TABLE event_json ( - event_id text NOT NULL, - room_id text NOT NULL, - internal_metadata text NOT NULL, - json text NOT NULL, - format_version integer -); - - - -CREATE TABLE event_push_actions ( - room_id text NOT NULL, - event_id text NOT NULL, - user_id text NOT NULL, - profile_tag character varying(32), - actions text NOT NULL, - topological_ordering bigint, - stream_ordering bigint, - notif smallint, - highlight smallint -); - - - -CREATE TABLE event_push_actions_staging ( - event_id text NOT NULL, - user_id text NOT NULL, - actions text NOT NULL, - notif smallint NOT NULL, - highlight smallint NOT NULL -); - - - -CREATE TABLE event_push_summary ( - user_id text NOT NULL, - room_id text NOT NULL, - notif_count bigint NOT NULL, - stream_ordering bigint NOT NULL -); - - - -CREATE TABLE event_push_summary_stream_ordering ( - lock character(1) DEFAULT 'X'::bpchar NOT NULL, - stream_ordering bigint NOT NULL, - CONSTRAINT event_push_summary_stream_ordering_lock_check CHECK ((lock = 'X'::bpchar)) -); - - - -CREATE TABLE event_reference_hashes ( - event_id text, - algorithm text, - hash bytea -); - - - -CREATE TABLE event_relations ( - event_id text NOT NULL, - relates_to_id text NOT NULL, - relation_type text NOT NULL, - aggregation_key text -); - - - -CREATE TABLE event_reports ( - id bigint NOT NULL, - received_ts bigint NOT NULL, - room_id text NOT NULL, - event_id text NOT NULL, - user_id text NOT NULL, - reason text, - content text -); - - - -CREATE TABLE event_search ( - event_id text, - room_id text, - sender text, - key text, - vector tsvector, - origin_server_ts bigint, - stream_ordering bigint -); - - - -CREATE TABLE event_to_state_groups ( - event_id text NOT NULL, - state_group bigint NOT NULL -); - - - -CREATE TABLE events ( - stream_ordering integer NOT NULL, - topological_ordering bigint NOT NULL, - event_id text NOT NULL, - type text NOT NULL, - room_id text NOT NULL, - content text, - unrecognized_keys text, - processed boolean NOT NULL, - outlier boolean NOT NULL, - depth bigint DEFAULT 0 NOT NULL, - origin_server_ts bigint, - received_ts bigint, - sender text, - contains_url boolean -); - - - -CREATE TABLE ex_outlier_stream ( - event_stream_ordering bigint NOT NULL, - event_id text NOT NULL, - state_group bigint NOT NULL -); - - - -CREATE TABLE federation_stream_position ( - type text NOT NULL, - stream_id integer NOT NULL -); - - - -CREATE TABLE group_attestations_remote ( - group_id text NOT NULL, - user_id text NOT NULL, - valid_until_ms bigint NOT NULL, - attestation_json text NOT NULL -); - - - -CREATE TABLE group_attestations_renewals ( - group_id text NOT NULL, - user_id text NOT NULL, - valid_until_ms bigint NOT NULL -); - - - -CREATE TABLE group_invites ( - group_id text NOT NULL, - user_id text NOT NULL -); - - - -CREATE TABLE group_roles ( - group_id text NOT NULL, - role_id text NOT NULL, - profile text NOT NULL, - is_public boolean NOT NULL -); - - - -CREATE TABLE group_room_categories ( - group_id text NOT NULL, - category_id text NOT NULL, - profile text NOT NULL, - is_public boolean NOT NULL -); - - - -CREATE TABLE group_rooms ( - group_id text NOT NULL, - room_id text NOT NULL, - is_public boolean NOT NULL -); - - - -CREATE TABLE group_summary_roles ( - group_id text NOT NULL, - role_id text NOT NULL, - role_order bigint NOT NULL, - CONSTRAINT group_summary_roles_role_order_check CHECK ((role_order > 0)) -); - - - -CREATE TABLE group_summary_room_categories ( - group_id text NOT NULL, - category_id text NOT NULL, - cat_order bigint NOT NULL, - CONSTRAINT group_summary_room_categories_cat_order_check CHECK ((cat_order > 0)) -); - - - -CREATE TABLE group_summary_rooms ( - group_id text NOT NULL, - room_id text NOT NULL, - category_id text NOT NULL, - room_order bigint NOT NULL, - is_public boolean NOT NULL, - CONSTRAINT group_summary_rooms_room_order_check CHECK ((room_order > 0)) -); - - - -CREATE TABLE group_summary_users ( - group_id text NOT NULL, - user_id text NOT NULL, - role_id text NOT NULL, - user_order bigint NOT NULL, - is_public boolean NOT NULL -); - - - -CREATE TABLE group_users ( - group_id text NOT NULL, - user_id text NOT NULL, - is_admin boolean NOT NULL, - is_public boolean NOT NULL -); - - - -CREATE TABLE groups ( - group_id text NOT NULL, - name text, - avatar_url text, - short_description text, - long_description text, - is_public boolean NOT NULL, - join_policy text DEFAULT 'invite'::text NOT NULL -); - - - -CREATE TABLE guest_access ( - event_id text NOT NULL, - room_id text NOT NULL, - guest_access text NOT NULL -); - - - -CREATE TABLE history_visibility ( - event_id text NOT NULL, - room_id text NOT NULL, - history_visibility text NOT NULL -); - - - -CREATE TABLE local_group_membership ( - group_id text NOT NULL, - user_id text NOT NULL, - is_admin boolean NOT NULL, - membership text NOT NULL, - is_publicised boolean NOT NULL, - content text NOT NULL -); - - - -CREATE TABLE local_group_updates ( - stream_id bigint NOT NULL, - group_id text NOT NULL, - user_id text NOT NULL, - type text NOT NULL, - content text NOT NULL -); - - - -CREATE TABLE local_invites ( - stream_id bigint NOT NULL, - inviter text NOT NULL, - invitee text NOT NULL, - event_id text NOT NULL, - room_id text NOT NULL, - locally_rejected text, - replaced_by text -); - - - -CREATE TABLE local_media_repository ( - media_id text, - media_type text, - media_length integer, - created_ts bigint, - upload_name text, - user_id text, - quarantined_by text, - url_cache text, - last_access_ts bigint -); - - - -CREATE TABLE local_media_repository_thumbnails ( - media_id text, - thumbnail_width integer, - thumbnail_height integer, - thumbnail_type text, - thumbnail_method text, - thumbnail_length integer -); - - - -CREATE TABLE local_media_repository_url_cache ( - url text, - response_code integer, - etag text, - expires_ts bigint, - og text, - media_id text, - download_ts bigint -); - - - -CREATE TABLE monthly_active_users ( - user_id text NOT NULL, - "timestamp" bigint NOT NULL -); - - - -CREATE TABLE open_id_tokens ( - token text NOT NULL, - ts_valid_until_ms bigint NOT NULL, - user_id text NOT NULL -); - - - -CREATE TABLE presence ( - user_id text NOT NULL, - state character varying(20), - status_msg text, - mtime bigint -); - - - -CREATE TABLE presence_allow_inbound ( - observed_user_id text NOT NULL, - observer_user_id text NOT NULL -); - - - -CREATE TABLE presence_stream ( - stream_id bigint, - user_id text, - state text, - last_active_ts bigint, - last_federation_update_ts bigint, - last_user_sync_ts bigint, - status_msg text, - currently_active boolean -); - - - -CREATE TABLE profiles ( - user_id text NOT NULL, - displayname text, - avatar_url text -); - - - -CREATE TABLE public_room_list_stream ( - stream_id bigint NOT NULL, - room_id text NOT NULL, - visibility boolean NOT NULL, - appservice_id text, - network_id text -); - - - -CREATE TABLE push_rules ( - id bigint NOT NULL, - user_name text NOT NULL, - rule_id text NOT NULL, - priority_class smallint NOT NULL, - priority integer DEFAULT 0 NOT NULL, - conditions text NOT NULL, - actions text NOT NULL -); - - - -CREATE TABLE push_rules_enable ( - id bigint NOT NULL, - user_name text NOT NULL, - rule_id text NOT NULL, - enabled smallint -); - - - -CREATE TABLE push_rules_stream ( - stream_id bigint NOT NULL, - event_stream_ordering bigint NOT NULL, - user_id text NOT NULL, - rule_id text NOT NULL, - op text NOT NULL, - priority_class smallint, - priority integer, - conditions text, - actions text -); - - - -CREATE TABLE pusher_throttle ( - pusher bigint NOT NULL, - room_id text NOT NULL, - last_sent_ts bigint, - throttle_ms bigint -); - - - -CREATE TABLE pushers ( - id bigint NOT NULL, - user_name text NOT NULL, - access_token bigint, - profile_tag text NOT NULL, - kind text NOT NULL, - app_id text NOT NULL, - app_display_name text NOT NULL, - device_display_name text NOT NULL, - pushkey text NOT NULL, - ts bigint NOT NULL, - lang text, - data text, - last_stream_ordering integer, - last_success bigint, - failing_since bigint -); - - - -CREATE TABLE ratelimit_override ( - user_id text NOT NULL, - messages_per_second bigint, - burst_count bigint -); - - - -CREATE TABLE receipts_graph ( - room_id text NOT NULL, - receipt_type text NOT NULL, - user_id text NOT NULL, - event_ids text NOT NULL, - data text NOT NULL -); - - - -CREATE TABLE receipts_linearized ( - stream_id bigint NOT NULL, - room_id text NOT NULL, - receipt_type text NOT NULL, - user_id text NOT NULL, - event_id text NOT NULL, - data text NOT NULL -); - - - -CREATE TABLE received_transactions ( - transaction_id text, - origin text, - ts bigint, - response_code integer, - response_json bytea, - has_been_referenced smallint DEFAULT 0 -); - - - -CREATE TABLE redactions ( - event_id text NOT NULL, - redacts text NOT NULL -); - - - -CREATE TABLE rejections ( - event_id text NOT NULL, - reason text NOT NULL, - last_check text NOT NULL -); - - - -CREATE TABLE remote_media_cache ( - media_origin text, - media_id text, - media_type text, - created_ts bigint, - upload_name text, - media_length integer, - filesystem_id text, - last_access_ts bigint, - quarantined_by text -); - - - -CREATE TABLE remote_media_cache_thumbnails ( - media_origin text, - media_id text, - thumbnail_width integer, - thumbnail_height integer, - thumbnail_method text, - thumbnail_type text, - thumbnail_length integer, - filesystem_id text -); - - - -CREATE TABLE remote_profile_cache ( - user_id text NOT NULL, - displayname text, - avatar_url text, - last_check bigint NOT NULL -); - - - -CREATE TABLE room_account_data ( - user_id text NOT NULL, - room_id text NOT NULL, - account_data_type text NOT NULL, - stream_id bigint NOT NULL, - content text NOT NULL -); - - - -CREATE TABLE room_alias_servers ( - room_alias text NOT NULL, - server text NOT NULL -); - - - -CREATE TABLE room_aliases ( - room_alias text NOT NULL, - room_id text NOT NULL, - creator text -); - - - -CREATE TABLE room_depth ( - room_id text NOT NULL, - min_depth integer NOT NULL -); - - - -CREATE TABLE room_memberships ( - event_id text NOT NULL, - user_id text NOT NULL, - sender text NOT NULL, - room_id text NOT NULL, - membership text NOT NULL, - forgotten integer DEFAULT 0, - display_name text, - avatar_url text -); - - - -CREATE TABLE room_names ( - event_id text NOT NULL, - room_id text NOT NULL, - name text NOT NULL -); - - - -CREATE TABLE room_state ( - room_id text NOT NULL, - join_rules text, - history_visibility text, - encryption text, - name text, - topic text, - avatar text, - canonical_alias text -); - - - -CREATE TABLE room_stats ( - room_id text NOT NULL, - ts bigint NOT NULL, - bucket_size integer NOT NULL, - current_state_events integer NOT NULL, - joined_members integer NOT NULL, - invited_members integer NOT NULL, - left_members integer NOT NULL, - banned_members integer NOT NULL, - state_events integer NOT NULL -); - - - -CREATE TABLE room_stats_earliest_token ( - room_id text NOT NULL, - token bigint NOT NULL -); - - - -CREATE TABLE room_tags ( - user_id text NOT NULL, - room_id text NOT NULL, - tag text NOT NULL, - content text NOT NULL -); - - - -CREATE TABLE room_tags_revisions ( - user_id text NOT NULL, - room_id text NOT NULL, - stream_id bigint NOT NULL -); - - - -CREATE TABLE rooms ( - room_id text NOT NULL, - is_public boolean, - creator text -); - - - -CREATE TABLE server_keys_json ( - server_name text NOT NULL, - key_id text NOT NULL, - from_server text NOT NULL, - ts_added_ms bigint NOT NULL, - ts_valid_until_ms bigint NOT NULL, - key_json bytea NOT NULL -); - - - -CREATE TABLE server_signature_keys ( - server_name text, - key_id text, - from_server text, - ts_added_ms bigint, - verify_key bytea, - ts_valid_until_ms bigint -); - - - -CREATE TABLE state_events ( - event_id text NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL, - prev_state text -); - - - -CREATE TABLE state_group_edges ( - state_group bigint NOT NULL, - prev_state_group bigint NOT NULL -); - - - -CREATE SEQUENCE state_group_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; - - - -CREATE TABLE state_groups ( - id bigint NOT NULL, - room_id text NOT NULL, - event_id text NOT NULL -); - - - -CREATE TABLE state_groups_state ( - state_group bigint NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL, - event_id text NOT NULL -); - - - -CREATE TABLE stats_stream_pos ( - lock character(1) DEFAULT 'X'::bpchar NOT NULL, - stream_id bigint, - CONSTRAINT stats_stream_pos_lock_check CHECK ((lock = 'X'::bpchar)) -); - - - -CREATE TABLE stream_ordering_to_exterm ( - stream_ordering bigint NOT NULL, - room_id text NOT NULL, - event_id text NOT NULL -); - - - -CREATE TABLE threepid_guest_access_tokens ( - medium text, - address text, - guest_access_token text, - first_inviter text -); - - - -CREATE TABLE topics ( - event_id text NOT NULL, - room_id text NOT NULL, - topic text NOT NULL -); - - - -CREATE TABLE user_daily_visits ( - user_id text NOT NULL, - device_id text, - "timestamp" bigint NOT NULL -); - - - -CREATE TABLE user_directory ( - user_id text NOT NULL, - room_id text, - display_name text, - avatar_url text -); - - - -CREATE TABLE user_directory_search ( - user_id text NOT NULL, - vector tsvector -); - - - -CREATE TABLE user_directory_stream_pos ( - lock character(1) DEFAULT 'X'::bpchar NOT NULL, - stream_id bigint, - CONSTRAINT user_directory_stream_pos_lock_check CHECK ((lock = 'X'::bpchar)) -); - - - -CREATE TABLE user_filters ( - user_id text, - filter_id bigint, - filter_json bytea -); - - - -CREATE TABLE user_ips ( - user_id text NOT NULL, - access_token text NOT NULL, - device_id text, - ip text NOT NULL, - user_agent text NOT NULL, - last_seen bigint NOT NULL -); - - - -CREATE TABLE user_stats ( - user_id text NOT NULL, - ts bigint NOT NULL, - bucket_size integer NOT NULL, - public_rooms integer NOT NULL, - private_rooms integer NOT NULL -); - - - -CREATE TABLE user_threepid_id_server ( - user_id text NOT NULL, - medium text NOT NULL, - address text NOT NULL, - id_server text NOT NULL -); - - - -CREATE TABLE user_threepids ( - user_id text NOT NULL, - medium text NOT NULL, - address text NOT NULL, - validated_at bigint NOT NULL, - added_at bigint NOT NULL -); - - - -CREATE TABLE users ( - name text, - password_hash text, - creation_ts bigint, - admin smallint DEFAULT 0 NOT NULL, - upgrade_ts bigint, - is_guest smallint DEFAULT 0 NOT NULL, - appservice_id text, - consent_version text, - consent_server_notice_sent text, - user_type text -); - - - -CREATE TABLE users_in_public_rooms ( - user_id text NOT NULL, - room_id text NOT NULL -); - - - -CREATE TABLE users_pending_deactivation ( - user_id text NOT NULL -); - - - -CREATE TABLE users_who_share_private_rooms ( - user_id text NOT NULL, - other_user_id text NOT NULL, - room_id text NOT NULL -); - - - -ALTER TABLE ONLY access_tokens - ADD CONSTRAINT access_tokens_pkey PRIMARY KEY (id); - - - -ALTER TABLE ONLY access_tokens - ADD CONSTRAINT access_tokens_token_key UNIQUE (token); - - - -ALTER TABLE ONLY account_data - ADD CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type); - - - -ALTER TABLE ONLY account_validity - ADD CONSTRAINT account_validity_pkey PRIMARY KEY (user_id); - - - -ALTER TABLE ONLY application_services_state - ADD CONSTRAINT application_services_state_pkey PRIMARY KEY (as_id); - - - -ALTER TABLE ONLY application_services_txns - ADD CONSTRAINT application_services_txns_as_id_txn_id_key UNIQUE (as_id, txn_id); - - - -ALTER TABLE ONLY appservice_stream_position - ADD CONSTRAINT appservice_stream_position_lock_key UNIQUE (lock); - - - -ALTER TABLE ONLY background_updates - ADD CONSTRAINT background_updates_uniqueness UNIQUE (update_name); - - - -ALTER TABLE ONLY current_state_events - ADD CONSTRAINT current_state_events_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY current_state_events - ADD CONSTRAINT current_state_events_room_id_type_state_key_key UNIQUE (room_id, type, state_key); - - - -ALTER TABLE ONLY destinations - ADD CONSTRAINT destinations_pkey PRIMARY KEY (destination); - - - -ALTER TABLE ONLY devices - ADD CONSTRAINT device_uniqueness UNIQUE (user_id, device_id); - - - -ALTER TABLE ONLY e2e_device_keys_json - ADD CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id); - - - -ALTER TABLE ONLY e2e_one_time_keys_json - ADD CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id); - - - -ALTER TABLE ONLY event_backward_extremities - ADD CONSTRAINT event_backward_extremities_event_id_room_id_key UNIQUE (event_id, room_id); - - - -ALTER TABLE ONLY event_edges - ADD CONSTRAINT event_edges_event_id_prev_event_id_room_id_is_state_key UNIQUE (event_id, prev_event_id, room_id, is_state); - - - -ALTER TABLE ONLY event_forward_extremities - ADD CONSTRAINT event_forward_extremities_event_id_room_id_key UNIQUE (event_id, room_id); - - - -ALTER TABLE ONLY event_push_actions - ADD CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag); - - - -ALTER TABLE ONLY event_json - ADD CONSTRAINT event_json_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY event_push_summary_stream_ordering - ADD CONSTRAINT event_push_summary_stream_ordering_lock_key UNIQUE (lock); - - - -ALTER TABLE ONLY event_reference_hashes - ADD CONSTRAINT event_reference_hashes_event_id_algorithm_key UNIQUE (event_id, algorithm); - - - -ALTER TABLE ONLY event_reports - ADD CONSTRAINT event_reports_pkey PRIMARY KEY (id); - - - -ALTER TABLE ONLY event_to_state_groups - ADD CONSTRAINT event_to_state_groups_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY events - ADD CONSTRAINT events_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY events - ADD CONSTRAINT events_pkey PRIMARY KEY (stream_ordering); - - - -ALTER TABLE ONLY ex_outlier_stream - ADD CONSTRAINT ex_outlier_stream_pkey PRIMARY KEY (event_stream_ordering); - - - -ALTER TABLE ONLY group_roles - ADD CONSTRAINT group_roles_group_id_role_id_key UNIQUE (group_id, role_id); - - - -ALTER TABLE ONLY group_room_categories - ADD CONSTRAINT group_room_categories_group_id_category_id_key UNIQUE (group_id, category_id); - - - -ALTER TABLE ONLY group_summary_roles - ADD CONSTRAINT group_summary_roles_group_id_role_id_role_order_key UNIQUE (group_id, role_id, role_order); - - - -ALTER TABLE ONLY group_summary_room_categories - ADD CONSTRAINT group_summary_room_categories_group_id_category_id_cat_orde_key UNIQUE (group_id, category_id, cat_order); - - - -ALTER TABLE ONLY group_summary_rooms - ADD CONSTRAINT group_summary_rooms_group_id_category_id_room_id_room_order_key UNIQUE (group_id, category_id, room_id, room_order); - - - -ALTER TABLE ONLY guest_access - ADD CONSTRAINT guest_access_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY history_visibility - ADD CONSTRAINT history_visibility_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY local_media_repository - ADD CONSTRAINT local_media_repository_media_id_key UNIQUE (media_id); - - - -ALTER TABLE ONLY local_media_repository_thumbnails - ADD CONSTRAINT local_media_repository_thumbn_media_id_thumbnail_width_thum_key UNIQUE (media_id, thumbnail_width, thumbnail_height, thumbnail_type); - - - -ALTER TABLE ONLY user_threepids - ADD CONSTRAINT medium_address UNIQUE (medium, address); - - - -ALTER TABLE ONLY open_id_tokens - ADD CONSTRAINT open_id_tokens_pkey PRIMARY KEY (token); - - - -ALTER TABLE ONLY presence_allow_inbound - ADD CONSTRAINT presence_allow_inbound_observed_user_id_observer_user_id_key UNIQUE (observed_user_id, observer_user_id); - - - -ALTER TABLE ONLY presence - ADD CONSTRAINT presence_user_id_key UNIQUE (user_id); - - - -ALTER TABLE ONLY account_data_max_stream_id - ADD CONSTRAINT private_user_data_max_stream_id_lock_key UNIQUE (lock); - - - -ALTER TABLE ONLY profiles - ADD CONSTRAINT profiles_user_id_key UNIQUE (user_id); - - - -ALTER TABLE ONLY push_rules_enable - ADD CONSTRAINT push_rules_enable_pkey PRIMARY KEY (id); - - - -ALTER TABLE ONLY push_rules_enable - ADD CONSTRAINT push_rules_enable_user_name_rule_id_key UNIQUE (user_name, rule_id); - - - -ALTER TABLE ONLY push_rules - ADD CONSTRAINT push_rules_pkey PRIMARY KEY (id); - - - -ALTER TABLE ONLY push_rules - ADD CONSTRAINT push_rules_user_name_rule_id_key UNIQUE (user_name, rule_id); - - - -ALTER TABLE ONLY pusher_throttle - ADD CONSTRAINT pusher_throttle_pkey PRIMARY KEY (pusher, room_id); - - - -ALTER TABLE ONLY pushers - ADD CONSTRAINT pushers2_app_id_pushkey_user_name_key UNIQUE (app_id, pushkey, user_name); - - - -ALTER TABLE ONLY pushers - ADD CONSTRAINT pushers2_pkey PRIMARY KEY (id); - - - -ALTER TABLE ONLY receipts_graph - ADD CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id); - - - -ALTER TABLE ONLY receipts_linearized - ADD CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id); - - - -ALTER TABLE ONLY received_transactions - ADD CONSTRAINT received_transactions_transaction_id_origin_key UNIQUE (transaction_id, origin); - - - -ALTER TABLE ONLY redactions - ADD CONSTRAINT redactions_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY rejections - ADD CONSTRAINT rejections_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY remote_media_cache - ADD CONSTRAINT remote_media_cache_media_origin_media_id_key UNIQUE (media_origin, media_id); - - - -ALTER TABLE ONLY remote_media_cache_thumbnails - ADD CONSTRAINT remote_media_cache_thumbnails_media_origin_media_id_thumbna_key UNIQUE (media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type); - - - -ALTER TABLE ONLY room_account_data - ADD CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type); - - - -ALTER TABLE ONLY room_aliases - ADD CONSTRAINT room_aliases_room_alias_key UNIQUE (room_alias); - - - -ALTER TABLE ONLY room_depth - ADD CONSTRAINT room_depth_room_id_key UNIQUE (room_id); - - - -ALTER TABLE ONLY room_memberships - ADD CONSTRAINT room_memberships_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY room_names - ADD CONSTRAINT room_names_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY room_tags_revisions - ADD CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id); - - - -ALTER TABLE ONLY room_tags - ADD CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag); - - - -ALTER TABLE ONLY rooms - ADD CONSTRAINT rooms_pkey PRIMARY KEY (room_id); - - - -ALTER TABLE ONLY server_keys_json - ADD CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server); - - - -ALTER TABLE ONLY server_signature_keys - ADD CONSTRAINT server_signature_keys_server_name_key_id_key UNIQUE (server_name, key_id); - - - -ALTER TABLE ONLY state_events - ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY state_groups - ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id); - - - -ALTER TABLE ONLY stats_stream_pos - ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock); - - - -ALTER TABLE ONLY topics - ADD CONSTRAINT topics_event_id_key UNIQUE (event_id); - - - -ALTER TABLE ONLY user_directory_stream_pos - ADD CONSTRAINT user_directory_stream_pos_lock_key UNIQUE (lock); - - - -ALTER TABLE ONLY users - ADD CONSTRAINT users_name_key UNIQUE (name); - - - -CREATE INDEX access_tokens_device_id ON access_tokens USING btree (user_id, device_id); - - - -CREATE INDEX account_data_stream_id ON account_data USING btree (user_id, stream_id); - - - -CREATE INDEX application_services_txns_id ON application_services_txns USING btree (as_id); - - - -CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list USING btree (appservice_id, network_id, room_id); - - - -CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms USING btree (room_id); - - - -CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream USING btree (stream_id); - - - -CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream USING btree (stream_id); - - - -CREATE INDEX current_state_events_member_index ON current_state_events USING btree (state_key) WHERE (type = 'm.room.member'::text); - - - -CREATE INDEX deleted_pushers_stream_id ON deleted_pushers USING btree (stream_id); - - - -CREATE INDEX device_federation_inbox_sender_id ON device_federation_inbox USING btree (origin, message_id); - - - -CREATE INDEX device_federation_outbox_destination_id ON device_federation_outbox USING btree (destination, stream_id); - - - -CREATE INDEX device_federation_outbox_id ON device_federation_outbox USING btree (stream_id); - - - -CREATE INDEX device_inbox_stream_id_user_id ON device_inbox USING btree (stream_id, user_id); - - - -CREATE INDEX device_inbox_user_stream_id ON device_inbox USING btree (user_id, device_id, stream_id); - - - -CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success USING btree (destination, user_id, stream_id); - - - -CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes USING btree (destination, stream_id); - - - -CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes USING btree (stream_id); - - - -CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes USING btree (destination, user_id); - - - -CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache USING btree (user_id, device_id); - - - -CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties USING btree (user_id); - - - -CREATE INDEX device_lists_stream_id ON device_lists_stream USING btree (stream_id, user_id); - - - -CREATE INDEX device_lists_stream_user_id ON device_lists_stream USING btree (user_id, device_id); - - - -CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys USING btree (user_id, room_id, session_id); - - - -CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions USING btree (user_id, version); - - - -CREATE UNIQUE INDEX erased_users_user ON erased_users USING btree (user_id); - - - -CREATE INDEX ev_b_extrem_id ON event_backward_extremities USING btree (event_id); - - - -CREATE INDEX ev_b_extrem_room ON event_backward_extremities USING btree (room_id); - - - -CREATE INDEX ev_edges_id ON event_edges USING btree (event_id); - - - -CREATE INDEX ev_edges_prev_id ON event_edges USING btree (prev_event_id); - - - -CREATE INDEX ev_extrem_id ON event_forward_extremities USING btree (event_id); - - - -CREATE INDEX ev_extrem_room ON event_forward_extremities USING btree (room_id); - - - -CREATE INDEX evauth_edges_id ON event_auth USING btree (event_id); - - - -CREATE INDEX event_contains_url_index ON events USING btree (room_id, topological_ordering, stream_ordering) WHERE ((contains_url = true) AND (outlier = false)); - - - -CREATE INDEX event_json_room_id ON event_json USING btree (room_id); - - - -CREATE INDEX event_push_actions_highlights_index ON event_push_actions USING btree (user_id, room_id, topological_ordering, stream_ordering) WHERE (highlight = 1); - - - -CREATE INDEX event_push_actions_rm_tokens ON event_push_actions USING btree (user_id, room_id, topological_ordering, stream_ordering); - - - -CREATE INDEX event_push_actions_room_id_user_id ON event_push_actions USING btree (room_id, user_id); - - - -CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging USING btree (event_id); - - - -CREATE INDEX event_push_actions_stream_ordering ON event_push_actions USING btree (stream_ordering, user_id); - - - -CREATE INDEX event_push_actions_u_highlight ON event_push_actions USING btree (user_id, stream_ordering); - - - -CREATE INDEX event_push_summary_user_rm ON event_push_summary USING btree (user_id, room_id); - - - -CREATE INDEX event_reference_hashes_id ON event_reference_hashes USING btree (event_id); - - - -CREATE UNIQUE INDEX event_relations_id ON event_relations USING btree (event_id); - - - -CREATE INDEX event_relations_relates ON event_relations USING btree (relates_to_id, relation_type, aggregation_key); - - - -CREATE INDEX event_search_ev_ridx ON event_search USING btree (room_id); - - - -CREATE UNIQUE INDEX event_search_event_id_idx ON event_search USING btree (event_id); - - - -CREATE INDEX event_search_fts_idx ON event_search USING gin (vector); - - - -CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups USING btree (state_group); - - - -CREATE INDEX events_order_room ON events USING btree (room_id, topological_ordering, stream_ordering); - - - -CREATE INDEX events_room_stream ON events USING btree (room_id, stream_ordering); - - - -CREATE INDEX events_ts ON events USING btree (origin_server_ts, stream_ordering); - - - -CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote USING btree (group_id, user_id); - - - -CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote USING btree (user_id); - - - -CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote USING btree (valid_until_ms); - - - -CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals USING btree (group_id, user_id); - - - -CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals USING btree (user_id); - - - -CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals USING btree (valid_until_ms); - - - -CREATE UNIQUE INDEX group_invites_g_idx ON group_invites USING btree (group_id, user_id); - - - -CREATE INDEX group_invites_u_idx ON group_invites USING btree (user_id); - - - -CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms USING btree (group_id, room_id); - - - -CREATE INDEX group_rooms_r_idx ON group_rooms USING btree (room_id); - - - -CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms USING btree (group_id, room_id, category_id); - - - -CREATE INDEX group_summary_users_g_idx ON group_summary_users USING btree (group_id); - - - -CREATE UNIQUE INDEX group_users_g_idx ON group_users USING btree (group_id, user_id); - - - -CREATE INDEX group_users_u_idx ON group_users USING btree (user_id); - - - -CREATE UNIQUE INDEX groups_idx ON groups USING btree (group_id); - - - -CREATE INDEX local_group_membership_g_idx ON local_group_membership USING btree (group_id); - - - -CREATE INDEX local_group_membership_u_idx ON local_group_membership USING btree (user_id, group_id); - - - -CREATE INDEX local_invites_for_user_idx ON local_invites USING btree (invitee, locally_rejected, replaced_by, room_id); - - - -CREATE INDEX local_invites_id ON local_invites USING btree (stream_id); - - - -CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails USING btree (media_id); - - - -CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache USING btree (url, download_ts); - - - -CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache USING btree (expires_ts); - - - -CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache USING btree (media_id); - - - -CREATE INDEX local_media_repository_url_idx ON local_media_repository USING btree (created_ts) WHERE (url_cache IS NOT NULL); - - - -CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users USING btree ("timestamp"); - - - -CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users USING btree (user_id); - - - -CREATE INDEX open_id_tokens_ts_valid_until_ms ON open_id_tokens USING btree (ts_valid_until_ms); - - - -CREATE INDEX presence_stream_id ON presence_stream USING btree (stream_id, user_id); - - - -CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id); - - - -CREATE INDEX public_room_index ON rooms USING btree (is_public); - - - -CREATE INDEX public_room_list_stream_idx ON public_room_list_stream USING btree (stream_id); - - - -CREATE INDEX public_room_list_stream_rm_idx ON public_room_list_stream USING btree (room_id, stream_id); - - - -CREATE INDEX push_rules_enable_user_name ON push_rules_enable USING btree (user_name); - - - -CREATE INDEX push_rules_stream_id ON push_rules_stream USING btree (stream_id); - - - -CREATE INDEX push_rules_stream_user_stream_id ON push_rules_stream USING btree (user_id, stream_id); - - - -CREATE INDEX push_rules_user_name ON push_rules USING btree (user_name); - - - -CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override USING btree (user_id); - - - -CREATE INDEX receipts_linearized_id ON receipts_linearized USING btree (stream_id); - - - -CREATE INDEX receipts_linearized_room_stream ON receipts_linearized USING btree (room_id, stream_id); - - - -CREATE INDEX receipts_linearized_user ON receipts_linearized USING btree (user_id); - - - -CREATE INDEX received_transactions_ts ON received_transactions USING btree (ts); - - - -CREATE INDEX redactions_redacts ON redactions USING btree (redacts); - - - -CREATE INDEX remote_profile_cache_time ON remote_profile_cache USING btree (last_check); - - - -CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache USING btree (user_id); - - - -CREATE INDEX room_account_data_stream_id ON room_account_data USING btree (user_id, stream_id); - - - -CREATE INDEX room_alias_servers_alias ON room_alias_servers USING btree (room_alias); - - - -CREATE INDEX room_aliases_id ON room_aliases USING btree (room_id); - - - -CREATE INDEX room_depth_room ON room_depth USING btree (room_id); - - - -CREATE INDEX room_memberships_room_id ON room_memberships USING btree (room_id); - - - -CREATE INDEX room_memberships_user_id ON room_memberships USING btree (user_id); - - - -CREATE INDEX room_names_room_id ON room_names USING btree (room_id); - - - -CREATE UNIQUE INDEX room_state_room ON room_state USING btree (room_id); - - - -CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token USING btree (room_id); - - - -CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts); - - - -CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group); - - - -CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group); - - - -CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key); - - - -CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering); - - - -CREATE INDEX stream_ordering_to_exterm_rm_idx ON stream_ordering_to_exterm USING btree (room_id, stream_ordering); - - - -CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens USING btree (medium, address); - - - -CREATE INDEX topics_room_id ON topics USING btree (room_id); - - - -CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits USING btree ("timestamp"); - - - -CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits USING btree (user_id, "timestamp"); - - - -CREATE INDEX user_directory_room_idx ON user_directory USING btree (room_id); - - - -CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin (vector); - - - -CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search USING btree (user_id); - - - -CREATE UNIQUE INDEX user_directory_user_idx ON user_directory USING btree (user_id); - - - -CREATE INDEX user_filters_by_user_id_filter_id ON user_filters USING btree (user_id, filter_id); - - - -CREATE INDEX user_ips_device_id ON user_ips USING btree (user_id, device_id, last_seen); - - - -CREATE INDEX user_ips_last_seen ON user_ips USING btree (user_id, last_seen); - - - -CREATE INDEX user_ips_last_seen_only ON user_ips USING btree (last_seen); - - - -CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips USING btree (user_id, access_token, ip); - - - -CREATE UNIQUE INDEX user_stats_user_ts ON user_stats USING btree (user_id, ts); - - - -CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server USING btree (user_id, medium, address, id_server); - - - -CREATE INDEX user_threepids_medium_address ON user_threepids USING btree (medium, address); - - - -CREATE INDEX user_threepids_user_id ON user_threepids USING btree (user_id); - - - -CREATE INDEX users_creation_ts ON users USING btree (creation_ts); - - - -CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms USING btree (user_id, room_id); - - - -CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms USING btree (other_user_id); - - - -CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms USING btree (room_id); - - - -CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms USING btree (user_id, other_user_id, room_id); - - - diff --git a/synapse/storage/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/schema/full_schemas/54/full.sql.sqlite deleted file mode 100644 index be9295e4c9..0000000000 --- a/synapse/storage/schema/full_schemas/54/full.sql.sqlite +++ /dev/null @@ -1,260 +0,0 @@ -CREATE TABLE application_services_state( as_id TEXT PRIMARY KEY, state VARCHAR(5), last_txn INTEGER ); -CREATE TABLE application_services_txns( as_id TEXT NOT NULL, txn_id INTEGER NOT NULL, event_ids TEXT NOT NULL, UNIQUE(as_id, txn_id) ); -CREATE INDEX application_services_txns_id ON application_services_txns ( as_id ); -CREATE TABLE presence( user_id TEXT NOT NULL, state VARCHAR(20), status_msg TEXT, mtime BIGINT, UNIQUE (user_id) ); -CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_user_id TEXT NOT NULL, UNIQUE (observed_user_id, observer_user_id) ); -CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) ); -CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) ); -CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL ); -CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) ); -CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) ); -CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER ); -CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) ); -CREATE INDEX events_order_room ON events ( room_id, topological_ordering, stream_ordering ); -CREATE TABLE event_json( event_id TEXT NOT NULL, room_id TEXT NOT NULL, internal_metadata TEXT NOT NULL, json TEXT NOT NULL, format_version INTEGER, UNIQUE (event_id) ); -CREATE INDEX event_json_room_id ON event_json(room_id); -CREATE TABLE state_events( event_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, prev_state TEXT, UNIQUE (event_id) ); -CREATE TABLE current_state_events( event_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, UNIQUE (event_id), UNIQUE (room_id, type, state_key) ); -CREATE TABLE room_memberships( event_id TEXT NOT NULL, user_id TEXT NOT NULL, sender TEXT NOT NULL, room_id TEXT NOT NULL, membership TEXT NOT NULL, forgotten INTEGER DEFAULT 0, display_name TEXT, avatar_url TEXT, UNIQUE (event_id) ); -CREATE INDEX room_memberships_room_id ON room_memberships (room_id); -CREATE INDEX room_memberships_user_id ON room_memberships (user_id); -CREATE TABLE topics( event_id TEXT NOT NULL, room_id TEXT NOT NULL, topic TEXT NOT NULL, UNIQUE (event_id) ); -CREATE INDEX topics_room_id ON topics(room_id); -CREATE TABLE room_names( event_id TEXT NOT NULL, room_id TEXT NOT NULL, name TEXT NOT NULL, UNIQUE (event_id) ); -CREATE INDEX room_names_room_id ON room_names(room_id); -CREATE TABLE rooms( room_id TEXT PRIMARY KEY NOT NULL, is_public BOOL, creator TEXT ); -CREATE TABLE server_signature_keys( server_name TEXT, key_id TEXT, from_server TEXT, ts_added_ms BIGINT, verify_key bytea, ts_valid_until_ms BIGINT, UNIQUE (server_name, key_id) ); -CREATE TABLE rejections( event_id TEXT NOT NULL, reason TEXT NOT NULL, last_check TEXT NOT NULL, UNIQUE (event_id) ); -CREATE TABLE push_rules ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, rule_id TEXT NOT NULL, priority_class SMALLINT NOT NULL, priority INTEGER NOT NULL DEFAULT 0, conditions TEXT NOT NULL, actions TEXT NOT NULL, UNIQUE(user_name, rule_id) ); -CREATE INDEX push_rules_user_name on push_rules (user_name); -CREATE TABLE user_filters( user_id TEXT, filter_id BIGINT, filter_json bytea ); -CREATE INDEX user_filters_by_user_id_filter_id ON user_filters( user_id, filter_id ); -CREATE TABLE push_rules_enable ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, rule_id TEXT NOT NULL, enabled SMALLINT, UNIQUE(user_name, rule_id) ); -CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name); -CREATE TABLE event_forward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, UNIQUE (event_id, room_id) ); -CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id); -CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id); -CREATE TABLE event_backward_extremities( event_id TEXT NOT NULL, room_id TEXT NOT NULL, UNIQUE (event_id, room_id) ); -CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id); -CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id); -CREATE TABLE event_edges( event_id TEXT NOT NULL, prev_event_id TEXT NOT NULL, room_id TEXT NOT NULL, is_state BOOL NOT NULL, UNIQUE (event_id, prev_event_id, room_id, is_state) ); -CREATE INDEX ev_edges_id ON event_edges(event_id); -CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); -CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) ); -CREATE INDEX room_depth_room ON room_depth(room_id); -CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); -CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL ); -CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) ); -CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) ); -CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); -CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id); -CREATE TABLE remote_media_cache ( media_origin TEXT, media_id TEXT, media_type TEXT, created_ts BIGINT, upload_name TEXT, media_length INTEGER, filesystem_id TEXT, last_access_ts BIGINT, quarantined_by TEXT, UNIQUE (media_origin, media_id) ); -CREATE TABLE remote_media_cache_thumbnails ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); -CREATE TABLE redactions ( event_id TEXT NOT NULL, redacts TEXT NOT NULL, UNIQUE (event_id) ); -CREATE INDEX redactions_redacts ON redactions (redacts); -CREATE TABLE room_aliases( room_alias TEXT NOT NULL, room_id TEXT NOT NULL, creator TEXT, UNIQUE (room_alias) ); -CREATE INDEX room_aliases_id ON room_aliases(room_id); -CREATE TABLE room_alias_servers( room_alias TEXT NOT NULL, server TEXT NOT NULL ); -CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias); -CREATE TABLE event_reference_hashes ( event_id TEXT, algorithm TEXT, hash bytea, UNIQUE (event_id, algorithm) ); -CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id); -CREATE TABLE IF NOT EXISTS "server_keys_json" ( server_name TEXT NOT NULL, key_id TEXT NOT NULL, from_server TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, ts_valid_until_ms BIGINT NOT NULL, key_json bytea NOT NULL, CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server) ); -CREATE TABLE e2e_device_keys_json ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, key_json TEXT NOT NULL, CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) ); -CREATE TABLE e2e_one_time_keys_json ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, algorithm TEXT NOT NULL, key_id TEXT NOT NULL, ts_added_ms BIGINT NOT NULL, key_json TEXT NOT NULL, CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) ); -CREATE TABLE receipts_graph( room_id TEXT NOT NULL, receipt_type TEXT NOT NULL, user_id TEXT NOT NULL, event_ids TEXT NOT NULL, data TEXT NOT NULL, CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id) ); -CREATE TABLE receipts_linearized ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, receipt_type TEXT NOT NULL, user_id TEXT NOT NULL, event_id TEXT NOT NULL, data TEXT NOT NULL, CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id) ); -CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); -CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); -CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, validated_at BIGINT NOT NULL, added_at BIGINT NOT NULL, CONSTRAINT medium_address UNIQUE (medium, address) ); -CREATE INDEX user_threepids_user_id ON user_threepids(user_id); -CREATE TABLE background_updates( update_name TEXT NOT NULL, progress_json TEXT NOT NULL, depends_on TEXT, CONSTRAINT background_updates_uniqueness UNIQUE (update_name) ); -CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value ) -/* event_search(event_id,room_id,sender,"key",value) */; -CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value'); -CREATE TABLE IF NOT EXISTS 'event_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); -CREATE TABLE IF NOT EXISTS 'event_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); -CREATE TABLE IF NOT EXISTS 'event_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); -CREATE TABLE IF NOT EXISTS 'event_search_stat'(id INTEGER PRIMARY KEY, value BLOB); -CREATE TABLE guest_access( event_id TEXT NOT NULL, room_id TEXT NOT NULL, guest_access TEXT NOT NULL, UNIQUE (event_id) ); -CREATE TABLE history_visibility( event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, UNIQUE (event_id) ); -CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) ); -CREATE TABLE room_tags_revisions ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, stream_id BIGINT NOT NULL, CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id) ); -CREATE TABLE IF NOT EXISTS "account_data_max_stream_id"( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT NOT NULL, CHECK (Lock='X') ); -CREATE TABLE account_data( user_id TEXT NOT NULL, account_data_type TEXT NOT NULL, stream_id BIGINT NOT NULL, content TEXT NOT NULL, CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) ); -CREATE TABLE room_account_data( user_id TEXT NOT NULL, room_id TEXT NOT NULL, account_data_type TEXT NOT NULL, stream_id BIGINT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) ); -CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); -CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); -CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering); -CREATE TABLE event_push_actions( room_id TEXT NOT NULL, event_id TEXT NOT NULL, user_id TEXT NOT NULL, profile_tag VARCHAR(32), actions TEXT NOT NULL, topological_ordering BIGINT, stream_ordering BIGINT, notif SMALLINT, highlight SMALLINT, CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) ); -CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id); -CREATE INDEX events_room_stream on events(room_id, stream_ordering); -CREATE INDEX public_room_index on rooms(is_public); -CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); -CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering ); -CREATE TABLE presence_stream( stream_id BIGINT, user_id TEXT, state TEXT, last_active_ts BIGINT, last_federation_update_ts BIGINT, last_user_sync_ts BIGINT, status_msg TEXT, currently_active BOOLEAN ); -CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); -CREATE INDEX presence_stream_user_id ON presence_stream(user_id); -CREATE TABLE push_rules_stream( stream_id BIGINT NOT NULL, event_stream_ordering BIGINT NOT NULL, user_id TEXT NOT NULL, rule_id TEXT NOT NULL, op TEXT NOT NULL, priority_class SMALLINT, priority INTEGER, conditions TEXT, actions TEXT ); -CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id); -CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id); -CREATE TABLE ex_outlier_stream( event_stream_ordering BIGINT PRIMARY KEY NOT NULL, event_id TEXT NOT NULL, state_group BIGINT NOT NULL ); -CREATE TABLE threepid_guest_access_tokens( medium TEXT, address TEXT, guest_access_token TEXT, first_inviter TEXT ); -CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); -CREATE TABLE local_invites( stream_id BIGINT NOT NULL, inviter TEXT NOT NULL, invitee TEXT NOT NULL, event_id TEXT NOT NULL, room_id TEXT NOT NULL, locally_rejected TEXT, replaced_by TEXT ); -CREATE INDEX local_invites_id ON local_invites(stream_id); -CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); -CREATE INDEX event_push_actions_stream_ordering on event_push_actions( stream_ordering, user_id ); -CREATE TABLE open_id_tokens ( token TEXT NOT NULL PRIMARY KEY, ts_valid_until_ms bigint NOT NULL, user_id TEXT NOT NULL, UNIQUE (token) ); -CREATE INDEX open_id_tokens_ts_valid_until_ms ON open_id_tokens(ts_valid_until_ms); -CREATE TABLE pusher_throttle( pusher BIGINT NOT NULL, room_id TEXT NOT NULL, last_sent_ts BIGINT, throttle_ms BIGINT, PRIMARY KEY (pusher, room_id) ); -CREATE TABLE event_reports( id BIGINT NOT NULL PRIMARY KEY, received_ts BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, user_id TEXT NOT NULL, reason TEXT, content TEXT ); -CREATE TABLE devices ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, display_name TEXT, CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) ); -CREATE TABLE appservice_stream_position( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_ordering BIGINT, CHECK (Lock='X') ); -CREATE TABLE device_inbox ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, stream_id BIGINT NOT NULL, message_json TEXT NOT NULL ); -CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); -CREATE INDEX received_transactions_ts ON received_transactions(ts); -CREATE TABLE device_federation_outbox ( destination TEXT NOT NULL, stream_id BIGINT NOT NULL, queued_ts BIGINT NOT NULL, messages_json TEXT NOT NULL ); -CREATE INDEX device_federation_outbox_destination_id ON device_federation_outbox(destination, stream_id); -CREATE TABLE device_federation_inbox ( origin TEXT NOT NULL, message_id TEXT NOT NULL, received_ts BIGINT NOT NULL ); -CREATE INDEX device_federation_inbox_sender_id ON device_federation_inbox(origin, message_id); -CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL ); -CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT); -CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id ); -CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id ); -CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL ); -CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); -CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); -CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); -CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering ); -CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering ); -CREATE TABLE IF NOT EXISTS "event_auth"( event_id TEXT NOT NULL, auth_id TEXT NOT NULL, room_id TEXT NOT NULL ); -CREATE INDEX evauth_edges_id ON event_auth(event_id); -CREATE INDEX user_threepids_medium_address on user_threepids (medium, address); -CREATE TABLE appservice_room_list( appservice_id TEXT NOT NULL, network_id TEXT NOT NULL, room_id TEXT NOT NULL ); -CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list( appservice_id, network_id, room_id ); -CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id); -CREATE TABLE federation_stream_position( type TEXT NOT NULL, stream_id INTEGER NOT NULL ); -CREATE TABLE device_lists_remote_cache ( user_id TEXT NOT NULL, device_id TEXT NOT NULL, content TEXT NOT NULL ); -CREATE TABLE device_lists_remote_extremeties ( user_id TEXT NOT NULL, stream_id TEXT NOT NULL ); -CREATE TABLE device_lists_stream ( stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL ); -CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); -CREATE TABLE device_lists_outbound_pokes ( destination TEXT NOT NULL, stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL, sent BOOLEAN NOT NULL, ts BIGINT NOT NULL ); -CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); -CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); -CREATE TABLE event_push_summary ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, notif_count BIGINT NOT NULL, stream_ordering BIGINT NOT NULL ); -CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); -CREATE TABLE event_push_summary_stream_ordering ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_ordering BIGINT NOT NULL, CHECK (Lock='X') ); -CREATE TABLE IF NOT EXISTS "pushers" ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, access_token BIGINT DEFAULT NULL, profile_tag TEXT NOT NULL, kind TEXT NOT NULL, app_id TEXT NOT NULL, app_display_name TEXT NOT NULL, device_display_name TEXT NOT NULL, pushkey TEXT NOT NULL, ts BIGINT NOT NULL, lang TEXT, data TEXT, last_stream_ordering INTEGER, last_success BIGINT, failing_since BIGINT, UNIQUE (app_id, pushkey, user_name) ); -CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); -CREATE TABLE ratelimit_override ( user_id TEXT NOT NULL, messages_per_second BIGINT, burst_count BIGINT ); -CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); -CREATE TABLE current_state_delta_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT, prev_event_id TEXT ); -CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); -CREATE TABLE device_lists_outbound_last_success ( destination TEXT NOT NULL, user_id TEXT NOT NULL, stream_id BIGINT NOT NULL ); -CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( destination, user_id, stream_id ); -CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') ); -CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value ) -/* user_directory_search(user_id,value) */; -CREATE TABLE IF NOT EXISTS 'user_directory_search_content'(docid INTEGER PRIMARY KEY, 'c0user_id', 'c1value'); -CREATE TABLE IF NOT EXISTS 'user_directory_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); -CREATE TABLE IF NOT EXISTS 'user_directory_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); -CREATE TABLE IF NOT EXISTS 'user_directory_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); -CREATE TABLE IF NOT EXISTS 'user_directory_search_stat'(id INTEGER PRIMARY KEY, value BLOB); -CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL ); -CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); -CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT ); -CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); -CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); -CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); -CREATE TABLE group_users ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, is_admin BOOLEAN NOT NULL, is_public BOOLEAN NOT NULL ); -CREATE TABLE group_invites ( group_id TEXT NOT NULL, user_id TEXT NOT NULL ); -CREATE TABLE group_rooms ( group_id TEXT NOT NULL, room_id TEXT NOT NULL, is_public BOOLEAN NOT NULL ); -CREATE TABLE group_summary_rooms ( group_id TEXT NOT NULL, room_id TEXT NOT NULL, category_id TEXT NOT NULL, room_order BIGINT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, category_id, room_id, room_order), CHECK (room_order > 0) ); -CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id); -CREATE TABLE group_summary_room_categories ( group_id TEXT NOT NULL, category_id TEXT NOT NULL, cat_order BIGINT NOT NULL, UNIQUE (group_id, category_id, cat_order), CHECK (cat_order > 0) ); -CREATE TABLE group_room_categories ( group_id TEXT NOT NULL, category_id TEXT NOT NULL, profile TEXT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, category_id) ); -CREATE TABLE group_summary_users ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, role_id TEXT NOT NULL, user_order BIGINT NOT NULL, is_public BOOLEAN NOT NULL ); -CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id); -CREATE TABLE group_summary_roles ( group_id TEXT NOT NULL, role_id TEXT NOT NULL, role_order BIGINT NOT NULL, UNIQUE (group_id, role_id, role_order), CHECK (role_order > 0) ); -CREATE TABLE group_roles ( group_id TEXT NOT NULL, role_id TEXT NOT NULL, profile TEXT NOT NULL, is_public BOOLEAN NOT NULL, UNIQUE (group_id, role_id) ); -CREATE TABLE group_attestations_renewals ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, valid_until_ms BIGINT NOT NULL ); -CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id); -CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id); -CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms); -CREATE TABLE group_attestations_remote ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, valid_until_ms BIGINT NOT NULL, attestation_json TEXT NOT NULL ); -CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); -CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); -CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); -CREATE TABLE local_group_membership ( group_id TEXT NOT NULL, user_id TEXT NOT NULL, is_admin BOOLEAN NOT NULL, membership TEXT NOT NULL, is_publicised BOOLEAN NOT NULL, content TEXT NOT NULL ); -CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); -CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); -CREATE TABLE local_group_updates ( stream_id BIGINT NOT NULL, group_id TEXT NOT NULL, user_id TEXT NOT NULL, type TEXT NOT NULL, content TEXT NOT NULL ); -CREATE TABLE remote_profile_cache ( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, last_check BIGINT NOT NULL ); -CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); -CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); -CREATE TABLE IF NOT EXISTS "deleted_pushers" ( stream_id BIGINT NOT NULL, app_id TEXT NOT NULL, pushkey TEXT NOT NULL, user_id TEXT NOT NULL ); -CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); -CREATE TABLE IF NOT EXISTS "groups" ( group_id TEXT NOT NULL, name TEXT, avatar_url TEXT, short_description TEXT, long_description TEXT, is_public BOOL NOT NULL , join_policy TEXT NOT NULL DEFAULT 'invite'); -CREATE UNIQUE INDEX groups_idx ON groups(group_id); -CREATE TABLE IF NOT EXISTS "user_directory" ( user_id TEXT NOT NULL, room_id TEXT, display_name TEXT, avatar_url TEXT ); -CREATE INDEX user_directory_room_idx ON user_directory(room_id); -CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); -CREATE TABLE event_push_actions_staging ( event_id TEXT NOT NULL, user_id TEXT NOT NULL, actions TEXT NOT NULL, notif SMALLINT NOT NULL, highlight SMALLINT NOT NULL ); -CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); -CREATE TABLE users_pending_deactivation ( user_id TEXT NOT NULL ); -CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); -CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); -CREATE INDEX group_users_u_idx ON group_users(user_id); -CREATE INDEX group_invites_u_idx ON group_invites(user_id); -CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); -CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); -CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL ); -CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); -CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); -CREATE TABLE erased_users ( user_id TEXT NOT NULL ); -CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id); -CREATE TABLE monthly_active_users ( user_id TEXT NOT NULL, timestamp BIGINT NOT NULL ); -CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id); -CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); -CREATE TABLE IF NOT EXISTS "e2e_room_keys_versions" ( user_id TEXT NOT NULL, version BIGINT NOT NULL, algorithm TEXT NOT NULL, auth_data TEXT NOT NULL, deleted SMALLINT DEFAULT 0 NOT NULL ); -CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); -CREATE TABLE IF NOT EXISTS "e2e_room_keys" ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, session_id TEXT NOT NULL, version BIGINT NOT NULL, first_message_index INT, forwarded_count INT, is_verified BOOLEAN, session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); -CREATE TABLE users_who_share_private_rooms ( user_id TEXT NOT NULL, other_user_id TEXT NOT NULL, room_id TEXT NOT NULL ); -CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id); -CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id); -CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id); -CREATE TABLE user_threepid_id_server ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, id_server TEXT NOT NULL ); -CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server( user_id, medium, address, id_server ); -CREATE TABLE users_in_public_rooms ( user_id TEXT NOT NULL, room_id TEXT NOT NULL ); -CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id); -CREATE TABLE account_validity ( user_id TEXT PRIMARY KEY, expiration_ts_ms BIGINT NOT NULL, email_sent BOOLEAN NOT NULL, renewal_token TEXT ); -CREATE TABLE event_relations ( event_id TEXT NOT NULL, relates_to_id TEXT NOT NULL, relation_type TEXT NOT NULL, aggregation_key TEXT ); -CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); -CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); -CREATE TABLE stats_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') ); -CREATE TABLE user_stats ( user_id TEXT NOT NULL, ts BIGINT NOT NULL, bucket_size INT NOT NULL, public_rooms INT NOT NULL, private_rooms INT NOT NULL ); -CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); -CREATE TABLE room_stats ( room_id TEXT NOT NULL, ts BIGINT NOT NULL, bucket_size INT NOT NULL, current_state_events INT NOT NULL, joined_members INT NOT NULL, invited_members INT NOT NULL, left_members INT NOT NULL, banned_members INT NOT NULL, state_events INT NOT NULL ); -CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); -CREATE TABLE room_state ( room_id TEXT NOT NULL, join_rules TEXT, history_visibility TEXT, encryption TEXT, name TEXT, topic TEXT, avatar TEXT, canonical_alias TEXT ); -CREATE UNIQUE INDEX room_state_room ON room_state(room_id); -CREATE TABLE room_stats_earliest_token ( room_id TEXT NOT NULL, token BIGINT NOT NULL ); -CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); -CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id); -CREATE INDEX user_ips_device_id ON user_ips (user_id, device_id, last_seen); -CREATE INDEX event_contains_url_index ON events (room_id, topological_ordering, stream_ordering); -CREATE INDEX event_push_actions_u_highlight ON event_push_actions (user_id, stream_ordering); -CREATE INDEX event_push_actions_highlights_index ON event_push_actions (user_id, room_id, topological_ordering, stream_ordering); -CREATE INDEX current_state_events_member_index ON current_state_events (state_key); -CREATE INDEX device_inbox_stream_id_user_id ON device_inbox (stream_id, user_id); -CREATE INDEX device_lists_stream_user_id ON device_lists_stream (user_id, device_id); -CREATE INDEX local_media_repository_url_idx ON local_media_repository (created_ts); -CREATE INDEX user_ips_last_seen ON user_ips (user_id, last_seen); -CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen); -CREATE INDEX users_creation_ts ON users (creation_ts); -CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group); -CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id); -CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key); -CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id); -CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip); diff --git a/synapse/storage/schema/full_schemas/54/stream_positions.sql b/synapse/storage/schema/full_schemas/54/stream_positions.sql deleted file mode 100644 index c265fd20e2..0000000000 --- a/synapse/storage/schema/full_schemas/54/stream_positions.sql +++ /dev/null @@ -1,7 +0,0 @@ - -INSERT INTO appservice_stream_position (stream_ordering) SELECT COALESCE(MAX(stream_ordering), 0) FROM events; -INSERT INTO federation_stream_position (type, stream_id) VALUES ('federation', -1); -INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coalesce(max(stream_ordering), -1) FROM events; -INSERT INTO user_directory_stream_pos (stream_id) VALUES (0); -INSERT INTO stats_stream_pos (stream_id) VALUES (0); -INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); diff --git a/synapse/storage/schema/full_schemas/README.txt b/synapse/storage/schema/full_schemas/README.txt deleted file mode 100644 index d3f6401344..0000000000 --- a/synapse/storage/schema/full_schemas/README.txt +++ /dev/null @@ -1,19 +0,0 @@ -Building full schema dumps -========================== - -These schemas need to be made from a database that has had all background updates run. - -Postgres --------- - -$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres - -SQLite ------- - -$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite - -After ------ - -Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas". \ No newline at end of file diff --git a/synapse/storage/search.py b/synapse/storage/search.py deleted file mode 100644 index 7695bf09fc..0000000000 --- a/synapse/storage/search.py +++ /dev/null @@ -1,713 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re -from collections import namedtuple - -from six import string_types - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.errors import SynapseError -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.engines import PostgresEngine, Sqlite3Engine - -from .background_updates import BackgroundUpdateStore - -logger = logging.getLogger(__name__) - -SearchEntry = namedtuple( - "SearchEntry", - ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], -) - - -class SearchBackgroundUpdateStore(BackgroundUpdateStore): - - EVENT_SEARCH_UPDATE_NAME = "event_search" - EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" - EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" - EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - - def __init__(self, db_conn, hs): - super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs) - - if not hs.config.enable_search: - return - - self.register_background_update_handler( - self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search - ) - self.register_background_update_handler( - self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order - ) - - # we used to have a background update to turn the GIN index into a - # GIST one; we no longer do that (obviously) because we actually want - # a GIN index. However, it's possible that some people might still have - # the background update queued, so we register a handler to clear the - # background update. - self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) - - self.register_background_update_handler( - self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search - ) - - @defer.inlineCallbacks - def _background_reindex_search(self, progress, batch_size): - # we work through the events table from highest stream id to lowest - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - - TYPES = ["m.room.name", "m.room.message", "m.room.topic"] - - def reindex_search_txn(txn): - sql = ( - "SELECT stream_ordering, event_id, room_id, type, json, " - " origin_server_ts FROM events" - " JOIN event_json USING (room_id, event_id)" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " AND (%s)" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - # we could stream straight from the results into - # store_search_entries_txn with a generator function, but that - # would mean having two cursors open on the database at once. - # Instead we just build a list of results. - rows = self.cursor_to_dict(txn) - if not rows: - return 0 - - min_stream_id = rows[-1]["stream_ordering"] - - event_search_rows = [] - for row in rows: - try: - event_id = row["event_id"] - room_id = row["room_id"] - etype = row["type"] - stream_ordering = row["stream_ordering"] - origin_server_ts = row["origin_server_ts"] - try: - event_json = json.loads(row["json"]) - content = event_json["content"] - except Exception: - continue - - if etype == "m.room.message": - key = "content.body" - value = content["body"] - elif etype == "m.room.topic": - key = "content.topic" - value = content["topic"] - elif etype == "m.room.name": - key = "content.name" - value = content["name"] - else: - raise Exception("unexpected event type %s" % etype) - except (KeyError, AttributeError): - # If the event is missing a necessary field then - # skip over it. - continue - - if not isinstance(value, string_types): - # If the event body, name or topic isn't a string - # then skip over it - continue - - event_search_rows.append( - SearchEntry( - key=key, - value=value, - event_id=event_id, - room_id=room_id, - stream_ordering=stream_ordering, - origin_server_ts=origin_server_ts, - ) - ) - - self.store_search_entries_txn(txn, event_search_rows) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(event_search_rows), - } - - self._background_update_progress_txn( - txn, self.EVENT_SEARCH_UPDATE_NAME, progress - ) - - return len(event_search_rows) - - result = yield self.runInteraction( - self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn - ) - - if not result: - yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) - - return result - - @defer.inlineCallbacks - def _background_reindex_gin_search(self, progress, batch_size): - """This handles old synapses which used GIST indexes, if any; - converting them back to be GIN as per the actual schema. - """ - - def create_index(conn): - conn.rollback() - - # we have to set autocommit, because postgres refuses to - # CREATE INDEX CONCURRENTLY without it. - conn.set_session(autocommit=True) - - try: - c = conn.cursor() - - # if we skipped the conversion to GIST, we may already/still - # have an event_search_fts_idx; unfortunately postgres 9.4 - # doesn't support CREATE INDEX IF EXISTS so we just catch the - # exception and ignore it. - import psycopg2 - - try: - c.execute( - "CREATE INDEX CONCURRENTLY event_search_fts_idx" - " ON event_search USING GIN (vector)" - ) - except psycopg2.ProgrammingError as e: - logger.warn( - "Ignoring error %r when trying to switch from GIST to GIN", e - ) - - # we should now be able to delete the GIST index. - c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist") - finally: - conn.set_session(autocommit=False) - - if isinstance(self.database_engine, PostgresEngine): - yield self.runWithConnection(create_index) - - yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) - return 1 - - @defer.inlineCallbacks - def _background_reindex_search_order(self, progress, batch_size): - target_min_stream_id = progress["target_min_stream_id_inclusive"] - max_stream_id = progress["max_stream_id_exclusive"] - rows_inserted = progress.get("rows_inserted", 0) - have_added_index = progress["have_added_indexes"] - - if not have_added_index: - - def create_index(conn): - conn.rollback() - conn.set_session(autocommit=True) - c = conn.cursor() - - # We create with NULLS FIRST so that when we search *backwards* - # we get the ones with non null origin_server_ts *first* - c.execute( - "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search(" - "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" - ) - c.execute( - "CREATE INDEX CONCURRENTLY event_search_order ON event_search(" - "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" - ) - conn.set_session(autocommit=False) - - yield self.runWithConnection(create_index) - - pg = dict(progress) - pg["have_added_indexes"] = True - - yield self.runInteraction( - self.EVENT_SEARCH_ORDER_UPDATE_NAME, - self._background_update_progress_txn, - self.EVENT_SEARCH_ORDER_UPDATE_NAME, - pg, - ) - - def reindex_search_txn(txn): - sql = ( - "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," - " origin_server_ts = e.origin_server_ts" - " FROM events AS e" - " WHERE e.event_id = es.event_id" - " AND ? <= e.stream_ordering AND e.stream_ordering < ?" - " RETURNING es.stream_ordering" - ) - - min_stream_id = max_stream_id - batch_size - txn.execute(sql, (min_stream_id, max_stream_id)) - rows = txn.fetchall() - - if min_stream_id < target_min_stream_id: - # We've recached the end. - return len(rows), False - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows), - "have_added_indexes": True, - } - - self._background_update_progress_txn( - txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress - ) - - return len(rows), True - - num_rows, finished = yield self.runInteraction( - self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn - ) - - if not finished: - yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME) - - return num_rows - - def store_search_entries_txn(self, txn, entries): - """Add entries to the search table - - Args: - txn (cursor): - entries (iterable[SearchEntry]): - entries to be added to the table - """ - if not self.hs.config.enable_search: - return - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) - - args = ( - ( - entry.event_id, - entry.room_id, - entry.key, - entry.value, - entry.stream_ordering, - entry.origin_server_ts, - ) - for entry in entries - ) - - txn.executemany(sql, args) - - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - args = ( - (entry.event_id, entry.room_id, entry.key, entry.value) - for entry in entries - ) - - txn.executemany(sql, args) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - - -class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(SearchStore, self).__init__(db_conn, hs) - - def store_event_search_txn(self, txn, event, key, value): - """Add event to the search table - - Args: - txn (cursor): - event (EventBase): - key (str): - value (str): - """ - self.store_search_entries_txn( - txn, - ( - SearchEntry( - key=key, - value=value, - event_id=event.event_id, - room_id=event.room_id, - stream_ordering=event.internal_metadata.stream_ordering, - origin_server_ts=event.origin_server_ts, - ), - ), - ) - - @defer.inlineCallbacks - def search_msgs(self, room_ids, search_term, keys): - """Performs a full text search over events with given keys. - - Args: - room_ids (list): List of room ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports - "content.body", "content.name", "content.topic" - - Returns: - list of dicts - """ - clauses = [] - - search_query = search_query = _parse_query(self.database_engine, search_term) - - args = [] - - # Make sure we don't explode because the person is in too many rooms. - # We filter the results below regardless. - if len(room_ids) < 500: - clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", room_ids - ) - clauses = [clause] - - local_clauses = [] - for key in keys: - local_clauses.append("key = ?") - args.append(key) - - clauses.append("(%s)" % (" OR ".join(local_clauses),)) - - count_args = args - count_clauses = clauses - - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," - " room_id, event_id" - " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" - ) - args = [search_query, search_query] + args - - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" - ) - count_args = [search_query] + count_args - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" - " FROM event_search" - " WHERE value MATCH ?" - ) - args = [search_query] + args - - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ?" - ) - count_args = [search_term] + count_args - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - - for clause in clauses: - sql += " AND " + clause - - for clause in count_clauses: - count_sql += " AND " + clause - - # We add an arbitrary limit here to ensure we don't try to pull the - # entire table from the database. - sql += " ORDER BY rank DESC LIMIT 500" - - results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args) - - results = list(filter(lambda row: row["room_id"] in room_ids, results)) - - events = yield self.get_events_as_list([r["event_id"] for r in results]) - - event_map = {ev.event_id: ev for ev in events} - - highlights = None - if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) - - count_sql += " GROUP BY room_id" - - count_results = yield self._execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args - ) - - count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - - return { - "results": [ - {"event": event_map[r["event_id"]], "rank": r["rank"]} - for r in results - if r["event_id"] in event_map - ], - "highlights": highlights, - "count": count, - } - - @defer.inlineCallbacks - def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): - """Performs a full text search over events with given keys. - - Args: - room_id (list): The room_ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports - "content.body", "content.name", "content.topic" - pagination_token (str): A pagination token previously returned - - Returns: - list of dicts - """ - clauses = [] - - search_query = search_query = _parse_query(self.database_engine, search_term) - - args = [] - - # Make sure we don't explode because the person is in too many rooms. - # We filter the results below regardless. - if len(room_ids) < 500: - clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", room_ids - ) - clauses = [clause] - - local_clauses = [] - for key in keys: - local_clauses.append("key = ?") - args.append(key) - - clauses.append("(%s)" % (" OR ".join(local_clauses),)) - - # take copies of the current args and clauses lists, before adding - # pagination clauses to main query. - count_args = list(args) - count_clauses = list(clauses) - - if pagination_token: - try: - origin_server_ts, stream = pagination_token.split(",") - origin_server_ts = int(origin_server_ts) - stream = int(stream) - except Exception: - raise SynapseError(400, "Invalid pagination token") - - clauses.append( - "(origin_server_ts < ?" - " OR (origin_server_ts = ? AND stream_ordering < ?))" - ) - args.extend([origin_server_ts, origin_server_ts, stream]) - - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," - " origin_server_ts, stream_ordering, room_id, event_id" - " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " - ) - args = [search_query, search_query] + args - - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " - ) - count_args = [search_query] + count_args - elif isinstance(self.database_engine, Sqlite3Engine): - # We use CROSS JOIN here to ensure we use the right indexes. - # https://sqlite.org/optoverview.html#crossjoin - # - # We want to use the full text search index on event_search to - # extract all possible matches first, then lookup those matches - # in the events table to get the topological ordering. We need - # to use the indexes in this order because sqlite refuses to - # MATCH unless it uses the full text search index - sql = ( - "SELECT rank(matchinfo) as rank, room_id, event_id," - " origin_server_ts, stream_ordering" - " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" - " FROM event_search" - " WHERE value MATCH ?" - " )" - " CROSS JOIN events USING (event_id)" - " WHERE " - ) - args = [search_query] + args - - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ? AND " - ) - count_args = [search_term] + count_args - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - - sql += " AND ".join(clauses) - count_sql += " AND ".join(count_clauses) - - # We add an arbitrary limit here to ensure we don't try to pull the - # entire table from the database. - if isinstance(self.database_engine, PostgresEngine): - sql += ( - " ORDER BY origin_server_ts DESC NULLS LAST," - " stream_ordering DESC NULLS LAST LIMIT ?" - ) - elif isinstance(self.database_engine, Sqlite3Engine): - sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" - else: - raise Exception("Unrecognized database engine") - - args.append(limit) - - results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args) - - results = list(filter(lambda row: row["room_id"] in room_ids, results)) - - events = yield self.get_events_as_list([r["event_id"] for r in results]) - - event_map = {ev.event_id: ev for ev in events} - - highlights = None - if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) - - count_sql += " GROUP BY room_id" - - count_results = yield self._execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args - ) - - count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - - return { - "results": [ - { - "event": event_map[r["event_id"]], - "rank": r["rank"], - "pagination_token": "%s,%s" - % (r["origin_server_ts"], r["stream_ordering"]), - } - for r in results - if r["event_id"] in event_map - ], - "highlights": highlights, - "count": count, - } - - def _find_highlights_in_postgres(self, search_query, events): - """Given a list of events and a search term, return a list of words - that match from the content of the event. - - This is used to give a list of words that clients can match against to - highlight the matching parts. - - Args: - search_query (str) - events (list): A list of events - - Returns: - deferred : A set of strings. - """ - - def f(txn): - highlight_words = set() - for event in events: - # As a hack we simply join values of all possible keys. This is - # fine since we're only using them to find possible highlights. - values = [] - for key in ("body", "name", "topic"): - v = event.content.get(key, None) - if v: - values.append(v) - - if not values: - continue - - value = " ".join(values) - - # We need to find some values for StartSel and StopSel that - # aren't in the value so that we can pick results out. - start_sel = "<" - stop_sel = ">" - - while start_sel in value: - start_sel += "<" - while stop_sel in value: - stop_sel += ">" - - query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( - _to_postgres_options( - { - "StartSel": start_sel, - "StopSel": stop_sel, - "MaxFragments": "50", - } - ) - ) - txn.execute(query, (value, search_query)) - headline, = txn.fetchall()[0] - - # Now we need to pick the possible highlights out of the haedline - # result. - matcher_regex = "%s(.*?)%s" % ( - re.escape(start_sel), - re.escape(stop_sel), - ) - - res = re.findall(matcher_regex, headline) - highlight_words.update([r.lower() for r in res]) - - return highlight_words - - return self.runInteraction("_find_highlights", f) - - -def _to_postgres_options(options_dict): - return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) - - -def _parse_query(database_engine, search_term): - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. - We use this so that we can add prefix matching, which isn't something - that is supported by default. - """ - - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - - if isinstance(database_engine, PostgresEngine): - return " & ".join(result + ":*" for result in results) - elif isinstance(database_engine, Sqlite3Engine): - return " & ".join(result + "*" for result in results) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py deleted file mode 100644 index fb83218f90..0000000000 --- a/synapse/storage/signatures.py +++ /dev/null @@ -1,102 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import six - -from unpaddedbase64 import encode_base64 - -from twisted.internet import defer - -from synapse.crypto.event_signing import compute_event_reference_hash -from synapse.util.caches.descriptors import cached, cachedList - -from ._base import SQLBaseStore - -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - - -class SignatureWorkerStore(SQLBaseStore): - @cached() - def get_event_reference_hash(self, event_id): - # This is a dummy function to allow get_event_reference_hashes - # to use its cache - raise NotImplementedError() - - @cachedList( - cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 - ) - def get_event_reference_hashes(self, event_ids): - def f(txn): - return { - event_id: self._get_event_reference_hashes_txn(txn, event_id) - for event_id in event_ids - } - - return self.runInteraction("get_event_reference_hashes", f) - - @defer.inlineCallbacks - def add_event_hashes(self, event_ids): - hashes = yield self.get_event_reference_hashes(event_ids) - hashes = { - e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} - for e_id, h in hashes.items() - } - - return list(hashes.items()) - - def _get_event_reference_hashes_txn(self, txn, event_id): - """Get all the hashes for a given PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict[unicode, bytes] of algorithm -> hash. - """ - query = ( - "SELECT algorithm, hash" - " FROM event_reference_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id,)) - return {k: v for k, v in txn} - - -class SignatureStore(SignatureWorkerStore): - """Persistence for event signatures and hashes""" - - def _store_event_reference_hashes_txn(self, txn, events): - """Store a hash for a PDU - Args: - txn (cursor): - events (list): list of Events. - """ - - vals = [] - for event in events: - ref_alg, ref_hash_bytes = compute_event_reference_hash(event) - vals.append( - { - "event_id": event.event_id, - "algorithm": ref_alg, - "hash": db_binary_type(ref_hash_bytes), - } - ) - - self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a941a5ae3f..a2df8fa827 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,45 +14,16 @@ # limitations under the License. import logging -from collections import namedtuple from six import iteritems, itervalues -from six.moves import range import attr -from twisted.internet import defer - from synapse.api.constants import EventTypes -from synapse.api.errors import NotFoundError -from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.storage.engines import PostgresEngine -from synapse.storage.events_worker import EventsWorkerStore -from synapse.util.caches import get_cache_factor_for, intern_string -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.dictionary_cache import DictionaryCache -from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) -MAX_STATE_DELTA_HOPS = 100 - - -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): - """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching - """ - - __slots__ = [] - - def __len__(self): - return len(self.delta_ids) if self.delta_ids else 0 - - @attr.s(slots=True) class StateFilter(object): """A filter used when querying for state. @@ -351,1195 +322,3 @@ class StateFilter(object): ) return member_filter, non_member_filter - - -class StateGroupBackgroundUpdateStore(SQLBaseStore): - """Defines functions related to state groups needed to run the state backgroud - updates. - """ - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """ - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() - ): - results = {group: {} for group in groups} - - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - - if isinstance(self.database_engine, PostgresEngine): - # Temporarily disable sequential scans in this transaction. This is - # a temporary hack until we can add the right indices in - txn.execute("SET LOCAL enable_seqscan=off") - - # The below query walks the state_group tree so that the "state" - # table includes all state_groups in the tree. It then joins - # against `state_groups_state` to fetch the latest state. - # It assumes that previous state groups are always numerically - # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. - # This may return multiple rows per (type, state_key), but last_value - # should be the same. - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT DISTINCT type, state_key, last_value(event_id) OVER ( - PARTITION BY type, state_key ORDER BY state_group ASC - ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - ) AS event_id FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) - """ - - for group in groups: - args = [group] - args.extend(where_args) - - txn.execute(sql + where_clause, args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id - else: - max_entries_returned = state_filter.max_entries_returned() - - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - for group in groups: - next_group = group - - while next_group: - # We did this before by getting the list of group ids, and - # then passing that list to sqlite to get latest event for - # each (type, state_key). However, that was terribly slow - # without the right indices (which we can't add until - # after we finish deduping state, which requires this func) - args = [next_group] - args.extend(where_args) - - txn.execute( - "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? " + where_clause, - args, - ) - results[group].update( - ((typ, state_key), event_id) - for typ, state_key, event_id in txn - if (typ, state_key) not in results[group] - ) - - # If the number of entries in the (type,state_key)->event_id dict - # matches the number of (type,state_keys) types we were searching - # for, then we must have found them all, so no need to go walk - # further down the tree... UNLESS our types filter contained - # wildcards (i.e. Nones) in which case we have to do an exhaustive - # search - if ( - max_entries_returned is not None - and len(results[group]) == max_entries_returned - ): - break - - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - - return results - - -# this inherits from EventsWorkerStore because it calls self.get_events -class StateGroupWorkerStore( - EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore -): - """The parts of StateGroupStore that can be called from workers. - """ - - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - - def __init__(self, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(db_conn, hs) - - # Originally the state store used a single DictionaryCache to cache the - # event IDs for the state types in a given state group to avoid hammering - # on the state_group* tables. - # - # The point of using a DictionaryCache is that it can cache a subset - # of the state events for a given state group (i.e. a subset of the keys for a - # given dict which is an entry in the cache for a given state group ID). - # - # However, this poses problems when performing complicated queries - # on the store - for instance: "give me all the state for this group, but - # limit members to this subset of users", as DictionaryCache's API isn't - # rich enough to say "please cache any of these fields, apart from this subset". - # This is problematic when lazy loading members, which requires this behaviour, - # as without it the cache has no choice but to speculatively load all - # state events for the group, which negates the efficiency being sought. - # - # Rather than overcomplicating DictionaryCache's API, we instead split the - # state_group_cache into two halves - one for tracking non-member events, - # and the other for tracking member_events. This means that lazy loading - # queries can be made in a cache-friendly manner by querying both caches - # separately and then merging the result. So for the example above, you - # would query the members cache for a specific subset of state keys - # (which DictionaryCache will handle efficiently and fine) and the non-members - # cache for all state (which DictionaryCache will similarly handle fine) - # and then just merge the results together. - # - # We size the non-members cache to be smaller than the members cache as the - # vast majority of state in Matrix (today) is member events. - - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", - # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), - ) - self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), - ) - - @defer.inlineCallbacks - def get_room_version(self, room_id): - """Get the room_version of a given room - - Args: - room_id (str) - - Returns: - Deferred[str] - - Raises: - NotFoundError if the room is unknown - """ - # for now we do this by looking at the create event. We may want to cache this - # more intelligently in future. - - # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) - return create_event.content.get("room_version", "1") - - @defer.inlineCallbacks - def get_room_predecessor(self, room_id): - """Get the predecessor room of an upgraded room if one exists. - Otherwise return None. - - Args: - room_id (str) - - Returns: - Deferred[unicode|None]: predecessor room id - - Raises: - NotFoundError if the room is unknown - """ - # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) - - # Return predecessor if present - return create_event.content.get("predecessor", None) - - @defer.inlineCallbacks - def get_create_event_for_room(self, room_id): - """Get the create state event for a room. - - Args: - room_id (str) - - Returns: - Deferred[EventBase]: The room creation event. - - Raises: - NotFoundError if the room is unknown - """ - state_ids = yield self.get_current_state_ids(room_id) - create_id = state_ids.get((EventTypes.Create, "")) - - # If we can't find the create event, assume we've hit a dead end - if not create_id: - raise NotFoundError("Unknown room %s" % (room_id)) - - # Retrieve the room's create event and return - create_event = yield self.get_event(create_id) - return create_event - - @cached(max_entries=100000, iterable=True) - def get_current_state_ids(self, room_id): - """Get the current state event ids for a room based on the - current_state_events table. - - Args: - room_id (str) - - Returns: - deferred: dict of (type, state_key) -> event_id - """ - - def _get_current_state_ids_txn(txn): - txn.execute( - """SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? - """, - (room_id,), - ) - - return { - (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn - } - - return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn) - - # FIXME: how should this be cached? - def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): - """Get the current state event of a given type for a room based on the - current_state_events table. This may not be as up-to-date as the result - of doing a fresh state resolution as per state_handler.get_current_state - - Args: - room_id (str) - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - Deferred[dict[tuple[str, str], str]]: Map from type/state_key to - event ID. - """ - - where_clause, where_args = state_filter.make_sql_filter_clause() - - if not where_clause: - # We delegate to the cached version - return self.get_current_state_ids(room_id) - - def _get_filtered_current_state_ids_txn(txn): - results = {} - sql = """ - SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? - """ - - if where_clause: - sql += " AND (%s)" % (where_clause,) - - args = [room_id] - args.extend(where_args) - txn.execute(sql, args) - for row in txn: - typ, state_key, event_id = row - key = (intern_string(typ), intern_string(state_key)) - results[key] = event_id - - return results - - return self.runInteraction( - "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn - ) - - @defer.inlineCallbacks - def get_canonical_alias_for_room(self, room_id): - """Get canonical alias for room, if any - - Args: - room_id (str) - - Returns: - Deferred[str|None]: The canonical alias, if any - """ - - state = yield self.get_filtered_current_state_ids( - room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) - ) - - event_id = state.get((EventTypes.CanonicalAlias, "")) - if not event_id: - return - - event = yield self.get_event(event_id, allow_none=True) - if not event: - return - - return event.content.get("canonical_alias") - - @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): - """Given a state group try to return a previous group and a delta between - the old and the new. - - Returns: - (prev_group, delta_ids), where both may be None. - """ - - def _get_state_group_delta_txn(txn): - prev_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - retcol="prev_state_group", - allow_none=True, - ) - - if not prev_group: - return _GetStateGroupDelta(None, None) - - delta_ids = self._simple_select_list_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - retcols=("type", "state_key", "event_id"), - ) - - return _GetStateGroupDelta( - prev_group, - {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, - ) - - return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn) - - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): - """Get the event IDs of all the state for the state groups for the given events - - Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events - - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - if not event_ids: - return {} - - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups) - - return group_to_state - - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): - """Get the event IDs of all the state in the given state group - - Args: - state_group (int) - - Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id - """ - group_to_state = yield self._get_state_for_groups((state_group,)) - - return group_to_state[state_group] - - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): - """ Get the state groups for the given list of event_ids - - Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. - """ - if not event_ids: - return {} - - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) - - state_event_map = yield self.get_events( - [ - ev_id - for group_ids in itervalues(group_to_ids) - for ev_id in itervalues(group_ids) - ], - get_prev_content=False, - ) - - return { - group: [ - state_event_map[v] - for v in itervalues(event_id_map) - if v in state_event_map - ] - for group, event_id_map in iteritems(group_to_ids) - } - - @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, state_filter): - """Returns the state groups for a given set of groups, filtering on - types of state events. - - Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - results = {} - - chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] - for chunk in chunks: - res = yield self.runInteraction( - "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, - chunk, - state_filter, - ) - results.update(res) - - return results - - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): - """Given a list of event_ids and type tuples, return a list of state - dicts for each event. - - Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - state_event_map = yield self.get_events( - [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], - get_prev_content=False, - ) - - event_to_state = { - event_id: { - k: state_event_map[v] - for k, v in iteritems(group_to_state[group]) - if v in state_event_map - } - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): - """ - Get the state dicts corresponding to a list of events, containing the event_ids - of the state events (as opposed to the events themselves) - - Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from event_id -> (type, state_key) -> event_id - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - event_to_state = { - event_id: group_to_state[group] - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_for_events([event_id], state_filter) - return state_map[event_id] - - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) - return state_map[event_id] - - @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self._simple_select_one_onecol( - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - desc="_get_state_group_for_event", - ) - - @cachedList( - cached_method_name="_get_state_group_for_event", - list_name="event_ids", - num_args=1, - inlineCallbacks=True, - ) - def _get_state_group_for_events(self, event_ids): - """Returns mapping event_id -> state_group - """ - rows = yield self._simple_select_many_batch( - table="event_to_state_groups", - column="event_id", - iterable=event_ids, - keyvalues={}, - retcols=("event_id", "state_group"), - desc="_get_state_group_for_events", - ) - - return {row["event_id"]: row["state_group"] for row in rows} - - def _get_state_for_group_using_cache(self, cache, group, state_filter): - """Checks if group is in cache. See `_get_state_for_groups` - - Args: - cache(DictionaryCache): the state group cache to use - group(int): The state group to lookup - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns 2-tuple (`state_dict`, `got_all`). - `got_all` is a bool indicating if we successfully retrieved all - requests state from the cache, if False we need to query the DB for the - missing state. - """ - is_all, known_absent, state_dict_ids = cache.get(group) - - if is_all or state_filter.is_full(): - # Either we have everything or want everything, either way - # `is_all` tells us whether we've gotten everything. - return state_filter.filter_state(state_dict_ids), is_all - - # tracks whether any of our requested types are missing from the cache - missing_types = False - - if state_filter.has_wildcards(): - # We don't know if we fetched all the state keys for the types in - # the filter that are wildcards, so we have to assume that we may - # have missed some. - missing_types = True - else: - # There aren't any wild cards, so `concrete_types()` returns the - # complete list of event types we're wanting. - for key in state_filter.concrete_types(): - if key not in state_dict_ids and key not in known_absent: - missing_types = True - break - - return state_filter.filter_state(state_dict_ids), not missing_types - - @defer.inlineCallbacks - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - - member_filter, non_member_filter = state_filter.get_member_split() - - # Now we look them up in the member and non-member caches - non_member_state, incomplete_groups_nm, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) - ) - - member_state, incomplete_groups_m, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) - ) - - state = dict(non_member_state) - for group in groups: - state[group].update(member_state[group]) - - # Now fetch any missing groups from the database - - incomplete_groups = incomplete_groups_m | incomplete_groups_nm - - if not incomplete_groups: - return state - - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - group_to_state_dict = yield self._get_state_groups_from_groups( - list(incomplete_groups), state_filter=db_state_filter - ) - - # Now lets update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # And finally update the result dict, by filtering out any extra - # stuff we pulled out of the database. - for group, group_state_dict in iteritems(group_to_state_dict): - # We just replace any existing entries, as we will have loaded - # everything we need from the database anyway. - state[group] = state_filter.filter_state(group_state_dict) - - return state - - def _get_state_for_groups_using_cache(self, groups, cache, state_filter): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key, querying from a specific cache. - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - cache (DictionaryCache): the cache of group ids to state dicts which - we will pass through - either the normal state cache or the specific - members state cache. - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of - dict of state_group_id -> (dict of (type, state_key) -> event id) - of entries in the cache, and the state group ids either missing - from the cache or incomplete. - """ - results = {} - incomplete_groups = set() - for group in set(groups): - state_dict_ids, got_all = self._get_state_for_group_using_cache( - cache, group, state_filter - ) - results[group] = state_dict_ids - - if not got_all: - incomplete_groups.add(group) - - return results, incomplete_groups - - def _insert_into_cache( - self, - group_to_state_dict, - state_filter, - cache_seq_num_members, - cache_seq_num_non_members, - ): - """Inserts results from querying the database into the relevant cache. - - Args: - group_to_state_dict (dict): The new entries pulled from database. - Map from state group to state dict - state_filter (StateFilter): The state filter used to fetch state - from the database. - cache_seq_num_members (int): Sequence number of member cache since - last lookup in cache - cache_seq_num_non_members (int): Sequence number of member cache since - last lookup in cache - """ - - # We need to work out which types we've fetched from the DB for the - # member vs non-member caches. This should be as accurate as possible, - # but can be an underestimate (e.g. when we have wild cards) - - member_filter, non_member_filter = state_filter.get_member_split() - if member_filter.is_full(): - # We fetched all member events - member_types = None - else: - # `concrete_types()` will only return a subset when there are wild - # cards in the filter, but that's fine. - member_types = member_filter.concrete_types() - - if non_member_filter.is_full(): - # We fetched all non member events - non_member_types = None - else: - non_member_types = non_member_filter.concrete_types() - - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict_members = {} - state_dict_non_members = {} - - for k, v in iteritems(group_state_dict): - if k[0] == EventTypes.Member: - state_dict_members[k] = v - else: - state_dict_non_members[k] = v - - self._state_group_members_cache.update( - cache_seq_num_members, - key=group, - value=state_dict_members, - fetched_keys=member_types, - ) - - self._state_group_cache.update( - cache_seq_num_non_members, - key=group, - value=state_dict_non_members, - fetched_keys=non_member_types, - ) - - def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): - """Store a new set of state, returning a newly assigned state group. - - Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and - `current_state_ids`, if `prev_group` was given. Same format as - `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) - to event_id. - - Returns: - Deferred[int]: The state group ID - """ - - def _store_state_group_txn(txn): - if current_state_ids is None: - # AFAIK, this can never happen - raise Exception("current_state_ids cannot be None") - - state_group = self.database_engine.get_next_state_group_id(txn) - - self._simple_insert_txn( - txn, - table="state_groups", - values={"id": state_group, "room_id": room_id, "event_id": event_id}, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if prev_group: - is_in_db = self._simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( - txn, - table="state_group_edges", - values={"state_group": state_group, "prev_state_group": prev_group}, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_ids) - ], - ) - else: - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(current_state_ids) - ], - ) - - # Prefill the state group caches with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - - current_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] == EventTypes.Member - } - txn.call_after( - self._state_group_members_cache.update, - self._state_group_members_cache.sequence, - key=state_group, - value=dict(current_member_state_ids), - ) - - current_non_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] != EventTypes.Member - } - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=state_group, - value=dict(current_non_member_state_ids), - ) - - return state_group - - return self.runInteraction("store_state_group", _store_state_group_txn) - - -class StateBackgroundUpdateStore( - StateGroupBackgroundUpdateStore, BackgroundUpdateStore -): - - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" - - def __init__(self, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state - ) - self.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", - ) - self.register_background_index_update( - self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, - index_name="event_to_state_groups_sg_index", - table="event_to_state_groups", - columns=["state_group"], - ) - - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): - """This background update will slowly deduplicate state by reencoding - them as deltas. - """ - last_state_group = progress.get("last_state_group", 0) - rows_inserted = progress.get("rows_inserted", 0) - max_group = progress.get("max_group", None) - - BATCH_SIZE_SCALE_FACTOR = 100 - - batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) - - if max_group is None: - rows = yield self._execute( - "_background_deduplicate_state", - None, - "SELECT coalesce(max(id), 0) FROM state_groups", - ) - max_group = rows[0][0] - - def reindex_txn(txn): - new_last_state_group = last_state_group - for count in range(batch_size): - txn.execute( - "SELECT id, room_id FROM state_groups" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC" - " LIMIT 1", - (new_last_state_group, max_group), - ) - row = txn.fetchone() - if row: - state_group, room_id = row - - if not row or not state_group: - return True, count - - txn.execute( - "SELECT state_group FROM state_group_edges" - " WHERE state_group = ?", - (state_group,), - ) - - # If we reach a point where we've already started inserting - # edges we should stop. - if txn.fetchall(): - return True, count - - txn.execute( - "SELECT coalesce(max(id), 0) FROM state_groups" - " WHERE id < ? AND room_id = ?", - (state_group, room_id), - ) - prev_group, = txn.fetchone() - new_last_state_group = state_group - - if prev_group: - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if potential_hops >= MAX_STATE_DELTA_HOPS: - # We want to ensure chains are at most this long,# - # otherwise read performance degrades. - continue - - prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group] - ) - prev_state = prev_state[prev_group] - - curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group] - ) - curr_state = curr_state[state_group] - - if not set(prev_state.keys()) - set(curr_state.keys()): - # We can only do a delta if the current has a strict super set - # of keys - - delta_state = { - key: value - for key, value in iteritems(curr_state) - if prev_state.get(key, None) != value - } - - self._simple_delete_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - ) - - self._simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": state_group, - "prev_state_group": prev_group, - }, - ) - - self._simple_delete_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_state) - ], - ) - - progress = { - "last_state_group": state_group, - "rows_inserted": rows_inserted + batch_size, - "max_group": max_group, - } - - self._background_update_progress_txn( - txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress - ) - - return False, batch_size - - finished, result = yield self.runInteraction( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn - ) - - if finished: - yield self._end_background_update( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME - ) - - return result * BATCH_SIZE_SCALE_FACTOR - - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): - def reindex_txn(conn): - conn.rollback() - if isinstance(self.database_engine, PostgresEngine): - # postgres insists on autocommit for the index - conn.set_session(autocommit=True) - try: - txn = conn.cursor() - txn.execute( - "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - finally: - conn.set_session(autocommit=False) - else: - txn = conn.cursor() - txn.execute( - "CREATE INDEX state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - - yield self.runWithConnection(reindex_txn) - - yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - - return 1 - - -class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): - """ Keeps track of the state at a given event. - - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. - - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. - """ - - def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) - - def _store_event_state_mappings_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.prev_group - continue - - state_groups[event.event_id] = context.state_group - - self._simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - {"state_group": state_group_id, "event_id": event_id} - for event_id, state_group_id in iteritems(state_groups) - ], - ) - - for event_id, state_group_id in iteritems(state_groups): - txn.call_after( - self._get_state_group_for_event.prefill, (event_id,), state_group_id - ) diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py deleted file mode 100644 index 28f33ec18f..0000000000 --- a/synapse/storage/state_deltas.py +++ /dev/null @@ -1,119 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.storage._base import SQLBaseStore - -logger = logging.getLogger(__name__) - - -class StateDeltasStore(SQLBaseStore): - def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): - """Fetch a list of room state changes since the given stream id - - Each entry in the result contains the following fields: - - stream_id (int) - - room_id (str) - - type (str): event type - - state_key (str): - - event_id (str|None): new event_id for this state key. None if the - state has been deleted. - - prev_event_id (str|None): previous event_id for this state key. None - if it's new state. - - Args: - prev_stream_id (int): point to get changes since (exclusive) - max_stream_id (int): the point that we know has been correctly persisted - - ie, an upper limit to return changes from. - - Returns: - Deferred[tuple[int, list[dict]]: A tuple consisting of: - - the stream id which these results go up to - - list of current_state_delta_stream rows. If it is empty, we are - up to date. - """ - prev_stream_id = int(prev_stream_id) - - # check we're not going backwards - assert prev_stream_id <= max_stream_id - - if not self._curr_state_delta_stream_cache.has_any_entity_changed( - prev_stream_id - ): - # if the CSDs haven't changed between prev_stream_id and now, we - # know for certain that they haven't changed between prev_stream_id and - # max_stream_id. - return max_stream_id, [] - - def get_current_state_deltas_txn(txn): - # First we calculate the max stream id that will give us less than - # N results. - # We arbitarily limit to 100 stream_id entries to ensure we don't - # select toooo many. - sql = """ - SELECT stream_id, count(*) - FROM current_state_delta_stream - WHERE stream_id > ? AND stream_id <= ? - GROUP BY stream_id - ORDER BY stream_id ASC - LIMIT 100 - """ - txn.execute(sql, (prev_stream_id, max_stream_id)) - - total = 0 - - for stream_id, count in txn: - total += count - if total > 100: - # We arbitarily limit to 100 entries to ensure we don't - # select toooo many. - logger.debug( - "Clipping current_state_delta_stream rows to stream_id %i", - stream_id, - ) - clipped_stream_id = stream_id - break - else: - # if there's no problem, we may as well go right up to the max_stream_id - clipped_stream_id = max_stream_id - - # Now actually get the deltas - sql = """ - SELECT stream_id, room_id, type, state_key, event_id, prev_event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC - """ - txn.execute(sql, (prev_stream_id, clipped_stream_id)) - return clipped_stream_id, self.cursor_to_dict(txn) - - return self.runInteraction( - "get_current_state_deltas", get_current_state_deltas_txn - ) - - def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self._simple_select_one_onecol_txn( - txn, - table="current_state_delta_stream", - keyvalues={}, - retcol="COALESCE(MAX(stream_id), -1)", - ) - - def get_max_stream_id_in_current_state_deltas(self): - return self.runInteraction( - "get_max_stream_id_in_current_state_deltas", - self._get_max_stream_id_in_current_state_deltas_txn, - ) diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py deleted file mode 100644 index 7c224cd3d9..0000000000 --- a/synapse/storage/stats.py +++ /dev/null @@ -1,881 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018, 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from itertools import chain - -from twisted.internet import defer -from twisted.internet.defer import DeferredLock - -from synapse.api.constants import EventTypes, Membership -from synapse.storage import PostgresEngine -from synapse.storage.state_deltas import StateDeltasStore -from synapse.util.caches.descriptors import cached - -logger = logging.getLogger(__name__) - -# these fields track absolutes (e.g. total number of rooms on the server) -# You can think of these as Prometheus Gauges. -# You can draw these stats on a line graph. -# Example: number of users in a room -ABSOLUTE_STATS_FIELDS = { - "room": ( - "current_state_events", - "joined_members", - "invited_members", - "left_members", - "banned_members", - "local_users_in_room", - ), - "user": ("joined_rooms",), -} - -# these fields are per-timeslice and so should be reset to 0 upon a new slice -# You can draw these stats on a histogram. -# Example: number of events sent locally during a time slice -PER_SLICE_FIELDS = { - "room": ("total_events", "total_event_bytes"), - "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"), -} - -TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} - -# these are the tables (& ID columns) which contain our actual subjects -TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} - - -class StatsStore(StateDeltasStore): - def __init__(self, db_conn, hs): - super(StatsStore, self).__init__(db_conn, hs) - - self.server_name = hs.hostname - self.clock = self.hs.get_clock() - self.stats_enabled = hs.config.stats_enabled - self.stats_bucket_size = hs.config.stats_bucket_size - - self.stats_delta_processing_lock = DeferredLock() - - self.register_background_update_handler( - "populate_stats_process_rooms", self._populate_stats_process_rooms - ) - self.register_background_update_handler( - "populate_stats_process_users", self._populate_stats_process_users - ) - # we no longer need to perform clean-up, but we will give ourselves - # the potential to reintroduce it in the future – so documentation - # will still encourage the use of this no-op handler. - self.register_noop_background_update("populate_stats_cleanup") - self.register_noop_background_update("populate_stats_prepare") - - def quantise_stats_time(self, ts): - """ - Quantises a timestamp to be a multiple of the bucket size. - - Args: - ts (int): the timestamp to quantise, in milliseconds since the Unix - Epoch - - Returns: - int: a timestamp which - - is divisible by the bucket size; - - is no later than `ts`; and - - is the largest such timestamp. - """ - return (ts // self.stats_bucket_size) * self.stats_bucket_size - - @defer.inlineCallbacks - def _populate_stats_process_users(self, progress, batch_size): - """ - This is a background update which regenerates statistics for users. - """ - if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_users") - return 1 - - last_user_id = progress.get("last_user_id", "") - - def _get_next_batch(txn): - sql = """ - SELECT DISTINCT name FROM users - WHERE name > ? - ORDER BY name ASC - LIMIT ? - """ - txn.execute(sql, (last_user_id, batch_size)) - return [r for r, in txn] - - users_to_work_on = yield self.runInteraction( - "_populate_stats_process_users", _get_next_batch - ) - - # No more rooms -- complete the transaction. - if not users_to_work_on: - yield self._end_background_update("populate_stats_process_users") - return 1 - - for user_id in users_to_work_on: - yield self._calculate_and_set_initial_state_for_user(user_id) - progress["last_user_id"] = user_id - - yield self.runInteraction( - "populate_stats_process_users", - self._background_update_progress_txn, - "populate_stats_process_users", - progress, - ) - - return len(users_to_work_on) - - @defer.inlineCallbacks - def _populate_stats_process_rooms(self, progress, batch_size): - """ - This is a background update which regenerates statistics for rooms. - """ - if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_rooms") - return 1 - - last_room_id = progress.get("last_room_id", "") - - def _get_next_batch(txn): - sql = """ - SELECT DISTINCT room_id FROM current_state_events - WHERE room_id > ? - ORDER BY room_id ASC - LIMIT ? - """ - txn.execute(sql, (last_room_id, batch_size)) - return [r for r, in txn] - - rooms_to_work_on = yield self.runInteraction( - "populate_stats_rooms_get_batch", _get_next_batch - ) - - # No more rooms -- complete the transaction. - if not rooms_to_work_on: - yield self._end_background_update("populate_stats_process_rooms") - return 1 - - for room_id in rooms_to_work_on: - yield self._calculate_and_set_initial_state_for_room(room_id) - progress["last_room_id"] = room_id - - yield self.runInteraction( - "_populate_stats_process_rooms", - self._background_update_progress_txn, - "populate_stats_process_rooms", - progress, - ) - - return len(rooms_to_work_on) - - def get_stats_positions(self): - """ - Returns the stats processor positions. - """ - return self._simple_select_one_onecol( - table="stats_incremental_position", - keyvalues={}, - retcol="stream_id", - desc="stats_incremental_position", - ) - - def update_room_state(self, room_id, fields): - """ - Args: - room_id (str) - fields (dict[str:Any]) - """ - - # For whatever reason some of the fields may contain null bytes, which - # postgres isn't a fan of, so we replace those fields with null. - for col in ( - "join_rules", - "history_visibility", - "encryption", - "name", - "topic", - "avatar", - "canonical_alias", - ): - field = fields.get(col) - if field and "\0" in field: - fields[col] = None - - return self._simple_upsert( - table="room_stats_state", - keyvalues={"room_id": room_id}, - values=fields, - desc="update_room_state", - ) - - def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): - """ - Get statistics for a given subject. - - Args: - stats_type (str): The type of subject - stats_id (str): The ID of the subject (e.g. room_id or user_id) - start (int): Pagination start. Number of entries, not timestamp. - size (int): How many entries to return. - - Returns: - Deferred[list[dict]], where the dict has the keys of - ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". - """ - return self.runInteraction( - "get_statistics_for_subject", - self._get_statistics_for_subject_txn, - stats_type, - stats_id, - start, - size, - ) - - def _get_statistics_for_subject_txn( - self, txn, stats_type, stats_id, start, size=100 - ): - """ - Transaction-bound version of L{get_statistics_for_subject}. - """ - - table, id_col = TYPE_TO_TABLE[stats_type] - selected_columns = list( - ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] - ) - - slice_list = self._simple_select_list_paginate_txn( - txn, - table + "_historical", - {id_col: stats_id}, - "end_ts", - start, - size, - retcols=selected_columns + ["bucket_size", "end_ts"], - order_direction="DESC", - ) - - return slice_list - - def get_room_stats_state(self, room_id): - """ - Returns the current room_stats_state for a room. - - Args: - room_id (str): The ID of the room to return state for. - - Returns (dict): - Dictionary containing these keys: - "name", "topic", "canonical_alias", "avatar", "join_rules", - "history_visibility" - """ - return self._simple_select_one( - "room_stats_state", - {"room_id": room_id}, - retcols=( - "name", - "topic", - "canonical_alias", - "avatar", - "join_rules", - "history_visibility", - ), - ) - - @cached() - def get_earliest_token_for_stats(self, stats_type, id): - """ - Fetch the "earliest token". This is used by the room stats delta - processor to ignore deltas that have been processed between the - start of the background task and any particular room's stats - being calculated. - - Returns: - Deferred[int] - """ - table, id_col = TYPE_TO_TABLE[stats_type] - - return self._simple_select_one_onecol( - "%s_current" % (table,), - keyvalues={id_col: id}, - retcol="completed_delta_stream_id", - allow_none=True, - ) - - def bulk_update_stats_delta(self, ts, updates, stream_id): - """Bulk update stats tables for a given stream_id and updates the stats - incremental position. - - Args: - ts (int): Current timestamp in ms - updates(dict[str, dict[str, dict[str, Counter]]]): The updates to - commit as a mapping stats_type -> stats_id -> field -> delta. - stream_id (int): Current position. - - Returns: - Deferred - """ - - def _bulk_update_stats_delta_txn(txn): - for stats_type, stats_updates in updates.items(): - for stats_id, fields in stats_updates.items(): - logger.info( - "Updating %s stats for %s: %s", stats_type, stats_id, fields - ) - self._update_stats_delta_txn( - txn, - ts=ts, - stats_type=stats_type, - stats_id=stats_id, - fields=fields, - complete_with_stream_id=stream_id, - ) - - self._simple_update_one_txn( - txn, - table="stats_incremental_position", - keyvalues={}, - updatevalues={"stream_id": stream_id}, - ) - - return self.runInteraction( - "bulk_update_stats_delta", _bulk_update_stats_delta_txn - ) - - def update_stats_delta( - self, - ts, - stats_type, - stats_id, - fields, - complete_with_stream_id, - absolute_field_overrides=None, - ): - """ - Updates the statistics for a subject, with a delta (difference/relative - change). - - Args: - ts (int): timestamp of the change - stats_type (str): "room" or "user" – the kind of subject - stats_id (str): the subject's ID (room ID or user ID) - fields (dict[str, int]): Deltas of stats values. - complete_with_stream_id (int, optional): - If supplied, converts an incomplete row into a complete row, - with the supplied stream_id marked as the stream_id where the - row was completed. - absolute_field_overrides (dict[str, int]): Current stats values - (i.e. not deltas) of absolute fields. - Does not work with per-slice fields. - """ - - return self.runInteraction( - "update_stats_delta", - self._update_stats_delta_txn, - ts, - stats_type, - stats_id, - fields, - complete_with_stream_id=complete_with_stream_id, - absolute_field_overrides=absolute_field_overrides, - ) - - def _update_stats_delta_txn( - self, - txn, - ts, - stats_type, - stats_id, - fields, - complete_with_stream_id, - absolute_field_overrides=None, - ): - if absolute_field_overrides is None: - absolute_field_overrides = {} - - table, id_col = TYPE_TO_TABLE[stats_type] - - quantised_ts = self.quantise_stats_time(int(ts)) - end_ts = quantised_ts + self.stats_bucket_size - - # Lets be paranoid and check that all the given field names are known - abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type] - slice_field_names = PER_SLICE_FIELDS[stats_type] - for field in chain(fields.keys(), absolute_field_overrides.keys()): - if field not in abs_field_names and field not in slice_field_names: - # guard against potential SQL injection dodginess - raise ValueError( - "%s is not a recognised field" - " for stats type %s" % (field, stats_type) - ) - - # Per slice fields do not get added to the _current table - - # This calculates the deltas (`field = field + ?` values) - # for absolute fields, - # * defaulting to 0 if not specified - # (required for the INSERT part of upserting to work) - # * omitting overrides specified in `absolute_field_overrides` - deltas_of_absolute_fields = { - key: fields.get(key, 0) - for key in abs_field_names - if key not in absolute_field_overrides - } - - # Keep the delta stream ID field up to date - absolute_field_overrides = absolute_field_overrides.copy() - absolute_field_overrides["completed_delta_stream_id"] = complete_with_stream_id - - # first upsert the `_current` table - self._upsert_with_additive_relatives_txn( - txn=txn, - table=table + "_current", - keyvalues={id_col: stats_id}, - absolutes=absolute_field_overrides, - additive_relatives=deltas_of_absolute_fields, - ) - - per_slice_additive_relatives = { - key: fields.get(key, 0) for key in slice_field_names - } - self._upsert_copy_from_table_with_additive_relatives_txn( - txn=txn, - into_table=table + "_historical", - keyvalues={id_col: stats_id}, - extra_dst_insvalues={"bucket_size": self.stats_bucket_size}, - extra_dst_keyvalues={"end_ts": end_ts}, - additive_relatives=per_slice_additive_relatives, - src_table=table + "_current", - copy_columns=abs_field_names, - ) - - def _upsert_with_additive_relatives_txn( - self, txn, table, keyvalues, absolutes, additive_relatives - ): - """Used to update values in the stats tables. - - This is basically a slightly convoluted upsert that *adds* to any - existing rows. - - Args: - txn - table (str): Table name - keyvalues (dict[str, any]): Row-identifying key values - absolutes (dict[str, any]): Absolute (set) fields - additive_relatives (dict[str, int]): Fields that will be added onto - if existing row present. - """ - if self.database_engine.can_native_upsert: - absolute_updates = [ - "%(field)s = EXCLUDED.%(field)s" % {"field": field} - for field in absolutes.keys() - ] - - relative_updates = [ - "%(field)s = EXCLUDED.%(field)s + %(table)s.%(field)s" - % {"table": table, "field": field} - for field in additive_relatives.keys() - ] - - insert_cols = [] - qargs = [] - - for (key, val) in chain( - keyvalues.items(), absolutes.items(), additive_relatives.items() - ): - insert_cols.append(key) - qargs.append(val) - - sql = """ - INSERT INTO %(table)s (%(insert_cols_cs)s) - VALUES (%(insert_vals_qs)s) - ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s - """ % { - "table": table, - "insert_cols_cs": ", ".join(insert_cols), - "insert_vals_qs": ", ".join( - ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives)) - ), - "key_columns": ", ".join(keyvalues), - "updates": ", ".join(chain(absolute_updates, relative_updates)), - } - - txn.execute(sql, qargs) - else: - self.database_engine.lock_table(txn, table) - retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self._simple_select_one_txn( - txn, table, keyvalues, retcols, allow_none=True - ) - if current_row is None: - merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self._simple_insert_txn(txn, table, merged_dict) - else: - for (key, val) in additive_relatives.items(): - current_row[key] += val - current_row.update(absolutes) - self._simple_update_one_txn(txn, table, keyvalues, current_row) - - def _upsert_copy_from_table_with_additive_relatives_txn( - self, - txn, - into_table, - keyvalues, - extra_dst_keyvalues, - extra_dst_insvalues, - additive_relatives, - src_table, - copy_columns, - ): - """Updates the historic stats table with latest updates. - - This involves copying "absolute" fields from the `_current` table, and - adding relative fields to any existing values. - - Args: - txn: Transaction - into_table (str): The destination table to UPSERT the row into - keyvalues (dict[str, any]): Row-identifying key values - extra_dst_keyvalues (dict[str, any]): Additional keyvalues - for `into_table`. - extra_dst_insvalues (dict[str, any]): Additional values to insert - on new row creation for `into_table`. - additive_relatives (dict[str, any]): Fields that will be added onto - if existing row present. (Must be disjoint from copy_columns.) - src_table (str): The source table to copy from - copy_columns (iterable[str]): The list of columns to copy - """ - if self.database_engine.can_native_upsert: - ins_columns = chain( - keyvalues, - copy_columns, - additive_relatives, - extra_dst_keyvalues, - extra_dst_insvalues, - ) - sel_exprs = chain( - keyvalues, - copy_columns, - ( - "?" - for _ in chain( - additive_relatives, extra_dst_keyvalues, extra_dst_insvalues - ) - ), - ) - keyvalues_where = ("%s = ?" % f for f in keyvalues) - - sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns) - sets_ar = ( - "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f) - for f in additive_relatives - ) - - sql = """ - INSERT INTO %(into_table)s (%(ins_columns)s) - SELECT %(sel_exprs)s - FROM %(src_table)s - WHERE %(keyvalues_where)s - ON CONFLICT (%(keyvalues)s) - DO UPDATE SET %(sets)s - """ % { - "into_table": into_table, - "ins_columns": ", ".join(ins_columns), - "sel_exprs": ", ".join(sel_exprs), - "keyvalues_where": " AND ".join(keyvalues_where), - "src_table": src_table, - "keyvalues": ", ".join( - chain(keyvalues.keys(), extra_dst_keyvalues.keys()) - ), - "sets": ", ".join(chain(sets_cc, sets_ar)), - } - - qargs = list( - chain( - additive_relatives.values(), - extra_dst_keyvalues.values(), - extra_dst_insvalues.values(), - keyvalues.values(), - ) - ) - txn.execute(sql, qargs) - else: - self.database_engine.lock_table(txn, into_table) - src_row = self._simple_select_one_txn( - txn, src_table, keyvalues, copy_columns - ) - all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self._simple_select_one_txn( - txn, - into_table, - keyvalues=all_dest_keyvalues, - retcols=list(chain(additive_relatives.keys(), copy_columns)), - allow_none=True, - ) - - if dest_current_row is None: - merged_dict = { - **keyvalues, - **extra_dst_keyvalues, - **extra_dst_insvalues, - **src_row, - **additive_relatives, - } - self._simple_insert_txn(txn, into_table, merged_dict) - else: - for (key, val) in additive_relatives.items(): - src_row[key] = dest_current_row[key] + val - self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) - - def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): - """Fetches the counts of events in the given range of stream IDs. - - Args: - min_pos (int) - max_pos (int) - - Returns: - Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field - changes. - """ - - return self.runInteraction( - "stats_incremental_total_events_and_bytes", - self.get_changes_room_total_events_and_bytes_txn, - min_pos, - max_pos, - ) - - def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos): - """Gets the total_events and total_event_bytes counts for rooms and - senders, in a range of stream_orderings (including backfilled events). - - Args: - txn - low_pos (int): Low stream ordering - high_pos (int): High stream ordering - - Returns: - tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The - room and user deltas for total_events/total_event_bytes in the - format of `stats_id` -> fields - """ - - if low_pos >= high_pos: - # nothing to do here. - return {}, {} - - if isinstance(self.database_engine, PostgresEngine): - new_bytes_expression = "OCTET_LENGTH(json)" - else: - new_bytes_expression = "LENGTH(CAST(json AS BLOB))" - - sql = """ - SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes - FROM events INNER JOIN event_json USING (event_id) - WHERE (? < stream_ordering AND stream_ordering <= ?) - OR (? <= stream_ordering AND stream_ordering <= ?) - GROUP BY events.room_id - """ % ( - new_bytes_expression, - ) - - txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) - - room_deltas = { - room_id: {"total_events": new_events, "total_event_bytes": new_bytes} - for room_id, new_events, new_bytes in txn - } - - sql = """ - SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes - FROM events INNER JOIN event_json USING (event_id) - WHERE (? < stream_ordering AND stream_ordering <= ?) - OR (? <= stream_ordering AND stream_ordering <= ?) - GROUP BY events.sender - """ % ( - new_bytes_expression, - ) - - txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) - - user_deltas = { - user_id: {"total_events": new_events, "total_event_bytes": new_bytes} - for user_id, new_events, new_bytes in txn - if self.hs.is_mine_id(user_id) - } - - return room_deltas, user_deltas - - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_room(self, room_id): - """Calculate and insert an entry into room_stats_current. - - Args: - room_id (str) - - Returns: - Deferred[tuple[dict, dict, int]]: A tuple of room state, membership - counts and stream position. - """ - - def _fetch_current_state_stats(txn): - pos = self.get_room_max_stream_ordering() - - rows = self._simple_select_many_txn( - txn, - table="current_state_events", - column="type", - iterable=[ - EventTypes.Create, - EventTypes.JoinRules, - EventTypes.RoomHistoryVisibility, - EventTypes.Encryption, - EventTypes.Name, - EventTypes.Topic, - EventTypes.RoomAvatar, - EventTypes.CanonicalAlias, - ], - keyvalues={"room_id": room_id, "state_key": ""}, - retcols=["event_id"], - ) - - event_ids = [row["event_id"] for row in rows] - - txn.execute( - """ - SELECT membership, count(*) FROM current_state_events - WHERE room_id = ? AND type = 'm.room.member' - GROUP BY membership - """, - (room_id,), - ) - membership_counts = {membership: cnt for membership, cnt in txn} - - txn.execute( - """ - SELECT COALESCE(count(*), 0) FROM current_state_events - WHERE room_id = ? - """, - (room_id,), - ) - - current_state_events_count, = txn.fetchone() - - users_in_room = self.get_users_in_room_txn(txn, room_id) - - return ( - event_ids, - membership_counts, - current_state_events_count, - users_in_room, - pos, - ) - - ( - event_ids, - membership_counts, - current_state_events_count, - users_in_room, - pos, - ) = yield self.runInteraction( - "get_initial_state_for_room", _fetch_current_state_stats - ) - - state_event_map = yield self.get_events(event_ids, get_prev_content=False) - - room_state = { - "join_rules": None, - "history_visibility": None, - "encryption": None, - "name": None, - "topic": None, - "avatar": None, - "canonical_alias": None, - "is_federatable": True, - } - - for event in state_event_map.values(): - if event.type == EventTypes.JoinRules: - room_state["join_rules"] = event.content.get("join_rule") - elif event.type == EventTypes.RoomHistoryVisibility: - room_state["history_visibility"] = event.content.get( - "history_visibility" - ) - elif event.type == EventTypes.Encryption: - room_state["encryption"] = event.content.get("algorithm") - elif event.type == EventTypes.Name: - room_state["name"] = event.content.get("name") - elif event.type == EventTypes.Topic: - room_state["topic"] = event.content.get("topic") - elif event.type == EventTypes.RoomAvatar: - room_state["avatar"] = event.content.get("url") - elif event.type == EventTypes.CanonicalAlias: - room_state["canonical_alias"] = event.content.get("alias") - elif event.type == EventTypes.Create: - room_state["is_federatable"] = ( - event.content.get("m.federate", True) is True - ) - - yield self.update_room_state(room_id, room_state) - - local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] - - yield self.update_stats_delta( - ts=self.clock.time_msec(), - stats_type="room", - stats_id=room_id, - fields={}, - complete_with_stream_id=pos, - absolute_field_overrides={ - "current_state_events": current_state_events_count, - "joined_members": membership_counts.get(Membership.JOIN, 0), - "invited_members": membership_counts.get(Membership.INVITE, 0), - "left_members": membership_counts.get(Membership.LEAVE, 0), - "banned_members": membership_counts.get(Membership.BAN, 0), - "local_users_in_room": len(local_users_in_room), - }, - ) - - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_user(self, user_id): - def _calculate_and_set_initial_state_for_user_txn(txn): - pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) - - txn.execute( - """ - SELECT COUNT(distinct room_id) FROM current_state_events - WHERE type = 'm.room.member' AND state_key = ? - AND membership = 'join' - """, - (user_id,), - ) - count, = txn.fetchone() - return count, pos - - joined_rooms, pos = yield self.runInteraction( - "calculate_and_set_initial_state_for_user", - _calculate_and_set_initial_state_for_user_txn, - ) - - yield self.update_stats_delta( - ts=self.clock.time_msec(), - stats_type="user", - stats_id=user_id, - fields={}, - complete_with_stream_id=pos, - absolute_field_overrides={"joined_rooms": joined_rooms}, - ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py deleted file mode 100644 index 490454f19a..0000000000 --- a/synapse/storage/stream.py +++ /dev/null @@ -1,948 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" This module is responsible for getting events from the DB for pagination -and event streaming. - -The order it returns events in depend on whether we are streaming forwards or -are paginating backwards. We do this because we want to handle out of order -messages nicely, while still returning them in the correct order when we -paginate bacwards. - -This is implemented by keeping two ordering columns: stream_ordering and -topological_ordering. Stream ordering is basically insertion/received order -(except for events from backfill requests). The topological_ordering is a -weak ordering of events based on the pdu graph. - -This means that we have to have two different types of tokens, depending on -what sort order was used: - - stream tokens are of the form: "s%d", which maps directly to the column - - topological tokems: "t%d-%d", where the integers map to the topological - and stream ordering columns respectively. -""" - -import abc -import logging -from collections import namedtuple - -from six.moves import range - -from twisted.internet import defer - -from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.storage._base import SQLBaseStore -from synapse.storage.engines import PostgresEngine -from synapse.storage.events_worker import EventsWorkerStore -from synapse.types import RoomStreamToken -from synapse.util.caches.stream_change_cache import StreamChangeCache - -logger = logging.getLogger(__name__) - - -MAX_STREAM_SIZE = 1000 - - -_STREAM_TOKEN = "stream" -_TOPOLOGICAL_TOKEN = "topological" - - -# Used as return values for pagination APIs -_EventDictReturn = namedtuple( - "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering") -) - - -def generate_pagination_where_clause( - direction, column_names, from_token, to_token, engine -): - """Creates an SQL expression to bound the columns by the pagination - tokens. - - For example creates an SQL expression like: - - (6, 7) >= (topological_ordering, stream_ordering) - AND (5, 3) < (topological_ordering, stream_ordering) - - would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). - - Note that tokens are considered to be after the row they are in, e.g. if - a row A has a token T, then we consider A to be before T. This convention - is important when figuring out inequalities for the generated SQL, and - produces the following result: - - If paginating forwards then we exclude any rows matching the from - token, but include those that match the to token. - - If paginating backwards then we include any rows matching the from - token, but include those that match the to token. - - Args: - direction (str): Whether we're paginating backwards("b") or - forwards ("f"). - column_names (tuple[str, str]): The column names to bound. Must *not* - be user defined as these get inserted directly into the SQL - statement without escapes. - from_token (tuple[int, int]|None): The start point for the pagination. - This is an exclusive minimum bound if direction is "f", and an - inclusive maximum bound if direction is "b". - to_token (tuple[int, int]|None): The endpoint point for the pagination. - This is an inclusive maximum bound if direction is "f", and an - exclusive minimum bound if direction is "b". - engine: The database engine to generate the clauses for - - Returns: - str: The sql expression - """ - assert direction in ("b", "f") - - where_clause = [] - if from_token: - where_clause.append( - _make_generic_sql_bound( - bound=">=" if direction == "b" else "<", - column_names=column_names, - values=from_token, - engine=engine, - ) - ) - - if to_token: - where_clause.append( - _make_generic_sql_bound( - bound="<" if direction == "b" else ">=", - column_names=column_names, - values=to_token, - engine=engine, - ) - ) - - return " AND ".join(where_clause) - - -def _make_generic_sql_bound(bound, column_names, values, engine): - """Create an SQL expression that bounds the given column names by the - values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. - - Only works with two columns. - - Older versions of SQLite don't support that syntax so we have to expand it - out manually. - - Args: - bound (str): The comparison operator to use. One of ">", "<", ">=", - "<=", where the values are on the left and columns on the right. - names (tuple[str, str]): The column names. Must *not* be user defined - as these get inserted directly into the SQL statement without - escapes. - values (tuple[int|None, int]): The values to bound the columns by. If - the first value is None then only creates a bound on the second - column. - engine: The database engine to generate the SQL for - - Returns: - str - """ - - assert bound in (">", "<", ">=", "<=") - - name1, name2 = column_names - val1, val2 = values - - if val1 is None: - val2 = int(val2) - return "(%d %s %s)" % (val2, bound, name2) - - val1 = int(val1) - val2 = int(val2) - - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x < a) OR (x=a AND y ? AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) - txn.execute(sql, (room_id, from_id, to_id, limit)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] - return rows - - rows = yield self.runInteraction("get_room_events_stream_for_room", f) - - ret = yield self.get_events_as_list( - [r.event_id for r in rows], get_prev_content=True - ) - - self._set_before_and_after(ret, rows, topo_order=from_id is None) - - if order.lower() == "desc": - ret.reverse() - - if rows: - key = "s%d" % min(r.stream_ordering for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. - key = from_key - - return ret, key - - @defer.inlineCallbacks - def get_membership_changes_for_user(self, user_id, from_key, to_key): - from_id = RoomStreamToken.parse_stream_token(from_key).stream - to_id = RoomStreamToken.parse_stream_token(to_key).stream - - if from_key == to_key: - return [] - - if from_id: - has_changed = self._membership_stream_cache.has_entity_changed( - user_id, int(from_id) - ) - if not has_changed: - return [] - - def f(txn): - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND e.stream_ordering > ? AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - ) - txn.execute(sql, (user_id, from_id, to_id)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] - - return rows - - rows = yield self.runInteraction("get_membership_changes_for_user", f) - - ret = yield self.get_events_as_list( - [r.event_id for r in rows], get_prev_content=True - ) - - self._set_before_and_after(ret, rows, topo_order=False) - - return ret - - @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token): - """Get the most recent events in the room in topological ordering. - - Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. - - Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns a list of - events and a token pointing to the start of the returned - events. - The events returned are in ascending order. - """ - - rows, token = yield self.get_recent_event_ids_for_room( - room_id, limit, end_token - ) - - logger.debug("stream before") - events = yield self.get_events_as_list( - [r.event_id for r in rows], get_prev_content=True - ) - logger.debug("stream after") - - self._set_before_and_after(events, rows) - - return (events, token) - - @defer.inlineCallbacks - def get_recent_event_ids_for_room(self, room_id, limit, end_token): - """Get the most recent events in the room in topological ordering. - - Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. - - Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of - _EventDictReturn and a token pointing to the start of the returned - events. - The events returned are in ascending order. - """ - # Allow a zero limit here, and no-op. - if limit == 0: - return [], end_token - - end_token = RoomStreamToken.parse(end_token) - - rows, token = yield self.runInteraction( - "get_recent_event_ids_for_room", - self._paginate_room_events_txn, - room_id, - from_token=end_token, - limit=limit, - ) - - # We want to return the results in ascending order. - rows.reverse() - - return rows, token - - def get_room_event_after_stream_ordering(self, room_id, stream_ordering): - """Gets details of the first event in a room at or after a stream ordering - - Args: - room_id (str): - stream_ordering (int): - - Returns: - Deferred[(int, int, str)]: - (stream ordering, topological ordering, event_id) - """ - - def _f(txn): - sql = ( - "SELECT stream_ordering, topological_ordering, event_id" - " FROM events" - " WHERE room_id = ? AND stream_ordering >= ?" - " AND NOT outlier" - " ORDER BY stream_ordering" - " LIMIT 1" - ) - txn.execute(sql, (room_id, stream_ordering)) - return txn.fetchone() - - return self.runInteraction("get_room_event_after_stream_ordering", _f) - - @defer.inlineCallbacks - def get_room_events_max_id(self, room_id=None): - """Returns the current token for rooms stream. - - By default, it returns the current global stream token. Specifying a - `room_id` causes it to return the current room specific topological - token. - """ - token = yield self.get_room_max_stream_ordering() - if room_id is None: - return "s%d" % (token,) - else: - topo = yield self.runInteraction( - "_get_max_topological_txn", self._get_max_topological_txn, room_id - ) - return "t%d-%d" % (topo, token) - - def get_stream_token_for_event(self, event_id): - """The stream token for an event - Args: - event_id(str): The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A deferred "s%d" stream token. - """ - return self._simple_select_one_onecol( - table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" - ).addCallback(lambda row: "s%d" % (row,)) - - def get_topological_token_for_event(self, event_id): - """The stream token for an event - Args: - event_id(str): The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A deferred "t%d-%d" topological token. - """ - return self._simple_select_one( - table="events", - keyvalues={"event_id": event_id}, - retcols=("stream_ordering", "topological_ordering"), - desc="get_topological_token_for_event", - ).addCallback( - lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) - ) - - def get_max_topological_token(self, room_id, stream_key): - """Get the max topological token in a room before the given stream - ordering. - - Args: - room_id (str) - stream_key (int) - - Returns: - Deferred[int] - """ - sql = ( - "SELECT coalesce(max(topological_ordering), 0) FROM events" - " WHERE room_id = ? AND stream_ordering < ?" - ) - return self._execute( - "get_max_topological_token", None, sql, room_id, stream_key - ).addCallback(lambda r: r[0][0] if r else 0) - - def _get_max_topological_txn(self, txn, room_id): - txn.execute( - "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?", - (room_id,), - ) - - rows = txn.fetchall() - return rows[0][0] if rows else 0 - - @staticmethod - def _set_before_and_after(events, rows, topo_order=True): - """Inserts ordering information to events' internal metadata from - the DB rows. - - Args: - events (list[FrozenEvent]) - rows (list[_EventDictReturn]) - topo_order (bool): Whether the events were ordered topologically - or by stream ordering. If true then all rows should have a non - null topological_ordering. - """ - for event, row in zip(events, rows): - stream = row.stream_ordering - if topo_order and row.topological_ordering: - topo = row.topological_ordering - else: - topo = None - internal = event.internal_metadata - internal.before = str(RoomStreamToken(topo, stream - 1)) - internal.after = str(RoomStreamToken(topo, stream)) - internal.order = (int(topo) if topo else 0, int(stream)) - - @defer.inlineCallbacks - def get_events_around( - self, room_id, event_id, before_limit, after_limit, event_filter=None - ): - """Retrieve events and pagination tokens around a given event in a - room. - - Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) - - Returns: - dict - """ - - results = yield self.runInteraction( - "get_events_around", - self._get_events_around_txn, - room_id, - event_id, - before_limit, - after_limit, - event_filter, - ) - - events_before = yield self.get_events_as_list( - [e for e in results["before"]["event_ids"]], get_prev_content=True - ) - - events_after = yield self.get_events_as_list( - [e for e in results["after"]["event_ids"]], get_prev_content=True - ) - - return { - "events_before": events_before, - "events_after": events_after, - "start": results["before"]["token"], - "end": results["after"]["token"], - } - - def _get_events_around_txn( - self, txn, room_id, event_id, before_limit, after_limit, event_filter - ): - """Retrieves event_ids and pagination tokens around a given event in a - room. - - Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) - - Returns: - dict - """ - - results = self._simple_select_one_txn( - txn, - "events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcols=["stream_ordering", "topological_ordering"], - ) - - # Paginating backwards includes the event at the token, but paginating - # forward doesn't. - before_token = RoomStreamToken( - results["topological_ordering"] - 1, results["stream_ordering"] - ) - - after_token = RoomStreamToken( - results["topological_ordering"], results["stream_ordering"] - ) - - rows, start_token = self._paginate_room_events_txn( - txn, - room_id, - before_token, - direction="b", - limit=before_limit, - event_filter=event_filter, - ) - events_before = [r.event_id for r in rows] - - rows, end_token = self._paginate_room_events_txn( - txn, - room_id, - after_token, - direction="f", - limit=after_limit, - event_filter=event_filter, - ) - events_after = [r.event_id for r in rows] - - return { - "before": {"event_ids": events_before, "token": start_token}, - "after": {"event_ids": events_after, "token": end_token}, - } - - @defer.inlineCallbacks - def get_all_new_events_stream(self, from_id, current_id, limit): - """Get all new events - - Returns all events with from_id < stream_ordering <= current_id. - - Args: - from_id (int): the stream_ordering of the last event we processed - current_id (int): the stream_ordering of the most recently processed event - limit (int): the maximum number of events to return - - Returns: - Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where - `next_id` is the next value to pass as `from_id` (it will either be the - stream_ordering of the last returned event, or, if fewer than `limit` events - were found, `current_id`. - """ - - def get_all_new_events_stream_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id" - " FROM events AS e" - " WHERE" - " ? < e.stream_ordering AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - " LIMIT ?" - ) - - txn.execute(sql, (from_id, current_id, limit)) - rows = txn.fetchall() - - upper_bound = current_id - if len(rows) == limit: - upper_bound = rows[-1][0] - - return upper_bound, [row[1] for row in rows] - - upper_bound, event_ids = yield self.runInteraction( - "get_all_new_events_stream", get_all_new_events_stream_txn - ) - - events = yield self.get_events_as_list(event_ids) - - return upper_bound, events - - def get_federation_out_pos(self, typ): - return self._simple_select_one_onecol( - table="federation_stream_position", - retcol="stream_id", - keyvalues={"type": typ}, - desc="get_federation_out_pos", - ) - - def update_federation_out_pos(self, typ, stream_id): - return self._simple_update_one( - table="federation_stream_position", - keyvalues={"type": typ}, - updatevalues={"stream_id": stream_id}, - desc="update_federation_out_pos", - ) - - def has_room_changed_since(self, room_id, stream_id): - return self._events_stream_cache.has_entity_changed(room_id, stream_id) - - def _paginate_room_events_txn( - self, - txn, - room_id, - from_token, - to_token=None, - direction="b", - limit=-1, - event_filter=None, - ): - """Returns list of events before or after a given token. - - Args: - txn - room_id (str) - from_token (RoomStreamToken): The token used to stream from - to_token (RoomStreamToken|None): A token which if given limits the - results to only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to - those that match the filter. - - Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns the results - as a list of _EventDictReturn and a token that points to the end - of the result set. If no events are returned then the end of the - stream has been reached (i.e. there are no events between - `from_token` and `to_token`), or `limit` is zero. - """ - - assert int(limit) >= 0 - - # Tokens really represent positions between elements, but we use - # the convention of pointing to the event before the gap. Hence - # we have a bit of asymmetry when it comes to equalities. - args = [False, room_id] - if direction == "b": - order = "DESC" - else: - order = "ASC" - - bounds = generate_pagination_where_clause( - direction=direction, - column_names=("topological_ordering", "stream_ordering"), - from_token=from_token, - to_token=to_token, - engine=self.database_engine, - ) - - filter_clause, filter_args = filter_to_clause(event_filter) - - if filter_clause: - bounds += " AND " + filter_clause - args.extend(filter_args) - - args.append(int(limit)) - - sql = ( - "SELECT event_id, topological_ordering, stream_ordering" - " FROM events" - " WHERE outlier = ? AND room_id = ? AND %(bounds)s" - " ORDER BY topological_ordering %(order)s," - " stream_ordering %(order)s LIMIT ?" - ) % {"bounds": bounds, "order": order} - - txn.execute(sql, args) - - rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] - - if rows: - topo = rows[-1].topological_ordering - toke = rows[-1].stream_ordering - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - toke -= 1 - next_token = RoomStreamToken(topo, toke) - else: - # TODO (erikj): We should work out what to do here instead. - next_token = to_token if to_token else from_token - - return rows, str(next_token) - - @defer.inlineCallbacks - def paginate_room_events( - self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None - ): - """Returns list of events before or after a given token. - - Args: - room_id (str) - from_key (str): The token used to stream from - to_key (str|None): A token which if given limits the results to - only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to - those that match the filter. - - Returns: - tuple[list[FrozenEvent], str]: Returns the results as a list of - events and a token that points to the end of the result set. If no - events are returned then the end of the stream has been reached - (i.e. there are no events between `from_key` and `to_key`). - """ - - from_key = RoomStreamToken.parse(from_key) - if to_key: - to_key = RoomStreamToken.parse(to_key) - - rows, token = yield self.runInteraction( - "paginate_room_events", - self._paginate_room_events_txn, - room_id, - from_key, - to_key, - direction, - limit, - event_filter, - ) - - events = yield self.get_events_as_list( - [r.event_id for r in rows], get_prev_content=True - ) - - self._set_before_and_after(events, rows) - - return (events, token) - - -class StreamStore(StreamWorkerStore): - def get_room_max_stream_ordering(self): - return self._stream_id_gen.get_current_token() - - def get_room_min_stream_ordering(self): - return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py deleted file mode 100644 index 20dd6bd53d..0000000000 --- a/synapse/storage/tags.py +++ /dev/null @@ -1,265 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from six.moves import range - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.storage.account_data import AccountDataWorkerStore -from synapse.util.caches.descriptors import cached - -logger = logging.getLogger(__name__) - - -class TagsWorkerStore(AccountDataWorkerStore): - @cached() - def get_tags_for_user(self, user_id): - """Get all the tags for a user. - - - Args: - user_id(str): The user to get the tags for. - Returns: - A deferred dict mapping from room_id strings to dicts mapping from - tag strings to tag content. - """ - - deferred = self._simple_select_list( - "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] - ) - - @deferred.addCallback - def tags_by_room(rows): - tags_by_room = {} - for row in rows: - room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = json.loads(row["content"]) - return tags_by_room - - return deferred - - @defer.inlineCallbacks - def get_all_updated_tags(self, last_id, current_id, limit): - """Get all the client tags that have changed on the server - Args: - last_id(int): The position to fetch from. - current_id(int): The position to fetch up to. - Returns: - A deferred list of tuples of stream_id int, user_id string, - room_id string, tag string and content string. - """ - if last_id == current_id: - return [] - - def get_all_updated_tags_txn(txn): - sql = ( - "SELECT stream_id, user_id, room_id" - " FROM room_tags_revisions as r" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() - - tag_ids = yield self.runInteraction( - "get_all_updated_tags", get_all_updated_tags_txn - ) - - def get_tag_content(txn, tag_ids): - sql = ( - "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?" - ) - results = [] - for stream_id, user_id, room_id in tag_ids: - txn.execute(sql, (user_id, room_id)) - tags = [] - for tag, content in txn: - tags.append(json.dumps(tag) + ":" + content) - tag_json = "{" + ",".join(tags) + "}" - results.append((stream_id, user_id, room_id, tag_json)) - - return results - - batch_size = 50 - results = [] - for i in range(0, len(tag_ids), batch_size): - tags = yield self.runInteraction( - "get_all_updated_tag_content", - get_tag_content, - tag_ids[i : i + batch_size], - ) - results.extend(tags) - - return results - - @defer.inlineCallbacks - def get_updated_tags(self, user_id, stream_id): - """Get all the tags for the rooms where the tags have changed since the - given version - - Args: - user_id(str): The user to get the tags for. - stream_id(int): The earliest update to get for the user. - Returns: - A deferred dict mapping from room_id strings to lists of tag - strings for all the rooms that changed since the stream_id token. - """ - - def get_updated_tags_txn(txn): - sql = ( - "SELECT room_id from room_tags_revisions" - " WHERE user_id = ? AND stream_id > ?" - ) - txn.execute(sql, (user_id, stream_id)) - room_ids = [row[0] for row in txn] - return room_ids - - changed = self._account_data_stream_cache.has_entity_changed( - user_id, int(stream_id) - ) - if not changed: - return {} - - room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn) - - results = {} - if room_ids: - tags_by_room = yield self.get_tags_for_user(user_id) - for room_id in room_ids: - results[room_id] = tags_by_room.get(room_id, {}) - - return results - - def get_tags_for_room(self, user_id, room_id): - """Get all the tags for the given room - Args: - user_id(str): The user to get tags for - room_id(str): The room to get tags for - Returns: - A deferred list of string tags. - """ - return self._simple_select_list( - table="room_tags", - keyvalues={"user_id": user_id, "room_id": room_id}, - retcols=("tag", "content"), - desc="get_tags_for_room", - ).addCallback( - lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows} - ) - - -class TagsStore(TagsWorkerStore): - @defer.inlineCallbacks - def add_tag_to_room(self, user_id, room_id, tag, content): - """Add a tag to a room for a user. - Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - tag(str): The tag name to add. - content(dict): A json object to associate with the tag. - Returns: - A deferred that completes once the tag has been added. - """ - content_json = json.dumps(content) - - def add_tag_txn(txn, next_id): - self._simple_upsert_txn( - txn, - table="room_tags", - keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, - values={"content": content_json}, - ) - self._update_revision_txn(txn, user_id, room_id, next_id) - - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("add_tag", add_tag_txn, next_id) - - self.get_tags_for_user.invalidate((user_id,)) - - result = self._account_data_id_gen.get_current_token() - return result - - @defer.inlineCallbacks - def remove_tag_from_room(self, user_id, room_id, tag): - """Remove a tag from a room for a user. - Returns: - A deferred that completes once the tag has been removed - """ - - def remove_tag_txn(txn, next_id): - sql = ( - "DELETE FROM room_tags " - " WHERE user_id = ? AND room_id = ? AND tag = ?" - ) - txn.execute(sql, (user_id, room_id, tag)) - self._update_revision_txn(txn, user_id, room_id, next_id) - - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("remove_tag", remove_tag_txn, next_id) - - self.get_tags_for_user.invalidate((user_id,)) - - result = self._account_data_id_gen.get_current_token() - return result - - def _update_revision_txn(self, txn, user_id, room_id, next_id): - """Update the latest revision of the tags for the given user and room. - - Args: - txn: The database cursor - user_id(str): The ID of the user. - room_id(str): The ID of the room. - next_id(int): The the revision to advance to. - """ - - txn.call_after( - self._account_data_stream_cache.entity_has_changed, user_id, next_id - ) - - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" - ) - txn.execute(update_max_id_sql, (next_id, next_id)) - - update_sql = ( - "UPDATE room_tags_revisions" - " SET stream_id = ?" - " WHERE user_id = ?" - " AND room_id = ?" - ) - txn.execute(update_sql, (next_id, user_id, room_id)) - - if txn.rowcount == 0: - insert_sql = ( - "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)" - " VALUES (?, ?, ?)" - ) - try: - txn.execute(insert_sql, (user_id, room_id, next_id)) - except self.database_engine.module.IntegrityError: - # Ignore insertion errors. It doesn't matter if the row wasn't - # inserted because if two updates happend concurrently the one - # with the higher stream_id will not be reported to a client - # unless the previous update has completed. It doesn't matter - # which stream_id ends up in the table, as long as it is higher - # than the id that the client has. - pass diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py deleted file mode 100644 index 289c117396..0000000000 --- a/synapse/storage/transactions.py +++ /dev/null @@ -1,274 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from collections import namedtuple - -import six - -from canonicaljson import encode_canonical_json - -from twisted.internet import defer - -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.caches.expiringcache import ExpiringCache - -from ._base import SQLBaseStore, db_to_json - -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - -logger = logging.getLogger(__name__) - - -_TransactionRow = namedtuple( - "_TransactionRow", - ("id", "transaction_id", "destination", "ts", "response_code", "response_json"), -) - -_UpdateTransactionRow = namedtuple( - "_TransactionRow", ("response_code", "response_json") -) - -SENTINEL = object() - - -class TransactionStore(SQLBaseStore): - """A collection of queries for handling PDUs. - """ - - def __init__(self, db_conn, hs): - super(TransactionStore, self).__init__(db_conn, hs) - - self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) - - self._destination_retry_cache = ExpiringCache( - cache_name="get_destination_retry_timings", - clock=self._clock, - expiry_ms=5 * 60 * 1000, - ) - - def get_received_txn_response(self, transaction_id, origin): - """For an incoming transaction from a given origin, check if we have - already responded to it. If so, return the response code and response - body (as a dict). - - Args: - transaction_id (str) - origin(str) - - Returns: - tuple: None if we have not previously responded to - this transaction or a 2-tuple of (int, dict) - """ - - return self.runInteraction( - "get_received_txn_response", - self._get_received_txn_response, - transaction_id, - origin, - ) - - def _get_received_txn_response(self, txn, transaction_id, origin): - result = self._simple_select_one_txn( - txn, - table="received_transactions", - keyvalues={"transaction_id": transaction_id, "origin": origin}, - retcols=( - "transaction_id", - "origin", - "ts", - "response_code", - "response_json", - "has_been_referenced", - ), - allow_none=True, - ) - - if result and result["response_code"]: - return result["response_code"], db_to_json(result["response_json"]) - - else: - return None - - def set_received_txn_response(self, transaction_id, origin, code, response_dict): - """Persist the response we returened for an incoming transaction, and - should return for subsequent transactions with the same transaction_id - and origin. - - Args: - txn - transaction_id (str) - origin (str) - code (int) - response_json (str) - """ - - return self._simple_insert( - table="received_transactions", - values={ - "transaction_id": transaction_id, - "origin": origin, - "response_code": code, - "response_json": db_binary_type(encode_canonical_json(response_dict)), - "ts": self._clock.time_msec(), - }, - or_ignore=True, - desc="set_received_txn_response", - ) - - @defer.inlineCallbacks - def get_destination_retry_timings(self, destination): - """Gets the current retry timings (if any) for a given destination. - - Args: - destination (str) - - Returns: - None if not retrying - Otherwise a dict for the retry scheme - """ - - result = self._destination_retry_cache.get(destination, SENTINEL) - if result is not SENTINEL: - return result - - result = yield self.runInteraction( - "get_destination_retry_timings", - self._get_destination_retry_timings, - destination, - ) - - # We don't hugely care about race conditions between getting and - # invalidating the cache, since we time out fairly quickly anyway. - self._destination_retry_cache[destination] = result - return result - - def _get_destination_retry_timings(self, txn, destination): - result = self._simple_select_one_txn( - txn, - table="destinations", - keyvalues={"destination": destination}, - retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), - allow_none=True, - ) - - if result and result["retry_last_ts"] > 0: - return result - else: - return None - - def set_destination_retry_timings( - self, destination, failure_ts, retry_last_ts, retry_interval - ): - """Sets the current retry timings for a given destination. - Both timings should be zero if retrying is no longer occuring. - - Args: - destination (str) - failure_ts (int|None) - when the server started failing (ms since epoch) - retry_last_ts (int) - time of last retry attempt in unix epoch ms - retry_interval (int) - how long until next retry in ms - """ - - self._destination_retry_cache.pop(destination, None) - return self.runInteraction( - "set_destination_retry_timings", - self._set_destination_retry_timings, - destination, - failure_ts, - retry_last_ts, - retry_interval, - ) - - def _set_destination_retry_timings( - self, txn, destination, failure_ts, retry_last_ts, retry_interval - ): - - if self.database_engine.can_native_upsert: - # Upsert retry time interval if retry_interval is zero (i.e. we're - # resetting it) or greater than the existing retry interval. - - sql = """ - INSERT INTO destinations ( - destination, failure_ts, retry_last_ts, retry_interval - ) - VALUES (?, ?, ?, ?) - ON CONFLICT (destination) DO UPDATE SET - failure_ts = EXCLUDED.failure_ts, - retry_last_ts = EXCLUDED.retry_last_ts, - retry_interval = EXCLUDED.retry_interval - WHERE - EXCLUDED.retry_interval = 0 - OR destinations.retry_interval < EXCLUDED.retry_interval - """ - - txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) - - return - - self.database_engine.lock_table(txn, "destinations") - - # We need to be careful here as the data may have changed from under us - # due to a worker setting the timings. - - prev_row = self._simple_select_one_txn( - txn, - table="destinations", - keyvalues={"destination": destination}, - retcols=("failure_ts", "retry_last_ts", "retry_interval"), - allow_none=True, - ) - - if not prev_row: - self._simple_insert_txn( - txn, - table="destinations", - values={ - "destination": destination, - "failure_ts": failure_ts, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - }, - ) - elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self._simple_update_one_txn( - txn, - "destinations", - keyvalues={"destination": destination}, - updatevalues={ - "failure_ts": failure_ts, - "retry_last_ts": retry_last_ts, - "retry_interval": retry_interval, - }, - ) - - def _start_cleanup_transactions(self): - return run_as_background_process( - "cleanup_transactions", self._cleanup_transactions - ) - - def _cleanup_transactions(self): - now = self._clock.time_msec() - month_ago = now - 30 * 24 * 60 * 60 * 1000 - - def _cleanup_transactions_txn(txn): - txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - - return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn) diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py deleted file mode 100644 index 1b1e4751b9..0000000000 --- a/synapse/storage/user_directory.py +++ /dev/null @@ -1,827 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re - -from twisted.internet import defer - -from synapse.api.constants import EventTypes, JoinRules -from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.storage.state import StateFilter -from synapse.storage.state_deltas import StateDeltasStore -from synapse.types import get_domain_from_id, get_localpart_from_id -from synapse.util.caches.descriptors import cached - -logger = logging.getLogger(__name__) - - -TEMP_TABLE = "_temp_populate_user_directory" - - -class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore): - - # How many records do we calculate before sending it to - # add_users_who_share_private_rooms? - SHARE_PRIVATE_WORKING_SET = 500 - - def __init__(self, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs) - - self.server_name = hs.hostname - - self.register_background_update_handler( - "populate_user_directory_createtables", - self._populate_user_directory_createtables, - ) - self.register_background_update_handler( - "populate_user_directory_process_rooms", - self._populate_user_directory_process_rooms, - ) - self.register_background_update_handler( - "populate_user_directory_process_users", - self._populate_user_directory_process_users, - ) - self.register_background_update_handler( - "populate_user_directory_cleanup", self._populate_user_directory_cleanup - ) - - @defer.inlineCallbacks - def _populate_user_directory_createtables(self, progress, batch_size): - - # Get all the rooms that we want to process. - def _make_staging_area(txn): - sql = ( - "CREATE TABLE IF NOT EXISTS " - + TEMP_TABLE - + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)" - ) - txn.execute(sql) - - sql = ( - "CREATE TABLE IF NOT EXISTS " - + TEMP_TABLE - + "_position(position TEXT NOT NULL)" - ) - txn.execute(sql) - - # Get rooms we want to process from the database - sql = """ - SELECT room_id, count(*) FROM current_state_events - GROUP BY room_id - """ - txn.execute(sql) - rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) - del rooms - - # If search all users is on, get all the users we want to add. - if self.hs.config.user_directory_search_all_users: - sql = ( - "CREATE TABLE IF NOT EXISTS " - + TEMP_TABLE - + "_users(user_id TEXT NOT NULL)" - ) - txn.execute(sql) - - txn.execute("SELECT name FROM users") - users = [{"user_id": x[0]} for x in txn.fetchall()] - - self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) - - new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.runInteraction( - "populate_user_directory_temp_build", _make_staging_area - ) - yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) - - yield self._end_background_update("populate_user_directory_createtables") - return 1 - - @defer.inlineCallbacks - def _populate_user_directory_cleanup(self, progress, batch_size): - """ - Update the user directory stream position, then clean up the old tables. - """ - position = yield self._simple_select_one_onecol( - TEMP_TABLE + "_position", None, "position" - ) - yield self.update_user_directory_stream_pos(position) - - def _delete_staging_area(txn): - txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") - txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") - txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") - - yield self.runInteraction( - "populate_user_directory_cleanup", _delete_staging_area - ) - - yield self._end_background_update("populate_user_directory_cleanup") - return 1 - - @defer.inlineCallbacks - def _populate_user_directory_process_rooms(self, progress, batch_size): - """ - Args: - progress (dict) - batch_size (int): Maximum number of state events to process - per cycle. - """ - state = self.hs.get_state_handler() - - # If we don't have progress filed, delete everything. - if not progress: - yield self.delete_all_from_user_dir() - - def _get_next_batch(txn): - # Only fetch 250 rooms, so we don't fetch too many at once, even - # if those 250 rooms have less than batch_size state events. - sql = """ - SELECT room_id, events FROM %s - ORDER BY events DESC - LIMIT 250 - """ % ( - TEMP_TABLE + "_rooms", - ) - txn.execute(sql) - rooms_to_work_on = txn.fetchall() - - if not rooms_to_work_on: - return None - - # Get how many are left to process, so we can give status on how - # far we are in processing - txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") - progress["remaining"] = txn.fetchone()[0] - - return rooms_to_work_on - - rooms_to_work_on = yield self.runInteraction( - "populate_user_directory_temp_read", _get_next_batch - ) - - # No more rooms -- complete the transaction. - if not rooms_to_work_on: - yield self._end_background_update("populate_user_directory_process_rooms") - return 1 - - logger.info( - "Processing the next %d rooms of %d remaining" - % (len(rooms_to_work_on), progress["remaining"]) - ) - - processed_event_count = 0 - - for room_id, event_count in rooms_to_work_on: - is_in_room = yield self.is_host_joined(room_id, self.server_name) - - if is_in_room: - is_public = yield self.is_room_world_readable_or_publicly_joinable( - room_id - ) - - users_with_profile = yield state.get_current_users_in_room(room_id) - user_ids = set(users_with_profile) - - # Update each user in the user directory. - for user_id, profile in users_with_profile.items(): - yield self.update_profile_in_user_dir( - user_id, profile.display_name, profile.avatar_url - ) - - to_insert = set() - - if is_public: - for user_id in user_ids: - if self.get_if_app_services_interested_in_user(user_id): - continue - - to_insert.add(user_id) - - if to_insert: - yield self.add_users_in_public_rooms(room_id, to_insert) - to_insert.clear() - else: - for user_id in user_ids: - if not self.hs.is_mine_id(user_id): - continue - - if self.get_if_app_services_interested_in_user(user_id): - continue - - for other_user_id in user_ids: - if user_id == other_user_id: - continue - - user_set = (user_id, other_user_id) - to_insert.add(user_set) - - # If it gets too big, stop and write to the database - # to prevent storing too much in RAM. - if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET: - yield self.add_users_who_share_private_room( - room_id, to_insert - ) - to_insert.clear() - - if to_insert: - yield self.add_users_who_share_private_room(room_id, to_insert) - to_insert.clear() - - # We've finished a room. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) - # Update the remaining counter. - progress["remaining"] -= 1 - yield self.runInteraction( - "populate_user_directory", - self._background_update_progress_txn, - "populate_user_directory_process_rooms", - progress, - ) - - processed_event_count += event_count - - if processed_event_count > batch_size: - # Don't process any more rooms, we've hit our batch size. - return processed_event_count - - return processed_event_count - - @defer.inlineCallbacks - def _populate_user_directory_process_users(self, progress, batch_size): - """ - If search_all_users is enabled, add all of the users to the user directory. - """ - if not self.hs.config.user_directory_search_all_users: - yield self._end_background_update("populate_user_directory_process_users") - return 1 - - def _get_next_batch(txn): - sql = "SELECT user_id FROM %s LIMIT %s" % ( - TEMP_TABLE + "_users", - str(batch_size), - ) - txn.execute(sql) - users_to_work_on = txn.fetchall() - - if not users_to_work_on: - return None - - users_to_work_on = [x[0] for x in users_to_work_on] - - # Get how many are left to process, so we can give status on how - # far we are in processing - sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users" - txn.execute(sql) - progress["remaining"] = txn.fetchone()[0] - - return users_to_work_on - - users_to_work_on = yield self.runInteraction( - "populate_user_directory_temp_read", _get_next_batch - ) - - # No more users -- complete the transaction. - if not users_to_work_on: - yield self._end_background_update("populate_user_directory_process_users") - return 1 - - logger.info( - "Processing the next %d users of %d remaining" - % (len(users_to_work_on), progress["remaining"]) - ) - - for user_id in users_to_work_on: - profile = yield self.get_profileinfo(get_localpart_from_id(user_id)) - yield self.update_profile_in_user_dir( - user_id, profile.display_name, profile.avatar_url - ) - - # We've finished processing a user. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) - # Update the remaining counter. - progress["remaining"] -= 1 - yield self.runInteraction( - "populate_user_directory", - self._background_update_progress_txn, - "populate_user_directory_process_users", - progress, - ) - - return len(users_to_work_on) - - @defer.inlineCallbacks - def is_room_world_readable_or_publicly_joinable(self, room_id): - """Check if the room is either world_readable or publically joinable - """ - - # Create a state filter that only queries join and history state event - types_to_filter = ( - (EventTypes.JoinRules, ""), - (EventTypes.RoomHistoryVisibility, ""), - ) - - current_state_ids = yield self.get_filtered_current_state_ids( - room_id, StateFilter.from_types(types_to_filter) - ) - - join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) - if join_rules_id: - join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) - if join_rule_ev: - if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: - return True - - hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) - if hist_vis_id: - hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) - if hist_vis_ev: - if hist_vis_ev.content.get("history_visibility") == "world_readable": - return True - - return False - - def update_profile_in_user_dir(self, user_id, display_name, avatar_url): - """ - Update or add a user's profile in the user directory. - """ - - def _update_profile_in_user_dir_txn(txn): - new_entry = self._simple_upsert_txn( - txn, - table="user_directory", - keyvalues={"user_id": user_id}, - values={"display_name": display_name, "avatar_url": avatar_url}, - lock=False, # We're only inserter - ) - - if isinstance(self.database_engine, PostgresEngine): - # We weight the localpart most highly, then display name and finally - # server name - if self.database_engine.can_native_upsert: - sql = """ - INSERT INTO user_directory_search(user_id, vector) - VALUES (?, - setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector - """ - txn.execute( - sql, - ( - user_id, - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - ), - ) - else: - # TODO: Remove this code after we've bumped the minimum version - # of postgres to always support upserts, so we can get rid of - # `new_entry` usage - if new_entry is True: - sql = """ - INSERT INTO user_directory_search(user_id, vector) - VALUES (?, - setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - ) - """ - txn.execute( - sql, - ( - user_id, - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - ), - ) - elif new_entry is False: - sql = """ - UPDATE user_directory_search - SET vector = setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - WHERE user_id = ? - """ - txn.execute( - sql, - ( - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - user_id, - ), - ) - else: - raise RuntimeError( - "upsert returned None when 'can_native_upsert' is False" - ) - elif isinstance(self.database_engine, Sqlite3Engine): - value = "%s %s" % (user_id, display_name) if display_name else user_id - self._simple_upsert_txn( - txn, - table="user_directory_search", - keyvalues={"user_id": user_id}, - values={"value": value}, - lock=False, # We're only inserter - ) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - - txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - - return self.runInteraction( - "update_profile_in_user_dir", _update_profile_in_user_dir_txn - ) - - def add_users_who_share_private_room(self, room_id, user_id_tuples): - """Insert entries into the users_who_share_private_rooms table. The first - user should be a local user. - - Args: - room_id (str) - user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. - """ - - def _add_users_who_share_room_txn(txn): - self._simple_upsert_many_txn( - txn, - table="users_who_share_private_rooms", - key_names=["user_id", "other_user_id", "room_id"], - key_values=[ - (user_id, other_user_id, room_id) - for user_id, other_user_id in user_id_tuples - ], - value_names=(), - value_values=None, - ) - - return self.runInteraction( - "add_users_who_share_room", _add_users_who_share_room_txn - ) - - def add_users_in_public_rooms(self, room_id, user_ids): - """Insert entries into the users_who_share_private_rooms table. The first - user should be a local user. - - Args: - room_id (str) - user_ids (list[str]) - """ - - def _add_users_in_public_rooms_txn(txn): - - self._simple_upsert_many_txn( - txn, - table="users_in_public_rooms", - key_names=["user_id", "room_id"], - key_values=[(user_id, room_id) for user_id in user_ids], - value_names=(), - value_values=None, - ) - - return self.runInteraction( - "add_users_in_public_rooms", _add_users_in_public_rooms_txn - ) - - def delete_all_from_user_dir(self): - """Delete the entire user directory - """ - - def _delete_all_from_user_dir_txn(txn): - txn.execute("DELETE FROM user_directory") - txn.execute("DELETE FROM user_directory_search") - txn.execute("DELETE FROM users_in_public_rooms") - txn.execute("DELETE FROM users_who_share_private_rooms") - txn.call_after(self.get_user_in_directory.invalidate_all) - - return self.runInteraction( - "delete_all_from_user_dir", _delete_all_from_user_dir_txn - ) - - @cached() - def get_user_in_directory(self, user_id): - return self._simple_select_one( - table="user_directory", - keyvalues={"user_id": user_id}, - retcols=("display_name", "avatar_url"), - allow_none=True, - desc="get_user_in_directory", - ) - - def update_user_directory_stream_pos(self, stream_id): - return self._simple_update_one( - table="user_directory_stream_pos", - keyvalues={}, - updatevalues={"stream_id": stream_id}, - desc="update_user_directory_stream_pos", - ) - - -class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): - - # How many records do we calculate before sending it to - # add_users_who_share_private_rooms? - SHARE_PRIVATE_WORKING_SET = 500 - - def __init__(self, db_conn, hs): - super(UserDirectoryStore, self).__init__(db_conn, hs) - - def remove_from_user_dir(self, user_id): - def _remove_from_user_dir_txn(txn): - self._simple_delete_txn( - txn, table="user_directory", keyvalues={"user_id": user_id} - ) - self._simple_delete_txn( - txn, table="user_directory_search", keyvalues={"user_id": user_id} - ) - self._simple_delete_txn( - txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} - ) - self._simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"user_id": user_id}, - ) - self._simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"other_user_id": user_id}, - ) - txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - - return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) - - @defer.inlineCallbacks - def get_users_in_dir_due_to_room(self, room_id): - """Get all user_ids that are in the room directory because they're - in the given room_id - """ - user_ids_share_pub = yield self._simple_select_onecol( - table="users_in_public_rooms", - keyvalues={"room_id": room_id}, - retcol="user_id", - desc="get_users_in_dir_due_to_room", - ) - - user_ids_share_priv = yield self._simple_select_onecol( - table="users_who_share_private_rooms", - keyvalues={"room_id": room_id}, - retcol="other_user_id", - desc="get_users_in_dir_due_to_room", - ) - - user_ids = set(user_ids_share_pub) - user_ids.update(user_ids_share_priv) - - return user_ids - - def remove_user_who_share_room(self, user_id, room_id): - """ - Deletes entries in the users_who_share_*_rooms table. The first - user should be a local user. - - Args: - user_id (str) - room_id (str) - """ - - def _remove_user_who_share_room_txn(txn): - self._simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"user_id": user_id, "room_id": room_id}, - ) - self._simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"other_user_id": user_id, "room_id": room_id}, - ) - self._simple_delete_txn( - txn, - table="users_in_public_rooms", - keyvalues={"user_id": user_id, "room_id": room_id}, - ) - - return self.runInteraction( - "remove_user_who_share_room", _remove_user_who_share_room_txn - ) - - @defer.inlineCallbacks - def get_user_dir_rooms_user_is_in(self, user_id): - """ - Returns the rooms that a user is in. - - Args: - user_id(str): Must be a local user - - Returns: - list: user_id - """ - rows = yield self._simple_select_onecol( - table="users_who_share_private_rooms", - keyvalues={"user_id": user_id}, - retcol="room_id", - desc="get_rooms_user_is_in", - ) - - pub_rows = yield self._simple_select_onecol( - table="users_in_public_rooms", - keyvalues={"user_id": user_id}, - retcol="room_id", - desc="get_rooms_user_is_in", - ) - - users = set(pub_rows) - users.update(rows) - return list(users) - - @defer.inlineCallbacks - def get_rooms_in_common_for_users(self, user_id, other_user_id): - """Given two user_ids find out the list of rooms they share. - """ - sql = """ - SELECT room_id FROM ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? - ) AS f1 INNER JOIN ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? - ) f2 USING (room_id) - """ - - rows = yield self._execute( - "get_rooms_in_common_for_users", None, sql, user_id, other_user_id - ) - - return [room_id for room_id, in rows] - - def get_user_directory_stream_pos(self): - return self._simple_select_one_onecol( - table="user_directory_stream_pos", - keyvalues={}, - retcol="stream_id", - desc="get_user_directory_stream_pos", - ) - - @defer.inlineCallbacks - def search_user_dir(self, user_id, search_term, limit): - """Searches for users in directory - - Returns: - dict of the form:: - - { - "limited": , # whether there were more results or not - "results": [ # Ordered by best match first - { - "user_id": , - "display_name": , - "avatar_url": - } - ] - } - """ - - if self.hs.config.user_directory_search_all_users: - join_args = (user_id,) - where_clause = "user_id != ?" - else: - join_args = (user_id,) - where_clause = """ - ( - EXISTS (select 1 from users_in_public_rooms WHERE user_id = t.user_id) - OR EXISTS ( - SELECT 1 FROM users_who_share_private_rooms - WHERE user_id = ? AND other_user_id = t.user_id - ) - ) - """ - - if isinstance(self.database_engine, PostgresEngine): - full_query, exact_query, prefix_query = _parse_query_postgres(search_term) - - # We order by rank and then if they have profile info - # The ranking algorithm is hand tweaked for "best" results. Broadly - # the idea is we give a higher weight to exact matches. - # The array of numbers are the weights for the various part of the - # search: (domain, _, display name, localpart) - sql = """ - SELECT d.user_id AS user_id, display_name, avatar_url - FROM user_directory_search as t - INNER JOIN user_directory AS d USING (user_id) - WHERE - %s - AND vector @@ to_tsquery('english', ?) - ORDER BY - (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) - * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END) - * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END) - * ( - 3 * ts_rank_cd( - '{0.1, 0.1, 0.9, 1.0}', - vector, - to_tsquery('english', ?), - 8 - ) - + ts_rank_cd( - '{0.1, 0.1, 0.9, 1.0}', - vector, - to_tsquery('english', ?), - 8 - ) - ) - DESC, - display_name IS NULL, - avatar_url IS NULL - LIMIT ? - """ % ( - where_clause, - ) - args = join_args + (full_query, exact_query, prefix_query, limit + 1) - elif isinstance(self.database_engine, Sqlite3Engine): - search_query = _parse_query_sqlite(search_term) - - sql = """ - SELECT d.user_id AS user_id, display_name, avatar_url - FROM user_directory_search as t - INNER JOIN user_directory AS d USING (user_id) - WHERE - %s - AND value MATCH ? - ORDER BY - rank(matchinfo(user_directory_search)) DESC, - display_name IS NULL, - avatar_url IS NULL - LIMIT ? - """ % ( - where_clause, - ) - args = join_args + (search_query, limit + 1) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - - results = yield self._execute( - "search_user_dir", self.cursor_to_dict, sql, *args - ) - - limited = len(results) > limit - - return {"limited": limited, "results": results} - - -def _parse_query_sqlite(search_term): - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. - We use this so that we can add prefix matching, which isn't something - that is supported by default. - - We specifically add both a prefix and non prefix matching term so that - exact matches get ranked higher. - """ - - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - return " & ".join("(%s* OR %s)" % (result, result) for result in results) - - -def _parse_query_postgres(search_term): - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. - We use this so that we can add prefix matching, which isn't something - that is supported by default. - """ - - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - - both = " & ".join("(%s:* | %s)" % (result, result) for result in results) - exact = " & ".join("%s" % (result,) for result in results) - prefix = " & ".join("%s:*" % (result,) for result in results) - - return both, exact, prefix diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py deleted file mode 100644 index aa4f0da5f0..0000000000 --- a/synapse/storage/user_erasure_store.py +++ /dev/null @@ -1,91 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import operator - -from synapse.storage._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList - - -class UserErasureWorkerStore(SQLBaseStore): - @cached() - def is_user_erased(self, user_id): - """ - Check if the given user id has requested erasure - - Args: - user_id (str): full user id to check - - Returns: - Deferred[bool]: True if the user has requested erasure - """ - return self._simple_select_onecol( - table="erased_users", - keyvalues={"user_id": user_id}, - retcol="1", - desc="is_user_erased", - ).addCallback(operator.truth) - - @cachedList( - cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True - ) - def are_users_erased(self, user_ids): - """ - Checks which users in a list have requested erasure - - Args: - user_ids (iterable[str]): full user id to check - - Returns: - Deferred[dict[str, bool]]: - for each user, whether the user has requested erasure. - """ - # this serves the dual purpose of (a) making sure we can do len and - # iterate it multiple times, and (b) avoiding duplicates. - user_ids = tuple(set(user_ids)) - - rows = yield self._simple_select_many_batch( - table="erased_users", - column="user_id", - iterable=user_ids, - retcols=("user_id",), - desc="are_users_erased", - ) - erased_users = set(row["user_id"] for row in rows) - - res = dict((u, u in erased_users) for u in user_ids) - return res - - -class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id): - """Indicate that user_id wishes their message history to be erased. - - Args: - user_id (str): full user_id to be erased - """ - - def f(txn): - # first check if they are already in the list - txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) - if txn.fetchone(): - return - - # they are not already there: do the insert. - txn.execute("INSERT INTO erased_users (user_id) VALUES (?)", (user_id,)) - - self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - - return self.runInteraction("mark_user_erased", f) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 7569b6fab5..d5c8bd7612 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse import storage from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.storage.data_stores.main import stats from tests import unittest @@ -87,10 +87,10 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) def _get_current_stats(self, stats_type, stat_id): - table, id_col = storage.stats.TYPE_TO_TABLE[stats_type] + table, id_col = stats.TYPE_TO_TABLE[stats_type] - cols = list(storage.stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list( - storage.stats.PER_SLICE_FIELDS[stats_type] + cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list( + stats.PER_SLICE_FIELDS[stats_type] ) end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 622b16a071..dfeea24599 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -24,7 +24,7 @@ from twisted.internet import defer from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.storage.appservice import ( +from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 34f9c72709..69dcaa63d5 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -50,6 +50,8 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): schema_path = os.path.join( prepare_database.dir_path, + "data_stores", + "main", "schema", "delta", "54", diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 45824bd3b2..24c7fe16c3 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -16,7 +16,7 @@ from twisted.internet import defer -from synapse.storage.profile import ProfileStore +from synapse.storage.data_stores.main.profile import ProfileStore from synapse.types import UserID from tests import unittest diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index d7d244ce97..7eea57c0e2 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from synapse.storage import UserDirectoryStore +from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from tests import unittest from tests.utils import setup_test_homeserver -- cgit 1.5.1 From 336eeea3ffd14dbd879459cde50f1b7f32e9a325 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 22 Oct 2019 11:02:01 +0100 Subject: Fix postgres unit tests to use prepare_database --- tests/utils.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/utils.py b/tests/utils.py index 46ef2959f2..0a64f75d04 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -38,11 +38,7 @@ from synapse.logging.context import LoggingContext from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine -from synapse.storage.prepare_database import ( - _get_or_create_schema_state, - _setup_new_database, - prepare_database, -) +from synapse.storage.prepare_database import prepare_database from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. @@ -88,11 +84,7 @@ def setupdb(): host=POSTGRES_HOST, password=POSTGRES_PASSWORD, ) - cur = db_conn.cursor() - _get_or_create_schema_state(cur, db_engine) - _setup_new_database(cur, db_engine) - db_conn.commit() - cur.close() + prepare_database(db_conn, db_engine, None) db_conn.close() def _cleanup(): -- cgit 1.5.1 From 3e3f9b684e2853217d86349f289780b397afa88a Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 22 Oct 2019 22:26:30 -0400 Subject: fix unit test --- tests/storage/test_devices.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 3cc18f9f1c..039cc79357 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) - received_device_ids = {update["device_id"] for update in device_updates} + received_device_ids = { + update["device_id"] for edu_type, update in device_updates + } self.assertEqual(received_device_ids, set(expected_device_ids)) @defer.inlineCallbacks -- cgit 1.5.1 From 3ca4c7c516a349cbb9dace0ace33bb889955eda1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Oct 2019 12:02:36 +0100 Subject: Use new EventPersistenceStore --- synapse/handlers/federation.py | 3 ++- synapse/handlers/message.py | 3 ++- synapse/server.py | 9 ++++++++- tests/replication/slave/storage/_base.py | 1 + tests/replication/slave/storage/test_events.py | 10 +++++++--- tests/storage/test_redaction.py | 11 ++++++----- tests/storage/test_room.py | 3 ++- tests/storage/test_roommember.py | 3 ++- tests/storage/test_state.py | 3 ++- tests/test_visibility.py | 7 ++++--- tests/utils.py | 10 ++++++++-- 11 files changed, 44 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4b4c6c15f9..7e9c9aee13 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -109,6 +109,7 @@ class FederationHandler(BaseHandler): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -2648,7 +2649,7 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) else: - max_stream_id = yield self.store.persist_events( + max_stream_id = yield self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0f8cce8ffe..7908a2d52c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -234,6 +234,7 @@ class EventCreationHandler(object): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() @@ -868,7 +869,7 @@ class EventCreationHandler(object): if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - (event_stream_id, max_stream_id) = yield self.store.persist_event( + event_stream_id, max_stream_id = yield self.storage.persistence.persist_event( event, context=context ) diff --git a/synapse/server.py b/synapse/server.py index 1fcc7375d3..54a7f4aa5f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -95,6 +95,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler +from synapse.storage import DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -196,6 +197,7 @@ class HomeServer(object): "account_validity_handler", "saml_handler", "event_client_serializer", + "storage", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -225,6 +227,7 @@ class HomeServer(object): self.registration_ratelimiter = Ratelimiter() self.datastore = None + self.datastores = None # Other kwargs are explicit dependencies for depname in kwargs: @@ -234,6 +237,7 @@ class HomeServer(object): logger.info("Setting up.") with self.get_db_conn() as conn: self.datastore = self.DATASTORE_CLASS(conn, self) + self.datastores = DataStores(self.datastore, conn, self) conn.commit() logger.info("Finished setting up.") @@ -266,7 +270,7 @@ class HomeServer(object): return self.clock def get_datastore(self): - return self.datastore + return self.datastores.main def get_config(self): return self.config @@ -537,6 +541,9 @@ class HomeServer(object): def build_event_client_serializer(self): return EventClientSerializer(self) + def build_storage(self) -> Storage: + return Storage(self, self.datastores) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 104349cdbd..4f924ce451 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -41,6 +41,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.master_store = self.hs.get_datastore() + self.storage = hs.get_storage() self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index a368117b43..b68e9fe082 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -234,7 +234,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) msg, msgctx = self.build_event() - self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)])) + self.get_success( + self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) + ) self.replicate() event_source = RoomEventSource(self.hs) @@ -290,10 +292,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if backfill: self.get_success( - self.master_store.persist_events([(event, context)], backfilled=True) + self.storage.persistence.persist_events( + [(event, context)], backfilled=True + ) ) else: - self.get_success(self.master_store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 427d3c49c5..4561c3e383 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -116,7 +117,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -263,7 +264,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) - self.get_success(self.store.persist_event(event_1, context_1)) + self.get_success(self.storage.persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( @@ -282,7 +283,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) - self.get_success(self.store.persist_event(event_2, context_2)) + self.get_success(self.storage.persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1bee45706f..3ddaa151fe 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -62,6 +62,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") @@ -72,7 +73,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.store.persist_event( + yield self.storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 447a3c6ffb..9ddd17f73d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -44,6 +44,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -70,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 5c2cf3c2db..d573a3e07b 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -34,6 +34,7 @@ class StateStoreTestCase(tests.unittest.TestCase): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -63,7 +64,7 @@ class StateStoreTestCase(tests.unittest.TestCase): builder ) - yield self.store.persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 18f1a0035d..6ae1ea9b04 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -36,6 +36,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.store = self.hs.get_datastore() + self.storage = self.hs.get_storage() yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") @@ -137,7 +138,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): event, context = yield self.event_creation_handler.create_new_client_event( builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -159,7 +160,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -180,7 +181,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks diff --git a/tests/utils.py b/tests/utils.py index 0a64f75d04..2808eea933 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -326,10 +326,16 @@ def setup_test_homeserver( if homeserverToUse.__name__ == "TestHomeServer": hs.setup_master() else: + # If we have been given an explicit datastore we probably want to mock + # out the DataStores somehow too. This all feels a bit wrong, but then + # mocking the stores feels wrong too. + datastores = Mock(datastore=datastore) + hs = homeserverToUse( name, db_pool=None, datastore=datastore, + datastores=datastores, config=config, version_string="Synapse/tests", database_engine=db_engine, @@ -647,7 +653,7 @@ def create_room(hs, room_id, creator_id): creator_id (str) """ - store = hs.get_datastore() + persistence_store = hs.get_storage().persistence event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() @@ -664,4 +670,4 @@ def create_room(hs, room_id, creator_id): event, context = yield event_creation_handler.create_new_client_event(builder) - yield store.persist_event(event, context) + yield persistence_store.persist_event(event, context) -- cgit 1.5.1 From 2794b79052f96b8103ce2b710959853313a82e90 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 24 Oct 2019 11:48:46 +0100 Subject: Option to suppress resource exceeded alerting (#6173) The expected use case is to suppress MAU limiting on small instances --- changelog.d/6173.feature | 1 + docs/sample_config.yaml | 8 +- synapse/api/auth.py | 12 ++- synapse/api/constants.py | 7 ++ synapse/config/server.py | 10 +- .../resource_limits_server_notices.py | 110 ++++++++++++++------- .../test_resource_limits_server_notices.py | 59 ++++++++++- tests/utils.py | 1 - 8 files changed, 161 insertions(+), 47 deletions(-) create mode 100644 changelog.d/6173.feature (limited to 'tests') diff --git a/changelog.d/6173.feature b/changelog.d/6173.feature new file mode 100644 index 0000000000..b1cabc322b --- /dev/null +++ b/changelog.d/6173.feature @@ -0,0 +1 @@ +Add config option to suppress client side resource limit alerting. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b4dd146f06..6c81c0db75 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -241,7 +241,6 @@ listeners: # #hs_disabled: false #hs_disabled_message: 'Human readable reason for why the HS is blocked' -#hs_disabled_limit_type: 'error code(str), to help clients decode reason' # Monthly Active User Blocking # @@ -261,9 +260,16 @@ listeners: # sign up in a short space of time never to return after their initial # session. # +# 'mau_limit_alerting' is a means of limiting client side alerting +# should the mau limit be reached. This is useful for small instances +# where the admin has 5 mau seats (say) for 5 specific people and no +# interest increasing the mau limit further. Defaults to True, which +# means that alerting is enabled +# #limit_usage_by_mau: false #max_mau_value: 50 #mau_trial_days: 2 +#mau_limit_alerting: false # If enabled, the metrics for the number of monthly active users will # be populated, however no one will be limited. If limit_usage_by_mau diff --git a/synapse/api/auth.py b/synapse/api/auth.py index cd347fbe1b..53f3bb0fa8 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -25,7 +25,13 @@ from twisted.internet import defer import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth -from synapse.api.constants import EventTypes, JoinRules, Membership, UserTypes +from synapse.api.constants import ( + EventTypes, + JoinRules, + LimitBlockingTypes, + Membership, + UserTypes, +) from synapse.api.errors import ( AuthError, Codes, @@ -726,7 +732,7 @@ class Auth(object): self.hs.config.hs_disabled_message, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_contact=self.hs.config.admin_contact, - limit_type=self.hs.config.hs_disabled_limit_type, + limit_type=LimitBlockingTypes.HS_DISABLED, ) if self.hs.config.limit_usage_by_mau is True: assert not (user_id and threepid) @@ -759,5 +765,5 @@ class Auth(object): "Monthly Active User Limit Exceeded", admin_contact=self.hs.config.admin_contact, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - limit_type="monthly_active_user", + limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER, ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 60e99e4663..312196675e 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -131,3 +131,10 @@ class RelationTypes(object): ANNOTATION = "m.annotation" REPLACE = "m.replace" REFERENCE = "m.reference" + + +class LimitBlockingTypes(object): + """Reasons that a server may be blocked""" + + MONTHLY_ACTIVE_USER = "monthly_active_user" + HS_DISABLED = "hs_disabled" diff --git a/synapse/config/server.py b/synapse/config/server.py index c942841578..d556df308d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -171,6 +171,7 @@ class ServerConfig(Config): ) self.mau_trial_days = config.get("mau_trial_days", 0) + self.mau_limit_alerting = config.get("mau_limit_alerting", True) # How long to keep redacted events in the database in unredacted form # before redacting them. @@ -192,7 +193,6 @@ class ServerConfig(Config): # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") - self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "") # Admin uri to direct users at should their instance become blocked # due to resource constraints @@ -675,7 +675,6 @@ class ServerConfig(Config): # #hs_disabled: false #hs_disabled_message: 'Human readable reason for why the HS is blocked' - #hs_disabled_limit_type: 'error code(str), to help clients decode reason' # Monthly Active User Blocking # @@ -695,9 +694,16 @@ class ServerConfig(Config): # sign up in a short space of time never to return after their initial # session. # + # 'mau_limit_alerting' is a means of limiting client side alerting + # should the mau limit be reached. This is useful for small instances + # where the admin has 5 mau seats (say) for 5 specific people and no + # interest increasing the mau limit further. Defaults to True, which + # means that alerting is enabled + # #limit_usage_by_mau: false #max_mau_value: 50 #mau_trial_days: 2 + #mau_limit_alerting: false # If enabled, the metrics for the number of monthly active users will # be populated, however no one will be limited. If limit_usage_by_mau diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 81c4aff496..c0e7f475c9 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -20,6 +20,7 @@ from twisted.internet import defer from synapse.api.constants import ( EventTypes, + LimitBlockingTypes, ServerNoticeLimitReached, ServerNoticeMsgType, ) @@ -70,7 +71,7 @@ class ResourceLimitsServerNotices(object): return if not self._server_notices_manager.is_enabled(): - # Don't try and send server notices unles they've been enabled + # Don't try and send server notices unless they've been enabled return timestamp = yield self._store.user_last_seen_monthly_active(user_id) @@ -79,8 +80,6 @@ class ResourceLimitsServerNotices(object): # In practice, not sure we can ever get here return - # Determine current state of room - room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) if not room_id: @@ -88,50 +87,85 @@ class ResourceLimitsServerNotices(object): return yield self._check_and_set_tags(user_id, room_id) + + # Determine current state of room currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) + limit_msg = None + limit_type = None try: - # Normally should always pass in user_id if you have it, but in - # this case are checking what would happen to other users if they - # were to arrive. - try: - yield self._auth.check_auth_blocking() - is_auth_blocking = False - except ResourceLimitError as e: - is_auth_blocking = True - event_content = e.msg - event_limit_type = e.limit_type - - if currently_blocked and not is_auth_blocking: - # Room is notifying of a block, when it ought not to be. - # Remove block notification - content = {"pinned": ref_events} - yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, "" - ) + # Normally should always pass in user_id to check_auth_blocking + # if you have it, but in this case are checking what would happen + # to other users if they were to arrive. + yield self._auth.check_auth_blocking() + except ResourceLimitError as e: + limit_msg = e.msg + limit_type = e.limit_type - elif not currently_blocked and is_auth_blocking: + try: + if ( + limit_type == LimitBlockingTypes.MONTHLY_ACTIVE_USER + and not self._config.mau_limit_alerting + ): + # We have hit the MAU limit, but MAU alerting is disabled: + # reset room if necessary and return + if currently_blocked: + self._remove_limit_block_notification(user_id, ref_events) + return + + if currently_blocked and not limit_msg: + # Room is notifying of a block, when it ought not to be. + yield self._remove_limit_block_notification(user_id, ref_events) + elif not currently_blocked and limit_msg: # Room is not notifying of a block, when it ought to be. - # Add block notification - content = { - "body": event_content, - "msgtype": ServerNoticeMsgType, - "server_notice_type": ServerNoticeLimitReached, - "admin_contact": self._config.admin_contact, - "limit_type": event_limit_type, - } - event = yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Message + yield self._apply_limit_block_notification( + user_id, limit_msg, limit_type ) - - content = {"pinned": [event.event_id]} - yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, "" - ) - except SynapseError as e: logger.error("Error sending resource limits server notice: %s", e) + @defer.inlineCallbacks + def _remove_limit_block_notification(self, user_id, ref_events): + """Utility method to remove limit block notifications from the server + notices room. + + Args: + user_id (str): user to notify + ref_events (list[str]): The event_ids of pinned events that are unrelated to + limit blocking and need to be preserved. + """ + content = {"pinned": ref_events} + yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, "" + ) + + @defer.inlineCallbacks + def _apply_limit_block_notification(self, user_id, event_body, event_limit_type): + """Utility method to apply limit block notifications in the server + notices room. + + Args: + user_id (str): user to notify + event_body(str): The human readable text that describes the block. + event_limit_type(str): Specifies the type of block e.g. monthly active user + limit has been exceeded. + """ + content = { + "body": event_body, + "msgtype": ServerNoticeMsgType, + "server_notice_type": ServerNoticeLimitReached, + "admin_contact": self._config.admin_contact, + "limit_type": event_limit_type, + } + event = yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Message + ) + + content = {"pinned": [event.event_id]} + yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, "" + ) + @defer.inlineCallbacks def _check_and_set_tags(self, user_id, room_id): """ diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index cdf89e3383..eb540e34f6 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.constants import EventTypes, ServerNoticeMsgType +from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, @@ -133,7 +133,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check contents, but 2 calls == set blocking event - self.assertTrue(self._send_notice.call_count == 2) + self.assertEqual(self._send_notice.call_count, 2) def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): """ @@ -158,6 +158,61 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() + def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): + """ + Test that when server is over MAU limit and alerting is suppressed, then + an alert message is not sent into the room + """ + self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + side_effect=ResourceLimitError( + 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER + ) + ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + + self.assertTrue(self._send_notice.call_count == 0) + + def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): + """ + Test that when a server is disabled, that MAU limit alerting is ignored. + """ + self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + side_effect=ResourceLimitError( + 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED + ) + ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + + # Would be better to check contents, but 2 calls == set blocking event + self.assertEqual(self._send_notice.call_count, 2) + + def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): + """ + When the room is already in a blocked state, test that when alerting + is suppressed that the room is returned to an unblocked state. + """ + self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + side_effect=ResourceLimitError( + 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER + ) + ) + self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( + return_value=defer.succeed((True, [])) + ) + + mock_event = Mock( + type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} + ) + self._rlsn._store.get_events = Mock( + return_value=defer.succeed({"123": mock_event}) + ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + + self._send_notice.assert_called_once() + class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): diff --git a/tests/utils.py b/tests/utils.py index 0a64f75d04..8cced4b7e8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -137,7 +137,6 @@ def default_config(name, parse=False): "limit_usage_by_mau": False, "hs_disabled": False, "hs_disabled_message": "", - "hs_disabled_limit_type": "", "max_mau_value": 50, "mau_trial_days": 0, "mau_stats_only": False, -- cgit 1.5.1 From 848cd388d96ec95b2598f1eaaf8967b8f064c08c Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 24 Oct 2019 21:13:01 -0400 Subject: delete keys when deleting backups --- synapse/storage/data_stores/main/e2e_room_keys.py | 8 +++ .../delta/56/delete_keys_from_deleted_backups.sql | 25 +++++++ tests/storage/test_e2e_room_keys.py | 76 ++++++++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql create mode 100644 tests/storage/test_e2e_room_keys.py (limited to 'tests') diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index ef88e79293..1cbbae5b63 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -321,9 +321,17 @@ class EndToEndRoomKeyStore(SQLBaseStore): def _delete_e2e_room_keys_version_txn(txn): if version is None: this_version = self._get_current_version(txn, user_id) + if this_version is None: + raise StoreError(404, "No current backup version") else: this_version = version + self._simple_delete_txn( + txn, + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": this_version}, + ) + return self._simple_update_one_txn( txn, table="e2e_room_keys_versions", diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql new file mode 100644 index 0000000000..1d2ddb1b1a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* delete room keys that belong to deleted room key version, or to room key + * versions that don't exist (anymore) + */ +DELETE FROM e2e_room_keys +WHERE version NOT IN ( + SELECT version + FROM e2e_room_keys_versions + WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id + AND e2e_room_keys_versions.deleted = 0 +); diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py new file mode 100644 index 0000000000..ef4e7ce9d6 --- /dev/null +++ b/tests/storage/test_e2e_room_keys.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from tests import unittest, utils + +# sample room_key data for use in the tests +room_key = { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": False, + "session_data": "SSBBTSBBIEZJU0gK", +} + + +class E2eRoomKeysHandlerTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) + self.hs = None # type: synapse.server.HomeServer + self.store = None # type: synapse.storage.DataStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield utils.setup_test_homeserver(self.addCleanup) + + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_room_keys_version_delete(self): + # test that deleting a room key backup deletes the keys + version1 = yield self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + + yield self.store.set_e2e_room_key( + "user_id", version1, "room", "session", room_key + ) + + version2 = yield self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + + yield self.store.set_e2e_room_key( + "user_id", version2, "room", "session", room_key + ) + + # make sure the keys were stored properly + keys = yield self.store.get_e2e_room_keys("user_id", version1) + self.assertEqual(len(keys["rooms"]), 1) + + keys = yield self.store.get_e2e_room_keys("user_id", version2) + self.assertEqual(len(keys["rooms"]), 1) + + # delete version1 + yield self.store.delete_e2e_room_keys_version("user_id", version1) + + # make sure the key from version1 is gone, and the key from version2 is + # still there + keys = yield self.store.get_e2e_room_keys("user_id", version1) + self.assertEqual(len(keys["rooms"]), 0) + + keys = yield self.store.get_e2e_room_keys("user_id", version2) + self.assertEqual(len(keys["rooms"]), 1) -- cgit 1.5.1 From 29a0bc5637e6811220f44ee727370a190b5be1ab Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 24 Oct 2019 21:43:02 -0400 Subject: remove some unnecessary lines --- tests/storage/test_e2e_room_keys.py | 5 ----- 1 file changed, 5 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index ef4e7ce9d6..6658dbda94 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -27,11 +27,6 @@ room_key = { class E2eRoomKeysHandlerTestCase(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) - self.hs = None # type: synapse.server.HomeServer - self.store = None # type: synapse.storage.DataStore - @defer.inlineCallbacks def setUp(self): hs = yield utils.setup_test_homeserver(self.addCleanup) -- cgit 1.5.1 From 7e7a1461f64888fa71c8f06fb12f29cc3d233179 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 25 Oct 2019 10:57:37 +0100 Subject: Fix tests --- tests/handlers/test_stats.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'tests') diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index d5c8bd7612..e0075ccd32 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -607,6 +607,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): """ self.hs.config.stats_enabled = False + self.handler.stats_enabled = False u1 = self.register_user("u1", "pass") u1token = self.login("u1", "pass") @@ -618,6 +619,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertIsNone(self._get_current_stats("user", u1)) self.hs.config.stats_enabled = True + self.handler.stats_enabled = True self._perform_background_initial_update() -- cgit 1.5.1 From 4cf3a30a20c64c3939135b00b3eb5b06f273c9f9 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 25 Oct 2019 10:42:07 -0400 Subject: switch to using HomeserverTestCase --- tests/storage/test_e2e_room_keys.py | 44 +++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 6658dbda94..9935ac59ce 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -26,46 +26,52 @@ room_key = { } -class E2eRoomKeysHandlerTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield utils.setup_test_homeserver(self.addCleanup) - +class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) self.store = hs.get_datastore() + return hs - @defer.inlineCallbacks def test_room_keys_version_delete(self): # test that deleting a room key backup deletes the keys - version1 = yield self.store.create_e2e_room_keys_version( - "user_id", {"algorithm": "rot13", "auth_data": {}} + version1 = self.get_success( + self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) ) - yield self.store.set_e2e_room_key( - "user_id", version1, "room", "session", room_key + self.get_success( + self.store.set_e2e_room_key( + "user_id", version1, "room", "session", room_key + ) ) - version2 = yield self.store.create_e2e_room_keys_version( - "user_id", {"algorithm": "rot13", "auth_data": {}} + version2 = self.get_success( + self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) ) - yield self.store.set_e2e_room_key( - "user_id", version2, "room", "session", room_key + self.get_success( + self.store.set_e2e_room_key( + "user_id", version2, "room", "session", room_key + ) ) # make sure the keys were stored properly - keys = yield self.store.get_e2e_room_keys("user_id", version1) + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1)) self.assertEqual(len(keys["rooms"]), 1) - keys = yield self.store.get_e2e_room_keys("user_id", version2) + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2)) self.assertEqual(len(keys["rooms"]), 1) # delete version1 - yield self.store.delete_e2e_room_keys_version("user_id", version1) + self.get_success(self.store.delete_e2e_room_keys_version("user_id", version1)) # make sure the key from version1 is gone, and the key from version2 is # still there - keys = yield self.store.get_e2e_room_keys("user_id", version1) + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1)) self.assertEqual(len(keys["rooms"]), 0) - keys = yield self.store.get_e2e_room_keys("user_id", version2) + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2)) self.assertEqual(len(keys["rooms"]), 1) -- cgit 1.5.1 From 4697c0de0b0b51b7b5791f3a842b174931261a47 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 25 Oct 2019 10:47:02 -0400 Subject: remove unneeded imports --- tests/storage/test_e2e_room_keys.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 9935ac59ce..d128fde441 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from tests import unittest, utils +from tests import unittest # sample room_key data for use in the tests room_key = { -- cgit 1.5.1 From d0d8a22c13427cce341dbb7ae1d92d2c0ae709c3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 28 Oct 2019 13:33:04 +0000 Subject: Quick fix to ensure cache descriptors always return deferreds --- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/storage/data_stores/main/roommember.py | 2 +- synapse/util/caches/descriptors.py | 4 ++-- tests/util/caches/test_descriptors.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 22491f3700..2bbdd11941 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -79,7 +79,7 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = self._get_rules_for_room(room_id) + rules_for_room = yield self._get_rules_for_room(room_id) rules_by_user = yield rules_for_room.get_rules(event, context) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index e47ab604dd..bc04bfd7d4 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -720,7 +720,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None - cache = self._get_joined_hosts_cache(room_id) + cache = yield self._get_joined_hosts_cache(room_id) joined_hosts = yield cache.get_destinations(state_entry) return joined_hosts diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5ac2530a6a..5a8da449b2 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -438,7 +438,7 @@ class CacheDescriptor(_CacheDescriptorBase): if isinstance(cached_result_d, ObservableDeferred): observer = cached_result_d.observe() else: - observer = cached_result_d + observer = defer.succeed(cached_result_d) except KeyError: ret = defer.maybeDeferred( @@ -618,7 +618,7 @@ class CacheListDescriptor(_CacheDescriptorBase): ) return make_deferred_yieldable(d) else: - return results + return defer.succeed(results) obj.__dict__[self.orig.__name__] = wrapped diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 5713870f48..f907903511 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -325,9 +325,9 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(len(obj.fn.cache.cache), 3) r = obj.fn(1, 2) - self.assertEqual(r, ["spam", "eggs"]) + self.assertEqual(r.result, ["spam", "eggs"]) r = obj.fn(1, 3) - self.assertEqual(r, ["chips"]) + self.assertEqual(r.result, ["chips"]) obj.mock.assert_not_called() def test_cache_iterable_with_sync_exception(self): -- cgit 1.5.1 From 3f33879be4697666a304a472d090d4def0c671d0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Oct 2019 14:05:32 +0000 Subject: Port federation_server to async/await --- synapse/federation/federation_server.py | 205 ++++++++++++++------------------ tests/handlers/test_typing.py | 3 + 2 files changed, 90 insertions(+), 118 deletions(-) (limited to 'tests') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 5fc7c1d67b..15c1fa0a51 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -21,7 +21,6 @@ from six import iteritems from canonicaljson import json from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure @@ -86,14 +85,12 @@ class FederationServer(FederationBase): # come in waves. self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, origin, room_id, versions, limit): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_backfill_request(self, origin, room_id, versions, limit): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - pdus = yield self.handler.on_backfill_request( + pdus = await self.handler.on_backfill_request( origin, room_id, versions, limit ) @@ -101,9 +98,7 @@ class FederationServer(FederationBase): return 200, res - @defer.inlineCallbacks - @log_function - def on_incoming_transaction(self, origin, transaction_data): + async def on_incoming_transaction(self, origin, transaction_data): # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() @@ -118,18 +113,17 @@ class FederationServer(FederationBase): # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. with ( - yield self._transaction_linearizer.queue( + await self._transaction_linearizer.queue( (origin, transaction.transaction_id) ) ): - result = yield self._handle_incoming_transaction( + result = await self._handle_incoming_transaction( origin, transaction, request_time ) return result - @defer.inlineCallbacks - def _handle_incoming_transaction(self, origin, transaction, request_time): + async def _handle_incoming_transaction(self, origin, transaction, request_time): """ Process an incoming transaction and return the HTTP response Args: @@ -140,7 +134,7 @@ class FederationServer(FederationBase): Returns: Deferred[(int, object)]: http response code and body """ - response = yield self.transaction_actions.have_responded(origin, transaction) + response = await self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( @@ -159,7 +153,7 @@ class FederationServer(FederationBase): logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} - yield self.transaction_actions.set_response( + await self.transaction_actions.set_response( origin, transaction, 400, response ) return 400, response @@ -195,7 +189,7 @@ class FederationServer(FederationBase): continue try: - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue @@ -221,11 +215,10 @@ class FederationServer(FederationBase): # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. - @defer.inlineCallbacks - def process_pdus_for_room(room_id): + async def process_pdus_for_room(room_id): logger.debug("Processing PDUs for %s", room_id) try: - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) except AuthError as e: logger.warn("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: @@ -237,7 +230,7 @@ class FederationServer(FederationBase): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu(origin, pdu) + await self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: logger.warn("Error handling PDU %s: %s", event_id, e) @@ -251,36 +244,33 @@ class FederationServer(FederationBase): exc_info=(f.type, f.value, f.getTracebackObject()), ) - yield concurrently_execute( + await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu(origin, edu.edu_type, edu.content) + await self.received_edu(origin, edu.edu_type, edu.content) response = {"pdus": pdu_results} logger.debug("Returning: %s", str(response)) - yield self.transaction_actions.set_response(origin, transaction, 200, response) + await self.transaction_actions.set_response(origin, transaction, 200, response) return 200, response - @defer.inlineCallbacks - def received_edu(self, origin, edu_type, content): + async def received_edu(self, origin, edu_type, content): received_edus_counter.inc() - yield self.registry.on_edu(edu_type, origin, content) + await self.registry.on_edu(edu_type, origin, content) - @defer.inlineCallbacks - @log_function - def on_context_state_request(self, origin, room_id, event_id): + async def on_context_state_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -289,8 +279,8 @@ class FederationServer(FederationBase): # in the cache so we could return it without waiting for the linearizer # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. - with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.wrap( + with (await self._server_linearizer.queue((origin, room_id))): + resp = await self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, room_id, @@ -299,65 +289,58 @@ class FederationServer(FederationBase): return 200, resp - @defer.inlineCallbacks - def on_state_ids_request(self, origin, room_id, event_id): + async def on_state_ids_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) + state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} - @defer.inlineCallbacks - def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu(room_id, event_id) - auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + async def _on_context_state_request_compute(self, room_id, event_id): + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], } - @defer.inlineCallbacks - @log_function - def on_pdu_request(self, origin, event_id): - pdu = yield self.handler.get_persisted_pdu(origin, event_id) + async def on_pdu_request(self, origin, event_id): + pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: return 200, self._transaction_from_pdus([pdu]).get_dict() else: return 404, "" - @defer.inlineCallbacks - def on_query_request(self, query_type, args): + async def on_query_request(self, query_type, args): received_queries_counter.labels(query_type).inc() - resp = yield self.registry.on_query(query_type, args) + resp = await self.registry.on_query(query_type, args) return 200, resp - @defer.inlineCallbacks - def on_make_join_request(self, origin, room_id, user_id, supported_versions): + async def on_make_join_request(self, origin, room_id, user_id, supported_versions): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) if room_version not in supported_versions: logger.warn("Room version %s not in %s", room_version, supported_versions) raise IncompatibleRoomVersionError(room_version=room_version) - pdu = yield self.handler.on_make_join_request(origin, room_id, user_id) + pdu = await self.handler.on_make_join_request(origin, room_id, user_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_invite_request(self, origin, content, room_version): + async def on_invite_request(self, origin, content, room_version): if room_version not in KNOWN_ROOM_VERSIONS: raise SynapseError( 400, @@ -369,28 +352,27 @@ class FederationServer(FederationBase): pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) - pdu = yield self._check_sigs_and_hash(room_version, pdu) - ret_pdu = yield self.handler.on_invite_request(origin, pdu) + await self.check_server_matches_acl(origin_host, pdu.room_id) + pdu = await self._check_sigs_and_hash(room_version, pdu) + ret_pdu = await self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} - @defer.inlineCallbacks - def on_send_join_request(self, origin, content, room_id): + async def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) - res_pdus = yield self.handler.on_send_join_request(origin, pdu) + res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() return ( 200, @@ -402,48 +384,44 @@ class FederationServer(FederationBase): }, ) - @defer.inlineCallbacks - def on_make_leave_request(self, origin, room_id, user_id): + async def on_make_leave_request(self, origin, room_id, user_id): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) - pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id) + await self.check_server_matches_acl(origin_host, room_id) + pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_send_leave_request(self, origin, content, room_id): + async def on_send_leave_request(self, origin, content, room_id): logger.debug("on_send_leave_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) - yield self.handler.on_send_leave_request(origin, pdu) + await self.handler.on_send_leave_request(origin, pdu) return 200, {} - @defer.inlineCallbacks - def on_event_auth(self, origin, room_id, event_id): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_event_auth(self, origin, room_id, event_id): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) + auth_pdus = await self.handler.on_event_auth(event_id) res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} return 200, res - @defer.inlineCallbacks - def on_query_auth_request(self, origin, content, room_id, event_id): + async def on_query_auth_request(self, origin, content, room_id, event_id): """ Content is a dict with keys:: auth_chain (list): A list of events that give the auth chain. @@ -462,22 +440,22 @@ class FederationServer(FederationBase): Returns: Deferred: Results in `dict` with the same format as `content` """ - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ event_from_pdu_json(e, format_ver) for e in content["auth_chain"] ] - signed_auth = yield self._check_sigs_and_hash_and_fetch( + signed_auth = await self._check_sigs_and_hash_and_fetch( origin, auth_chain, outlier=True, room_version=room_version ) - ret = yield self.handler.on_query_auth( + ret = await self.handler.on_query_auth( origin, event_id, room_id, @@ -503,16 +481,14 @@ class FederationServer(FederationBase): return self.on_query_request("user_devices", user_id) @trace - @defer.inlineCallbacks - @log_function - def on_claim_client_keys(self, origin, content): + async def on_claim_client_keys(self, origin, content): query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = yield self.store.claim_e2e_one_time_keys(query) + results = await self.store.claim_e2e_one_time_keys(query) json_result = {} for user_id, device_keys in results.items(): @@ -536,14 +512,12 @@ class FederationServer(FederationBase): return {"one_time_keys": json_result} - @defer.inlineCallbacks - @log_function - def on_get_missing_events( + async def on_get_missing_events( self, origin, room_id, earliest_events, latest_events, limit ): - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," @@ -553,7 +527,7 @@ class FederationServer(FederationBase): limit, ) - missing_events = yield self.handler.on_get_missing_events( + missing_events = await self.handler.on_get_missing_events( origin, room_id, earliest_events, latest_events, limit ) @@ -586,8 +560,7 @@ class FederationServer(FederationBase): destination=None, ) - @defer.inlineCallbacks - def _handle_received_pdu(self, origin, pdu): + async def _handle_received_pdu(self, origin, pdu): """ Process a PDU received in a federation /send/ transaction. If the event is invalid, then this method throws a FederationError. @@ -640,37 +613,34 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = yield self.store.get_room_version(pdu.room_id) + room_version = await self.store.get_room_version(pdu.room_id) # Check signature. try: - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "" % self.server_name - @defer.inlineCallbacks - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): - ret = yield self.handler.exchange_third_party_invite( + ret = await self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) return ret - @defer.inlineCallbacks - def on_exchange_third_party_invite_request(self, room_id, event_dict): - ret = yield self.handler.on_exchange_third_party_invite_request( + async def on_exchange_third_party_invite_request(self, room_id, event_dict): + ret = await self.handler.on_exchange_third_party_invite_request( room_id, event_dict ) return ret - @defer.inlineCallbacks - def check_server_matches_acl(self, server_name, room_id): + async def check_server_matches_acl(self, server_name, room_id): """Check if the given server is allowed by the server ACLs in the room Args: @@ -680,13 +650,13 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = yield self.store.get_current_state_ids(room_id) + state_ids = await self.store.get_current_state_ids(room_id) acl_event_id = state_ids.get((EventTypes.ServerACL, "")) if not acl_event_id: return - acl_event = yield self.store.get_event(acl_event_id) + acl_event = await self.store.get_event(acl_event_id) if server_matches_acl_event(server_name, acl_event): return @@ -799,15 +769,14 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler - @defer.inlineCallbacks - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): handler = self.edu_handlers.get(edu_type) if not handler: logger.warn("No handler registered for EDU type %s", edu_type) with start_active_span_from_edu(content, "handle_edu"): try: - yield handler(origin, content) + await handler(origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 67f1013051..f360c8e965 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -144,6 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None + self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + None + ) def test_started_typing_local(self): self.room_members = [U_APPLE, U_BANANA] -- cgit 1.5.1 From 326b3dace77aeb36e516ea9b04ba1baa171bcb47 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Oct 2019 11:35:46 +0000 Subject: Make ObservableDeferred.observe() always return deferred. This makes it easier to use in an async/await world. Also fixes a bug where cache descriptors would occaisonally return a raw value rather than a deferred. --- synapse/util/async_helpers.py | 7 ++----- tests/storage/test__base.py | 2 +- tests/util/caches/test_descriptors.py | 4 ++-- 3 files changed, 5 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 7659eaeb42..fd75ba27ad 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -86,11 +86,8 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) - def observe(self): + def observe(self) -> defer.Deferred: """Observe the underlying deferred. - - Can return either a deferred if the underlying deferred is still pending - (or has failed), or the actual value. Callers may need to use maybeDeferred. """ if not self._result: d = defer.Deferred() @@ -105,7 +102,7 @@ class ObservableDeferred(object): return d else: success, res = self._result - return res if success else defer.fail(res) + return defer.succeed(res) if success else defer.fail(res) def observers(self): return self._observers diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index dd49a14524..9b81b536f5 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a.func.prefill(("foo",), ObservableDeferred(d)) - self.assertEquals(a.func("foo"), d.result) + self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) @defer.inlineCallbacks diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index f907903511..39e360fe24 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -310,14 +310,14 @@ class DescriptorTestCase(unittest.TestCase): obj.mock.return_value = ["spam", "eggs"] r = obj.fn(1, 2) - self.assertEqual(r, ["spam", "eggs"]) + self.assertEqual(r.result, ["spam", "eggs"]) obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = ["chips"] r = obj.fn(1, 3) - self.assertEqual(r, ["chips"]) + self.assertEqual(r.result, ["chips"]) obj.mock.assert_called_once_with(1, 3) obj.mock.reset_mock() -- cgit 1.5.1 From a8d16f6c00d5adb204af5fa73ffaea40eea4b632 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Oct 2019 13:33:38 +0000 Subject: Review comments --- synapse/server.py | 5 ++--- synapse/storage/__init__.py | 4 ++-- synapse/storage/data_stores/main/events.py | 19 ++++++++++++++----- synapse/storage/persist_events.py | 13 ++++++++++--- tests/crypto/test_keyring.py | 4 ++-- tests/test_federation.py | 11 +++++++---- 6 files changed, 37 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/synapse/server.py b/synapse/server.py index 54a7f4aa5f..0b81af646c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -226,7 +226,6 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() - self.datastore = None self.datastores = None # Other kwargs are explicit dependencies @@ -236,8 +235,8 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") with self.get_db_conn() as conn: - self.datastore = self.DATASTORE_CLASS(conn, self) - self.datastores = DataStores(self.datastore, conn, self) + datastore = self.DATASTORE_CLASS(conn, self) + self.datastores = DataStores(datastore, conn, self) conn.commit() logger.info("Finished setting up.") diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a58187a76f..a6429d17ed 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -29,7 +29,7 @@ stored in `synapse.storage.schema`. from synapse.storage.data_stores import DataStores from synapse.storage.data_stores.main import DataStore -from synapse.storage.persist_events import EventsPersistenceStore +from synapse.storage.persist_events import EventsPersistenceStorage __all__ = ["DataStores", "DataStore"] @@ -44,7 +44,7 @@ class Storage(object): # interfaces. self.main = stores.main - self.persistence = EventsPersistenceStore(hs, stores) + self.persistence = EventsPersistenceStorage(hs, stores) def are_all_users_on_domain(txn, database_engine, domain): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 6304531cd5..813f34528c 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -146,7 +146,7 @@ class EventsStore( @_retry_on_integrity_error @defer.inlineCallbacks - def _persist_events( + def _persist_events_and_state_updates( self, events_and_contexts, current_state_for_room, @@ -155,18 +155,27 @@ class EventsStore( backfilled=False, delete_existing=False, ): - """Persist events to db + """Persist a set of events alongside updates to the current state and + forward extremities tables. Args: events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): + current_state_for_room (dict[str, dict]): Map from room_id to the + current state of the room based on forward extremities + state_delta_for_room (dict[str, tuple]): Map from room_id to tuple + of `(to_delete, to_insert)` where to_delete is a list + of type/state keys to remove from current state, and to_insert + is a map (type,key)->event_id giving the state delta in each + room. + new_forward_extremities (dict[str, list[str]]): Map from room_id + to list of event IDs that are the new forward extremities of + the room. + backfilled (bool) delete_existing (bool): Returns: Deferred: resolves when the events have been persisted """ - if not events_and_contexts: - return # We want to calculate the stream orderings as late as possible, as # we only notify after all events with a lesser stream ordering have diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 9a63953d4d..cf66225574 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -171,7 +171,13 @@ class _EventPeristenceQueue(object): pass -class EventsPersistenceStore(object): +class EventsPersistenceStorage(object): + """High level interface for handling persisting newly received events. + + Takes care of batching up events by room, and calculating the necessary + current state and forward extremity changes. + """ + def __init__(self, hs, stores: DataStores): # We ultimately want to split out the state store from the main store, # so we use separate variables here even though they point to the same @@ -257,7 +263,8 @@ class EventsPersistenceStore(object): def _persist_events( self, events_and_contexts, backfilled=False, delete_existing=False ): - """Persist events to db + """Calculates the change to current state and forward extremities, and + persists the given events and with those updates. Args: events_and_contexts (list[(EventBase, EventContext)]): @@ -399,7 +406,7 @@ class EventsPersistenceStore(object): if current_state is not None: current_state_for_room[room_id] = current_state - yield self.main_store._persist_events( + yield self.main_store._persist_events_and_state_updates( chunk, current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index c4f0bbd3dd..8efd39c7f7 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -178,7 +178,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_keys( + r = self.hs.get_datastore().store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], @@ -209,7 +209,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_keys( + r = self.hs.get_datastore().store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], diff --git a/tests/test_federation.py b/tests/test_federation.py index a73f18f88e..d1acb16f30 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -36,7 +36,8 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0] @@ -75,7 +76,8 @@ class MessageAcceptTests(unittest.TestCase): self.assertEqual( self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0], "$join:test.serv", @@ -97,7 +99,8 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0] @@ -137,6 +140,6 @@ class MessageAcceptTests(unittest.TestCase): # Make sure the invalid event isn't there extrem = maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id ) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") -- cgit 1.5.1 From 69f0054ce675bd9d35104c39af9fae9a908b7f33 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Oct 2019 17:25:54 +0100 Subject: Port to use state storage --- synapse/handlers/admin.py | 7 +- synapse/handlers/device.py | 3 +- synapse/handlers/events.py | 6 +- synapse/handlers/federation.py | 19 ++--- synapse/handlers/initial_sync.py | 14 ++-- synapse/handlers/message.py | 10 +-- synapse/handlers/pagination.py | 6 +- synapse/handlers/room.py | 6 +- synapse/handlers/search.py | 12 ++-- synapse/handlers/sync.py | 20 +++--- synapse/notifier.py | 6 +- synapse/push/httppusher.py | 3 +- synapse/push/mailer.py | 3 +- synapse/push/push_tools.py | 9 +-- synapse/state/__init__.py | 13 ++-- synapse/visibility.py | 30 ++++---- tests/storage/test_state.py | 150 ++++++++++++++++++++++++++------------- tests/test_state.py | 3 + tests/test_visibility.py | 11 ++- 19 files changed, 216 insertions(+), 115 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1a87b58838..6407d56f8e 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -30,6 +30,9 @@ class AdminHandler(BaseHandler): def __init__(self, hs): super(AdminHandler, self).__init__(hs) + self.storage = hs.get_storage() + self.state_store = self.storage.state + @defer.inlineCallbacks def get_whois(self, user): connections = [] @@ -205,7 +208,7 @@ class AdminHandler(BaseHandler): from_key = events[-1].internal_metadata.after - events = yield filter_events_for_client(self.store, user_id, events) + events = yield filter_events_for_client(self.storage, user_id, events) writer.write_events(room_id, events) @@ -241,7 +244,7 @@ class AdminHandler(BaseHandler): for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = yield self.store.get_state_for_event(event_id) + state = yield self.state_store.get_state_for_event(event_id) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5f23ee4488..b3fd7e6249 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() + self.state_store = hs.get_storage().state self._auth_handler = hs.get_auth_handler() @trace @@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler): continue # mapping from event_id -> state_dict - prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 5e748687e3..45fe13c62f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): + def __init__(self, hs): + super(EventHandler, self).__init__(hs) + self.storage = hs.get_storage() + @defer.inlineCallbacks def get_event(self, user, room_id, event_id): """Retrieve a single specified event. @@ -172,7 +176,7 @@ class EventHandler(BaseHandler): is_peeking = user.to_string() not in users filtered = yield filter_events_for_client( - self.store, user.to_string(), [event], is_peeking=is_peeking + self.storage, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 08276fdebf..4d9e33346d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -110,6 +110,7 @@ class FederationHandler(BaseHandler): self.store = hs.get_datastore() self.storage = hs.get_storage() + self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -325,7 +326,7 @@ class FederationHandler(BaseHandler): event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.store.get_state_groups_ids(room_id, seen) + ours = yield self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps = list( @@ -889,7 +890,7 @@ class FederationHandler(BaseHandler): # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = yield filter_events_for_server( - self.store, + self.storage, self.server_name, list(extremities_events.values()), redact=False, @@ -1550,7 +1551,7 @@ class FederationHandler(BaseHandler): event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups(room_id, [event_id]) + state_groups = yield self.state_store.get_state_groups(room_id, [event_id]) if state_groups: _, state = list(iteritems(state_groups)).pop() @@ -1579,7 +1580,7 @@ class FederationHandler(BaseHandler): event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) + state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id]) if state_groups: _, state = list(state_groups.items()).pop() @@ -1607,7 +1608,7 @@ class FederationHandler(BaseHandler): events = yield self.store.get_backfill_events(room_id, pdu_list, limit) - events = yield filter_events_for_server(self.store, origin, events) + events = yield filter_events_for_server(self.storage, origin, events) return events @@ -1637,7 +1638,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - events = yield filter_events_for_server(self.store, origin, [event]) + events = yield filter_events_for_server(self.storage, origin, [event]) event = events[0] return event else: @@ -1903,7 +1904,7 @@ class FederationHandler(BaseHandler): # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets = yield self.store.get_state_groups( + state_sets = yield self.state_store.get_state_groups( event.room_id, extrem_ids ) state_sets = list(state_sets.values()) @@ -1994,7 +1995,7 @@ class FederationHandler(BaseHandler): ) missing_events = yield filter_events_for_server( - self.store, origin, missing_events + self.storage, origin, missing_events ) return missing_events @@ -2235,7 +2236,7 @@ class FederationHandler(BaseHandler): # create a new state group as a delta from the existing one. prev_group = context.state_group - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index f991efeee3..49c9e031f9 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler): self.validator = EventValidator() self.snapshot_cache = SnapshotCache() self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state def snapshot_all_rooms( self, @@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler): elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( - self.store.get_state_for_events, [event.event_id] + self.state_store.get_state_for_events, [event.event_id] ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client(self.store, user_id, messages) + messages = yield filter_events_for_client( + self.storage, user_id, messages + ) start_token = now_token.copy_and_replace("room_key", token) end_token = now_token.copy_and_replace("room_key", room_end_token) @@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler): def _room_initial_sync_parted( self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking ): - room_state = yield self.store.get_state_for_events([member_event_id]) + room_state = yield self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler): ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = StreamToken.START.copy_and_replace("room_key", token) @@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler): ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace("room_key", token) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7908a2d52c..6e2a360262 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -59,6 +59,8 @@ class MessageHandler(object): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks @@ -82,7 +84,7 @@ class MessageHandler(object): data = yield self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -135,12 +137,12 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.store, user_id, last_events + self.storage, user_id, last_events ) event = last_events[0] if visible_events: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] @@ -161,7 +163,7 @@ class MessageHandler(object): ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5744f4579d..b7185fe7a0 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -69,6 +69,8 @@ class PaginationHandler(object): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self.clock = hs.get_clock() self._server_name = hs.hostname @@ -255,7 +257,7 @@ class PaginationHandler(object): events = event_filter.filter(events) events = yield filter_events_for_client( - self.store, user_id, events, is_peeking=(member_event_id is None) + self.storage, user_id, events, is_peeking=(member_event_id is None) ) if not events: @@ -274,7 +276,7 @@ class PaginationHandler(object): (EventTypes.Member, event.sender) for event in events ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2816bd8f87..84bad39815 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -822,6 +822,8 @@ class RoomContextHandler(object): def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state @defer.inlineCallbacks def get_event_context(self, user, room_id, event_id, limit, event_filter): @@ -848,7 +850,7 @@ class RoomContextHandler(object): def filter_evts(events): return filter_events_for_client( - self.store, user.to_string(), events, is_peeking=is_peeking + self.storage, user.to_string(), events, is_peeking=is_peeking ) event = yield self.store.get_event( @@ -890,7 +892,7 @@ class RoomContextHandler(object): # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 - state = yield self.store.get_state_for_events( + state = yield self.state_store.get_state_for_events( [last_event_id], state_filter=state_filter ) results["state"] = list(state[last_event_id].values()) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index cd5e90bacb..f4d8a60774 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -35,6 +35,8 @@ class SearchHandler(BaseHandler): def __init__(self, hs): super(SearchHandler, self).__init__(hs) self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state @defer.inlineCallbacks def get_old_rooms_from_upgraded_room(self, room_id): @@ -221,7 +223,7 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + self.storage, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -271,7 +273,7 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + self.storage, user.to_string(), filtered_events ) room_events.extend(events) @@ -340,11 +342,11 @@ class SearchHandler(BaseHandler): ) res["events_before"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_before"] + self.storage, user.to_string(), res["events_before"] ) res["events_after"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_after"] + self.storage, user.to_string(), res["events_after"] ) res["start"] = now_token.copy_and_replace( @@ -372,7 +374,7 @@ class SearchHandler(BaseHandler): [(EventTypes.Member, sender) for sender in senders] ) - state = yield self.store.get_state_for_event( + state = yield self.state_store.get_state_for_event( last_event_id, state_filter ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d99160e9d7..43a082dcda 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -230,6 +230,8 @@ class SyncHandler(object): self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() self.auth = hs.get_auth() + self.storage = hs.get_storage() + self.state_store = self.storage.state # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( @@ -417,7 +419,7 @@ class SyncHandler(object): current_state_ids = frozenset(itervalues(current_state_ids)) recents = yield filter_events_for_client( - self.store, + self.storage, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, @@ -470,7 +472,7 @@ class SyncHandler(object): current_state_ids = frozenset(itervalues(current_state_ids)) loaded_recents = yield filter_events_for_client( - self.store, + self.storage, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, @@ -509,7 +511,7 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( event.event_id, state_filter=state_filter ) if event.is_state(): @@ -580,7 +582,7 @@ class SyncHandler(object): return None last_event = last_events[-1] - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -757,11 +759,11 @@ class SyncHandler(object): if full_state: if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = yield self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) @@ -781,7 +783,7 @@ class SyncHandler(object): ) elif batch.limited: if batch: - state_at_timeline_start = yield self.store.get_state_ids_for_event( + state_at_timeline_start = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: @@ -810,7 +812,7 @@ class SyncHandler(object): ) if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = yield self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) else: @@ -841,7 +843,7 @@ class SyncHandler(object): # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( diff --git a/synapse/notifier.py b/synapse/notifier.py index 4e091314e6..af161a81d7 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -159,6 +159,7 @@ class Notifier(object): self.room_to_user_streams = {} self.hs = hs + self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() self.pending_new_room_events = [] @@ -425,7 +426,10 @@ class Notifier(object): if name == "room": new_events = yield filter_events_for_client( - self.store, user.to_string(), new_events, is_peeking=is_peeking + self.storage, + user.to_string(), + new_events, + is_peeking=is_peeking, ) elif name == "presence": now = self.clock.time_msec() diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 6299587808..36e26032a1 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -64,6 +64,7 @@ class HttpPusher(object): def __init__(self, hs, pusherdict): self.hs = hs self.store = self.hs.get_datastore() + self.storage = self.hs.get_storage() self.clock = self.hs.get_clock() self.state_handler = self.hs.get_state_handler() self.user_id = pusherdict["user_name"] @@ -329,7 +330,7 @@ class HttpPusher(object): return d ctx = yield push_tools.get_context_for_event( - self.store, self.state_handler, event, self.user_id + self.storage, self.state_handler, event, self.user_id ) d = { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 5b16ab4ae8..1d15a06a58 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -119,6 +119,7 @@ class Mailer(object): self.store = self.hs.get_datastore() self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() + self.storage = hs.get_storage() self.app_name = app_name logger.info("Created Mailer for app_name %s" % app_name) @@ -389,7 +390,7 @@ class Mailer(object): } the_events = yield filter_events_for_client( - self.store, user_id, results["events_before"] + self.storage, user_id, results["events_before"] ) the_events.append(notif_event) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index a54051a726..de5c101a58 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.push.presentable_names import calculate_room_name, name_from_member_event +from synapse.storage import Storage @defer.inlineCallbacks @@ -43,22 +44,22 @@ def get_badge_count(store, user_id): @defer.inlineCallbacks -def get_context_for_event(store, state_handler, ev, user_id): +def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield store.get_state_ids_for_event(ev.event_id) + room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room name = yield calculate_room_name( - store, room_state_ids, user_id, fallback_to_single_member=False + storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield store.get_event(sender_state_event_id) + sender_state_event = yield storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index dc9f5a9008..4e91eb66fe 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -103,6 +103,7 @@ class StateHandler(object): def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() @@ -271,7 +272,7 @@ class StateHandler(object): else: current_state_ids = prev_state_ids - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=None, @@ -321,7 +322,7 @@ class StateHandler(object): delta_ids = dict(entry.delta_ids) delta_ids[key] = event.event_id - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -334,7 +335,7 @@ class StateHandler(object): delta_ids = entry.delta_ids if entry.state_group is None: - entry.state_group = yield self.store.store_state_group( + entry.state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=entry.prev_group, @@ -376,14 +377,16 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) + state_groups_ids = yield self.state_store.get_state_groups_ids( + room_id, event_ids + ) if len(state_groups_ids) == 0: return _StateCacheEntry(state={}, state_group=None) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield self.store.get_state_group_delta(name) + prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, diff --git a/synapse/visibility.py b/synapse/visibility.py index bf0f1eebd8..8c843febd8 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.storage import Storage from synapse.storage.state import StateFilter from synapse.types import get_domain_from_id @@ -43,14 +44,13 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks def filter_events_for_client( - store, user_id, events, is_peeking=False, always_include_ids=frozenset() + storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset() ): """ Check which events a user is allowed to see Args: - store (synapse.storage.DataStore): our datastore (can also be a worker - store) + storage user_id(str): user id to be checked events(list[synapse.events.EventBase]): sequence of events to be checked is_peeking(bool): should be True if: @@ -68,12 +68,12 @@ def filter_events_for_client( events = list(e for e in events if not e.internal_metadata.is_soft_failed()) types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) - event_id_to_state = yield store.get_state_for_events( + event_id_to_state = yield storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = yield store.get_global_account_data_by_type_for_user( + ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id ) @@ -84,7 +84,7 @@ def filter_events_for_client( else [] ) - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) def allowed(event): """ @@ -213,13 +213,17 @@ def filter_events_for_client( @defer.inlineCallbacks def filter_events_for_server( - store, server_name, events, redact=True, check_history_visibility_only=False + storage: Storage, + server_name, + events, + redact=True, + check_history_visibility_only=False, ): """Filter a list of events based on whether given server is allowed to see them. Args: - store (DataStore) + storage server_name (str) events (iterable[FrozenEvent]) redact (bool): Whether to return a redacted version of the event, or @@ -274,7 +278,7 @@ def filter_events_for_server( # Lets check to see if all the events have a history visibility # of "shared" or "world_readable". If thats the case then we don't # need to check membership (as we know the server is in the room). - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""),) @@ -292,14 +296,14 @@ def filter_events_for_server( if not visibility_ids: all_open = True else: - event_map = yield store.get_events(visibility_ids) + event_map = yield storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") for e in itervalues(event_map) ) if not check_history_visibility_only: - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -328,7 +332,7 @@ def filter_events_for_server( # first, for each event we're wanting to return, get the event_ids # of the history vis and membership state at those events. - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) @@ -358,7 +362,7 @@ def filter_events_for_server( return False return state_key[idx + 1 :] == server_name - event_map = yield store.get_events( + event_map = yield storage.main.get_events( [ e_id for e_id, key in iteritems(event_id_to_state_key) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index d573a3e07b..43200654f1 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -35,6 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store = hs.get_datastore() self.storage = hs.get_storage() + self.state_datastore = self.store self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -83,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups_ids( + state_group_map = yield self.storage.state.get_state_groups_ids( self.room, [e2.event_id] ) self.assertEqual(len(state_group_map), 1) @@ -102,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) + state_group_map = yield self.storage.state.get_state_groups( + self.room, [e2.event_id] + ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -142,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.store.get_state_for_event(e5.event_id) + state = yield self.storage.state.get_state_for_event(e5.event_id) self.assertIsNotNone(e4) @@ -158,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -182,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: {self.u_alice.to_string()}}, @@ -200,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -216,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) + group_ids = yield self.storage.state.get_state_groups_ids( + room_id, [e5.event_id] + ) group = list(group_ids.keys())[0] # test _get_state_for_group_using_cache correctly filters out members # with types=[] - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -238,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -251,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -268,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -288,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -305,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -318,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -332,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### # deliberately remove e2 (room name) from the _state_group_cache - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, True) self.assertEqual(known_absent, set()) @@ -347,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase): ) state_dict_ids.pop((e2.type, e2.state_key)) - self.store._state_group_cache.invalidate(group) - self.store._state_group_cache.update( - sequence=self.store._state_group_cache.sequence, + self.state_datastore._state_group_cache.invalidate(group) + self.state_datastore._state_group_cache.update( + sequence=self.state_datastore._state_group_cache.sequence, key=group, value=state_dict_ids, # list fetched keys so it knows it's partial fetched_keys=((e1.type, e1.state_key),), ) - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, False) self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) @@ -370,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -382,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -395,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -406,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -425,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -436,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -449,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -460,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False diff --git a/tests/test_state.py b/tests/test_state.py index 610ec9fb46..38246555bd 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -158,10 +158,12 @@ class Graph(object): class StateTestCase(unittest.TestCase): def setUp(self): self.store = StateGroupStore() + storage = Mock(main=self.store, state=self.store) hs = Mock( spec_set=[ "config", "get_datastore", + "get_storage", "get_auth", "get_state_handler", "get_clock", @@ -174,6 +176,7 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) + hs.get_storage.return_value = storage self.state = StateHandler(hs) self.event_id = 0 diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 6ae1ea9b04..f7381b2885 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -14,6 +14,8 @@ # limitations under the License. import logging +from mock import Mock + from twisted.internet import defer from twisted.internet.defer import succeed @@ -63,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): events_to_filter.append(evt) filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter + self.storage, "test_server", events_to_filter ) # the result should be 5 redacted events, and 5 unredacted events. @@ -101,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # ... and the filtering happens. filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter + self.storage, "test_server", events_to_filter ) for i in range(0, len(events_to_filter)): @@ -258,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): logger.info("Starting filtering") start = time.time() + + storage = Mock() + storage.main = test_store + storage.state = test_store + filtered = yield filter_events_for_server( test_store, "test_server", events_to_filter ) -- cgit 1.5.1 From 7c8c97e635811609c5a7ae4c0bb94e6573c30753 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Oct 2019 15:12:49 +0000 Subject: Split purge API into events vs state --- synapse/handlers/pagination.py | 7 +- synapse/storage/__init__.py | 2 + synapse/storage/data_stores/main/events.py | 328 ++++++++++++++--------------- synapse/storage/data_stores/main/state.py | 23 ++ synapse/storage/purge_events.py | 117 ++++++++++ tests/storage/test_purge.py | 15 +- 6 files changed, 308 insertions(+), 184 deletions(-) create mode 100644 synapse/storage/purge_events.py (limited to 'tests') diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5744f4579d..9088ba14cd 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -69,6 +69,7 @@ class PaginationHandler(object): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() self._server_name = hs.hostname @@ -125,7 +126,9 @@ class PaginationHandler(object): self._purges_in_progress_by_room.add(room_id) try: with (yield self.pagination_lock.write(room_id)): - yield self.store.purge_history(room_id, token, delete_local_events) + yield self.storage.purge_events.purge_history( + room_id, token, delete_local_events + ) logger.info("[purge] complete") self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE except Exception: @@ -168,7 +171,7 @@ class PaginationHandler(object): if joined: raise SynapseError(400, "Users are still joined to this room") - await self.store.purge_room(room_id) + await self.storage.purge_events.purge_room(room_id) @defer.inlineCallbacks def get_messages( diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a6429d17ed..3646ebd007 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -30,6 +30,7 @@ stored in `synapse.storage.schema`. from synapse.storage.data_stores import DataStores from synapse.storage.data_stores.main import DataStore from synapse.storage.persist_events import EventsPersistenceStorage +from synapse.storage.purge_events import PurgeEventsStorage __all__ = ["DataStores", "DataStore"] @@ -45,6 +46,7 @@ class Storage(object): self.main = stores.main self.persistence = EventsPersistenceStorage(hs, stores) + self.purge_events = PurgeEventsStorage(hs, stores) def are_all_users_on_domain(txn, database_engine, domain): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 7c3607f308..4eacba8058 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1368,6 +1368,10 @@ class EventsStore( if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). + + Returns: + Deferred[set[int]]: The set of state groups that reference deleted + events. """ return self.runInteraction( @@ -1521,60 +1525,6 @@ class EventsStore( "[purge] found %i referenced state groups", len(referenced_state_groups) ) - logger.info("[purge] finding state groups that can be deleted") - - _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups) - state_groups_to_delete, remaining_state_groups = _ - - logger.info( - "[purge] found %i state groups to delete", len(state_groups_to_delete) - ) - - logger.info( - "[purge] de-delta-ing %i remaining state groups", - len(remaining_state_groups), - ) - - # Now we turn the state groups that reference to-be-deleted state - # groups to non delta versions. - for sg in remaining_state_groups: - logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] - - self._simple_delete_txn( - txn, table="state_groups_state", keyvalues={"state_group": sg} - ) - - self._simple_delete_txn( - txn, table="state_group_edges", keyvalues={"state_group": sg} - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": sg, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(curr_state) - ], - ) - - logger.info("[purge] removing redundant state groups") - txn.executemany( - "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), - ) - txn.executemany( - "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), - ) - logger.info("[purge] removing events from event_to_state_groups") txn.execute( "DELETE FROM event_to_state_groups " @@ -1661,87 +1611,7 @@ class EventsStore( logger.info("[purge] done") - def _find_unreferenced_groups_during_purge(self, txn, state_groups): - """Used when purging history to figure out which state groups can be - deleted and which need to be de-delta'ed (due to one of its prev groups - being scheduled for deletion). - - Args: - txn - state_groups (set[int]): Set of state groups referenced by events - that are going to be deleted. - - Returns: - tuple[set[int], set[int]]: The set of state groups that can be - deleted and the set of state groups that need to be de-delta'ed - """ - # Graph of state group -> previous group - graph = {} - - # Set of events that we have found to be referenced by events - referenced_groups = set() - - # Set of state groups we've already seen - state_groups_seen = set(state_groups) - - # Set of state groups to handle next. - next_to_search = set(state_groups) - while next_to_search: - # We bound size of groups we're looking up at once, to stop the - # SQL query getting too big - if len(next_to_search) < 100: - current_search = next_to_search - next_to_search = set() - else: - current_search = set(itertools.islice(next_to_search, 100)) - next_to_search -= current_search - - # Check if state groups are referenced - sql = """ - SELECT DISTINCT state_group FROM event_to_state_groups - LEFT JOIN events_to_purge AS ep USING (event_id) - WHERE ep.event_id IS NULL AND - """ - clause, args = make_in_list_sql_clause( - txn.database_engine, "state_group", current_search - ) - txn.execute(sql + clause, list(args)) - - referenced = set(sg for sg, in txn) - referenced_groups |= referenced - - # We don't continue iterating up the state group graphs for state - # groups that are referenced. - current_search -= referenced - - rows = self._simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=current_search, - keyvalues={}, - retcols=("prev_state_group", "state_group"), - ) - - prevs = set(row["state_group"] for row in rows) - # We don't bother re-handling groups we've already seen - prevs -= state_groups_seen - next_to_search |= prevs - state_groups_seen |= prevs - - for row in rows: - # Note: Each state group can have at most one prev group - graph[row["state_group"]] = row["prev_state_group"] - - to_delete = state_groups_seen - referenced_groups - - to_dedelta = set() - for sg in referenced_groups: - prev_sg = graph.get(sg) - if prev_sg and prev_sg in to_delete: - to_dedelta.add(sg) - - return to_delete, to_dedelta + return referenced_state_groups def purge_room(self, room_id): """Deletes all record of a room @@ -1753,46 +1623,7 @@ class EventsStore( return self.runInteraction("purge_room", self._purge_room_txn, room_id) def _purge_room_txn(self, txn, room_id): - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) - - txn.execute( - """ - DELETE FROM state_groups_state WHERE state_group IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) - """, - (room_id,), - ) - - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) - - txn.execute( - """ - DELETE FROM state_group_edges WHERE state_group IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) - """, - (room_id,), - ) - - # ... and the state groups - logger.info("[purge] removing %s from state_groups", room_id) - - txn.execute( - """ - DELETE FROM state_groups WHERE id IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) - """, - (room_id,), - ) - - # and then tables which lack an index on room_id but have one on event_id + # First delete tables which lack an index on room_id but have one on event_id for table in ( "event_auth", "event_edges", @@ -1881,6 +1712,153 @@ class EventsStore( logger.info("[purge] done") + def purge_unreferenced_state_groups(self, room_id, state_groups_to_delete): + """Deletes no longer referenced state groups and de-deltas any state + groups that reference them. + """ + + return self.runInteraction( + "purge_unreferenced_state_groups", + self._purge_unreferenced_state_groups, + room_id, + state_groups_to_delete, + ) + + def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + logger.info( + "[purge] found %i state groups to delete", len(state_groups_to_delete) + ) + + rows = self._simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=state_groups_to_delete, + keyvalues={}, + retcols=("state_group",), + ) + + remaining_state_groups = set( + row["state_group"] + for row in rows + if row["state_group"] not in state_groups_to_delete + ) + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. + for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state[sg] + + self._simple_delete_txn( + txn, table="state_groups_state", keyvalues={"state_group": sg} + ) + + self._simple_delete_txn( + txn, table="state_group_edges", keyvalues={"state_group": sg} + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": sg, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(curr_state) + ], + ) + + logger.info("[purge] removing redundant state groups") + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + ((sg,) for sg in state_groups_to_delete), + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + ((sg,) for sg in state_groups_to_delete), + ) + + @defer.inlineCallbacks + def get_previous_state_groups(self, state_groups): + """Fetch the previous groups of the given state groups. + + Args: + state_groups (Iterable[int]) + + Returns: + Deferred[dict[int, int]]: mapping from state group to previous + state group. + """ + + rows = yield self._simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("prev_state_group", "state_group"), + desc="get_previous_state_groups", + ) + + return {row["state_group"]: row["prev_state_group"] for row in rows} + + def purge_room_state(self, room_id): + """Deletes all record of a room from state tables + + Args: + room_id (str): + """ + + return self.runInteraction( + "purge_room_state", self._purge_room_state_txn, room_id + ) + + def _purge_room_state_txn(self, txn, room_id): + # first we have to delete the state groups states + logger.info("[purge] removing %s from state_groups_state", room_id) + + txn.execute( + """ + DELETE FROM state_groups_state + INNER JOIN state_groups USING (event_id) + WHEREE state_groups.room_id = ? + """, + (room_id,), + ) + + # ... and the state group edges + logger.info("[purge] removing %s from state_group_edges", room_id) + + txn.execute( + """ + DELETE FROM state_group_edges + INNER JOIN state_groups USING (event_id) + WHEREE state_groups.room_id = ? + ) + """, + (room_id,), + ) + + # ... and the state groups + logger.info("[purge] removing %s from state_groups", room_id) + + txn.execute( + """ + DELETE FROM state_groups WHEREE room_id = ? + """, + (room_id,), + ) + async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 9b2207075b..36be8f0a9d 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -989,6 +989,29 @@ class StateGroupWorkerStore( return self.runInteraction("store_state_group", _store_state_group_txn) + @defer.inlineCallbacks + def get_referenced_state_groups(self, state_groups): + """Check if the state groups are referenced by events. + + Args: + state_groups (Iterable[int]) + + Returns: + Deferred[set[int]]: The subset of state groups that are + referenced. + """ + + rows = yield self._simple_select_many_batch( + table="event_to_state_groups", + column="state_group", + iterable=state_groups, + keyvalues={}, + retcols=("DISTINCT state_group",), + desc="get_referenced_state_groups", + ) + + return set(row["state_group"] for row in rows) + class StateBackgroundUpdateStore( StateGroupBackgroundUpdateStore, BackgroundUpdateStore diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py new file mode 100644 index 0000000000..dd45df0c88 --- /dev/null +++ b/synapse/storage/purge_events.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging + +from twisted.internet import defer + +logger = logging.getLogger(__name__) + + +class PurgeEventsStorage(object): + """High level interface for purging rooms and event history. + """ + + def __init__(self, hs, stores): + self.stores = stores + + @defer.inlineCallbacks + def purge_room(self, room_id: str): + """Deletes all record of a room + """ + + yield self.stores.main.purge_room(room_id) + yield self.stores.main.purge_room_state(room_id) + + @defer.inlineCallbacks + def purge_history(self, room_id, token, delete_local_events): + """Deletes room history before a certain point + + Args: + room_id (str): + + token (str): A topological token to delete events before + + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). + """ + state_groups = yield self.stores.main.purge_history( + room_id, token, delete_local_events + ) + + logger.info("[purge] finding state groups that can be deleted") + + sg_to_delete = yield self._find_unreferenced_groups(state_groups) + + yield self.stores.main.purge_unreferenced_state_groups(room_id, sg_to_delete) + + @defer.inlineCallbacks + def _find_unreferenced_groups(self, state_groups): + """Used when purging history to figure out which state groups can be + deleted. + + Args: + state_groups (set[int]): Set of state groups referenced by events + that are going to be deleted. + + Returns: + Deferred[set[int]] The set of state groups that can be deleted. + """ + # Graph of state group -> previous group + graph = {} + + # Set of events that we have found to be referenced by events + referenced_groups = set() + + # Set of state groups we've already seen + state_groups_seen = set(state_groups) + + # Set of state groups to handle next. + next_to_search = set(state_groups) + while next_to_search: + # We bound size of groups we're looking up at once, to stop the + # SQL query getting too big + if len(next_to_search) < 100: + current_search = next_to_search + next_to_search = set() + else: + current_search = set(itertools.islice(next_to_search, 100)) + next_to_search -= current_search + + referenced = yield self.stores.main.get_referenced_state_groups( + current_search + ) + referenced_groups |= referenced + + # We don't continue iterating up the state group graphs for state + # groups that are referenced. + current_search -= referenced + + edges = yield self.stores.main.get_previous_state_groups(current_search) + + prevs = set(edges.values()) + # We don't bother re-handling groups we've already seen + prevs -= state_groups_seen + next_to_search |= prevs + state_groups_seen |= prevs + + graph.update(edges) + + to_delete = state_groups_seen - referenced_groups + + return to_delete diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index f671599cb8..b9fafaa1a6 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -40,23 +40,24 @@ class PurgeTests(HomeserverTestCase): third = self.helper.send(self.room_id, body="test3") last = self.helper.send(self.room_id, body="test4") - storage = self.hs.get_datastore() + store = self.hs.get_datastore() + storage = self.hs.get_storage() # Get the topological token - event = storage.get_topological_token_for_event(last["event_id"]) + event = store.get_topological_token_for_event(last["event_id"]) self.pump() event = self.successResultOf(event) # Purge everything before this topological token - purge = storage.purge_history(self.room_id, event, True) + purge = storage.purge_events.purge_history(self.room_id, event, True) self.pump() self.assertEqual(self.successResultOf(purge), None) # Try and get the events - get_first = storage.get_event(first["event_id"]) - get_second = storage.get_event(second["event_id"]) - get_third = storage.get_event(third["event_id"]) - get_last = storage.get_event(last["event_id"]) + get_first = store.get_event(first["event_id"]) + get_second = store.get_event(second["event_id"]) + get_third = store.get_event(third["event_id"]) + get_last = store.get_event(last["event_id"]) self.pump() # 1-3 should fail and last will succeed, meaning that 1-3 are deleted -- cgit 1.5.1 From e7943f660add8b602ea5225060bd0d74e6440017 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 30 Oct 2019 16:15:04 +0000 Subject: Add unit tests --- synapse/api/filtering.py | 2 +- tests/api/test_filtering.py | 51 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index a27029c678..bd91b9f018 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -307,7 +307,7 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) - labels = content.get(LabelsField) + labels = content.get(LabelsField, []) return self.check_fields(room_id, sender, ev_type, labels, contains_url) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 6ba623de13..66b3c828db 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -19,6 +19,7 @@ import jsonschema from twisted.internet import defer +from synapse.api.constants import LabelsField from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import FrozenEvent @@ -95,6 +96,8 @@ class FilteringTestCase(unittest.TestCase): "types": ["m.room.message"], "not_rooms": ["!726s6s6q:example.com"], "not_senders": ["@spam:example.com"], + "org.matrix.labels": ["#fun"], + "org.matrix.not_labels": ["#work"], }, "ephemeral": { "types": ["m.receipt", "m.typing"], @@ -320,6 +323,54 @@ class FilteringTestCase(unittest.TestCase): ) self.assertFalse(Filter(definition).check(event)) + def test_filter_labels(self): + definition = {"org.matrix.labels": ["#fun"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={ + LabelsField: ["#fun"] + }, + ) + + self.assertTrue(Filter(definition).check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={ + LabelsField: ["#notfun"] + }, + ) + + self.assertFalse(Filter(definition).check(event)) + + def test_filter_not_labels(self): + definition = {"org.matrix.not_labels": ["#fun"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={ + LabelsField: ["#fun"] + }, + ) + + self.assertFalse(Filter(definition).check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={ + LabelsField: ["#notfun"] + }, + ) + + self.assertTrue(Filter(definition).check(event)) + @defer.inlineCallbacks def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} -- cgit 1.5.1 From 395683add1d569c0fdfd83d279551a3ba926f4d5 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 30 Oct 2019 16:47:37 +0000 Subject: Add integration tests for sync --- tests/rest/client/v1/utils.py | 15 ++++- tests/rest/client/v2_alpha/test_sync.py | 112 +++++++++++++++++++++++++++++++- 2 files changed, 122 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index cdded88b7f..8ea0cb05ea 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -106,13 +106,22 @@ class RestHelper(object): self.auth_user_id = temp_id def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) if body is None: body = "body_text_here" - path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) content = {"msgtype": "m.text", "body": body} + + return self.send_event( + room_id, "m.room.message", content, txn_id, tok, expect_code + ) + + def send_event( + self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200 + ): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id) if tok: path = path + "?access_token=%s" % tok diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 71895094bd..0263be010f 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json from mock import Mock +from synapse.api.constants import EventTypes, LabelsField import synapse.rest.admin from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync @@ -26,7 +27,12 @@ from tests.server import TimedOutException class FilterTestCase(unittest.HomeserverTestCase): user_id = "@apple:test" - servlets = [sync.register_servlets] + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] def make_homeserver(self, reactor, clock): @@ -70,6 +76,108 @@ class FilterTestCase(unittest.HomeserverTestCase): ) +class SyncFilterTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_sync_filter_labels(self): + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 2, events) + self.assertEqual(events[0]["content"]["body"], "with label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with label", events[1]) + + def test_sync_filter_not_labels(self): + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 2, events) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) + + def _test_sync_filter_labels(self, sync_filter): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as(user_id, tok=tok) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with label", + LabelsField: ["#fun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "without label", + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + LabelsField: ["#work"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with label", + LabelsField: ["#fun"], + }, + tok=tok, + ) + + request, channel = self.make_request( + "GET", "/sync?filter=%s" % sync_filter, access_token=tok + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] + + class SyncTypingTests(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From fe51d6cacf6e1a2da5fc3589d0bc4118342b33dd Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 30 Oct 2019 17:28:41 +0000 Subject: Add more integration testing --- synapse/storage/data_stores/main/stream.py | 2 +- tests/rest/client/v2_alpha/test_sync.py | 45 ++++++++++++++++++++++++++---- 2 files changed, 40 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 907d7f20ba..cfa34ba1e7 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -872,7 +872,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): args.append(int(limit)) sql = ( - "SELECT event_id, topological_ordering, stream_ordering" + "SELECT DISTINCT event_id, topological_ordering, stream_ordering" " FROM events" " LEFT JOIN event_labels USING (event_id)" " WHERE outlier = ? AND room_id = ? AND %(bounds)s" diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 0263be010f..a1aa7d87bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -85,6 +85,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): ] def test_sync_filter_labels(self): + """Test that we can filter by a label.""" sync_filter = json.dumps( { "room": { @@ -98,11 +99,12 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): events = self._test_sync_filter_labels(sync_filter) - self.assertEqual(len(events), 2, events) - self.assertEqual(events[0]["content"]["body"], "with label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with label", events[1]) + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) def test_sync_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" sync_filter = json.dumps( { "room": { @@ -116,9 +118,29 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): events = self._test_sync_filter_labels(sync_filter) - self.assertEqual(len(events), 2, events) + self.assertEqual(len(events), 3, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "without label", events[0]) self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with two wrong labels", events[2]) + + def test_sync_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) def _test_sync_filter_labels(self, sync_filter): user_id = self.register_user("kermit", "test") @@ -131,7 +153,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): type=EventTypes.Message, content={ "msgtype": "m.text", - "body": "with label", + "body": "with right label", LabelsField: ["#fun"], }, tok=tok, @@ -163,7 +185,18 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): type=EventTypes.Message, content={ "msgtype": "m.text", - "body": "with label", + "body": "with two wrong labels", + LabelsField: ["#work", "#notfun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", LabelsField: ["#fun"], }, tok=tok, -- cgit 1.5.1 From d8c9109aeee58950f0fd4d9865836b82aa7aafb6 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 30 Oct 2019 17:48:22 +0000 Subject: Add integration tests for /messages --- tests/rest/client/v1/test_rooms.py | 102 ++++++++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2f2ca74611..ba2008497e 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import Membership +from synapse.api.constants import EventTypes, LabelsField, Membership from synapse.rest.client.v1 import login, profile, room from tests import unittest @@ -811,6 +811,106 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) + def test_filter_labels(self): + """Test that we can filter by a label.""" + message_filter = json.dumps({ + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + }) + + events = self._test_filter_labels(message_filter) + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" + message_filter = json.dumps({ + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + }) + + events = self._test_filter_labels(message_filter) + + self.assertEqual(len(events), 3, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with two wrong labels", events[2]) + + def test_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + sync_filter = json.dumps({ + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + }) + + events = self._test_filter_labels(sync_filter) + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def _test_filter_labels(self, message_filter): + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + LabelsField: ["#fun"], + } + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "without label", + } + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + LabelsField: ["#work"], + } + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + LabelsField: ["#work", "#notfun"], + } + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + LabelsField: ["#fun"], + } + ) + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", "/rooms/%s/messages?access_token=x&from=%s&filter=%s" % ( + self.room_id, token, message_filter + ) + ) + self.render(request) + + return channel.json_body["chunk"] + class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From dcc069a2e2540862c233a20037e3e59591a42431 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 30 Oct 2019 18:01:56 +0000 Subject: Lint --- synapse/storage/data_stores/main/events.py | 8 +---- tests/api/test_filtering.py | 16 +++------- tests/rest/client/v1/test_rooms.py | 49 +++++++++++++++--------------- tests/rest/client/v2_alpha/test_sync.py | 12 ++++---- 4 files changed, 35 insertions(+), 50 deletions(-) (limited to 'tests') diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index f80b5f1a3f..2b900f1ce1 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -2486,13 +2486,7 @@ class EventsStore( return self._simple_insert_many_txn( txn=txn, table="event_labels", - values=[ - { - "event_id": event_id, - "label": label, - } - for label in labels - ], + values=[{"event_id": event_id, "label": label} for label in labels], ) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 66b3c828db..e004ab1ee5 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -329,9 +329,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={ - LabelsField: ["#fun"] - }, + content={LabelsField: ["#fun"]}, ) self.assertTrue(Filter(definition).check(event)) @@ -340,9 +338,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={ - LabelsField: ["#notfun"] - }, + content={LabelsField: ["#notfun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -353,9 +349,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={ - LabelsField: ["#fun"] - }, + content={LabelsField: ["#fun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -364,9 +358,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={ - LabelsField: ["#notfun"] - }, + content={LabelsField: ["#notfun"]}, ) self.assertTrue(Filter(definition).check(event)) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ba2008497e..188f47bd7d 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -813,10 +813,9 @@ class RoomMessageListTestCase(RoomBase): def test_filter_labels(self): """Test that we can filter by a label.""" - message_filter = json.dumps({ - "types": [EventTypes.Message], - "org.matrix.labels": ["#fun"], - }) + message_filter = json.dumps( + {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} + ) events = self._test_filter_labels(message_filter) @@ -826,25 +825,28 @@ class RoomMessageListTestCase(RoomBase): def test_filter_not_labels(self): """Test that we can filter by the absence of a label.""" - message_filter = json.dumps({ - "types": [EventTypes.Message], - "org.matrix.not_labels": ["#fun"], - }) + message_filter = json.dumps( + {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} + ) events = self._test_filter_labels(message_filter) self.assertEqual(len(events), 3, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "without label", events[0]) self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual(events[2]["content"]["body"], "with two wrong labels", events[2]) + self.assertEqual( + events[2]["content"]["body"], "with two wrong labels", events[2] + ) def test_filter_labels_not_labels(self): """Test that we can filter by both a label and the absence of another label.""" - sync_filter = json.dumps({ - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - }) + sync_filter = json.dumps( + { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + ) events = self._test_filter_labels(sync_filter) @@ -859,16 +861,13 @@ class RoomMessageListTestCase(RoomBase): "msgtype": "m.text", "body": "with right label", LabelsField: ["#fun"], - } + }, ) self.helper.send_event( room_id=self.room_id, type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "without label", - } + content={"msgtype": "m.text", "body": "without label"}, ) self.helper.send_event( @@ -878,7 +877,7 @@ class RoomMessageListTestCase(RoomBase): "msgtype": "m.text", "body": "with wrong label", LabelsField: ["#work"], - } + }, ) self.helper.send_event( @@ -888,7 +887,7 @@ class RoomMessageListTestCase(RoomBase): "msgtype": "m.text", "body": "with two wrong labels", LabelsField: ["#work", "#notfun"], - } + }, ) self.helper.send_event( @@ -898,14 +897,14 @@ class RoomMessageListTestCase(RoomBase): "msgtype": "m.text", "body": "with right label", LabelsField: ["#fun"], - } + }, ) token = "s0_0_0_0_0_0_0_0_0" request, channel = self.make_request( - "GET", "/rooms/%s/messages?access_token=x&from=%s&filter=%s" % ( - self.room_id, token, message_filter - ) + "GET", + "/rooms/%s/messages?access_token=x&from=%s&filter=%s" + % (self.room_id, token, message_filter), ) self.render(request) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index a1aa7d87bd..c5c199d412 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import json + from mock import Mock -from synapse.api.constants import EventTypes, LabelsField import synapse.rest.admin +from synapse.api.constants import EventTypes, LabelsField from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync @@ -121,7 +122,9 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 3, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "without label", events[0]) self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual(events[2]["content"]["body"], "with two wrong labels", events[2]) + self.assertEqual( + events[2]["content"]["body"], "with two wrong labels", events[2] + ) def test_sync_filter_labels_not_labels(self): """Test that we can filter by both a label and the absence of another label.""" @@ -162,10 +165,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): self.helper.send_event( room_id=room_id, type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "without label", - }, + content={"msgtype": "m.text", "body": "without label"}, tok=tok, ) -- cgit 1.5.1 From bb6cec27a5ac6d5d6d5f67df21610a63745ac0a9 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 30 Oct 2019 14:57:34 -0400 Subject: rename get_devices_by_remote to get_device_updates_by_remote --- synapse/federation/sender/per_destination_queue.py | 4 ++-- synapse/storage/data_stores/main/devices.py | 8 ++++---- tests/handlers/test_typing.py | 4 ++-- tests/storage/test_devices.py | 12 ++++++------ 4 files changed, 14 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index d5d4a60c88..6e3012cd41 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -359,7 +359,7 @@ class PerDestinationQueue(object): last_device_list = self._last_device_list_stream_id # Retrieve list of new device updates to send to the destination - now_stream_id, results = yield self._store.get_devices_by_remote( + now_stream_id, results = yield self._store.get_device_updates_by_remote( self._destination, last_device_list, limit=limit ) edus = [ @@ -372,7 +372,7 @@ class PerDestinationQueue(object): for (edu_type, content) in results ] - assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" + assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs" return (edus, now_stream_id) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 0b12bc58c4..717eab4159 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -91,7 +91,7 @@ class DeviceWorkerStore(SQLBaseStore): @trace @defer.inlineCallbacks - def get_devices_by_remote(self, destination, from_stream_id, limit): + def get_device_updates_by_remote(self, destination, from_stream_id, limit): """Get a stream of device updates to send to the given remote server. Args: @@ -123,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore): # stream_id; the rationale being that such a large device list update # is likely an error. updates = yield self.runInteraction( - "get_devices_by_remote", - self._get_devices_by_remote_txn, + "get_device_updates_by_remote", + self._get_device_updates_by_remote_txn, destination, from_stream_id, now_stream_id, @@ -241,7 +241,7 @@ class DeviceWorkerStore(SQLBaseStore): return now_stream_id, results - def _get_devices_by_remote_txn( + def _get_device_updates_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id, limit ): """Return device update information for a given remote destination diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f360c8e965..5ec568f4e6 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "get_received_txn_response", "set_received_txn_response", "get_destination_retry_timings", - "get_devices_by_remote", + "get_device_updates_by_remote", # Bits that user_directory needs "get_user_directory_stream_pos", "get_current_state_deltas", @@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_devices_by_remote.return_value = (0, []) + self.datastore.get_device_updates_by_remote.return_value = (0, []) def get_received_txn_response(*args): return defer.succeed(None) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 039cc79357..6f8d990959 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) @defer.inlineCallbacks - def test_get_devices_by_remote(self): + def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id @@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "somehost", -1, limit=100 ) @@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): self._check_devices_in_updates(device_ids, device_updates) @defer.inlineCallbacks - def test_get_devices_by_remote_limited(self): + def test_get_device_updates_by_remote_limited(self): # Test breaking the update limit in 1, 101, and 1 device_id segments # first add one device @@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # # first we should get a single update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "someotherhost", -1, limit=100 ) self._check_devices_in_updates(device_ids1, device_updates) # Then we should get an empty list back as the 101 devices broke the limit - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "someotherhost", now_stream_id, limit=100 ) self.assertEqual(len(device_updates), 0) # The 101 devices should've been cleared, so we should now just get one device # update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "someotherhost", now_stream_id, limit=100 ) self._check_devices_in_updates(device_ids3, device_updates) -- cgit 1.5.1 From 97c60ccaa35059a55866304f6850b24e99912036 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 31 Oct 2019 11:30:25 +0000 Subject: Add unit test for /purge_room API --- tests/rest/admin/test_admin.py | 78 ++++++++++++++++++++++++++++++++++++++++++ tests/server.py | 6 +++- 2 files changed, 83 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index d3a4f717f7..8e1ca8b738 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -561,3 +561,81 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) return channel.json_body["groups"] + + +class PurgeRoomTestCase(unittest.HomeserverTestCase): + """Test /purge_room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_purge_room(self): + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # All users have to have left the room. + self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/purge_room" + request, channel = self.make_request( + "POST", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the following tables have been purged of all rows related to the room. + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + ): + count = self.get_success( + self.store._simple_select_one_onecol( + table="events", + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) diff --git a/tests/server.py b/tests/server.py index e397ebe8fa..469efb4edb 100644 --- a/tests/server.py +++ b/tests/server.py @@ -161,7 +161,11 @@ def make_request( path = path.encode("ascii") # Decorate it to be the full path, if we're using shorthand - if shorthand and not path.startswith(b"/_matrix"): + if ( + shorthand + and not path.startswith(b"/_matrix") + and not path.startswith(b"/_synapse") + ): path = b"/_matrix/client/r0/" + path path = path.replace(b"//", b"/") -- cgit 1.5.1 From c6dbca2422bf77ccbf0b52d9245d28c258dac4f3 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 1 Nov 2019 10:30:51 +0000 Subject: Incorporate review --- changelog.d/6301.feature | 2 +- synapse/api/constants.py | 5 ++++- synapse/api/filtering.py | 6 ++++-- synapse/storage/data_stores/main/events.py | 12 ++++++++++-- tests/api/test_filtering.py | 10 +++++----- tests/rest/client/v1/test_rooms.py | 10 +++++----- tests/rest/client/v2_alpha/test_sync.py | 10 +++++----- 7 files changed, 34 insertions(+), 21 deletions(-) (limited to 'tests') diff --git a/changelog.d/6301.feature b/changelog.d/6301.feature index b7ff3fad3b..78a187a1dc 100644 --- a/changelog.d/6301.feature +++ b/changelog.d/6301.feature @@ -1 +1 @@ -Implement label-based filtering. +Implement label-based filtering on `/sync` and `/messages` ([MSC2326](https://github.com/matrix-org/matrix-doc/pull/2326)). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 999ec02fd9..cf4ce5f5a2 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -140,4 +140,7 @@ class LimitBlockingTypes(object): HS_DISABLED = "hs_disabled" -LabelsField = "org.matrix.labels" +class EventContentFields(object): + """Fields found in events' content, regardless of type.""" + # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 + Labels = "org.matrix.labels" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index bd91b9f018..30a7ee0a7a 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -20,7 +20,7 @@ from jsonschema import FormatChecker from twisted.internet import defer -from synapse.api.constants import LabelsField +from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.storage.presence import UserPresenceState from synapse.types import RoomID, UserID @@ -67,6 +67,8 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + # Include or exclude events with the provided labels. + # cf https://github.com/matrix-org/matrix-doc/pull/2326 "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, }, @@ -307,7 +309,7 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) - labels = content.get(LabelsField, []) + labels = content.get(EventContentFields.Labels, []) return self.check_fields(room_id, sender, ev_type, labels, contains_url) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 2b900f1ce1..42ffa9066a 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -29,7 +29,7 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer import synapse.metrics -from synapse.api.constants import EventTypes, LabelsField +from synapse.api.constants import EventTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 @@ -1491,7 +1491,7 @@ class EventsStore( self._handle_event_relations(txn, event) # Store the labels for this event. - labels = event.content.get(LabelsField) + labels = event.content.get(EventContentFields.Labels) if labels: self.insert_labels_for_event_txn(txn, event.event_id, labels) @@ -2483,6 +2483,14 @@ class EventsStore( ) def insert_labels_for_event_txn(self, txn, event_id, labels): + """Store the mapping between an event's ID and its labels, with one row per + (event_id, label) tuple. + + Args: + txn (LoggingTransaction): The transaction to execute. + event_id (str): The event's ID. + labels (list[str]): A list of text labels. + """ return self._simple_insert_many_txn( txn=txn, table="event_labels", diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index e004ab1ee5..8ec48c4154 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -19,7 +19,7 @@ import jsonschema from twisted.internet import defer -from synapse.api.constants import LabelsField +from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import FrozenEvent @@ -329,7 +329,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={LabelsField: ["#fun"]}, + content={EventContentFields.Labels: ["#fun"]}, ) self.assertTrue(Filter(definition).check(event)) @@ -338,7 +338,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={LabelsField: ["#notfun"]}, + content={EventContentFields.Labels: ["#notfun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -349,7 +349,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={LabelsField: ["#fun"]}, + content={EventContentFields.Labels: ["#fun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -358,7 +358,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={LabelsField: ["#notfun"]}, + content={EventContentFields.Labels: ["#notfun"]}, ) self.assertTrue(Filter(definition).check(event)) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 188f47bd7d..0dc0faa0e5 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import EventTypes, LabelsField, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.rest.client.v1 import login, profile, room from tests import unittest @@ -860,7 +860,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with right label", - LabelsField: ["#fun"], + EventContentFields.Labels: ["#fun"], }, ) @@ -876,7 +876,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with wrong label", - LabelsField: ["#work"], + EventContentFields.Labels: ["#work"], }, ) @@ -886,7 +886,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with two wrong labels", - LabelsField: ["#work", "#notfun"], + EventContentFields.Labels: ["#work", "#notfun"], }, ) @@ -896,7 +896,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with right label", - LabelsField: ["#fun"], + EventContentFields.Labels: ["#fun"], }, ) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index c5c199d412..c3c6f75ced 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -17,7 +17,7 @@ import json from mock import Mock import synapse.rest.admin -from synapse.api.constants import EventTypes, LabelsField +from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync @@ -157,7 +157,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with right label", - LabelsField: ["#fun"], + EventContentFields.Labels: ["#fun"], }, tok=tok, ) @@ -175,7 +175,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with wrong label", - LabelsField: ["#work"], + EventContentFields.Labels: ["#work"], }, tok=tok, ) @@ -186,7 +186,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with two wrong labels", - LabelsField: ["#work", "#notfun"], + EventContentFields.Labels: ["#work", "#notfun"], }, tok=tok, ) @@ -197,7 +197,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with right label", - LabelsField: ["#fun"], + EventContentFields.Labels: ["#fun"], }, tok=tok, ) -- cgit 1.5.1 From 1cb84c6486a5131dd284f341bb657434becda255 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 1 Nov 2019 14:07:44 +0000 Subject: Support for routing outbound HTTP requests via a proxy (#6239) The `http_proxy` and `HTTPS_PROXY` env vars can be set to a `host[:port]` value which should point to a proxy. The address of the proxy should be excluded from IP blacklists such as the `url_preview_ip_range_blacklist`. The proxy will then be used for * push * url previews * phone-home stats * recaptcha validation * CAS auth validation It will *not* be used for: * Application Services * Identity servers * Outbound federation * In worker configurations, connections from workers to masters Fixes #4198. --- changelog.d/6238.feature | 1 + synapse/app/homeserver.py | 2 +- synapse/handlers/ui_auth/checkers.py | 2 +- synapse/http/client.py | 17 +- synapse/http/connectproxyclient.py | 195 ++++++++++++ synapse/http/proxyagent.py | 195 ++++++++++++ synapse/push/httppusher.py | 2 +- synapse/rest/client/v1/login.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 + synapse/server.py | 9 + synapse/server.pyi | 9 + tests/http/__init__.py | 17 ++ .../federation/test_matrix_federation_agent.py | 11 +- tests/http/test_proxyagent.py | 334 +++++++++++++++++++++ tests/push/test_http.py | 2 +- tests/server.py | 24 +- 16 files changed, 812 insertions(+), 12 deletions(-) create mode 100644 changelog.d/6238.feature create mode 100644 synapse/http/connectproxyclient.py create mode 100644 synapse/http/proxyagent.py create mode 100644 tests/http/test_proxyagent.py (limited to 'tests') diff --git a/changelog.d/6238.feature b/changelog.d/6238.feature new file mode 100644 index 0000000000..d225ac33b6 --- /dev/null +++ b/changelog.d/6238.feature @@ -0,0 +1 @@ +Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8997c1f9e7..8d28076d92 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -565,7 +565,7 @@ def run(hs): "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats) ) try: - yield hs.get_simple_http_client().put_json( + yield hs.get_proxied_http_client().put_json( hs.config.report_stats_endpoint, stats ) except Exception as e: diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 29aa1e5aaf..8363d887a9 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs): super().__init__(hs) self._enabled = bool(hs.config.recaptcha_private_key) - self._http_client = hs.get_simple_http_client() + self._http_client = hs.get_proxied_http_client() self._url = hs.config.recaptcha_siteverify_api self._secret = hs.config.recaptcha_private_key diff --git a/synapse/http/client.py b/synapse/http/client.py index 2df5b383b5..d4c285445e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -45,6 +45,7 @@ from synapse.http import ( cancelled_to_request_timed_out_error, redact_uri, ) +from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util.async_helpers import timeout_deferred @@ -183,7 +184,15 @@ class SimpleHttpClient(object): using HTTP in Matrix """ - def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None): + def __init__( + self, + hs, + treq_args={}, + ip_whitelist=None, + ip_blacklist=None, + http_proxy=None, + https_proxy=None, + ): """ Args: hs (synapse.server.HomeServer) @@ -192,6 +201,8 @@ class SimpleHttpClient(object): we may not request. ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. + http_proxy (bytes): proxy server to use for http connections. host[:port] + https_proxy (bytes): proxy server to use for https connections. host[:port] """ self.hs = hs @@ -236,11 +247,13 @@ class SimpleHttpClient(object): # The default context factory in Twisted 14.0.0 (which we require) is # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' - self.agent = Agent( + self.agent = ProxyAgent( self.reactor, connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, + http_proxy=http_proxy, + https_proxy=https_proxy, ) if self._ip_blacklist: diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py new file mode 100644 index 0000000000..be7b2ceb8e --- /dev/null +++ b/synapse/http/connectproxyclient.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from zope.interface import implementer + +from twisted.internet import defer, protocol +from twisted.internet.error import ConnectError +from twisted.internet.interfaces import IStreamClientEndpoint +from twisted.internet.protocol import connectionDone +from twisted.web import http + +logger = logging.getLogger(__name__) + + +class ProxyConnectError(ConnectError): + pass + + +@implementer(IStreamClientEndpoint) +class HTTPConnectProxyEndpoint(object): + """An Endpoint implementation which will send a CONNECT request to an http proxy + + Wraps an existing HostnameEndpoint for the proxy. + + When we get the connect() request from the connection pool (via the TLS wrapper), + we'll first connect to the proxy endpoint with a ProtocolFactory which will make the + CONNECT request. Once that completes, we invoke the protocolFactory which was passed + in. + + Args: + reactor: the Twisted reactor to use for the connection + proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the + proxy + host (bytes): hostname that we want to CONNECT to + port (int): port that we want to connect to + """ + + def __init__(self, reactor, proxy_endpoint, host, port): + self._reactor = reactor + self._proxy_endpoint = proxy_endpoint + self._host = host + self._port = port + + def __repr__(self): + return "" % (self._proxy_endpoint,) + + def connect(self, protocolFactory): + f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory) + d = self._proxy_endpoint.connect(f) + # once the tcp socket connects successfully, we need to wait for the + # CONNECT to complete. + d.addCallback(lambda conn: f.on_connection) + return d + + +class HTTPProxiedClientFactory(protocol.ClientFactory): + """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect. + + Once the CONNECT completes, invokes the original ClientFactory to build the + HTTP Protocol object and run the rest of the connection. + + Args: + dst_host (bytes): hostname that we want to CONNECT to + dst_port (int): port that we want to connect to + wrapped_factory (protocol.ClientFactory): The original Factory + """ + + def __init__(self, dst_host, dst_port, wrapped_factory): + self.dst_host = dst_host + self.dst_port = dst_port + self.wrapped_factory = wrapped_factory + self.on_connection = defer.Deferred() + + def startedConnecting(self, connector): + return self.wrapped_factory.startedConnecting(connector) + + def buildProtocol(self, addr): + wrapped_protocol = self.wrapped_factory.buildProtocol(addr) + + return HTTPConnectProtocol( + self.dst_host, self.dst_port, wrapped_protocol, self.on_connection + ) + + def clientConnectionFailed(self, connector, reason): + logger.debug("Connection to proxy failed: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionFailed(connector, reason) + + def clientConnectionLost(self, connector, reason): + logger.debug("Connection to proxy lost: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionLost(connector, reason) + + +class HTTPConnectProtocol(protocol.Protocol): + """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect + + Args: + host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal + to put in the CONNECT request + + port (int): The original HTTP(s) port to put in the CONNECT request + + wrapped_protocol (interfaces.IProtocol): the original protocol (probably + HTTPChannel or TLSMemoryBIOProtocol, but could be anything really) + + connected_deferred (Deferred): a Deferred which will be callbacked with + wrapped_protocol when the CONNECT completes + """ + + def __init__(self, host, port, wrapped_protocol, connected_deferred): + self.host = host + self.port = port + self.wrapped_protocol = wrapped_protocol + self.connected_deferred = connected_deferred + self.http_setup_client = HTTPConnectSetupClient(self.host, self.port) + self.http_setup_client.on_connected.addCallback(self.proxyConnected) + + def connectionMade(self): + self.http_setup_client.makeConnection(self.transport) + + def connectionLost(self, reason=connectionDone): + if self.wrapped_protocol.connected: + self.wrapped_protocol.connectionLost(reason) + + self.http_setup_client.connectionLost(reason) + + if not self.connected_deferred.called: + self.connected_deferred.errback(reason) + + def proxyConnected(self, _): + self.wrapped_protocol.makeConnection(self.transport) + + self.connected_deferred.callback(self.wrapped_protocol) + + # Get any pending data from the http buf and forward it to the original protocol + buf = self.http_setup_client.clearLineBuffer() + if buf: + self.wrapped_protocol.dataReceived(buf) + + def dataReceived(self, data): + # if we've set up the HTTP protocol, we can send the data there + if self.wrapped_protocol.connected: + return self.wrapped_protocol.dataReceived(data) + + # otherwise, we must still be setting up the connection: send the data to the + # setup client + return self.http_setup_client.dataReceived(data) + + +class HTTPConnectSetupClient(http.HTTPClient): + """HTTPClient protocol to send a CONNECT message for proxies and read the response. + + Args: + host (bytes): The hostname to send in the CONNECT message + port (int): The port to send in the CONNECT message + """ + + def __init__(self, host, port): + self.host = host + self.port = port + self.on_connected = defer.Deferred() + + def connectionMade(self): + logger.debug("Connected to proxy, sending CONNECT") + self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) + self.endHeaders() + + def handleStatus(self, version, status, message): + logger.debug("Got Status: %s %s %s", status, message, version) + if status != b"200": + raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) + + def handleEndHeaders(self): + logger.debug("End Headers") + self.on_connected.callback(None) + + def handleResponse(self, body): + pass diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py new file mode 100644 index 0000000000..332da02a8d --- /dev/null +++ b/synapse/http/proxyagent.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.python.failure import Failure +from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase +from twisted.web.error import SchemeNotSupported +from twisted.web.iweb import IAgent + +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint + +logger = logging.getLogger(__name__) + +_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") + + +@implementer(IAgent) +class ProxyAgent(_AgentBase): + """An Agent implementation which will use an HTTP proxy if one was requested + + Args: + reactor: twisted reactor to place outgoing + connections. + + contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the + verification parameters of OpenSSL. The default is to use a + `BrowserLikePolicyForHTTPS`, so unless you have special + requirements you can leave this as-is. + + connectTimeout (float): The amount of time that this Agent will wait + for the peer to accept a connection. + + bindAddress (bytes): The local address for client sockets to bind to. + + pool (HTTPConnectionPool|None): connection pool to be used. If None, a + non-persistent pool instance will be created. + """ + + def __init__( + self, + reactor, + contextFactory=BrowserLikePolicyForHTTPS(), + connectTimeout=None, + bindAddress=None, + pool=None, + http_proxy=None, + https_proxy=None, + ): + _AgentBase.__init__(self, reactor, pool) + + self._endpoint_kwargs = {} + if connectTimeout is not None: + self._endpoint_kwargs["timeout"] = connectTimeout + if bindAddress is not None: + self._endpoint_kwargs["bindAddress"] = bindAddress + + self.http_proxy_endpoint = _http_proxy_endpoint( + http_proxy, reactor, **self._endpoint_kwargs + ) + + self.https_proxy_endpoint = _http_proxy_endpoint( + https_proxy, reactor, **self._endpoint_kwargs + ) + + self._policy_for_https = contextFactory + self._reactor = reactor + + def request(self, method, uri, headers=None, bodyProducer=None): + """ + Issue a request to the server indicated by the given uri. + + Supports `http` and `https` schemes. + + An existing connection from the connection pool may be used or a new one may be + created. + + See also: twisted.web.iweb.IAgent.request + + Args: + method (bytes): The request method to use, such as `GET`, `POST`, etc + + uri (bytes): The location of the resource to request. + + headers (Headers|None): Extra headers to send with the request + + bodyProducer (IBodyProducer|None): An object which can generate bytes to + make up the body of this request (for example, the properly encoded + contents of a file for a file upload). Or, None if the request is to + have no body. + + Returns: + Deferred[IResponse]: completes when the header of the response has + been received (regardless of the response status code). + """ + uri = uri.strip() + if not _VALID_URI.match(uri): + raise ValueError("Invalid URI {!r}".format(uri)) + + parsed_uri = URI.fromBytes(uri) + pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) + request_path = parsed_uri.originForm + + if parsed_uri.scheme == b"http" and self.http_proxy_endpoint: + # Cache *all* connections under the same key, since we are only + # connecting to a single destination, the proxy: + pool_key = ("http-proxy", self.http_proxy_endpoint) + endpoint = self.http_proxy_endpoint + request_path = uri + elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: + endpoint = HTTPConnectProxyEndpoint( + self._reactor, + self.https_proxy_endpoint, + parsed_uri.host, + parsed_uri.port, + ) + else: + # not using a proxy + endpoint = HostnameEndpoint( + self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs + ) + + logger.debug("Requesting %s via %s", uri, endpoint) + + if parsed_uri.scheme == b"https": + tls_connection_creator = self._policy_for_https.creatorForNetloc( + parsed_uri.host, parsed_uri.port + ) + endpoint = wrapClientTLS(tls_connection_creator, endpoint) + elif parsed_uri.scheme == b"http": + pass + else: + return defer.fail( + Failure( + SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,)) + ) + ) + + return self._requestWithEndpoint( + pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path + ) + + +def _http_proxy_endpoint(proxy, reactor, **kwargs): + """Parses an http proxy setting and returns an endpoint for the proxy + + Args: + proxy (bytes|None): the proxy setting + reactor: reactor to be used to connect to the proxy + kwargs: other args to be passed to HostnameEndpoint + + Returns: + interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy, + or None + """ + if proxy is None: + return None + + # currently we only support hostname:port. Some apps also support + # protocol://[:port], which allows a way of requiring a TLS connection to the + # proxy. + + host, port = parse_host_port(proxy, default_port=1080) + return HostnameEndpoint(reactor, host, port, **kwargs) + + +def parse_host_port(hostport, default_port=None): + # could have sworn we had one of these somewhere else... + if b":" in hostport: + host, port = hostport.rsplit(b":", 1) + try: + port = int(port) + return host, port + except ValueError: + # the thing after the : wasn't a valid port; presumably this is an + # IPv6 address. + pass + + return hostport, default_port diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 7dde2ad055..e994037be6 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -103,7 +103,7 @@ class HttpPusher(object): if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") self.url = self.data["url"] - self.http_client = hs.get_simple_http_client() + self.http_client = hs.get_proxied_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 00a7dd6d09..24a0ce74f2 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -381,7 +381,7 @@ class CasTicketServlet(RestServlet): self.cas_displayname_attribute = hs.config.cas_displayname_attribute self.cas_required_attributes = hs.config.cas_required_attributes self._sso_auth_handler = SSOAuthHandler(hs) - self._http_client = hs.get_simple_http_client() + self._http_client = hs.get_proxied_http_client() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 5a25b6b3fc..531d923f76 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource): treq_args={"browser_like_redirects": True}, ip_whitelist=hs.config.url_preview_ip_range_whitelist, ip_blacklist=hs.config.url_preview_ip_range_blacklist, + http_proxy=os.getenv("http_proxy"), + https_proxy=os.getenv("HTTPS_PROXY"), ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path diff --git a/synapse/server.py b/synapse/server.py index 0b81af646c..f8aeebcff8 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -23,6 +23,7 @@ # Imports required for the default HomeServer() implementation import abc import logging +import os from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail @@ -168,6 +169,7 @@ class HomeServer(object): "filtering", "http_client_context_factory", "simple_http_client", + "proxied_http_client", "media_repository", "media_repository_resource", "federation_transport_client", @@ -311,6 +313,13 @@ class HomeServer(object): def build_simple_http_client(self): return SimpleHttpClient(self) + def build_proxied_http_client(self): + return SimpleHttpClient( + self, + http_proxy=os.getenv("http_proxy"), + https_proxy=os.getenv("HTTPS_PROXY"), + ) + def build_room_creation_handler(self): return RoomCreationHandler(self) diff --git a/synapse/server.pyi b/synapse/server.pyi index 83d1f11283..b5e0b57095 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -12,6 +12,7 @@ import synapse.handlers.message import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password +import synapse.http.client import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -38,6 +39,14 @@ class HomeServer(object): pass def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass + def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + or support any HTTP_PROXY settings""" + pass + def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + but does support HTTP_PROXY settings""" + pass def get_deactivate_account_handler( self, ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 2d5dba6464..2096ba3c91 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -20,6 +20,23 @@ from zope.interface import implementer from OpenSSL import SSL from OpenSSL.SSL import Connection from twisted.internet.interfaces import IOpenSSLServerConnectionCreator +from twisted.internet.ssl import Certificate, trustRootFromCertificates +from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401 +from twisted.web.iweb import IPolicyForHTTPS # noqa: F401 + + +def get_test_https_policy(): + """Get a test IPolicyForHTTPS which trusts the test CA cert + + Returns: + IPolicyForHTTPS + """ + ca_file = get_test_ca_cert_file() + with open(ca_file) as stream: + content = stream.read() + cert = Certificate.loadPEM(content) + trust_root = trustRootFromCertificates([cert]) + return BrowserLikePolicyForHTTPS(trustRoot=trust_root) def get_test_ca_cert_file(): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 71d7025264..cfcd98ff7d 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase): FakeTransport(client_protocol, self.reactor, server_tls_protocol) ) + # grab a hold of the TLS connection, in case it gets torn down + server_tls_connection = server_tls_protocol._tlsConnection + + # fish the test server back out of the server-side TLS protocol. + http_protocol = server_tls_protocol.wrappedProtocol + # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) # check the SNI - server_name = server_tls_protocol._tlsConnection.get_servername() + server_name = server_tls_connection.get_servername() self.assertEqual( server_name, expected_sni, "Expected SNI %s but got %s" % (expected_sni, server_name), ) - # fish the test server back out of the server-side TLS protocol. - return server_tls_protocol.wrappedProtocol + return http_protocol @defer.inlineCallbacks def _make_get_request(self, uri): diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py new file mode 100644 index 0000000000..22abf76515 --- /dev/null +++ b/tests/http/test_proxyagent.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import treq + +from twisted.internet import interfaces # noqa: F401 +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web.http import HTTPChannel + +from synapse.http.proxyagent import ProxyAgent + +from tests.http import TestServerTLSConnectionFactory, get_test_https_policy +from tests.server import FakeTransport, ThreadedMemoryReactorClock +from tests.unittest import TestCase + +logger = logging.getLogger(__name__) + +HTTPFactory = Factory.forProtocol(HTTPChannel) + + +class MatrixFederationAgentTests(TestCase): + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + + def _make_connection( + self, client_factory, server_factory, ssl=False, expected_sni=None + ): + """Builds a test server, and completes the outgoing client connection + + Args: + client_factory (interfaces.IProtocolFactory): the the factory that the + application is trying to use to make the outbound connection. We will + invoke it to build the client Protocol + + server_factory (interfaces.IProtocolFactory): a factory to build the + server-side protocol + + ssl (bool): If true, we will expect an ssl connection and wrap + server_factory with a TLSMemoryBIOFactory + + expected_sni (bytes|None): the expected SNI value + + Returns: + IProtocol: the server Protocol returned by server_factory + """ + if ssl: + server_factory = _wrap_server_factory_for_tls(server_factory) + + server_protocol = server_factory.buildProtocol(None) + + # now, tell the client protocol factory to build the client protocol, + # and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_protocol, self.reactor, client_protocol) + ) + + # tell the server protocol to send its stuff back to the client, too + server_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_protocol) + ) + + if ssl: + http_protocol = server_protocol.wrappedProtocol + tls_connection = server_protocol._tlsConnection + else: + http_protocol = server_protocol + tls_connection = None + + # give the reactor a pump to get the TLS juices flowing (if needed) + self.reactor.advance(0) + + if expected_sni is not None: + server_name = tls_connection.get_servername() + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + return http_protocol + + def test_http_request(self): + agent = ProxyAgent(self.reactor) + + self.reactor.lookups["test.com"] = "1.2.3.4" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 80) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request(self): + agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy()) + + self.reactor.lookups["test.com"] = "1.2.3.4" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 443) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, + _get_test_protocol_factory(), + ssl=True, + expected_sni=b"test.com", + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_http_request_via_proxy(self): + agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888") + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 8888) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"http://test.com") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request_via_proxy(self): + agent = ProxyAgent( + self.reactor, + contextFactory=get_test_https_policy(), + https_proxy=b"proxy.com", + ) + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 1080) + + # make a test HTTP server, and wire up the client + proxy_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # fish the transports back out so that we can do the old switcheroo + s2c_transport = proxy_server.transport + client_protocol = s2c_transport.other + c2s_transport = client_protocol.transport + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"test.com:443") + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + # this just stops the http Request trying to do a chunked response + # request.setHeader(b"Content-Length", b"0") + request.finish() + + # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel + ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) + ssl_protocol = ssl_factory.buildProtocol(None) + http_server = ssl_protocol.wrappedProtocol + + ssl_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, ssl_protocol) + ) + c2s_transport.other = ssl_protocol + + self.reactor.advance(0) + + server_name = ssl_protocol._tlsConnection.get_servername() + expected_sni = b"test.com" + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + +def _wrap_server_factory_for_tls(factory, sanlist=None): + """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory + + The resultant factory will create a TLS server which presents a certificate + signed by our test CA, valid for the domains in `sanlist` + + Args: + factory (interfaces.IProtocolFactory): protocol factory to wrap + sanlist (iterable[bytes]): list of domains the cert should be valid for + + Returns: + interfaces.IProtocolFactory + """ + if sanlist is None: + sanlist = [b"DNS:test.com"] + + connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) + return TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=factory + ) + + +def _get_test_protocol_factory(): + """Get a protocol Factory which will build an HTTPChannel + + Returns: + interfaces.IProtocolFactory + """ + server_factory = Factory.forProtocol(HTTPChannel) + + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + return server_factory + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 8ce6bb62da..af2327fb66 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase): config = self.default_config() config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, simple_http_client=m) + hs = self.setup_test_homeserver(config=config, proxied_http_client=m) return hs diff --git a/tests/server.py b/tests/server.py index 469efb4edb..f878aeaada 100644 --- a/tests/server.py +++ b/tests/server.py @@ -395,11 +395,24 @@ class FakeTransport(object): self.disconnecting = True if self._protocol: self._protocol.connectionLost(reason) - self.disconnected = True + + # if we still have data to write, delay until that is done + if self.buffer: + logger.info( + "FakeTransport: Delaying disconnect until buffer is flushed" + ) + else: + self.disconnected = True def abortConnection(self): logger.info("FakeTransport: abortConnection()") - self.loseConnection() + + if not self.disconnecting: + self.disconnecting = True + if self._protocol: + self._protocol.connectionLost(None) + + self.disconnected = True def pauseProducing(self): if not self.producer: @@ -430,6 +443,9 @@ class FakeTransport(object): self._reactor.callLater(0.0, _produce) def write(self, byt): + if self.disconnecting: + raise Exception("Writing to disconnecting FakeTransport") + self.buffer = self.buffer + byt # always actually do the write asynchronously. Some protocols (notably the @@ -474,6 +490,10 @@ class FakeTransport(object): if self.buffer and self.autoflush: self._reactor.callLater(0.0, self.flush) + if not self.buffer and self.disconnecting: + logger.info("FakeTransport: Buffer now empty, completing disconnect") + self.disconnected = True + def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol: """ -- cgit 1.5.1 From c6516adbe03a0acdd614ba6eb9d6f447dd4259e9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 1 Nov 2019 16:19:09 +0000 Subject: Factor out an _AsyncEventContextImpl (#6298) The intention here is to make it clearer which fields we can expect to be populated when: notably, that the _event_type etc aren't used for the synchronous impl of EventContext. --- changelog.d/6298.misc | 1 + synapse/events/snapshot.py | 107 ++++++++++++++++------------------------- synapse/handlers/federation.py | 38 +++++++-------- tests/test_federation.py | 4 +- 4 files changed, 65 insertions(+), 85 deletions(-) create mode 100644 changelog.d/6298.misc (limited to 'tests') diff --git a/changelog.d/6298.misc b/changelog.d/6298.misc new file mode 100644 index 0000000000..d4190730b2 --- /dev/null +++ b/changelog.d/6298.misc @@ -0,0 +1 @@ +Refactor EventContext for clarity. \ No newline at end of file diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 27cd8a63ff..a269de5482 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -37,9 +37,6 @@ class EventContext: delta_ids (dict[(str, str), str]): Delta from ``prev_group``. (type, state_key) -> event_id. ``None`` for an outlier. - prev_state_events (?): XXX: is this ever set to anything other than - the empty list? - app_service: FIXME _current_state_ids (dict[(str, str), str]|None): @@ -51,36 +48,16 @@ class EventContext: The current state map excluding the current event. None if outlier or we haven't fetched the state from DB yet. (type, state_key) -> event_id - - _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have - been calculated. None if we haven't started calculating yet - - _event_type (str): The type of the event the context is associated with. - Only set when state has not been fetched yet. - - _event_state_key (str|None): The state_key of the event the context is - associated with. Only set when state has not been fetched yet. - - _prev_state_id (str|None): If the event associated with the context is - a state event, then `_prev_state_id` is the event_id of the state - that was replaced. - Only set when state has not been fetched yet. """ state_group = attr.ib(default=None) rejected = attr.ib(default=False) prev_group = attr.ib(default=None) delta_ids = attr.ib(default=None) - prev_state_events = attr.ib(default=attr.Factory(list)) app_service = attr.ib(default=None) - _current_state_ids = attr.ib(default=None) _prev_state_ids = attr.ib(default=None) - _prev_state_id = attr.ib(default=None) - - _event_type = attr.ib(default=None) - _event_state_key = attr.ib(default=None) - _fetching_state_deferred = attr.ib(default=None) + _current_state_ids = attr.ib(default=None) @staticmethod def with_state( @@ -90,7 +67,6 @@ class EventContext: current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, state_group=state_group, - fetching_state_deferred=defer.succeed(None), prev_group=prev_group, delta_ids=delta_ids, ) @@ -125,7 +101,6 @@ class EventContext: "rejected": self.rejected, "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), - "prev_state_events": self.prev_state_events, "app_service_id": self.app_service.id if self.app_service else None, } @@ -141,7 +116,7 @@ class EventContext: Returns: EventContext """ - context = EventContext( + context = _AsyncEventContextImpl( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. prev_state_id=input["prev_state_id"], @@ -151,7 +126,6 @@ class EventContext: prev_group=input["prev_group"], delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], - prev_state_events=input["prev_state_events"], ) app_service_id = input["app_service_id"] @@ -170,14 +144,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._current_state_ids @defer.inlineCallbacks @@ -190,14 +157,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._prev_state_ids def get_cached_current_state_ids(self): @@ -211,6 +171,44 @@ class EventContext: return self._current_state_ids + def _ensure_fetched(self, store): + return defer.succeed(None) + + +@attr.s(slots=True) +class _AsyncEventContextImpl(EventContext): + """ + An implementation of EventContext which fetches _current_state_ids and + _prev_state_ids from the database on demand. + + Attributes: + + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have + been calculated. None if we haven't started calculating yet + + _event_type (str): The type of the event the context is associated with. + + _event_state_key (str): The state_key of the event the context is + associated with. + + _prev_state_id (str|None): If the event associated with the context is + a state event, then `_prev_state_id` is the event_id of the state + that was replaced. + """ + + _prev_state_id = attr.ib(default=None) + _event_type = attr.ib(default=None) + _event_state_key = attr.ib(default=None) + _fetching_state_deferred = attr.ib(default=None) + + def _ensure_fetched(self, store): + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store + ) + + return make_deferred_yieldable(self._fetching_state_deferred) + @defer.inlineCallbacks def _fill_out_state(self, store): """Called to populate the _current_state_ids and _prev_state_ids @@ -228,27 +226,6 @@ class EventContext: else: self._prev_state_ids = self._current_state_ids - @defer.inlineCallbacks - def update_state( - self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids - ): - """Replace the state in the context - """ - - # We need to make sure we wait for any ongoing fetching of state - # to complete so that the updated state doesn't get clobbered - if self._fetching_state_deferred: - yield make_deferred_yieldable(self._fetching_state_deferred) - - self.state_group = state_group - self._prev_state_ids = prev_state_ids - self.prev_group = prev_group - self._current_state_ids = current_state_ids - self.delta_ids = delta_ids - - # We need to ensure that that we've marked as having fetched the state - self._fetching_state_deferred = defer.succeed(None) - def _encode_state_dict(state_dict): """Since dicts of (type, state_key) -> event_id cannot be serialized in diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index dab6be9573..8cafcfdab0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event +from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( make_deferred_yieldable, @@ -1871,14 +1872,7 @@ class FederationHandler(BaseHandler): if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c - try: - yield self.do_auth(origin, event, context, auth_events=auth_events) - except AuthError as e: - logger.warning( - "[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg - ) - - context.rejected = RejectedReason.AUTH_ERROR + context = yield self.do_auth(origin, event, context, auth_events=auth_events) if not context.rejected: yield self._check_for_soft_fail(event, state, backfilled) @@ -2047,12 +2041,12 @@ class FederationHandler(BaseHandler): Also NB that this function adds entries to it. Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context object """ room_version = yield self.store.get_room_version(event.room_id) try: - yield self._update_auth_events_and_context_for_auth( + context = yield self._update_auth_events_and_context_for_auth( origin, event, context, auth_events ) except Exception: @@ -2070,7 +2064,9 @@ class FederationHandler(BaseHandler): event_auth.check(room_version, event, auth_events=auth_events) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) - raise e + context.rejected = RejectedReason.AUTH_ERROR + + return context @defer.inlineCallbacks def _update_auth_events_and_context_for_auth( @@ -2094,7 +2090,7 @@ class FederationHandler(BaseHandler): auth_events (dict[(str, str)->synapse.events.EventBase]): Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context """ event_auth_events = set(event.auth_event_ids()) @@ -2133,7 +2129,7 @@ class FederationHandler(BaseHandler): # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e) - return + return context seen_remotes = yield self.store.have_seen_events( [e.event_id for e in remote_auth_chain] @@ -2174,7 +2170,7 @@ class FederationHandler(BaseHandler): if event.internal_metadata.is_outlier(): logger.info("Skipping auth_event fetch for outlier") - return + return context # FIXME: Assumes we have and stored all the state for all the # prev_events @@ -2183,7 +2179,7 @@ class FederationHandler(BaseHandler): ) if not different_auth: - return + return context logger.info( "auth_events refers to events which are not in our calculated auth " @@ -2230,10 +2226,12 @@ class FederationHandler(BaseHandler): auth_events.update(new_state) - yield self._update_context_for_auth_events( + context = yield self._update_context_for_auth_events( event, context, auth_events, event_key ) + return context + @defer.inlineCallbacks def _update_context_for_auth_events(self, event, context, auth_events, event_key): """Update the state_ids in an event context after auth event resolution, @@ -2242,14 +2240,16 @@ class FederationHandler(BaseHandler): Args: event (Event): The event we're handling the context for - context (synapse.events.snapshot.EventContext): event context - to be updated + context (synapse.events.snapshot.EventContext): initial event context auth_events (dict[(str, str)->str]): Events to update in the event context. event_key ((str, str)): (type, state_key) for the current event. this will not be included in the current_state in the context. + + Returns: + Deferred[EventContext]: new event context """ state_updates = { k: a.event_id for k, a in iteritems(auth_events) if k != event_key @@ -2274,7 +2274,7 @@ class FederationHandler(BaseHandler): current_state_ids=current_state_ids, ) - yield context.update_state( + return EventContext.with_state( state_group=state_group, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, diff --git a/tests/test_federation.py b/tests/test_federation.py index d1acb16f30..7d82b58466 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -59,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase): ) self.handler = self.homeserver.get_handlers().federation_handler - self.handler.do_auth = lambda *a, **b: succeed(True) + self.handler.do_auth = lambda origin, event, context, auth_events: succeed( + context + ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( pdus -- cgit 1.5.1 From 988d8d6507a0e8b34f2c352c77b5742197762190 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 1 Nov 2019 16:22:44 +0000 Subject: Incorporate review --- synapse/api/constants.py | 2 +- synapse/api/filtering.py | 2 +- synapse/storage/data_stores/main/events.py | 2 +- synapse/storage/data_stores/main/schema/delta/56/event_labels.sql | 6 ++++++ tests/api/test_filtering.py | 8 ++++---- tests/rest/client/v1/test_rooms.py | 8 ++++---- tests/rest/client/v2_alpha/test_sync.py | 8 ++++---- 7 files changed, 21 insertions(+), 15 deletions(-) (limited to 'tests') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 066ce18704..49c4b85054 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -144,4 +144,4 @@ class EventContentFields(object): """Fields found in events' content, regardless of type.""" # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 - Labels = "org.matrix.labels" + LABELS = "org.matrix.labels" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 30a7ee0a7a..bec13f08d8 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -309,7 +309,7 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) - labels = content.get(EventContentFields.Labels, []) + labels = content.get(EventContentFields.LABELS, []) return self.check_fields(room_id, sender, ev_type, labels, contains_url) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 577e79bcf9..1045c7fa2e 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1491,7 +1491,7 @@ class EventsStore( self._handle_event_relations(txn, event) # Store the labels for this event. - labels = event.content.get(EventContentFields.Labels) + labels = event.content.get(EventContentFields.LABELS) if labels: self.insert_labels_for_event_txn( txn, event.event_id, labels, event.room_id, event.depth diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql index 2acd8e1be5..5e29c1da19 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql @@ -13,6 +13,8 @@ * limitations under the License. */ +-- room_id and topoligical_ordering are denormalised from the events table in order to +-- make the index work. CREATE TABLE IF NOT EXISTS event_labels ( event_id TEXT, label TEXT, @@ -21,4 +23,8 @@ CREATE TABLE IF NOT EXISTS event_labels ( PRIMARY KEY(event_id, label) ); + +-- This index enables an event pagination looking for a particular label to index the +-- event_labels table first, which is much quicker than scanning the events table and then +-- filtering by label, if the label is rarely used relative to the size of the room. CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering); diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 8ec48c4154..2dc5052249 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -329,7 +329,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={EventContentFields.Labels: ["#fun"]}, + content={EventContentFields.LABELS: ["#fun"]}, ) self.assertTrue(Filter(definition).check(event)) @@ -338,7 +338,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={EventContentFields.Labels: ["#notfun"]}, + content={EventContentFields.LABELS: ["#notfun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -349,7 +349,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={EventContentFields.Labels: ["#fun"]}, + content={EventContentFields.LABELS: ["#fun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -358,7 +358,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={EventContentFields.Labels: ["#notfun"]}, + content={EventContentFields.LABELS: ["#notfun"]}, ) self.assertTrue(Filter(definition).check(event)) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 0dc0faa0e5..5e38fd6ced 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -860,7 +860,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with right label", - EventContentFields.Labels: ["#fun"], + EventContentFields.LABELS: ["#fun"], }, ) @@ -876,7 +876,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with wrong label", - EventContentFields.Labels: ["#work"], + EventContentFields.LABELS: ["#work"], }, ) @@ -886,7 +886,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with two wrong labels", - EventContentFields.Labels: ["#work", "#notfun"], + EventContentFields.LABELS: ["#work", "#notfun"], }, ) @@ -896,7 +896,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with right label", - EventContentFields.Labels: ["#fun"], + EventContentFields.LABELS: ["#fun"], }, ) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index c3c6f75ced..3283c0e47b 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -157,7 +157,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with right label", - EventContentFields.Labels: ["#fun"], + EventContentFields.LABELS: ["#fun"], }, tok=tok, ) @@ -175,7 +175,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with wrong label", - EventContentFields.Labels: ["#work"], + EventContentFields.LABELS: ["#work"], }, tok=tok, ) @@ -186,7 +186,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with two wrong labels", - EventContentFields.Labels: ["#work", "#notfun"], + EventContentFields.LABELS: ["#work", "#notfun"], }, tok=tok, ) @@ -197,7 +197,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with right label", - EventContentFields.Labels: ["#fun"], + EventContentFields.LABELS: ["#fun"], }, tok=tok, ) -- cgit 1.5.1 From 09957ce0e4dcfd84c2de4039653059faae03065b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 4 Nov 2019 17:09:22 +0000 Subject: Implement per-room message retention policies --- changelog.d/5815.feature | 1 + docs/sample_config.yaml | 63 ++++ synapse/api/constants.py | 2 + synapse/config/server.py | 172 +++++++++++ synapse/events/validator.py | 100 ++++++- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 4 +- synapse/handlers/pagination.py | 111 +++++++ synapse/storage/data_stores/main/events.py | 3 + synapse/storage/data_stores/main/room.py | 252 ++++++++++++++++ .../main/schema/delta/56/room_retention.sql | 33 +++ synapse/visibility.py | 17 ++ tests/rest/client/test_retention.py | 320 +++++++++++++++++++++ 13 files changed, 1074 insertions(+), 6 deletions(-) create mode 100644 changelog.d/5815.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/56/room_retention.sql create mode 100644 tests/rest/client/test_retention.py (limited to 'tests') diff --git a/changelog.d/5815.feature b/changelog.d/5815.feature new file mode 100644 index 0000000000..ca4df4e7f6 --- /dev/null +++ b/changelog.d/5815.feature @@ -0,0 +1 @@ +Implement per-room message retention policies. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d2f4aff826..87fba27d13 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -328,6 +328,69 @@ listeners: # #user_ips_max_age: 14d +# Message retention policy at the server level. +# +# Room admins and mods can define a retention period for their rooms using the +# 'm.room.retention' state event, and server admins can cap this period by setting +# the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. +# +# If this feature is enabled, Synapse will regularly look for and purge events +# which are older than the room's maximum retention period. Synapse will also +# filter events received over federation so that events that should have been +# purged are ignored and not stored again. +# +retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a very frequent basis (e.g. every 5min), but not want + # that purge to be performed by a job that's iterating over every room it knows, + # which would be quite heavy on the server. + # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 5m: + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 24h + ## TLS ## diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 49c4b85054..e3f086f1c3 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -94,6 +94,8 @@ class EventTypes(object): ServerACL = "m.room.server_acl" Pinned = "m.room.pinned_events" + Retention = "m.room.retention" + class RejectedReason(object): AUTH_ERROR = "auth_error" diff --git a/synapse/config/server.py b/synapse/config/server.py index d556df308d..aa93a416f1 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -246,6 +246,115 @@ class ServerConfig(Config): # events with profile information that differ from the target's global profile. self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) + retention_config = config.get("retention") + if retention_config is None: + retention_config = {} + + self.retention_enabled = retention_config.get("enabled", False) + + retention_default_policy = retention_config.get("default_policy") + + if retention_default_policy is not None: + self.retention_default_min_lifetime = retention_default_policy.get( + "min_lifetime" + ) + if self.retention_default_min_lifetime is not None: + self.retention_default_min_lifetime = self.parse_duration( + self.retention_default_min_lifetime + ) + + self.retention_default_max_lifetime = retention_default_policy.get( + "max_lifetime" + ) + if self.retention_default_max_lifetime is not None: + self.retention_default_max_lifetime = self.parse_duration( + self.retention_default_max_lifetime + ) + + if ( + self.retention_default_min_lifetime is not None + and self.retention_default_max_lifetime is not None + and ( + self.retention_default_min_lifetime + > self.retention_default_max_lifetime + ) + ): + raise ConfigError( + "The default retention policy's 'min_lifetime' can not be greater" + " than its 'max_lifetime'" + ) + else: + self.retention_default_min_lifetime = None + self.retention_default_max_lifetime = None + + self.retention_allowed_lifetime_min = retention_config.get("allowed_lifetime_min") + if self.retention_allowed_lifetime_min is not None: + self.retention_allowed_lifetime_min = self.parse_duration( + self.retention_allowed_lifetime_min + ) + + self.retention_allowed_lifetime_max = retention_config.get("allowed_lifetime_max") + if self.retention_allowed_lifetime_max is not None: + self.retention_allowed_lifetime_max = self.parse_duration( + self.retention_allowed_lifetime_max + ) + + if ( + self.retention_allowed_lifetime_min is not None + and self.retention_allowed_lifetime_max is not None + and self.retention_allowed_lifetime_min > self.retention_allowed_lifetime_max + ): + raise ConfigError( + "Invalid retention policy limits: 'allowed_lifetime_min' can not be" + " greater than 'allowed_lifetime_max'" + ) + + self.retention_purge_jobs = [] + for purge_job_config in retention_config.get("purge_jobs", []): + interval_config = purge_job_config.get("interval") + + if interval_config is None: + raise ConfigError( + "A retention policy's purge jobs configuration must have the" + " 'interval' key set." + ) + + interval = self.parse_duration(interval_config) + + shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime") + + if shortest_max_lifetime is not None: + shortest_max_lifetime = self.parse_duration(shortest_max_lifetime) + + longest_max_lifetime = purge_job_config.get("longest_max_lifetime") + + if longest_max_lifetime is not None: + longest_max_lifetime = self.parse_duration(longest_max_lifetime) + + if ( + shortest_max_lifetime is not None + and longest_max_lifetime is not None + and shortest_max_lifetime > longest_max_lifetime + ): + raise ConfigError( + "A retention policy's purge jobs configuration's" + " 'shortest_max_lifetime' value can not be greater than its" + " 'longest_max_lifetime' value." + ) + + self.retention_purge_jobs.append({ + "interval": interval, + "shortest_max_lifetime": shortest_max_lifetime, + "longest_max_lifetime": longest_max_lifetime, + }) + + if not self.retention_purge_jobs: + self.retention_purge_jobs = [{ + "interval": self.parse_duration("1d"), + "shortest_max_lifetime": None, + "longest_max_lifetime": None, + }] + self.listeners = [] # type: List[dict] for listener in config.get("listeners", []): if not isinstance(listener.get("port", None), int): @@ -761,6 +870,69 @@ class ServerConfig(Config): # Defaults to `28d`. Set to `null` to disable clearing out of old rows. # #user_ips_max_age: 14d + + # Message retention policy at the server level. + # + # Room admins and mods can define a retention period for their rooms using the + # 'm.room.retention' state event, and server admins can cap this period by setting + # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. + # + # If this feature is enabled, Synapse will regularly look for and purge events + # which are older than the room's maximum retention period. Synapse will also + # filter events received over federation so that events that should have been + # purged are ignored and not stored again. + # + retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a very frequent basis (e.g. every 5min), but not want + # that purge to be performed by a job that's iterating over every room it knows, + # which would be quite heavy on the server. + # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 5m: + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 24h """ % locals() ) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 272426e105..9b90c9ce04 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import string_types +from six import integer_types, string_types from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -22,11 +22,12 @@ from synapse.types import EventID, RoomID, UserID class EventValidator(object): - def validate_new(self, event): + def validate_new(self, event, config): """Validates the event has roughly the right format Args: - event (FrozenEvent) + event (FrozenEvent): The event to validate. + config (Config): The homeserver's configuration. """ self.validate_builder(event) @@ -67,6 +68,99 @@ class EventValidator(object): Codes.INVALID_PARAM, ) + if event.type == EventTypes.Retention: + self._validate_retention(event, config) + + def _validate_retention(self, event, config): + """Checks that an event that defines the retention policy for a room respects the + boundaries imposed by the server's administrator. + + Args: + event (FrozenEvent): The event to validate. + config (Config): The homeserver's configuration. + """ + min_lifetime = event.content.get("min_lifetime") + max_lifetime = event.content.get("max_lifetime") + + if min_lifetime is not None: + if not isinstance(min_lifetime, integer_types): + raise SynapseError( + code=400, + msg="'min_lifetime' must be an integer", + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_min is not None + and min_lifetime < config.retention_allowed_lifetime_min + ): + raise SynapseError( + code=400, + msg=( + "'min_lifetime' can't be lower than the minimum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_max is not None + and min_lifetime > config.retention_allowed_lifetime_max + ): + raise SynapseError( + code=400, + msg=( + "'min_lifetime' can't be greater than the maximum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if max_lifetime is not None: + if not isinstance(max_lifetime, integer_types): + raise SynapseError( + code=400, + msg="'max_lifetime' must be an integer", + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_min is not None + and max_lifetime < config.retention_allowed_lifetime_min + ): + raise SynapseError( + code=400, + msg=( + "'max_lifetime' can't be lower than the minimum allowed value" + " enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_max is not None + and max_lifetime > config.retention_allowed_lifetime_max + ): + raise SynapseError( + code=400, + msg=( + "'max_lifetime' can't be greater than the maximum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + min_lifetime is not None + and max_lifetime is not None + and min_lifetime > max_lifetime + ): + raise SynapseError( + code=400, + msg="'min_lifetime' can't be greater than 'max_lifetime", + errcode=Codes.BAD_JSON, + ) + def validate_builder(self, event): """Validates that the builder/event has roughly the right format. Only checks values that we expect a proto event to have, rather than all the diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8cafcfdab0..3994137d18 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2454,7 +2454,7 @@ class FederationHandler(BaseHandler): room_version, event_dict, event, context ) - EventValidator().validate_new(event) + EventValidator().validate_new(event, self.config) # We need to tell the transaction queue to send this out, even # though the sender isn't a local user. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d682dc2b7a..155ed6e06a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -417,7 +417,7 @@ class EventCreationHandler(object): 403, "You must be in the room to create an alias for it" ) - self.validator.validate_new(event) + self.validator.validate_new(event, self.config) return (event, context) @@ -634,7 +634,7 @@ class EventCreationHandler(object): if requester: context.app_service = requester.app_service - self.validator.validate_new(event) + self.validator.validate_new(event, self.config) # If this event is an annotation then we check that that the sender # can't annotate the same way twice (e.g. stops users from liking an diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 97f15a1c32..e1800177fa 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -15,12 +15,15 @@ # limitations under the License. import logging +from six import iteritems + from twisted.internet import defer from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.logging.context import run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock @@ -80,6 +83,114 @@ class PaginationHandler(object): self._purges_by_id = {} self._event_serializer = hs.get_event_client_serializer() + self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime + + if hs.config.retention_enabled: + # Run the purge jobs described in the configuration file. + for job in hs.config.retention_purge_jobs: + self.clock.looping_call( + run_as_background_process, + job["interval"], + "purge_history_for_rooms_in_range", + self.purge_history_for_rooms_in_range, + job["shortest_max_lifetime"], + job["longest_max_lifetime"], + ) + + @defer.inlineCallbacks + def purge_history_for_rooms_in_range(self, min_ms, max_ms): + """Purge outdated events from rooms within the given retention range. + + If a default retention policy is defined in the server's configuration and its + 'max_lifetime' is within this range, also targets rooms which don't have a + retention policy. + + Args: + min_ms (int|None): Duration in milliseconds that define the lower limit of + the range to handle (exclusive). If None, it means that the range has no + lower limit. + max_ms (int|None): Duration in milliseconds that define the upper limit of + the range to handle (inclusive). If None, it means that the range has no + upper limit. + """ + # We want the storage layer to to include rooms with no retention policy in its + # return value only if a default retention policy is defined in the server's + # configuration and that policy's 'max_lifetime' is either lower (or equal) than + # max_ms or higher than min_ms (or both). + if self._retention_default_max_lifetime is not None: + include_null = True + + if min_ms is not None and min_ms >= self._retention_default_max_lifetime: + # The default max_lifetime is lower than (or equal to) min_ms. + include_null = False + + if max_ms is not None and max_ms < self._retention_default_max_lifetime: + # The default max_lifetime is higher than max_ms. + include_null = False + else: + include_null = False + + rooms = yield self.store.get_rooms_for_retention_period_in_range( + min_ms, max_ms, include_null + ) + + for room_id, retention_policy in iteritems(rooms): + if room_id in self._purges_in_progress_by_room: + logger.warning( + "[purge] not purging room %s as there's an ongoing purge running" + " for this room", + room_id, + ) + continue + + max_lifetime = retention_policy["max_lifetime"] + + if max_lifetime is None: + # If max_lifetime is None, it means that include_null equals True, + # therefore we can safely assume that there is a default policy defined + # in the server's configuration. + max_lifetime = self._retention_default_max_lifetime + + # Figure out what token we should start purging at. + ts = self.clock.time_msec() - max_lifetime + + stream_ordering = ( + yield self.store.find_first_stream_ordering_after_ts(ts) + ) + + r = ( + yield self.store.get_room_event_after_stream_ordering( + room_id, stream_ordering, + ) + ) + if not r: + logger.warning( + "[purge] purging events not possible: No event found " + "(ts %i => stream_ordering %i)", + ts, stream_ordering, + ) + continue + + (stream, topo, _event_id) = r + token = "t%d-%d" % (topo, stream) + + purge_id = random_string(16) + + self._purges_by_id[purge_id] = PurgeStatus() + + logger.info( + "Starting purging events in room %s (purge_id %s)" % (room_id, purge_id) + ) + + # We want to purge everything, including local events, and to run the purge in + # the background so that it's not blocking any other operation apart from + # other purges in the same room. + run_as_background_process( + "_purge_history", + self._purge_history, + purge_id, room_id, token, True, + ) + def start_purge_history(self, room_id, token, delete_local_events=False): """Start off a history purge on a room. diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 301f8ea128..b332a42d82 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -929,6 +929,9 @@ class EventsStore( elif event.type == EventTypes.Redaction: # Insert into the redactions table. self._store_redaction(txn, event) + elif event.type == EventTypes.Retention: + # Update the room_retention table. + self._store_retention_policy_for_room_txn(txn, event) self._handle_event_relations(txn, event) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 67bb1b6f60..54a7d24c73 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -19,10 +19,13 @@ import logging import re from typing import Optional, Tuple +from six import integer_types + from canonicaljson import json from twisted.internet import defer +from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.search import SearchStore @@ -302,6 +305,85 @@ class RoomWorkerStore(SQLBaseStore): class RoomStore(RoomWorkerStore, SearchStore): + def __init__(self, db_conn, hs): + super(RoomStore, self).__init__(db_conn, hs) + + self.config = hs.config + + self.register_background_update_handler( + "insert_room_retention", self._background_insert_retention, + ) + + @defer.inlineCallbacks + def _background_insert_retention(self, progress, batch_size): + """Retrieves a list of all rooms within a range and inserts an entry for each of + them into the room_retention table. + NULLs the property's columns if missing from the retention event in the room's + state (or NULLs all of them if there's no retention event in the room's state), + so that we fall back to the server's retention policy. + """ + + last_room = progress.get("room_id", "") + + def _background_insert_retention_txn(txn): + txn.execute( + """ + SELECT state.room_id, state.event_id, events.json + FROM current_state_events as state + LEFT JOIN event_json AS events ON (state.event_id = events.event_id) + WHERE state.room_id > ? AND state.type = '%s' + ORDER BY state.room_id ASC + LIMIT ?; + """ % EventTypes.Retention, + (last_room, batch_size) + ) + + rows = self.cursor_to_dict(txn) + + if not rows: + return True + + for row in rows: + if not row["json"]: + retention_policy = {} + else: + ev = json.loads(row["json"]) + retention_policy = json.dumps(ev["content"]) + + self._simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": row["room_id"], + "event_id": row["event_id"], + "min_lifetime": retention_policy.get("min_lifetime"), + "max_lifetime": retention_policy.get("max_lifetime"), + } + ) + + logger.info("Inserted %d rows into room_retention", len(rows)) + + self._background_update_progress_txn( + txn, "insert_room_retention", { + "room_id": rows[-1]["room_id"], + } + ) + + if batch_size > len(rows): + return True + else: + return False + + end = yield self.runInteraction( + "insert_room_retention", + _background_insert_retention_txn, + ) + + if end: + yield self._end_background_update("insert_room_retention") + + defer.returnValue(batch_size) + @defer.inlineCallbacks def store_room(self, room_id, room_creator_user_id, is_public): """Stores a room. @@ -502,6 +584,37 @@ class RoomStore(RoomWorkerStore, SearchStore): txn, event, "content.body", event.content["body"] ) + def _store_retention_policy_for_room_txn(self, txn, event): + if ( + hasattr(event, "content") + and ("min_lifetime" in event.content or "max_lifetime" in event.content) + ): + if ( + ("min_lifetime" in event.content and not isinstance( + event.content.get("min_lifetime"), integer_types + )) + or ("max_lifetime" in event.content and not isinstance( + event.content.get("max_lifetime"), integer_types + )) + ): + # Ignore the event if one of the value isn't an integer. + return + + self._simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + "min_lifetime": event.content.get("min_lifetime"), + "max_lifetime": event.content.get("max_lifetime"), + }, + ) + + self._invalidate_cache_and_stream( + txn, self.get_retention_policy_for_room, (event.room_id,) + ) + def add_event_report( self, room_id, event_id, user_id, reason, content, received_ts ): @@ -683,3 +796,142 @@ class RoomStore(RoomWorkerStore, SearchStore): remote_media_mxcs.append((hostname, media_id)) return local_media_mxcs, remote_media_mxcs + + @defer.inlineCallbacks + def get_rooms_for_retention_period_in_range(self, min_ms, max_ms, include_null=False): + """Retrieves all of the rooms within the given retention range. + + Optionally includes the rooms which don't have a retention policy. + + Args: + min_ms (int|None): Duration in milliseconds that define the lower limit of + the range to handle (exclusive). If None, doesn't set a lower limit. + max_ms (int|None): Duration in milliseconds that define the upper limit of + the range to handle (inclusive). If None, doesn't set an upper limit. + include_null (bool): Whether to include rooms which retention policy is NULL + in the returned set. + + Returns: + dict[str, dict]: The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). + """ + + def get_rooms_for_retention_period_in_range_txn(txn): + range_conditions = [] + args = [] + + if min_ms is not None: + range_conditions.append("max_lifetime > ?") + args.append(min_ms) + + if max_ms is not None: + range_conditions.append("max_lifetime <= ?") + args.append(max_ms) + + # Do a first query which will retrieve the rooms that have a retention policy + # in their current state. + sql = """ + SELECT room_id, min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + """ + + if len(range_conditions): + sql += " WHERE (" + " AND ".join(range_conditions) + ")" + + if include_null: + sql += " OR max_lifetime IS NULL" + + txn.execute(sql, args) + + rows = self.cursor_to_dict(txn) + rooms_dict = {} + + for row in rows: + rooms_dict[row["room_id"]] = { + "min_lifetime": row["min_lifetime"], + "max_lifetime": row["max_lifetime"], + } + + if include_null: + # If required, do a second query that retrieves all of the rooms we know + # of so we can handle rooms with no retention policy. + sql = "SELECT DISTINCT room_id FROM current_state_events" + + txn.execute(sql) + + rows = self.cursor_to_dict(txn) + + # If a room isn't already in the dict (i.e. it doesn't have a retention + # policy in its state), add it with a null policy. + for row in rows: + if row["room_id"] not in rooms_dict: + rooms_dict[row["room_id"]] = { + "min_lifetime": None, + "max_lifetime": None, + } + + return rooms_dict + + rooms = yield self.runInteraction( + "get_rooms_for_retention_period_in_range", + get_rooms_for_retention_period_in_range_txn, + ) + + defer.returnValue(rooms) + + @cachedInlineCallbacks() + def get_retention_policy_for_room(self, room_id): + """Get the retention policy for a given room. + + If no retention policy has been found for this room, returns a policy defined + by the configured default policy (which has None as both the 'min_lifetime' and + the 'max_lifetime' if no default policy has been defined in the server's + configuration). + + Args: + room_id (str): The ID of the room to get the retention policy of. + + Returns: + dict[int, int]: "min_lifetime" and "max_lifetime" for this room. + """ + + def get_retention_policy_for_room_txn(txn): + txn.execute( + """ + SELECT min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + WHERE room_id = ?; + """, + (room_id,) + ) + + return self.cursor_to_dict(txn) + + ret = yield self.runInteraction( + "get_retention_policy_for_room", + get_retention_policy_for_room_txn, + ) + + # If we don't know this room ID, ret will be None, in this case return the default + # policy. + if not ret: + defer.returnValue({ + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + }) + + row = ret[0] + + # If one of the room's policy's attributes isn't defined, use the matching + # attribute from the default policy. + # The default values will be None if no default policy has been defined, or if one + # of the attributes is missing from the default policy. + if row["min_lifetime"] is None: + row["min_lifetime"] = self.config.retention_default_min_lifetime + + if row["max_lifetime"] is None: + row["max_lifetime"] = self.config.retention_default_max_lifetime + + defer.returnValue(row) diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql new file mode 100644 index 0000000000..ee6cdf7a14 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql @@ -0,0 +1,33 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks the retention policy of a room. +-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in +-- the room's retention policy state event. +-- If a room doesn't have a retention policy state event in its state, both max_lifetime +-- and min_lifetime are NULL. +CREATE TABLE IF NOT EXISTS room_retention( + room_id TEXT, + event_id TEXT, + min_lifetime BIGINT, + max_lifetime BIGINT, + + PRIMARY KEY(room_id, event_id) +); + +CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('insert_room_retention', '{}'); diff --git a/synapse/visibility.py b/synapse/visibility.py index 8c843febd8..4498c156bc 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -86,6 +86,14 @@ def filter_events_for_client( erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + room_ids = set(e.room_id for e in events) + retention_policies = {} + + for room_id in room_ids: + retention_policies[room_id] = yield storage.main.get_retention_policy_for_room( + room_id + ) + def allowed(event): """ Args: @@ -103,6 +111,15 @@ def filter_events_for_client( if not event.is_state() and event.sender in ignore_list: return None + retention_policy = retention_policies[event.room_id] + max_lifetime = retention_policy.get("max_lifetime") + + if max_lifetime is not None: + oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime + + if event.origin_server_ts < oldest_allowed_ts: + return None + if event.event_id in always_include_ids: return event diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py new file mode 100644 index 0000000000..41ea9db689 --- /dev/null +++ b/tests/rest/client/test_retention.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.api.constants import EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.visibility import filter_events_for_client + +from tests import unittest + +one_hour_ms = 3600000 +one_day_ms = one_hour_ms * 24 + + +class RetentionTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["default_room_version"] = "1" + config["retention"] = { + "enabled": True, + "default_policy": { + "min_lifetime": one_day_ms, + "max_lifetime": one_day_ms * 3, + }, + "allowed_lifetime_min": one_day_ms, + "allowed_lifetime_max": one_day_ms * 3, + } + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_retention_state_event(self): + """Tests that the server configuration can limit the values a user can set to the + room's retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_day_ms * 4, + }, + tok=self.token, + expect_code=400, + ) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_hour_ms, + }, + tok=self.token, + expect_code=400, + ) + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": lifetime, + }, + tok=self.token, + ) + + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_without_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by the server's configuration's default retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention_event_purged(room_id, one_day_ms * 2) + + def test_visibility(self): + """Tests that synapse.visibility.filter_events_for_client correctly filters out + outdated events + """ + store = self.hs.get_datastore() + storage = self.hs.get_storage() + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + events = [] + + # Send a first event, which should be filtered out at the end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + # Get the event from the store so that we end up with a FrozenEvent that we can + # give to filter_events_for_client. We need to do this now because the event won't + # be in the database anymore after it has expired. + events.append(self.get_success( + store.get_event( + resp.get("event_id") + ) + )) + + # Advance the time by 2 days. We're using the default retention policy, therefore + # after this the first event will still be valid. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Send another event, which shouldn't get filtered out. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + valid_event_id = resp.get("event_id") + + events.append(self.get_success( + store.get_event( + valid_event_id + ) + )) + + # Advance the time by anothe 2 days. After this, the first event should be + # outdated but not the second one. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Run filter_events_for_client with our list of FrozenEvents. + filtered_events = self.get_success(filter_events_for_client( + storage, self.user_id, events + )) + + # We should only get one event back. + self.assertEqual(len(filtered_events), 1, filtered_events) + # That event should be the second, not outdated event. + self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) + + def _test_retention_event_purged(self, room_id, increment): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + expired_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, expired_event_id) + self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + + # Advance the time. + self.reactor.advance(increment / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + valid_event_id = resp.get("event_id") + + # Advance the time again. Now our first event should have expired but our second + # one should still be kept. + self.reactor.advance(increment / 1000) + + # Check that the event has been purged from the database. + self.get_event(room_id, expired_event_id, expected_code=404) + + # Check that the event that hasn't been purged can still be retrieved. + valid_event = self.get_event(room_id, valid_event_id) + self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body + + +class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["default_room_version"] = "1" + config["retention"] = { + "enabled": True, + } + + mock_federation_client = Mock(spec=["backfill"]) + + self.hs = self.setup_test_homeserver( + config=config, + federation_client=mock_federation_client, + ) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_no_default_policy(self): + """Tests that an event doesn't get expired if there is neither a default retention + policy nor a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention(room_id) + + def test_state_policy(self): + """Tests that an event gets correctly expired if there is no default retention + policy but there's a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the maximum lifetime to 35 days so that the first event gets expired but not + # the second one. + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_day_ms * 35, + }, + tok=self.token, + ) + + self._test_retention(room_id, expected_code_for_first_event=404) + + def _test_retention(self, room_id, expected_code_for_first_event=200): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + first_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, first_event_id) + self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + + # Advance the time by a month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + second_event_id = resp.get("event_id") + + # Advance the time by another month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Check if the event has been purged from the database. + first_event = self.get_event( + room_id, first_event_id, expected_code=expected_code_for_first_event + ) + + if expected_code_for_first_event == 200: + self.assertEqual(first_event.get("content", {}).get("body"), "1", first_event) + + # Check that the event that hasn't been purged can still be retrieved. + second_event = self.get_event(room_id, second_event_id) + self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body -- cgit 1.5.1 From 4e1c7b79fa3498c48106c17d0edbab2f7bcc0c38 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 5 Nov 2019 05:05:48 +1100 Subject: Remove the psutil dependency (#6318) * remove psutil and replace with resource --- changelog.d/6318.misc | 1 + synapse/app/homeserver.py | 174 ++++++++++++++++++++++------------------- synapse/python_dependencies.py | 1 - synapse/server.py | 2 + tests/test_phone_home.py | 51 ++++++++++++ 5 files changed, 146 insertions(+), 83 deletions(-) create mode 100644 changelog.d/6318.misc create mode 100644 tests/test_phone_home.py (limited to 'tests') diff --git a/changelog.d/6318.misc b/changelog.d/6318.misc new file mode 100644 index 0000000000..63527ccef4 --- /dev/null +++ b/changelog.d/6318.misc @@ -0,0 +1 @@ +Remove the dependency on psutil and replace functionality with the stdlib `resource` module. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8d28076d92..00a7f8330e 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -19,12 +19,13 @@ from __future__ import print_function import gc import logging +import math import os +import resource import sys from six import iteritems -import psutil from prometheus_client import Gauge from twisted.application import service @@ -471,6 +472,87 @@ class SynapseService(service.Service): return self._port.stopListening() +# Contains the list of processes we will be monitoring +# currently either 0 or 1 +_stats_process = [] + + +@defer.inlineCallbacks +def phone_stats_home(hs, stats, stats_process=_stats_process): + logger.info("Gathering stats for reporting") + now = int(hs.get_clock().time()) + uptime = int(now - hs.start_time) + if uptime < 0: + uptime = 0 + + stats["homeserver"] = hs.config.server_name + stats["server_context"] = hs.config.server_context + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + version = sys.version_info + stats["python_version"] = "{}.{}.{}".format( + version.major, version.minor, version.micro + ) + stats["total_users"] = yield hs.get_datastore().count_all_users() + + total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() + stats["total_nonbridged_users"] = total_nonbridged_users + + daily_user_type_results = yield hs.get_datastore().count_daily_user_type() + for name, count in iteritems(daily_user_type_results): + stats["daily_user_type_" + name] = count + + room_count = yield hs.get_datastore().get_room_count() + stats["total_room_count"] = room_count + + stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() + stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() + stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() + stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() + + r30_results = yield hs.get_datastore().count_r30_users() + for name, count in iteritems(r30_results): + stats["r30_users_" + name] = count + + daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages + stats["cache_factor"] = CACHE_SIZE_FACTOR + stats["event_cache_size"] = hs.config.event_cache_size + + # + # Performance statistics + # + old = stats_process[0] + new = (now, resource.getrusage(resource.RUSAGE_SELF)) + stats_process[0] = new + + # Get RSS in bytes + stats["memory_rss"] = new[1].ru_maxrss + + # Get CPU time in % of a single core, not % of all cores + used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( + old[1].ru_utime + old[1].ru_stime + ) + if used_cpu_time == 0 or new[0] == old[0]: + stats["cpu_average"] = 0 + else: + stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) + + # + # Database version + # + + stats["database_engine"] = hs.get_datastore().database_engine_name + stats["database_server_version"] = hs.get_datastore().get_server_version() + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) + try: + yield hs.get_proxied_http_client().put_json( + hs.config.report_stats_endpoint, stats + ) + except Exception as e: + logger.warning("Error reporting stats: %s", e) + + def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: @@ -497,91 +579,19 @@ def run(hs): reactor.run = profile(reactor.run) clock = hs.get_clock() - start_time = clock.time() stats = {} - # Contains the list of processes we will be monitoring - # currently either 0 or 1 - stats_process = [] + def performance_stats_init(): + _stats_process.clear() + _stats_process.append( + (int(hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))) + ) def start_phone_stats_home(): - return run_as_background_process("phone_stats_home", phone_stats_home) - - @defer.inlineCallbacks - def phone_stats_home(): - logger.info("Gathering stats for reporting") - now = int(hs.get_clock().time()) - uptime = int(now - start_time) - if uptime < 0: - uptime = 0 - - stats["homeserver"] = hs.config.server_name - stats["server_context"] = hs.config.server_context - stats["timestamp"] = now - stats["uptime_seconds"] = uptime - version = sys.version_info - stats["python_version"] = "{}.{}.{}".format( - version.major, version.minor, version.micro - ) - stats["total_users"] = yield hs.get_datastore().count_all_users() - - total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() - stats["total_nonbridged_users"] = total_nonbridged_users - - daily_user_type_results = yield hs.get_datastore().count_daily_user_type() - for name, count in iteritems(daily_user_type_results): - stats["daily_user_type_" + name] = count - - room_count = yield hs.get_datastore().get_room_count() - stats["total_room_count"] = room_count - - stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() - stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() - stats[ - "daily_active_rooms" - ] = yield hs.get_datastore().count_daily_active_rooms() - stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() - - r30_results = yield hs.get_datastore().count_r30_users() - for name, count in iteritems(r30_results): - stats["r30_users_" + name] = count - - daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() - stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = CACHE_SIZE_FACTOR - stats["event_cache_size"] = hs.config.event_cache_size - - if len(stats_process) > 0: - stats["memory_rss"] = 0 - stats["cpu_average"] = 0 - for process in stats_process: - stats["memory_rss"] += process.memory_info().rss - stats["cpu_average"] += int(process.cpu_percent(interval=None)) - - stats["database_engine"] = hs.get_datastore().database_engine_name - stats["database_server_version"] = hs.get_datastore().get_server_version() - logger.info( - "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats) + return run_as_background_process( + "phone_stats_home", phone_stats_home, hs, stats ) - try: - yield hs.get_proxied_http_client().put_json( - hs.config.report_stats_endpoint, stats - ) - except Exception as e: - logger.warning("Error reporting stats: %s", e) - - def performance_stats_init(): - try: - process = psutil.Process() - # Ensure we can fetch both, and make the initial request for cpu_percent - # so the next request will use this as the initial point. - process.memory_info().rss - process.cpu_percent(interval=None) - logger.info("report_stats can use psutil") - stats_process.append(process) - except (AttributeError): - logger.warning("Unable to read memory/cpu stats. Disabling reporting.") def generate_user_daily_visit_stats(): return run_as_background_process( @@ -626,7 +636,7 @@ def run(hs): if hs.config.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") - clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) + clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000, hs, stats) # We need to defer this init for the cases that we daemonize # otherwise the process ID we get is that of the non-daemon process @@ -634,7 +644,7 @@ def run(hs): # We wait 5 minutes to send the first set of stats as the server can # be quite busy the first few minutes - clock.call_later(5 * 60, start_phone_stats_home) + clock.call_later(5 * 60, start_phone_stats_home, hs, stats) _base.start_reactor( "synapse-homeserver", diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index aa7da1c543..5871feaafd 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -61,7 +61,6 @@ REQUIREMENTS = [ "bcrypt>=3.1.0", "pillow>=4.3.0", "sortedcontainers>=1.4.4", - "psutil>=2.0.0", "pymacaroons>=0.13.0", "msgpack>=0.5.2", "phonenumbers>=8.2.0", diff --git a/synapse/server.py b/synapse/server.py index f8aeebcff8..90c3b072e8 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -221,6 +221,7 @@ class HomeServer(object): self.hostname = hostname self._building = {} self._listening_services = [] + self.start_time = None self.clock = Clock(reactor) self.distributor = Distributor() @@ -240,6 +241,7 @@ class HomeServer(object): datastore = self.DATASTORE_CLASS(conn, self) self.datastores = DataStores(datastore, conn, self) conn.commit() + self.start_time = int(self.get_clock().time()) logger.info("Finished setting up.") def setup_master(self): diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py new file mode 100644 index 0000000000..7657bddea5 --- /dev/null +++ b/tests/test_phone_home.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import resource + +import mock + +from synapse.app.homeserver import phone_stats_home + +from tests.unittest import HomeserverTestCase + + +class PhoneHomeStatsTestCase(HomeserverTestCase): + def test_performance_frozen_clock(self): + """ + If time doesn't move, don't error out. + """ + past_stats = [ + (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF)) + ] + stats = {} + self.get_success(phone_stats_home(self.hs, stats, past_stats)) + self.assertEqual(stats["cpu_average"], 0) + + def test_performance_100(self): + """ + 1 second of usage over 1 second is 100% CPU usage. + """ + real_res = resource.getrusage(resource.RUSAGE_SELF) + old_resource = mock.Mock(spec=real_res) + old_resource.ru_utime = real_res.ru_utime - 1 + old_resource.ru_stime = real_res.ru_stime + old_resource.ru_maxrss = real_res.ru_maxrss + + past_stats = [(self.hs.get_clock().time(), old_resource)] + stats = {} + self.reactor.advance(1) + self.get_success(phone_stats_home(self.hs, stats, past_stats)) + self.assertApproximates(stats["cpu_average"], 100, tolerance=2.5) -- cgit 1.5.1 From a7c818c79b70d6b70abc5b26f0e1e78fd60c087e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Nov 2019 13:21:26 +0000 Subject: Add test case --- tests/rest/client/v1/test_rooms.py | 182 +++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 5e38fd6ced..621c894e35 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1106,3 +1106,185 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) + + +class ContextTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + def test_context_filter_labels(self): + """Test that we can filter by a label.""" + context_filter = json.dumps( + { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + ) + + res = self._test_context_filter_labels(context_filter) + + self.assertEqual( + res["event"]["content"]["body"], "with right label", res["event"] + ) + + events_before = res["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "with right label", events_before[0] + ) + + events_after = res["events_before"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with right label", events_after[0] + ) + + def test_context_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" + context_filter = json.dumps( + { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + ) + + res = self._test_context_filter_labels(context_filter) + + events_before = res["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "without label", events_before[0] + ) + + events_after = res["events_after"] + + self.assertEqual( + len(events_after), 2, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + self.assertEqual( + events_after[1]["content"]["body"], "with two wrong labels", events_after[1] + ) + + def test_context_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + context_filter = json.dumps( + { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + ) + + res = self._test_context_filter_labels(context_filter) + + events_before = res["events_before"] + + self.assertEqual( + len(events_before), 0, [event["content"] for event in events_before] + ) + + events_after = res["events_after"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + + def _test_context_filter_labels(self, context_filter): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as(user_id, tok=tok) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=tok, + ) + + # The event we'll look up the context for. + res = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + event_id = res["event_id"] + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" % (room_id, event_id, context_filter), + access_token=tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body + -- cgit 1.5.1 From c9e4748cb75271a2178d0cae05d551829249ada3 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Nov 2019 13:47:47 +0000 Subject: Merge labels tests for /context and /messages --- tests/rest/client/v1/test_rooms.py | 276 +++++++++++++++++-------------------- 1 file changed, 130 insertions(+), 146 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 621c894e35..fe327d1bf8 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -811,105 +811,6 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_filter_labels(self): - """Test that we can filter by a label.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} - ) - - events = self._test_filter_labels(message_filter) - - self.assertEqual(len(events), 2, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - - def test_filter_not_labels(self): - """Test that we can filter by the absence of a label.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} - ) - - events = self._test_filter_labels(message_filter) - - self.assertEqual(len(events), 3, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "without label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual( - events[2]["content"]["body"], "with two wrong labels", events[2] - ) - - def test_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label.""" - sync_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - ) - - events = self._test_filter_labels(sync_filter) - - self.assertEqual(len(events), 1, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - - def _test_filter_labels(self, message_filter): - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with wrong label", - EventContentFields.LABELS: ["#work"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with two wrong labels", - EventContentFields.LABELS: ["#work", "#notfun"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - ) - - token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&filter=%s" - % (self.room_id, token, message_filter), - ) - self.render(request) - - return channel.json_body["chunk"] - class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ @@ -1108,7 +1009,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): self.assertEqual(res_displayname, self.displayname, channel.result) -class ContextTestCase(unittest.HomeserverTestCase): +class LabelsTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -1116,8 +1017,13 @@ class ContextTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + def test_context_filter_labels(self): - """Test that we can filter by a label.""" + """Test that we can filter by a label on a /context request.""" context_filter = json.dumps( { "types": [EventTypes.Message], @@ -1125,13 +1031,17 @@ class ContextTestCase(unittest.HomeserverTestCase): } ) - res = self._test_context_filter_labels(context_filter) + event_id = self._send_labelled_messages_in_room() - self.assertEqual( - res["event"]["content"]["body"], "with right label", res["event"] + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + access_token=self.tok, ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) - events_before = res["events_before"] + events_before = channel.json_body["events_before"] self.assertEqual( len(events_before), 1, [event["content"] for event in events_before] @@ -1140,7 +1050,7 @@ class ContextTestCase(unittest.HomeserverTestCase): events_before[0]["content"]["body"], "with right label", events_before[0] ) - events_after = res["events_before"] + events_after = channel.json_body["events_before"] self.assertEqual( len(events_after), 1, [event["content"] for event in events_after] @@ -1150,7 +1060,7 @@ class ContextTestCase(unittest.HomeserverTestCase): ) def test_context_filter_not_labels(self): - """Test that we can filter by the absence of a label.""" + """Test that we can filter by the absence of a label on a /context request.""" context_filter = json.dumps( { "types": [EventTypes.Message], @@ -1158,9 +1068,17 @@ class ContextTestCase(unittest.HomeserverTestCase): } ) - res = self._test_context_filter_labels(context_filter) + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) - events_before = res["events_before"] + events_before = channel.json_body["events_before"] self.assertEqual( len(events_before), 1, [event["content"] for event in events_before] @@ -1169,7 +1087,7 @@ class ContextTestCase(unittest.HomeserverTestCase): events_before[0]["content"]["body"], "without label", events_before[0] ) - events_after = res["events_after"] + events_after = channel.json_body["events_after"] self.assertEqual( len(events_after), 2, [event["content"] for event in events_after] @@ -1182,7 +1100,9 @@ class ContextTestCase(unittest.HomeserverTestCase): ) def test_context_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label.""" + """Test that we can filter by both a label and the absence of another label on a + /context request. + """ context_filter = json.dumps( { "types": [EventTypes.Message], @@ -1191,15 +1111,23 @@ class ContextTestCase(unittest.HomeserverTestCase): } ) - res = self._test_context_filter_labels(context_filter) + event_id = self._send_labelled_messages_in_room() - events_before = res["events_before"] + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] self.assertEqual( len(events_before), 0, [event["content"] for event in events_before] ) - events_after = res["events_after"] + events_after = channel.json_body["events_after"] self.assertEqual( len(events_after), 1, [event["content"] for event in events_after] @@ -1208,83 +1136,139 @@ class ContextTestCase(unittest.HomeserverTestCase): events_after[0]["content"]["body"], "with wrong label", events_after[0] ) - def _test_context_filter_labels(self, context_filter): - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") + def test_messages_filter_labels(self): + """Test that we can filter by a label on a /messages request.""" + message_filter = json.dumps( + {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} + ) + + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, message_filter), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_messages_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /messages request.""" + message_filter = json.dumps( + {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} + ) + + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, message_filter), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 4, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "without label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) + self.assertEqual( + events[3]["content"]["body"], "with two wrong labels", events[3] + ) + + def test_messages_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /messages request. + """ + message_filter = json.dumps( + { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + ) + + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, message_filter), + ) + self.render(request) + + events = channel.json_body["chunk"] - room_id = self.helper.create_room_as(user_id, tok=tok) + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + def _send_labelled_messages_in_room(self): self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, content={ "msgtype": "m.text", "body": "with right label", EventContentFields.LABELS: ["#fun"], }, - tok=tok, + tok=self.tok, ) self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, content={"msgtype": "m.text", "body": "without label"}, - tok=tok, + tok=self.tok, ) - # The event we'll look up the context for. res = self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=tok, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, ) event_id = res["event_id"] self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, content={ "msgtype": "m.text", "body": "with wrong label", EventContentFields.LABELS: ["#work"], }, - tok=tok, + tok=self.tok, ) self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, content={ "msgtype": "m.text", "body": "with two wrong labels", EventContentFields.LABELS: ["#work", "#notfun"], }, - tok=tok, + tok=self.tok, ) self.helper.send_event( - room_id=room_id, + room_id=self.room_id, type=EventTypes.Message, content={ "msgtype": "m.text", "body": "with right label", EventContentFields.LABELS: ["#fun"], }, - tok=tok, + tok=self.tok, ) - request, channel = self.make_request( - "GET", - "/rooms/%s/context/%s?filter=%s" % (room_id, event_id, context_filter), - access_token=tok, - ) - self.render(request) - self.assertEqual(channel.code, 200, channel.result) - - return channel.json_body - + return event_id -- cgit 1.5.1 From 037360e6cf2ca181b7cf03884375d4a4d52ad64e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Nov 2019 14:33:18 +0000 Subject: Add tests for /search --- tests/rest/client/v1/test_rooms.py | 187 ++++++++++++++++++++++++++++--------- 1 file changed, 143 insertions(+), 44 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index fe327d1bf8..cc7499dcc0 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1017,6 +1017,18 @@ class LabelsTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] + # Filter that should only catch messages with the label "#fun". + FILTER_LABELS = {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} + # Filter that should only catch messages without the label "#fun". + FILTER_NOT_LABELS = {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} + # Filter that should only catch messages with the label "#work" but without the label + # "#notfun". + FILTER_LABELS_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + def prepare(self, reactor, clock, homeserver): self.user_id = self.register_user("test", "test") self.tok = self.login("test", "test") @@ -1024,18 +1036,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_context_filter_labels(self): """Test that we can filter by a label on a /context request.""" - context_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.labels": ["#fun"], - } - ) - event_id = self._send_labelled_messages_in_room() request, channel = self.make_request( "GET", - "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), access_token=self.tok, ) self.render(request) @@ -1061,18 +1067,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_context_filter_not_labels(self): """Test that we can filter by the absence of a label on a /context request.""" - context_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.not_labels": ["#fun"], - } - ) - event_id = self._send_labelled_messages_in_room() request, channel = self.make_request( "GET", - "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), access_token=self.tok, ) self.render(request) @@ -1103,19 +1103,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /context request. """ - context_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - ) - event_id = self._send_labelled_messages_in_room() request, channel = self.make_request( "GET", - "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, context_filter), + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), access_token=self.tok, ) self.render(request) @@ -1138,17 +1131,13 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_messages_filter_labels(self): """Test that we can filter by a label on a /messages request.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} - ) - self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, message_filter), + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), ) self.render(request) @@ -1160,17 +1149,13 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_messages_filter_not_labels(self): """Test that we can filter by the absence of a label on a /messages request.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} - ) - self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, message_filter), + % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), ) self.render(request) @@ -1188,21 +1173,13 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /messages request. """ - message_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - ) - self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, message_filter), + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS_NOT_LABELS)), ) self.render(request) @@ -1211,7 +1188,128 @@ class LabelsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 1, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + def test_search_filter_labels(self): + """Test that we can filter by a label on a /search request.""" + request_data = json.dumps({ + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } + } + }) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 2, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with right label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "with right label", + results[1]["result"]["content"]["body"], + ) + + def test_search_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /search request.""" + request_data = json.dumps({ + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } + } + }) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 4, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "without label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "without label", + results[1]["result"]["content"]["body"], + ) + self.assertEqual( + results[2]["result"]["content"]["body"], + "with wrong label", + results[2]["result"]["content"]["body"], + ) + self.assertEqual( + results[3]["result"]["content"]["body"], + "with two wrong labels", + results[3]["result"]["content"]["body"], + ) + + def test_search_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /search request. + """ + request_data = json.dumps({ + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } + } + }) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 1, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with wrong label", + results[0]["result"]["content"]["body"], + ) + def _send_labelled_messages_in_room(self): + """Sends several messages to a room with different labels (or without any) to test + filtering by label. + + Returns: + The ID of the event to use if we're testing filtering on /context. + """ self.helper.send_event( room_id=self.room_id, type=EventTypes.Message, @@ -1236,6 +1334,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): content={"msgtype": "m.text", "body": "without label"}, tok=self.tok, ) + # Return this event's ID when we test filtering in /context requests. event_id = res["event_id"] self.helper.send_event( -- cgit 1.5.1 From 8822b331114a2f6fdcd5916f0c91991c0acae07e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Nov 2019 10:56:39 +0000 Subject: Update copyrights --- synapse/api/constants.py | 3 ++- synapse/api/filtering.py | 3 +++ synapse/rest/client/versions.py | 3 +++ synapse/storage/data_stores/main/stream.py | 3 +++ tests/api/test_filtering.py | 3 +++ tests/rest/client/v1/test_rooms.py | 2 ++ tests/rest/client/v1/utils.py | 3 +++ tests/rest/client/v2_alpha/test_sync.py | 3 ++- 8 files changed, 21 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 49c4b85054..312acff3d6 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index bec13f08d8..6eab1f13f0 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index bb30ce3f34..2a477ad22e 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 616ef91d4e..9cac664880 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 2dc5052249..63d8633582 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index cc7499dcc0..b2c1ef6f0e 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 8ea0cb05ea..e7417b3d14 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 3283c0e47b..661c1f88b9 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -- cgit 1.5.1 From a6863da24934dcbb2ae09a9e0b6e37140ef390ff Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Nov 2019 14:50:19 +0000 Subject: Lint --- tests/rest/client/v1/test_rooms.py | 71 ++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 30 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index b2c1ef6f0e..c5d67fc1cd 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1020,9 +1020,15 @@ class LabelsTestCase(unittest.HomeserverTestCase): ] # Filter that should only catch messages with the label "#fun". - FILTER_LABELS = {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} + FILTER_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } # Filter that should only catch messages without the label "#fun". - FILTER_NOT_LABELS = {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} + FILTER_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } # Filter that should only catch messages with the label "#work" but without the label # "#notfun". FILTER_LABELS_NOT_LABELS = { @@ -1181,7 +1187,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS_NOT_LABELS)), + % ( + self.room_id, + self.tok, + token, + json.dumps(self.FILTER_LABELS_NOT_LABELS), + ), ) self.render(request) @@ -1192,14 +1203,16 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_labels(self): """Test that we can filter by a label on a /search request.""" - request_data = json.dumps({ - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS, + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } } } - }) + ) self._send_labelled_messages_in_room() @@ -1211,9 +1224,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), - 2, - [result["result"]["content"] for result in results], + len(results), 2, [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], @@ -1228,14 +1239,16 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_not_labels(self): """Test that we can filter by the absence of a label on a /search request.""" - request_data = json.dumps({ - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_NOT_LABELS, + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } } } - }) + ) self._send_labelled_messages_in_room() @@ -1247,9 +1260,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), - 4, - [result["result"]["content"] for result in results], + len(results), 4, [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], @@ -1276,14 +1287,16 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /search request. """ - request_data = json.dumps({ - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS_NOT_LABELS, + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } } } - }) + ) self._send_labelled_messages_in_room() @@ -1295,9 +1308,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), - 1, - [result["result"]["content"] for result in results], + len(results), 1, [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], -- cgit 1.5.1 From e9bfe719ba1928dc191cea93120c5c8a89584434 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 5 Nov 2019 15:45:17 +0000 Subject: Strip overlong OpenGraph data from url preview ... to stop people causing DoSes with malicious web pages --- changelog.d/6331.feature | 1 + synapse/rest/media/v1/preview_url_resource.py | 20 +++++++++++++++- tests/rest/media/v1/test_url_preview.py | 34 +++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6331.feature (limited to 'tests') diff --git a/changelog.d/6331.feature b/changelog.d/6331.feature new file mode 100644 index 0000000000..eaf69ef3f6 --- /dev/null +++ b/changelog.d/6331.feature @@ -0,0 +1 @@ +Limit the length of data returned by url previews, to prevent DoS attacks. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0c68c3aad5..6d8c39a410 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -56,6 +56,9 @@ logger = logging.getLogger(__name__) _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) +OG_TAG_NAME_MAXLEN = 50 +OG_TAG_VALUE_MAXLEN = 1000 + class PreviewUrlResource(DirectServeResource): isLeaf = True @@ -167,7 +170,7 @@ class PreviewUrlResource(DirectServeResource): ts (int): Returns: - Deferred[str]: json-encoded og data + Deferred[bytes]: json-encoded og data """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) @@ -268,6 +271,17 @@ class PreviewUrlResource(DirectServeResource): logger.warn("Failed to find any OG data in %s", url) og = {} + # filter out any stupidly long values + keys_to_remove = [] + for k, v in og.items(): + if len(k) > OG_TAG_NAME_MAXLEN or len(v) > OG_TAG_VALUE_MAXLEN: + logger.warning( + "Pruning overlong tag %s from OG data", k[:OG_TAG_NAME_MAXLEN] + ) + keys_to_remove.append(k) + for k in keys_to_remove: + del og[k] + logger.debug("Calculated OG for %s as %s" % (url, og)) jsonog = json.dumps(og) @@ -502,6 +516,10 @@ def _calc_og(tree, media_uri): og = {} for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): if "content" in tag.attrib: + # if we've got more than 50 tags, someone is taking the piss + if len(og) >= 50: + logger.warning("skipping OG for page with too many og: tags") + return {} og[tag.attrib["property"]] = tag.attrib["content"] # TODO: grab article: meta tags too, e.g.: diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 976652aee8..da19a8e86f 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -247,6 +247,40 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") + def test_overlong_title(self): + self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + + end_content = ( + b"" + b"" + b"x" * 2000 + b"" + b'' + b"" + ) + + request, channel = self.make_request( + "GET", "url_preview?url=http://matrix.org", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="windows-1251"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + res = channel.json_body + self.assertCountEqual(["og:description"], res.keys()) + def test_ipaddr(self): """ IP addresses can be previewed directly. -- cgit 1.5.1 From e78167c94b3f63136f7d0e4f32a05ad1befdc0ec Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 5 Nov 2019 16:46:39 +0000 Subject: Apply suggestions from code review Co-Authored-By: Brendan Abolivier Co-Authored-By: Erik Johnston --- synapse/rest/media/v1/preview_url_resource.py | 2 +- tests/rest/media/v1/test_url_preview.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 6d8c39a410..4d4b3c1462 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -518,7 +518,7 @@ def _calc_og(tree, media_uri): if "content" in tag.attrib: # if we've got more than 50 tags, someone is taking the piss if len(og) >= 50: - logger.warning("skipping OG for page with too many og: tags") + logger.warning("Skipping OG for page with too many 'og:' tags") return {} og[tag.attrib["property"]] = tag.attrib["content"] diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index da19a8e86f..852b8ab11c 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -279,6 +279,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.pump() self.assertEqual(channel.code, 200) res = channel.json_body + # We should only see the `og:description` field, as `title` is too long and should be stripped out self.assertCountEqual(["og:description"], res.keys()) def test_ipaddr(self): -- cgit 1.5.1 From 807ec3bd99908d2991d2b3d0615b0862610c6dc3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 6 Nov 2019 10:01:39 +0000 Subject: Fix bug which caused rejected events to be stored with the wrong room state (#6320) Fixes a bug where rejected events were persisted with the wrong state group. Also fixes an occasional internal-server-error when receiving events over federation which are rejected and (possibly because they are backwards-extremities) have no prev_group. Fixes #6289. --- changelog.d/6320.bugfix | 1 + synapse/events/snapshot.py | 25 ++++- synapse/handlers/federation.py | 1 + synapse/state/__init__.py | 172 ++++++++++++++---------------- synapse/storage/data_stores/main/state.py | 2 +- tests/handlers/test_federation.py | 126 ++++++++++++++++++++++ tests/test_state.py | 61 +++++++++-- 7 files changed, 285 insertions(+), 103 deletions(-) create mode 100644 changelog.d/6320.bugfix (limited to 'tests') diff --git a/changelog.d/6320.bugfix b/changelog.d/6320.bugfix new file mode 100644 index 0000000000..2c3fad5655 --- /dev/null +++ b/changelog.d/6320.bugfix @@ -0,0 +1 @@ +Fix bug which casued rejected events to be persisted with the wrong room state. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 0f3c5989cb..64e898f40c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -48,10 +48,21 @@ class EventContext: Note that this is a private attribute: it should be accessed via the ``state_group`` property. + state_group_before_event: The ID of the state group representing the state + of the room before this event. + + If this is a non-state event, this will be the same as ``state_group``. If + it's a state event, it will be the same as ``prev_group``. + + If ``state_group`` is None (ie, the event is an outlier), + ``state_group_before_event`` will always also be ``None``. + prev_group: If it is known, ``state_group``'s prev_group. Note that this being None does not necessarily mean that ``state_group`` does not have a prev_group! + If the event is a state event, this is normally the same as ``prev_group``. + If ``state_group`` is None (ie, the event is an outlier), ``prev_group`` will always also be ``None``. @@ -77,7 +88,8 @@ class EventContext: ``get_current_state_ids``. _AsyncEventContext impl calculates this on-demand: it will be None until that happens. - _prev_state_ids: The room state map, excluding this event. For a non-state + _prev_state_ids: The room state map, excluding this event - ie, the state + in ``state_group_before_event``. For a non-state event, this will be the same as _current_state_events. Note that it is a completely different thing to prev_group! @@ -92,6 +104,7 @@ class EventContext: rejected = attr.ib(default=False, type=Union[bool, str]) _state_group = attr.ib(default=None, type=Optional[int]) + state_group_before_event = attr.ib(default=None, type=Optional[int]) prev_group = attr.ib(default=None, type=Optional[int]) delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) app_service = attr.ib(default=None, type=Optional[ApplicationService]) @@ -103,12 +116,18 @@ class EventContext: @staticmethod def with_state( - state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None + state_group, + state_group_before_event, + current_state_ids, + prev_state_ids, + prev_group=None, + delta_ids=None, ): return EventContext( current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, state_group=state_group, + state_group_before_event=state_group_before_event, prev_group=prev_group, delta_ids=delta_ids, ) @@ -140,6 +159,7 @@ class EventContext: "event_type": event.type, "event_state_key": event.state_key if event.is_state() else None, "state_group": self._state_group, + "state_group_before_event": self.state_group_before_event, "rejected": self.rejected, "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), @@ -165,6 +185,7 @@ class EventContext: event_type=input["event_type"], event_state_key=input["event_state_key"], state_group=input["state_group"], + state_group_before_event=input["state_group_before_event"], prev_group=input["prev_group"], delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b7916de909..05dd8d2671 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2280,6 +2280,7 @@ class FederationHandler(BaseHandler): return EventContext.with_state( state_group=state_group, + state_group_before_event=context.state_group_before_event, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, prev_group=prev_group, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 2c04ab1854..139beef8ed 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,6 +16,7 @@ import logging from collections import namedtuple +from typing import Iterable, Optional from six import iteritems, itervalues @@ -27,6 +28,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 @@ -212,15 +214,17 @@ class StateHandler(object): return joined_hosts @defer.inlineCallbacks - def compute_event_context(self, event, old_state=None): + def compute_event_context( + self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None + ): """Build an EventContext structure for the event. This works out what the current state should be for the event, and generates a new state group if necessary. Args: - event (synapse.events.EventBase): - old_state (dict|None): The state at the event if it can't be + event: + old_state: The state at the event if it can't be calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. @@ -251,113 +255,103 @@ class StateHandler(object): # group for it. context = EventContext.with_state( state_group=None, + state_group_before_event=None, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, ) return context + # + # first of all, figure out the state before the event + # + if old_state: - # We already have the state, so we don't need to calculate it. - # Let's just correctly fill out the context and create a - # new state group for it. - - prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} - - if event.is_state(): - key = (event.type, event.state_key) - if key in prev_state_ids: - replaces = prev_state_ids[key] - if replaces != event.event_id: # Paranoia check - event.unsigned["replaces_state"] = replaces - current_state_ids = dict(prev_state_ids) - current_state_ids[key] = event.event_id - else: - current_state_ids = prev_state_ids + # if we're given the state before the event, then we use that + state_ids_before_event = { + (s.type, s.state_key): s.event_id for s in old_state + } + state_group_before_event = None + state_group_before_event_prev_group = None + deltas_to_state_group_before_event = None - state_group = yield self.state_store.store_state_group( - event.event_id, - event.room_id, - prev_group=None, - delta_ids=None, - current_state_ids=current_state_ids, - ) + else: + # otherwise, we'll need to resolve the state across the prev_events. + logger.debug("calling resolve_state_groups from compute_event_context") - context = EventContext.with_state( - state_group=state_group, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, + entry = yield self.resolve_state_groups_for_events( + event.room_id, event.prev_event_ids() ) - return context + state_ids_before_event = entry.state + state_group_before_event = entry.state_group + state_group_before_event_prev_group = entry.prev_group + deltas_to_state_group_before_event = entry.delta_ids - logger.debug("calling resolve_state_groups from compute_event_context") + # + # make sure that we have a state group at that point. If it's not a state event, + # that will be the state group for the new event. If it *is* a state event, + # it might get rejected (in which case we'll need to persist it with the + # previous state group) + # - entry = yield self.resolve_state_groups_for_events( - event.room_id, event.prev_event_ids() - ) + if not state_group_before_event: + state_group_before_event = yield self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + current_state_ids=state_ids_before_event, + ) - prev_state_ids = entry.state - prev_group = None - delta_ids = None + # XXX: can we update the state cache entry for the new state group? or + # could we set a flag on resolve_state_groups_for_events to tell it to + # always make a state group? + + # + # now if it's not a state event, we're done + # + + if not event.is_state(): + return EventContext.with_state( + state_group_before_event=state_group_before_event, + state_group=state_group_before_event, + current_state_ids=state_ids_before_event, + prev_state_ids=state_ids_before_event, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + ) - if event.is_state(): - # If this is a state event then we need to create a new state - # group for the state after this event. + # + # otherwise, we'll need to create a new state group for after the event + # - key = (event.type, event.state_key) - if key in prev_state_ids: - replaces = prev_state_ids[key] + key = (event.type, event.state_key) + if key in state_ids_before_event: + replaces = state_ids_before_event[key] + if replaces != event.event_id: event.unsigned["replaces_state"] = replaces - current_state_ids = dict(prev_state_ids) - current_state_ids[key] = event.event_id - - if entry.state_group: - # If the state at the event has a state group assigned then - # we can use that as the prev group - prev_group = entry.state_group - delta_ids = {key: event.event_id} - elif entry.prev_group: - # If the state at the event only has a prev group, then we can - # use that as a prev group too. - prev_group = entry.prev_group - delta_ids = dict(entry.delta_ids) - delta_ids[key] = event.event_id - - state_group = yield self.state_store.store_state_group( - event.event_id, - event.room_id, - prev_group=prev_group, - delta_ids=delta_ids, - current_state_ids=current_state_ids, - ) - else: - current_state_ids = prev_state_ids - prev_group = entry.prev_group - delta_ids = entry.delta_ids - - if entry.state_group is None: - entry.state_group = yield self.state_store.store_state_group( - event.event_id, - event.room_id, - prev_group=entry.prev_group, - delta_ids=entry.delta_ids, - current_state_ids=current_state_ids, - ) - entry.state_id = entry.state_group - - state_group = entry.state_group - - context = EventContext.with_state( - state_group=state_group, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, - prev_group=prev_group, + state_ids_after_event = dict(state_ids_before_event) + state_ids_after_event[key] = event.event_id + delta_ids = {key: event.event_id} + + state_group_after_event = yield self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, delta_ids=delta_ids, + current_state_ids=state_ids_after_event, ) - return context + return EventContext.with_state( + state_group=state_group_after_event, + state_group_before_event=state_group_before_event, + current_state_ids=state_ids_after_event, + prev_state_ids=state_ids_before_event, + prev_group=state_group_before_event, + delta_ids=delta_ids, + ) @measure_func() @defer.inlineCallbacks diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 3132848034..9e1541988e 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -1231,7 +1231,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): # if the event was rejected, just give it the same state as its # predecessor. if context.rejected: - state_groups[event.event_id] = context.prev_group + state_groups[event.event_id] = context.state_group_before_event continue state_groups[event.event_id] = context.state_group diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index d56220f403..b4d92cf732 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -12,13 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + from synapse.api.constants import EventTypes from synapse.api.errors import AuthError, Codes +from synapse.federation.federation_base import event_from_pdu_json +from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client.v1 import login, room from tests import unittest +logger = logging.getLogger(__name__) + class FederationTestCase(unittest.HomeserverTestCase): servlets = [ @@ -79,3 +85,123 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(failure.code, 403, failure) self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure) self.assertEqual(failure.msg, "You are not invited to this room.") + + def test_rejected_message_event_state(self): + """ + Check that we store the state group correctly for rejected non-state events. + + Regression test for #6289. + """ + OTHER_SERVER = "otherserver" + OTHER_USER = "@otheruser:" + OTHER_SERVER + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + # pretend that another server has joined + join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) + + # check the state group + sg = self.successResultOf( + self.store._get_state_group_for_event(join_event.event_id) + ) + + # build and send an event which will be rejected + ev = event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {}, + "room_id": room_id, + "sender": "@yetanotheruser:" + OTHER_SERVER, + "depth": join_event["depth"] + 1, + "prev_events": [join_event.event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + join_event.format_version, + ) + + with LoggingContext(request="send_rejected"): + d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + self.get_success(d) + + # that should have been rejected + e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True)) + self.assertIsNotNone(e.rejected_reason) + + # ... and the state group should be the same as before + sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id)) + + self.assertEqual(sg, sg2) + + def test_rejected_state_event_state(self): + """ + Check that we store the state group correctly for rejected state events. + + Regression test for #6289. + """ + OTHER_SERVER = "otherserver" + OTHER_USER = "@otheruser:" + OTHER_SERVER + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + # pretend that another server has joined + join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) + + # check the state group + sg = self.successResultOf( + self.store._get_state_group_for_event(join_event.event_id) + ) + + # build and send an event which will be rejected + ev = event_from_pdu_json( + { + "type": "org.matrix.test", + "state_key": "test_key", + "content": {}, + "room_id": room_id, + "sender": "@yetanotheruser:" + OTHER_SERVER, + "depth": join_event["depth"] + 1, + "prev_events": [join_event.event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + join_event.format_version, + ) + + with LoggingContext(request="send_rejected"): + d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + self.get_success(d) + + # that should have been rejected + e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True)) + self.assertIsNotNone(e.rejected_reason) + + # ... and the state group should be the same as before + sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id)) + + self.assertEqual(sg, sg2) + + def _build_and_send_join_event(self, other_server, other_user, room_id): + join_event = self.get_success( + self.handler.on_make_join_request(other_server, room_id, other_user) + ) + # the auth code requires that a signature exists, but doesn't check that + # signature... go figure. + join_event.signatures[other_server] = {"x": "y"} + with LoggingContext(request="send_join"): + d = run_in_background( + self.handler.on_send_join_request, other_server, join_event + ) + self.get_success(d) + + # sanity-check: the room should show that the new user is a member + r = self.get_success(self.store.get_current_state_ids(room_id)) + self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id) + + return join_event diff --git a/tests/test_state.py b/tests/test_state.py index 38246555bd..176535947a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -21,6 +21,7 @@ from synapse.api.auth import Auth from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext from synapse.state import StateHandler, StateResolutionHandler from tests import unittest @@ -198,16 +199,22 @@ class StateTestCase(unittest.TestCase): self.store.register_events(graph.walk()) - context_store = {} + context_store = {} # type: dict[str, EventContext] for event in graph.walk(): context = yield self.state.compute_event_context(event) self.store.register_event_context(event, context) context_store[event.event_id] = context - prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + ctx_c = context_store["C"] + ctx_d = context_store["D"] + + prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) self.assertEqual(2, len(prev_state_ids)) + self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) + self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) + @defer.inlineCallbacks def test_branch_basic_conflict(self): graph = Graph( @@ -241,12 +248,19 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context - prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + # C ends up winning the resolution between B and C + + ctx_c = context_store["C"] + ctx_d = context_store["D"] + prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) self.assertSetEqual( {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} ) + self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) + self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) + @defer.inlineCallbacks def test_branch_have_banned_conflict(self): graph = Graph( @@ -292,11 +306,18 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context - prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store) + # C ends up winning the resolution between C and D because bans win over other + # changes + + ctx_c = context_store["C"] + ctx_e = context_store["E"] + prev_state_ids = yield ctx_e.get_prev_state_ids(self.store) self.assertSetEqual( {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} ) + self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) + self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) @defer.inlineCallbacks def test_branch_have_perms_conflict(self): @@ -360,12 +381,20 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context - prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + # B ends up winning the resolution between B and C because power levels + # win over other changes. + ctx_b = context_store["B"] + ctx_d = context_store["D"] + + prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} ) + self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) + self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) + def _add_depths(self, nodes, edges): def _get_depth(ev): node = nodes[ev] @@ -390,13 +419,16 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) - current_state_ids = yield context.get_current_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids(self.store) + self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - self.assertEqual( - set(e.event_id for e in old_state), set(current_state_ids.values()) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertCountEqual( + (e.event_id for e in old_state), current_state_ids.values() ) - self.assertIsNotNone(context.state_group) + self.assertIsNotNone(context.state_group_before_event) + self.assertEqual(context.state_group_before_event, context.state_group) @defer.inlineCallbacks def test_annotate_with_old_state(self): @@ -411,11 +443,18 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) prev_state_ids = yield context.get_prev_state_ids(self.store) + self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - self.assertEqual( - set(e.event_id for e in old_state), set(prev_state_ids.values()) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertCountEqual( + (e.event_id for e in old_state + [event]), current_state_ids.values() ) + self.assertIsNotNone(context.state_group_before_event) + self.assertNotEqual(context.state_group_before_event, context.state_group) + self.assertEqual(context.state_group_before_event, context.prev_group) + self.assertEqual({("state", ""): event.event_id}, context.delta_ids) + @defer.inlineCallbacks def test_trivial_annotate_message(self): prev_event_id = "prev_event_id" -- cgit 1.5.1 From 5c3363233cca7044a333b7e19ba239eaf5587ff8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 6 Nov 2019 17:02:05 +0000 Subject: Fix deleting state groups during room purge. And fix the tests to actually test that things got deleted. --- synapse/storage/data_stores/main/events.py | 27 ++++++++++++++------------- tests/rest/admin/test_admin.py | 4 +++- 2 files changed, 17 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index d69c59f5a1..946823876a 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1633,7 +1633,20 @@ class EventsStore( return self.runInteraction("purge_room", self._purge_room_txn, room_id) def _purge_room_txn(self, txn, room_id): - # First delete tables which lack an index on room_id but have one on event_id + # First we fetch all the state groups that should be deleted, before + # we delete that information. + txn.execute( + """ + SELECT DISTINCT state_group FROM events + INNER JOIN event_to_state_groups USING(event_id) + WHERE events.room_id = ? + """, + (room_id,), + ) + + state_groups = [row[0] for row in txn] + + # Now we delete tables which lack an index on room_id but have one on event_id for table in ( "event_auth", "event_edges", @@ -1717,18 +1730,6 @@ class EventsStore( # index on them. In any case we should be clearing out 'stream' tables # periodically anyway (#5888) - # Now we fetch all the state groups that should be deleted. - txn.execute( - """ - SELECT DISTINCT state_group FROM events - INNER JOIN event_to_state_groups USING(event_id) - WHERE events.room_id = ? - """, - (room_id,), - ) - - state_groups = [row[0] for row in txn] - # TODO: we could probably usefully do a bunch of cache invalidation here logger.info("[purge] done") diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 8e1ca8b738..d9f1b95cb0 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -628,10 +628,12 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "local_invites", "room_account_data", "room_tags", + "state_groups", + "state_groups_state", ): count = self.get_success( self.store._simple_select_one_onecol( - table="events", + table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", desc="test_purge_room", -- cgit 1.5.1 From f03c9d34442368c8f2c26d2ac16b770bc451c76d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 6 Nov 2019 15:47:40 +0000 Subject: Don't apply retention policy based filtering on state events As per MSC1763, 'Retention is only considered for non-state events.', so don't filter out state events based on the room's retention policy. --- synapse/visibility.py | 15 +++++++++------ tests/rest/client/test_retention.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/synapse/visibility.py b/synapse/visibility.py index 4498c156bc..4d4141dacc 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -111,14 +111,17 @@ def filter_events_for_client( if not event.is_state() and event.sender in ignore_list: return None - retention_policy = retention_policies[event.room_id] - max_lifetime = retention_policy.get("max_lifetime") + # Don't try to apply the room's retention policy if the event is a state event, as + # MSC1763 states that retention is only considered for non-state events. + if not event.is_state(): + retention_policy = retention_policies[event.room_id] + max_lifetime = retention_policy.get("max_lifetime") - if max_lifetime is not None: - oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime + if max_lifetime is not None: + oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime - if event.origin_server_ts < oldest_allowed_ts: - return None + if event.origin_server_ts < oldest_allowed_ts: + return None if event.event_id in always_include_ids: return event diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 41ea9db689..7b6f25a838 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -164,6 +164,12 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) def _test_retention_event_purged(self, room_id, increment): + # Get the create event to, later, check that we can still access it. + message_handler = self.hs.get_message_handler() + create_event = self.get_success( + message_handler.get_room_data(self.user_id, room_id, EventTypes.Create) + ) + # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. resp = self.helper.send( @@ -202,6 +208,10 @@ class RetentionTestCase(unittest.HomeserverTestCase): valid_event = self.get_event(room_id, valid_event_id) self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) + # Check that we can still access state events that were sent before the event that + # has been purged. + self.get_event(room_id, create_event.event_id) + def get_event(self, room_id, event_id, expected_code=200): url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) -- cgit 1.5.1 From c350bc2f92d87e46a40f917f65c9e10e0f4999fc Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 13 Nov 2019 19:09:20 +0000 Subject: Blacklist PurgeRoomTestCase (#6361) --- changelog.d/6361.misc | 1 + tests/rest/admin/test_admin.py | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/6361.misc (limited to 'tests') diff --git a/changelog.d/6361.misc b/changelog.d/6361.misc new file mode 100644 index 0000000000..324d74ebf9 --- /dev/null +++ b/changelog.d/6361.misc @@ -0,0 +1 @@ +Temporarily blacklist the failing unit test PurgeRoomTestCase.test_purge_room. diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index d9f1b95cb0..9575058252 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -641,3 +641,5 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + + test_purge_room.skip = "Disabled because it's currently broken" -- cgit 1.5.1 From 7c24d0f443724082376c89f9f75954d81f524a8e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 19 Nov 2019 13:22:37 +0000 Subject: Lint --- synapse/config/server.py | 39 ++++++++++------- synapse/handlers/pagination.py | 17 +++----- synapse/storage/data_stores/main/room.py | 49 +++++++++++---------- tests/rest/client/test_retention.py | 73 ++++++++++---------------------- 4 files changed, 77 insertions(+), 101 deletions(-) (limited to 'tests') diff --git a/synapse/config/server.py b/synapse/config/server.py index aa93a416f1..8a55ffac4f 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,7 +19,7 @@ import logging import os.path import re from textwrap import indent -from typing import List +from typing import List, Dict, Optional import attr import yaml @@ -287,13 +287,17 @@ class ServerConfig(Config): self.retention_default_min_lifetime = None self.retention_default_max_lifetime = None - self.retention_allowed_lifetime_min = retention_config.get("allowed_lifetime_min") + self.retention_allowed_lifetime_min = retention_config.get( + "allowed_lifetime_min" + ) if self.retention_allowed_lifetime_min is not None: self.retention_allowed_lifetime_min = self.parse_duration( self.retention_allowed_lifetime_min ) - self.retention_allowed_lifetime_max = retention_config.get("allowed_lifetime_max") + self.retention_allowed_lifetime_max = retention_config.get( + "allowed_lifetime_max" + ) if self.retention_allowed_lifetime_max is not None: self.retention_allowed_lifetime_max = self.parse_duration( self.retention_allowed_lifetime_max @@ -302,14 +306,15 @@ class ServerConfig(Config): if ( self.retention_allowed_lifetime_min is not None and self.retention_allowed_lifetime_max is not None - and self.retention_allowed_lifetime_min > self.retention_allowed_lifetime_max + and self.retention_allowed_lifetime_min + > self.retention_allowed_lifetime_max ): raise ConfigError( "Invalid retention policy limits: 'allowed_lifetime_min' can not be" " greater than 'allowed_lifetime_max'" ) - self.retention_purge_jobs = [] + self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]] for purge_job_config in retention_config.get("purge_jobs", []): interval_config = purge_job_config.get("interval") @@ -342,18 +347,22 @@ class ServerConfig(Config): " 'longest_max_lifetime' value." ) - self.retention_purge_jobs.append({ - "interval": interval, - "shortest_max_lifetime": shortest_max_lifetime, - "longest_max_lifetime": longest_max_lifetime, - }) + self.retention_purge_jobs.append( + { + "interval": interval, + "shortest_max_lifetime": shortest_max_lifetime, + "longest_max_lifetime": longest_max_lifetime, + } + ) if not self.retention_purge_jobs: - self.retention_purge_jobs = [{ - "interval": self.parse_duration("1d"), - "shortest_max_lifetime": None, - "longest_max_lifetime": None, - }] + self.retention_purge_jobs = [ + { + "interval": self.parse_duration("1d"), + "shortest_max_lifetime": None, + "longest_max_lifetime": None, + } + ] self.listeners = [] # type: List[dict] for listener in config.get("listeners", []): diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index e1800177fa..d122c11a4d 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -154,20 +154,17 @@ class PaginationHandler(object): # Figure out what token we should start purging at. ts = self.clock.time_msec() - max_lifetime - stream_ordering = ( - yield self.store.find_first_stream_ordering_after_ts(ts) - ) + stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts) - r = ( - yield self.store.get_room_event_after_stream_ordering( - room_id, stream_ordering, - ) + r = yield self.store.get_room_event_after_stream_ordering( + room_id, stream_ordering, ) if not r: logger.warning( "[purge] purging events not possible: No event found " "(ts %i => stream_ordering %i)", - ts, stream_ordering, + ts, + stream_ordering, ) continue @@ -186,9 +183,7 @@ class PaginationHandler(object): # the background so that it's not blocking any other operation apart from # other purges in the same room. run_as_background_process( - "_purge_history", - self._purge_history, - purge_id, room_id, token, True, + "_purge_history", self._purge_history, purge_id, room_id, token, True, ) def start_purge_history(self, room_id, token, delete_local_events=False): diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 54a7d24c73..7fceae59ca 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -334,8 +334,9 @@ class RoomStore(RoomWorkerStore, SearchStore): WHERE state.room_id > ? AND state.type = '%s' ORDER BY state.room_id ASC LIMIT ?; - """ % EventTypes.Retention, - (last_room, batch_size) + """ + % EventTypes.Retention, + (last_room, batch_size), ) rows = self.cursor_to_dict(txn) @@ -358,15 +359,13 @@ class RoomStore(RoomWorkerStore, SearchStore): "event_id": row["event_id"], "min_lifetime": retention_policy.get("min_lifetime"), "max_lifetime": retention_policy.get("max_lifetime"), - } + }, ) logger.info("Inserted %d rows into room_retention", len(rows)) self._background_update_progress_txn( - txn, "insert_room_retention", { - "room_id": rows[-1]["room_id"], - } + txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} ) if batch_size > len(rows): @@ -375,8 +374,7 @@ class RoomStore(RoomWorkerStore, SearchStore): return False end = yield self.runInteraction( - "insert_room_retention", - _background_insert_retention_txn, + "insert_room_retention", _background_insert_retention_txn, ) if end: @@ -585,17 +583,15 @@ class RoomStore(RoomWorkerStore, SearchStore): ) def _store_retention_policy_for_room_txn(self, txn, event): - if ( - hasattr(event, "content") - and ("min_lifetime" in event.content or "max_lifetime" in event.content) + if hasattr(event, "content") and ( + "min_lifetime" in event.content or "max_lifetime" in event.content ): if ( - ("min_lifetime" in event.content and not isinstance( - event.content.get("min_lifetime"), integer_types - )) - or ("max_lifetime" in event.content and not isinstance( - event.content.get("max_lifetime"), integer_types - )) + "min_lifetime" in event.content + and not isinstance(event.content.get("min_lifetime"), integer_types) + ) or ( + "max_lifetime" in event.content + and not isinstance(event.content.get("max_lifetime"), integer_types) ): # Ignore the event if one of the value isn't an integer. return @@ -798,7 +794,9 @@ class RoomStore(RoomWorkerStore, SearchStore): return local_media_mxcs, remote_media_mxcs @defer.inlineCallbacks - def get_rooms_for_retention_period_in_range(self, min_ms, max_ms, include_null=False): + def get_rooms_for_retention_period_in_range( + self, min_ms, max_ms, include_null=False + ): """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. @@ -904,23 +902,24 @@ class RoomStore(RoomWorkerStore, SearchStore): INNER JOIN current_state_events USING (event_id, room_id) WHERE room_id = ?; """, - (room_id,) + (room_id,), ) return self.cursor_to_dict(txn) ret = yield self.runInteraction( - "get_retention_policy_for_room", - get_retention_policy_for_room_txn, + "get_retention_policy_for_room", get_retention_policy_for_room_txn, ) # If we don't know this room ID, ret will be None, in this case return the default # policy. if not ret: - defer.returnValue({ - "min_lifetime": self.config.retention_default_min_lifetime, - "max_lifetime": self.config.retention_default_max_lifetime, - }) + defer.returnValue( + { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } + ) row = ret[0] diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 7b6f25a838..6bf485c239 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -61,9 +61,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": one_day_ms * 4, - }, + body={"max_lifetime": one_day_ms * 4}, tok=self.token, expect_code=400, ) @@ -71,9 +69,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": one_hour_ms, - }, + body={"max_lifetime": one_hour_ms}, tok=self.token, expect_code=400, ) @@ -89,9 +85,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": lifetime, - }, + body={"max_lifetime": lifetime}, tok=self.token, ) @@ -115,20 +109,12 @@ class RetentionTestCase(unittest.HomeserverTestCase): events = [] # Send a first event, which should be filtered out at the end of the test. - resp = self.helper.send( - room_id=room_id, - body="1", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) # Get the event from the store so that we end up with a FrozenEvent that we can # give to filter_events_for_client. We need to do this now because the event won't # be in the database anymore after it has expired. - events.append(self.get_success( - store.get_event( - resp.get("event_id") - ) - )) + events.append(self.get_success(store.get_event(resp.get("event_id")))) # Advance the time by 2 days. We're using the default retention policy, therefore # after this the first event will still be valid. @@ -143,20 +129,16 @@ class RetentionTestCase(unittest.HomeserverTestCase): valid_event_id = resp.get("event_id") - events.append(self.get_success( - store.get_event( - valid_event_id - ) - )) + events.append(self.get_success(store.get_event(valid_event_id))) # Advance the time by anothe 2 days. After this, the first event should be # outdated but not the second one. self.reactor.advance(one_day_ms * 2 / 1000) # Run filter_events_for_client with our list of FrozenEvents. - filtered_events = self.get_success(filter_events_for_client( - storage, self.user_id, events - )) + filtered_events = self.get_success( + filter_events_for_client(storage, self.user_id, events) + ) # We should only get one event back. self.assertEqual(len(filtered_events), 1, filtered_events) @@ -172,28 +154,22 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. - resp = self.helper.send( - room_id=room_id, - body="1", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) expired_event_id = resp.get("event_id") # Check that we can retrieve the event. expired_event = self.get_event(room_id, expired_event_id) - self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) # Advance the time. self.reactor.advance(increment / 1000) # Send another event. We need this because the purge job won't purge the most # recent event in the room. - resp = self.helper.send( - room_id=room_id, - body="2", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") @@ -240,8 +216,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["backfill"]) self.hs = self.setup_test_homeserver( - config=config, - federation_client=mock_federation_client, + config=config, federation_client=mock_federation_client, ) return self.hs @@ -268,9 +243,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": one_day_ms * 35, - }, + body={"max_lifetime": one_day_ms * 35}, tok=self.token, ) @@ -289,18 +262,16 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): # Check that we can retrieve the event. expired_event = self.get_event(room_id, first_event_id) - self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) # Advance the time by a month. self.reactor.advance(one_day_ms * 30 / 1000) # Send another event. We need this because the purge job won't purge the most # recent event in the room. - resp = self.helper.send( - room_id=room_id, - body="2", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) second_event_id = resp.get("event_id") @@ -313,7 +284,9 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): ) if expected_code_for_first_event == 200: - self.assertEqual(first_event.get("content", {}).get("body"), "1", first_event) + self.assertEqual( + first_event.get("content", {}).get("body"), "1", first_event + ) # Check that the event that hasn't been purged can still be retrieved. second_event = self.get_event(room_id, second_event_id) -- cgit 1.5.1 From bf9a11c54d3c9ca1e4fa9420b567efceb1e46d5b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 19 Nov 2019 13:30:04 +0000 Subject: Lint again --- synapse/config/server.py | 2 +- tests/rest/client/test_retention.py | 12 ++---------- 2 files changed, 3 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/synapse/config/server.py b/synapse/config/server.py index 8a55ffac4f..ed916e1400 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,7 +19,7 @@ import logging import os.path import re from textwrap import indent -from typing import List, Dict, Optional +from typing import Dict, List, Optional import attr import yaml diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 6bf485c239..9e549d8a91 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -121,11 +121,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.reactor.advance(one_day_ms * 2 / 1000) # Send another event, which shouldn't get filtered out. - resp = self.helper.send( - room_id=room_id, - body="2", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") @@ -252,11 +248,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): def _test_retention(self, room_id, expected_code_for_first_event=200): # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. - resp = self.helper.send( - room_id=room_id, - body="1", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) first_event_id = resp.get("event_id") -- cgit 1.5.1 From 6356f2088f0adb681fe24a8435955b19883fa3b4 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 20 Nov 2019 12:09:06 +0000 Subject: Test if a purge can make /messages return 500 responses --- tests/rest/client/v1/test_rooms.py | 72 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 5e38fd6ced..ebaa67e899 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -25,7 +25,9 @@ from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import login, profile, room +from synapse.util.stringutils import random_string from tests import unittest @@ -910,6 +912,76 @@ class RoomMessageListTestCase(RoomBase): return channel.json_body["chunk"] + def test_room_messages_purge(self): + store = self.hs.get_datastore() + pagination_handler = self.hs.get_pagination_handler() + + # Send a first message in the room, which will be removed by the purge. + first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) + ) + + # Send a second message in the room, which won't be removed, and which we'll + # use as the marker to purge events before. + second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) + ) + + # Send a third event in the room to ensure we don't fall under any edge case + # due to our marker being the latest forward extremity in the room. + self.helper.send(self.room_id, "message 3") + + # Check that we get the first and second message when querying /messages. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) + + # Purge every event before the second event. + purge_id = random_string(16) + pagination_handler._purges_by_id[purge_id] = PurgeStatus() + self.get_success(pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token, + delete_local_events=True, + )) + + # Check that we only get the second message through /message now that the first + # has been purged. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) + + # Check that we get no event, but also no error, when querying /messages with + # the token that was pointing at the first event, because we don't have it + # anymore. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) + class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From e2a20326e8141fdf9304434901da38c64b917a78 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 20 Nov 2019 15:08:47 +0000 Subject: Lint --- tests/rest/client/v1/test_rooms.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ebaa67e899..e84e578f99 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -948,12 +948,14 @@ class RoomMessageListTestCase(RoomBase): # Purge every event before the second event. purge_id = random_string(16) pagination_handler._purges_by_id[purge_id] = PurgeStatus() - self.get_success(pagination_handler._purge_history( - purge_id=purge_id, - room_id=self.room_id, - token=second_token, - delete_local_events=True, - )) + self.get_success( + pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token, + delete_local_events=True, + ) + ) # Check that we only get the second message through /message now that the first # has been purged. -- cgit 1.5.1 From 9eebd46048d0b34767047b2156760a1467f19ae6 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 26 Nov 2019 03:45:50 +1100 Subject: Improve the performance of structured logging (#6322) --- changelog.d/6322.misc | 1 + synapse/logging/_structured.py | 14 +++++- synapse/logging/_terse_json.py | 106 ++++++++++++++++++++++++++++++----------- tests/server.py | 2 + 4 files changed, 93 insertions(+), 30 deletions(-) create mode 100644 changelog.d/6322.misc (limited to 'tests') diff --git a/changelog.d/6322.misc b/changelog.d/6322.misc new file mode 100644 index 0000000000..70ef36ca80 --- /dev/null +++ b/changelog.d/6322.misc @@ -0,0 +1 @@ +Improve the performance of outputting structured logging. diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 334ddaf39a..ffa7b20ca8 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -261,6 +261,18 @@ def parse_drain_configs( ) +class StoppableLogPublisher(LogPublisher): + """ + A log publisher that can tell its observers to shut down any external + communications. + """ + + def stop(self): + for obs in self._observers: + if hasattr(obs, "stop"): + obs.stop() + + def setup_structured_logging( hs, config, @@ -336,7 +348,7 @@ def setup_structured_logging( # We should never get here, but, just in case, throw an error. raise ConfigError("%s drain type cannot be configured" % (observer.type,)) - publisher = LogPublisher(*observers) + publisher = StoppableLogPublisher(*observers) log_filter = LogLevelFilterPredicate() for namespace, namespace_config in log_config.get( diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 76ce7d8808..05fc64f409 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -17,25 +17,29 @@ Log formatters that output terse JSON. """ +import json import sys +import traceback from collections import deque from ipaddress import IPv4Address, IPv6Address, ip_address from math import floor -from typing import IO +from typing import IO, Optional import attr -from simplejson import dumps from zope.interface import implementer from twisted.application.internet import ClientService +from twisted.internet.defer import Deferred from twisted.internet.endpoints import ( HostnameEndpoint, TCP4ClientEndpoint, TCP6ClientEndpoint, ) +from twisted.internet.interfaces import IPushProducer, ITransport from twisted.internet.protocol import Factory, Protocol from twisted.logger import FileLogObserver, ILogObserver, Logger -from twisted.python.failure import Failure + +_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":")) def flatten_event(event: dict, metadata: dict, include_time: bool = False): @@ -141,11 +145,49 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb def formatEvent(_event: dict) -> str: flattened = flatten_event(_event, metadata) - return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n" + return _encoder.encode(flattened) + "\n" return FileLogObserver(outFile, formatEvent) +@attr.s +@implementer(IPushProducer) +class LogProducer(object): + """ + An IPushProducer that writes logs from its buffer to its transport when it + is resumed. + + Args: + buffer: Log buffer to read logs from. + transport: Transport to write to. + """ + + transport = attr.ib(type=ITransport) + _buffer = attr.ib(type=deque) + _paused = attr.ib(default=False, type=bool, init=False) + + def pauseProducing(self): + self._paused = True + + def stopProducing(self): + self._paused = True + self._buffer = None + + def resumeProducing(self): + self._paused = False + + while self._paused is False and (self._buffer and self.transport.connected): + try: + event = self._buffer.popleft() + self.transport.write(_encoder.encode(event).encode("utf8")) + self.transport.write(b"\n") + except Exception: + # Something has gone wrong writing to the transport -- log it + # and break out of the while. + traceback.print_exc(file=sys.__stderr__) + break + + @attr.s @implementer(ILogObserver) class TerseJSONToTCPLogObserver(object): @@ -165,8 +207,9 @@ class TerseJSONToTCPLogObserver(object): metadata = attr.ib(type=dict) maximum_buffer = attr.ib(type=int) _buffer = attr.ib(default=attr.Factory(deque), type=deque) - _writer = attr.ib(default=None) + _connection_waiter = attr.ib(default=None, type=Optional[Deferred]) _logger = attr.ib(default=attr.Factory(Logger)) + _producer = attr.ib(default=None, type=Optional[LogProducer]) def start(self) -> None: @@ -187,38 +230,43 @@ class TerseJSONToTCPLogObserver(object): factory = Factory.forProtocol(Protocol) self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) self._service.startService() + self._connect() - def _write_loop(self) -> None: + def stop(self): + self._service.stopService() + + def _connect(self) -> None: """ - Implement the write loop. + Triggers an attempt to connect then write to the remote if not already writing. """ - if self._writer: + if self._connection_waiter: return - self._writer = self._service.whenConnected() + self._connection_waiter = self._service.whenConnected(failAfterFailures=1) + + @self._connection_waiter.addErrback + def fail(r): + r.printTraceback(file=sys.__stderr__) + self._connection_waiter = None + self._connect() - @self._writer.addBoth + @self._connection_waiter.addCallback def writer(r): - if isinstance(r, Failure): - r.printTraceback(file=sys.__stderr__) - self._writer = None - self.hs.get_reactor().callLater(1, self._write_loop) + # We have a connection. If we already have a producer, and its + # transport is the same, just trigger a resumeProducing. + if self._producer and r.transport is self._producer.transport: + self._producer.resumeProducing() return - try: - for event in self._buffer: - r.transport.write( - dumps(event, ensure_ascii=False, separators=(",", ":")).encode( - "utf8" - ) - ) - r.transport.write(b"\n") - self._buffer.clear() - except Exception as e: - sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),)) - - self._writer = False - self.hs.get_reactor().callLater(1, self._write_loop) + # If the producer is still producing, stop it. + if self._producer: + self._producer.stopProducing() + + # Make a new producer and start it. + self._producer = LogProducer(buffer=self._buffer, transport=r.transport) + r.transport.registerProducer(self._producer, True) + self._producer.resumeProducing() + self._connection_waiter = None def _handle_pressure(self) -> None: """ @@ -277,4 +325,4 @@ class TerseJSONToTCPLogObserver(object): self._logger.failure("Failed clearing backpressure") # Try and write immediately. - self._write_loop() + self._connect() diff --git a/tests/server.py b/tests/server.py index f878aeaada..2b7cf4242e 100644 --- a/tests/server.py +++ b/tests/server.py @@ -379,6 +379,7 @@ class FakeTransport(object): disconnecting = False disconnected = False + connected = True buffer = attr.ib(default=b"") producer = attr.ib(default=None) autoflush = attr.ib(default=True) @@ -402,6 +403,7 @@ class FakeTransport(object): "FakeTransport: Delaying disconnect until buffer is flushed" ) else: + self.connected = False self.disconnected = True def abortConnection(self): -- cgit 1.5.1 From f0ef9708241ec65fe6f32f1ad36f719ab4ab2b53 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 26 Nov 2019 17:49:12 +0000 Subject: Don't restrict the tests to v1 rooms --- tests/rest/client/test_retention.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 9e549d8a91..95475bb651 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -34,7 +34,6 @@ class RetentionTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config["default_room_version"] = "1" config["retention"] = { "enabled": True, "default_policy": { @@ -204,7 +203,6 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config["default_room_version"] = "1" config["retention"] = { "enabled": True, } -- cgit 1.5.1 From ce578031f4d0fe6f1eb26de4cb3d30a4175468db Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 26 Nov 2019 18:42:27 +0000 Subject: Remove assertion and provide a clear warning on startup for missing public_baseurl (#6379) --- changelog.d/6379.misc | 1 + synapse/config/emailconfig.py | 2 ++ synapse/config/registration.py | 7 +++++++ tests/rest/client/v2_alpha/test_register.py | 1 + 4 files changed, 11 insertions(+) create mode 100644 changelog.d/6379.misc (limited to 'tests') diff --git a/changelog.d/6379.misc b/changelog.d/6379.misc new file mode 100644 index 0000000000..725c2e7d87 --- /dev/null +++ b/changelog.d/6379.misc @@ -0,0 +1 @@ +Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary. \ No newline at end of file diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 43fad0bf8b..ac1724045f 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -146,6 +146,8 @@ class EmailConfig(Config): if k not in email_config: missing.append("email." + k) + # public_baseurl is required to build password reset and validation links that + # will be emailed to users if config.get("public_baseurl") is None: missing.append("public_baseurl") diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 1f6dac69da..ee9614c5f7 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -106,6 +106,13 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") + if self.account_threepid_delegate_msisdn and not self.public_baseurl: + raise ConfigError( + "The configuration option `public_baseurl` is required if " + "`account_threepid_delegate.msisdn` is set, such that " + "clients know where to submit validation tokens to. Please " + "configure `public_baseurl`." + ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index dab87e5edf..c0d0d2b44e 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -203,6 +203,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @unittest.override_config( { + "public_baseurl": "https://test_server", "enable_registration_captcha": True, "user_consent": { "version": "1", -- cgit 1.5.1 From 0d27aba900136514a8801b902f9a8ac69150e2c0 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 27 Nov 2019 16:14:44 -0500 Subject: add etag and count to key backup endpoints (#5858) --- changelog.d/5858.feature | 1 + synapse/handlers/e2e_room_keys.py | 130 +++++++----- synapse/rest/client/v2_alpha/room_keys.py | 8 +- synapse/storage/data_stores/main/e2e_room_keys.py | 226 +++++++++++++++------ .../main/schema/delta/56/room_key_etag.sql | 17 ++ tests/handlers/test_e2e_room_keys.py | 31 +++ tests/storage/test_e2e_room_keys.py | 8 +- 7 files changed, 297 insertions(+), 124 deletions(-) create mode 100644 changelog.d/5858.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql (limited to 'tests') diff --git a/changelog.d/5858.feature b/changelog.d/5858.feature new file mode 100644 index 0000000000..55ee93051e --- /dev/null +++ b/changelog.d/5858.feature @@ -0,0 +1 @@ +Add etag and count fields to key backup endpoints to help clients guess if there are new keys. diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 0cea445f0d..f1b4424a02 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017, 2018 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -103,14 +104,35 @@ class E2eRoomKeysHandler(object): rooms session_id(string): session ID to delete keys for, for None to delete keys for all sessions + Raises: + NotFoundError: if the backup version does not exist Returns: - A deferred of the deletion transaction + A dict containing the count and etag for the backup version """ # lock for consistency with uploading with (yield self._upload_linearizer.queue(user_id)): + # make sure the backup version exists + try: + version_info = yield self.store.get_e2e_room_keys_version_info( + user_id, version + ) + except StoreError as e: + if e.code == 404: + raise NotFoundError("Unknown backup version") + else: + raise + yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + version_etag = version_info["etag"] + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag + ) + + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} + @trace @defer.inlineCallbacks def upload_room_keys(self, user_id, version, room_keys): @@ -138,6 +160,9 @@ class E2eRoomKeysHandler(object): } } + Returns: + A dict containing the count and etag for the backup version + Raises: NotFoundError: if there are no versions defined RoomKeysVersionError: if the uploaded version is not the current version @@ -171,59 +196,62 @@ class E2eRoomKeysHandler(object): else: raise - # go through the room_keys. - # XXX: this should/could be done concurrently, given we're in a lock. + # Fetch any existing room keys for the sessions that have been + # submitted. Then compare them with the submitted keys. If the + # key is new, insert it; if the key should be updated, then update + # it; otherwise, drop it. + existing_keys = yield self.store.get_e2e_room_keys_multi( + user_id, version, room_keys["rooms"] + ) + to_insert = [] # batch the inserts together + changed = False # if anything has changed, we need to update the etag for room_id, room in iteritems(room_keys["rooms"]): - for session_id, session in iteritems(room["sessions"]): - yield self._upload_room_key( - user_id, version, room_id, session_id, session + for session_id, room_key in iteritems(room["sessions"]): + log_kv( + { + "message": "Trying to upload room key", + "room_id": room_id, + "session_id": session_id, + "user_id": user_id, + } ) - - @defer.inlineCallbacks - def _upload_room_key(self, user_id, version, room_id, session_id, room_key): - """Upload a given room_key for a given room and session into a given - version of the backup. Merges the key with any which might already exist. - - Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_id(str): the ID of the room whose keys we're setting - session_id(str): the session whose room_key we're setting - room_key(dict): the room_key being set - """ - log_kv( - { - "message": "Trying to upload room key", - "room_id": room_id, - "session_id": session_id, - "user_id": user_id, - } - ) - # get the room_key for this particular row - current_room_key = None - try: - current_room_key = yield self.store.get_e2e_room_key( - user_id, version, room_id, session_id - ) - except StoreError as e: - if e.code == 404: - log_kv( - { - "message": "Room key not found.", - "room_id": room_id, - "user_id": user_id, - } + current_room_key = existing_keys.get(room_id, {}).get(session_id) + if current_room_key: + if self._should_replace_room_key(current_room_key, room_key): + log_kv({"message": "Replacing room key."}) + # updates are done one at a time in the DB, so send + # updates right away rather than batching them up, + # like we do with the inserts + yield self.store.update_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) + changed = True + else: + log_kv({"message": "Not replacing room_key."}) + else: + log_kv( + { + "message": "Room key not found.", + "room_id": room_id, + "user_id": user_id, + } + ) + log_kv({"message": "Replacing room key."}) + to_insert.append((room_id, session_id, room_key)) + changed = True + + if len(to_insert): + yield self.store.add_e2e_room_keys(user_id, version, to_insert) + + version_etag = version_info["etag"] + if changed: + version_etag = version_etag + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag ) - else: - raise - if self._should_replace_room_key(current_room_key, room_key): - log_kv({"message": "Replacing room key."}) - yield self.store.set_e2e_room_key( - user_id, version, room_id, session_id, room_key - ) - else: - log_kv({"message": "Not replacing room_key."}) + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} @staticmethod def _should_replace_room_key(current_room_key, room_key): @@ -314,6 +342,8 @@ class E2eRoomKeysHandler(object): raise NotFoundError("Unknown backup version") else: raise + + res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) return res @trace diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index d596786430..d83ac8e3c5 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -134,8 +134,8 @@ class RoomKeysServlet(RestServlet): if room_id: body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - return 200, {} + ret = yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + return 200, ret @defer.inlineCallbacks def on_GET(self, request, room_id, session_id): @@ -239,10 +239,10 @@ class RoomKeysServlet(RestServlet): user_id = requester.user.to_string() version = parse_string(request, "version") - yield self.e2e_room_keys_handler.delete_room_keys( + ret = yield self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id ) - return 200, {} + return 200, ret class RoomKeysNewVersionServlet(RestServlet): diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 1cbbae5b63..113224fd7c 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,49 +25,8 @@ from synapse.storage._base import SQLBaseStore class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks - def get_e2e_room_key(self, user_id, version, room_id, session_id): - """Get the encrypted E2E room key for a given session from a given - backup version of room_keys. We only store the 'best' room key for a given - session at a given time, as determined by the handler. - - Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): the ID of the room whose keys we're querying. - This is a bit redundant as it's implied by the session_id, but - we include for consistency with the rest of the API. - session_id(str): the session whose room_key we're querying. - - Returns: - A deferred dict giving the session_data and message metadata for - this room key. - """ - - row = yield self._simple_select_one( - table="e2e_room_keys", - keyvalues={ - "user_id": user_id, - "version": version, - "room_id": room_id, - "session_id": session_id, - }, - retcols=( - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", - ), - desc="get_e2e_room_key", - ) - - row["session_data"] = json.loads(row["session_data"]) - - return row - - @defer.inlineCallbacks - def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): - """Replaces or inserts the encrypted E2E room key for a given session in - a given backup + def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces the encrypted E2E room key for a given session in a given backup Args: user_id(str): the user whose backup we're setting @@ -78,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self._simple_upsert( + yield self._simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -86,21 +46,51 @@ class EndToEndRoomKeyStore(SQLBaseStore): "room_id": room_id, "session_id": session_id, }, - values={ + updatevalues={ "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], "session_data": json.dumps(room_key["session_data"]), }, - lock=False, + desc="update_e2e_room_key", ) - log_kv( - { - "message": "Set room key", - "room_id": room_id, - "session_id": session_id, - "room_key": room_key, - } + + @defer.inlineCallbacks + def add_e2e_room_keys(self, user_id, version, room_keys): + """Bulk add room keys to a given backup. + + Args: + user_id (str): the user whose backup we're adding to + version (str): the version ID of the backup for the set of keys we're adding to + room_keys (iterable[(str, str, dict)]): the keys to add, in the form + (roomID, sessionID, keyData) + """ + + values = [] + for (room_id, session_id, room_key) in room_keys: + values.append( + { + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), + } + ) + log_kv( + { + "message": "Set room key", + "room_id": room_id, + "session_id": session_id, + "room_key": room_key, + } + ) + + yield self._simple_insert_many( + table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @trace @@ -110,11 +100,11 @@ class EndToEndRoomKeyStore(SQLBaseStore): room, or a given session. Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): Optional. the ID of the room whose keys we're querying, if any. + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup for the set of keys we're querying + room_id (str): Optional. the ID of the room whose keys we're querying, if any. If not specified, we return the keys for all the rooms in the backup. - session_id(str): Optional. the session whose room_key we're querying, if any. + session_id (str): Optional. the session whose room_key we're querying, if any. If specified, we also require the room_id to be specified. If not specified, we return all the keys in this version of the backup (or for the specified room) @@ -162,6 +152,95 @@ class EndToEndRoomKeyStore(SQLBaseStore): return sessions + def get_e2e_room_keys_multi(self, user_id, version, room_keys): + """Get multiple room keys at a time. The difference between this function and + get_e2e_room_keys is that this function can be used to retrieve + multiple specific keys at a time, whereas get_e2e_room_keys is used for + getting all the keys in a backup version, all the keys for a room, or a + specific key. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + room_keys (dict[str, dict[str, iterable[str]]]): a map from + room ID -> {"session": [session ids]} indicating the session IDs + that we want to query + + Returns: + Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + """ + + return self.runInteraction( + "get_e2e_room_keys_multi", + self._get_e2e_room_keys_multi_txn, + user_id, + version, + room_keys, + ) + + @staticmethod + def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): + if not room_keys: + return {} + + where_clauses = [] + params = [user_id, version] + for room_id, room in room_keys.items(): + sessions = list(room["sessions"]) + if not sessions: + continue + params.append(room_id) + params.extend(sessions) + where_clauses.append( + "(room_id = ? AND session_id IN (%s))" + % (",".join(["?" for _ in sessions]),) + ) + + # check if we're actually querying something + if not where_clauses: + return {} + + sql = """ + SELECT room_id, session_id, first_message_index, forwarded_count, + is_verified, session_data + FROM e2e_room_keys + WHERE user_id = ? AND version = ? AND (%s) + """ % ( + " OR ".join(where_clauses) + ) + + txn.execute(sql, params) + + ret = {} + + for row in txn: + room_id = row[0] + session_id = row[1] + ret.setdefault(room_id, {}) + ret[room_id][session_id] = { + "first_message_index": row[2], + "forwarded_count": row[3], + "is_verified": row[4], + "session_data": json.loads(row[5]), + } + + return ret + + def count_e2e_room_keys(self, user_id, version): + """Get the number of keys in a backup version. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + """ + + return self._simple_select_one_onecol( + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": version}, + retcol="COUNT(*)", + desc="count_e2e_room_keys", + ) + @trace @defer.inlineCallbacks def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): @@ -219,6 +298,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version(str) algorithm(str) auth_data(object): opaque dict supplied by the client + etag(int): tag of the keys in the backup """ def _get_e2e_room_keys_version_info_txn(txn): @@ -236,10 +316,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data"), + retcols=("version", "algorithm", "auth_data", "etag"), ) result["auth_data"] = json.loads(result["auth_data"]) result["version"] = str(result["version"]) + if result["etag"] is None: + result["etag"] = 0 return result return self.runInteraction( @@ -288,21 +370,33 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - def update_e2e_room_keys_version(self, user_id, version, info): + def update_e2e_room_keys_version( + self, user_id, version, info=None, version_etag=None + ): """Update a given backup version Args: user_id(str): the user whose backup version we're updating version(str): the version ID of the backup version we're updating - info(dict): the new backup version info to store + info (dict): the new backup version info to store. If None, then + the backup version info is not updated + version_etag (Optional[int]): etag of the keys in the backup. If + None, then the etag is not updated """ + updatevalues = {} - return self._simple_update( - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": version}, - updatevalues={"auth_data": json.dumps(info["auth_data"])}, - desc="update_e2e_room_keys_version", - ) + if info is not None and "auth_data" in info: + updatevalues["auth_data"] = json.dumps(info["auth_data"]) + if version_etag is not None: + updatevalues["etag"] = version_etag + + if updatevalues: + return self._simple_update( + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": version}, + updatevalues=updatevalues, + desc="update_e2e_room_keys_version", + ) @trace def delete_e2e_room_keys_version(self, user_id, version=None): diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql new file mode 100644 index 0000000000..7d70dd071e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- store the current etag of backup version +ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT; diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 0bb96674a2..70f172eb02 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + version_etag = res["etag"] + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") + self.assertEqual(res["etag"], version_etag) + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) @@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "version": "2", "algorithm": "m.megolm_backup.v1", "auth_data": "second_version_auth_data", + "count": 0, }, ) @@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -207,12 +218,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -409,6 +422,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.upload_room_keys(self.local_user, version, room_keys) + # get the etag to compare to future versions + res = yield self.handler.get_version_info(self.local_user) + backup_etag = res["etag"] + self.assertEqual(res["count"], 1) + new_room_keys = copy.deepcopy(room_keys) new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"] @@ -423,6 +441,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): "SSBBTSBBIEZJU0gK", ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) @@ -432,6 +454,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should NOT be equal now, since the key changed + res = yield self.handler.get_version_info(self.local_user) + self.assertNotEqual(res["etag"], backup_etag) + backup_etag = res["etag"] + # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count new_room_key["forwarded_count"] = 2 @@ -443,6 +470,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # TODO: check edge cases as well as the common variations here @defer.inlineCallbacks diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index d128fde441..35dafbb904 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -39,8 +39,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.store.set_e2e_room_key( - "user_id", version1, "room", "session", room_key + self.store.add_e2e_room_keys( + "user_id", version1, [("room", "session", room_key)] ) ) @@ -51,8 +51,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.store.set_e2e_room_key( - "user_id", version2, "room", "session", room_key + self.store.add_e2e_room_keys( + "user_id", version2, [("room", "session", room_key)] ) ) -- cgit 1.5.1 From 0f87b912aba7e678041632bc9a6d1f7c2d24342c Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 28 Nov 2019 08:54:07 +1100 Subject: Implementation of MSC2314 (#6176) --- changelog.d/6176.feature | 1 + synapse/federation/federation_server.py | 26 ++++++++---- synapse/federation/transport/server.py | 6 +-- sytest-blacklist | 6 ++- tests/federation/test_complexity.py | 28 ++---------- tests/federation/test_federation_sender.py | 4 +- tests/federation/test_federation_server.py | 63 +++++++++++++++++++++++++++ tests/handlers/test_typing.py | 3 ++ tests/replication/slave/storage/_base.py | 3 ++ tests/replication/tcp/streams/_base.py | 4 ++ tests/storage/test_roommember.py | 26 +----------- tests/unittest.py | 68 +++++++++++++++++++++++++++++- tests/utils.py | 1 + 13 files changed, 174 insertions(+), 65 deletions(-) create mode 100644 changelog.d/6176.feature (limited to 'tests') diff --git a/changelog.d/6176.feature b/changelog.d/6176.feature new file mode 100644 index 0000000000..3c66d689d4 --- /dev/null +++ b/changelog.d/6176.feature @@ -0,0 +1 @@ +Implement the `/_matrix/federation/unstable/net.atleastfornow/state/` API as drafted in MSC2314. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d942d77a72..84d4eca041 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,6 +74,7 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler + self.state = hs.get_state_handler() self._server_linearizer = Linearizer("fed_server") self._transaction_linearizer = Linearizer("fed_txn_handler") @@ -264,9 +266,6 @@ class FederationServer(FederationBase): await self.registry.on_edu(edu_type, origin, content) async def on_context_state_request(self, origin, room_id, event_id): - if not event_id: - raise NotImplementedError("Specify an event") - origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -280,13 +279,18 @@ class FederationServer(FederationBase): # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. with (await self._server_linearizer.queue((origin, room_id))): - resp = await self._state_resp_cache.wrap( - (room_id, event_id), - self._on_context_state_request_compute, - room_id, - event_id, + resp = dict( + await self._state_resp_cache.wrap( + (room_id, event_id), + self._on_context_state_request_compute, + room_id, + event_id, + ) ) + room_version = await self.store.get_room_version(room_id) + resp["room_version"] = room_version + return 200, resp async def on_state_ids_request(self, origin, room_id, event_id): @@ -306,7 +310,11 @@ class FederationServer(FederationBase): return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute(self, room_id, event_id): - pdus = await self.handler.get_state_for_pdu(room_id, event_id) + if event_id: + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + else: + pdus = (await self.state.get_current_state(room_id)).values() + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 09baa9c57d..fefc789c85 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -421,7 +421,7 @@ class FederationEventServlet(BaseFederationServlet): return await self.handler.on_pdu_request(origin, event_id) -class FederationStateServlet(BaseFederationServlet): +class FederationStateV1Servlet(BaseFederationServlet): PATH = "/state/(?P[^/]*)/?" # This is when someone asks for all data for a given context. @@ -429,7 +429,7 @@ class FederationStateServlet(BaseFederationServlet): return await self.handler.on_context_state_request( origin, context, - parse_string_from_args(query, "event_id", None, required=True), + parse_string_from_args(query, "event_id", None, required=False), ) @@ -1360,7 +1360,7 @@ class RoomComplexityServlet(BaseFederationServlet): FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, FederationEventServlet, - FederationStateServlet, + FederationStateV1Servlet, FederationStateIdsServlet, FederationBackfillServlet, FederationQueryServlet, diff --git a/sytest-blacklist b/sytest-blacklist index 11785fd43f..411cce0692 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -1,6 +1,6 @@ # This file serves as a blacklist for SyTest tests that we expect will fail in # Synapse. -# +# # Each line of this file is scanned by sytest during a run and if the line # exactly matches the name of a test, it will be marked as "expected fail", # meaning the test will still run, but failure will not mark the entire test @@ -29,3 +29,7 @@ Enabling an unknown default rule fails with 404 # Blacklisted due to https://github.com/matrix-org/synapse/issues/1663 New federated private chats get full presence information (SYN-115) + +# Blacklisted due to https://github.com/matrix-org/matrix-doc/pull/2314 removing +# this requirement from the spec +Inbound federation of state requires event_id as a mandatory paramater diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 51714a2b06..24fa8dbb45 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -18,17 +18,14 @@ from mock import Mock from twisted.internet import defer from synapse.api.errors import Codes, SynapseError -from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.federation.transport import server from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.types import UserID -from synapse.util.ratelimitutils import FederationRateLimiter from tests import unittest -class RoomComplexityTests(unittest.HomeserverTestCase): +class RoomComplexityTests(unittest.FederatingHomeserverTestCase): servlets = [ admin.register_servlets, @@ -41,25 +38,6 @@ class RoomComplexityTests(unittest.HomeserverTestCase): config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} return config - def prepare(self, reactor, clock, homeserver): - class Authenticator(object): - def authenticate_request(self, request, content): - return defer.succeed("otherserver.nottld") - - ratelimiter = FederationRateLimiter( - clock, - FederationRateLimitConfig( - window_size=1, - sleep_limit=1, - sleep_msec=1, - reject_limit=1000, - concurrent_requests=1000, - ), - ) - server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - def test_complexity_simple(self): u1 = self.register_user("u1", "pass") @@ -105,7 +83,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): d = handler._remote_join( None, - ["otherserver.example"], + ["other.example.com"], "roomid", UserID.from_string(u1), {"membership": "join"}, @@ -146,7 +124,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): d = handler._remote_join( None, - ["otherserver.example"], + ["other.example.com"], room_1, UserID.from_string(u1), {"membership": "join"}, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index cce8d8c6de..d456267b87 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.types import ReadReceipt -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class FederationSenderTestCases(HomeserverTestCase): @@ -29,6 +29,7 @@ class FederationSenderTestCases(HomeserverTestCase): federation_transport_client=Mock(spec=["send_transaction"]), ) + @override_config({"send_federation": True}) def test_send_receipts(self): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] @@ -69,6 +70,7 @@ class FederationSenderTestCases(HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index b08be451aa..1ec8c40901 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +17,8 @@ import logging from synapse.events import FrozenEvent from synapse.federation.federation_server import server_matches_acl_event +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from tests import unittest @@ -41,6 +44,66 @@ class ServerACLsTestCase(unittest.TestCase): self.assertTrue(server_matches_acl_event("1:2:3:4", e)) +class StateQueryTests(unittest.FederatingHomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def test_without_event_id(self): + """ + Querying v1/state/ without an event ID will return the current + known state. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.inject_room_member(room_1, "@user:other.example.com", "join") + + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + self.assertEqual( + channel.json_body["room_version"], + self.hs.config.default_room_version.identifier, + ) + + members = set( + map( + lambda x: x["state_key"], + filter( + lambda x: x["type"] == "m.room.member", channel.json_body["pdus"] + ), + ) + ) + + self.assertEqual(members, set(["@user:other.example.com", u1])) + self.assertEqual(len(channel.json_body["pdus"]), 6) + + def test_needs_to_be_in_room(self): + """ + Querying v1/state/ requires the server + be in the room to provide data. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.render(request) + self.assertEquals(403, channel.code, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + + def _create_acl_event(content): return FrozenEvent( { diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5ec568f4e6..f6d8660285 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -24,6 +24,7 @@ from synapse.api.errors import AuthError from synapse.types import UserID from tests import unittest +from tests.unittest import override_config from tests.utils import register_federation_servlets # Some local users to test with @@ -174,6 +175,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] @@ -237,6 +239,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_stopped_typing(self): self.room_members = [U_APPLE, U_BANANA, U_ONION] diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 4f924ce451..e7472e3a93 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -48,7 +48,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = server_factory.streamer + handler_factory = Mock() self.replication_handler = ReplicationClientHandler(self.slaved_store) + self.replication_handler.factory = handler_factory + client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index ce3835ae6a..1d14e77255 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from synapse.replication.tcp.commands import ReplicateCommand from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -30,7 +32,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): server = server_factory.buildProtocol(None) # build a replication client, with a dummy handler + handler_factory = Mock() self.test_handler = TestReplicationClientHandler() + self.test_handler.factory = handler_factory self.client = ClientReplicationStreamProtocol( "client", "test", clock, self.test_handler ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 9ddd17f73d..105a0c2b02 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -16,8 +16,7 @@ from unittest.mock import Mock -from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import RoomVersions +from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room from synapse.types import Requester, UserID @@ -44,9 +43,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() - self.storage = hs.get_storage() - self.event_builder_factory = hs.get_event_builder_factory() - self.event_creation_handler = hs.get_event_creation_handler() self.u_alice = self.register_user("alice", "pass") self.t_alice = self.login("alice", "pass") @@ -55,26 +51,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # User elsewhere on another host self.u_charlie = UserID.from_string("@charlie:elsewhere") - def inject_room_member(self, room, user, membership, replaces_state=None): - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - - self.get_success(self.storage.persistence.persist_event(event, context)) - - return event - def test_one_member(self): # Alice creates the room, and is automatically joined diff --git a/tests/unittest.py b/tests/unittest.py index 561cebc223..31997a0f31 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gc import hashlib import hmac @@ -27,13 +29,17 @@ from twisted.internet.defer import Deferred, succeed from twisted.python.threadpool import ThreadPool from twisted.trial import unittest -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.homeserver import HomeServerConfig +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest from synapse.logging.context import LoggingContext from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester +from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import get_clock, make_request, render, setup_test_homeserver from tests.test_utils.logging_setup import setup_logging @@ -559,6 +565,66 @@ class HomeserverTestCase(TestCase): self.render(request) self.assertEqual(channel.code, 403, channel.result) + def inject_room_member(self, room: str, user: str, membership: Membership) -> None: + """ + Inject a membership event into a room. + + Args: + room: Room ID to inject the event into. + user: MXID of the user to inject the membership for. + membership: The membership type. + """ + event_builder_factory = self.hs.get_event_builder_factory() + event_creation_handler = self.hs.get_event_creation_handler() + + room_version = self.get_success(self.hs.get_datastore().get_room_version(room)) + + builder = event_builder_factory.for_room_version( + KNOWN_ROOM_VERSIONS[room_version], + { + "type": EventTypes.Member, + "sender": user, + "state_key": user, + "room_id": room, + "content": {"membership": membership}, + }, + ) + + event, context = self.get_success( + event_creation_handler.create_new_client_event(builder) + ) + + self.get_success( + self.hs.get_storage().persistence.persist_event(event, context) + ) + + +class FederatingHomeserverTestCase(HomeserverTestCase): + """ + A federating homeserver that authenticates incoming requests as `other.example.com`. + """ + + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return succeed("other.example.com") + + ratelimiter = FederationRateLimiter( + clock, + FederationRateLimitConfig( + window_size=1, + sleep_limit=1, + sleep_msec=1, + reject_limit=1000, + concurrent_requests=1000, + ), + ) + federation_server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + return super().prepare(reactor, clock, homeserver) + def override_config(extra_config): """A decorator which can be applied to test functions to give additional HS config diff --git a/tests/utils.py b/tests/utils.py index 7dc9bdc505..de2ac1ed33 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -109,6 +109,7 @@ def default_config(name, parse=False): """ config_dict = { "server_name": name, + "send_federation": False, "media_store_path": "media", "uploads_path": "uploads", # the test signing key is just an arbitrary ed25519 key to keep the config -- cgit 1.5.1 From 8c9a713f8db1d6fcc1f876ac6fbd0e54b5e5819c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Nov 2019 11:32:06 +0000 Subject: Add tests --- tests/rest/client/v1/test_rooms.py | 140 +++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e84e578f99..eda2fabc71 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1180,3 +1180,143 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) + + +class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): + """Tests that clients can add a "reason" field to membership events and + that they get correctly added to the generated events and propagated. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + def test_join_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/join".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_leave_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_kick_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_ban_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_unban_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_invite_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_reject_invite_reason(self): + self.helper.invite( + self.room_id, + src=self.creator, + targ=self.second_user_id, + tok=self.creator_tok, + ) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def _check_for_reason(self, reason): + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( + self.room_id, self.second_user_id + ), + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + event_content = channel.json_body + + self.assertEqual(event_content.get("reason"), reason, channel.result) -- cgit 1.5.1 From 54dd5dc12b0ac5c48303144c4a73ce3822209488 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 3 Dec 2019 19:19:45 +0000 Subject: Add ephemeral messages support (MSC2228) (#6409) Implement part [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). The parts that differ are: * the feature is hidden behind a configuration flag (`enable_ephemeral_messages`) * self-destruction doesn't happen for state events * only implement support for the `m.self_destruct_after` field (not the `m.self_destruct` one) * doesn't send synthetic redactions to clients because for this specific case we consider the clients to be able to destroy an event themselves, instead we just censor it (by pruning its JSON) in the database --- changelog.d/6409.feature | 1 + synapse/api/constants.py | 4 + synapse/config/server.py | 2 + synapse/handlers/federation.py | 8 ++ synapse/handlers/message.py | 123 +++++++++++++++++++- synapse/storage/data_stores/main/events.py | 126 ++++++++++++++++++++- .../main/schema/delta/56/event_expiry.sql | 21 ++++ tests/rest/client/test_ephemeral_message.py | 101 +++++++++++++++++ 8 files changed, 379 insertions(+), 7 deletions(-) create mode 100644 changelog.d/6409.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql create mode 100644 tests/rest/client/test_ephemeral_message.py (limited to 'tests') diff --git a/changelog.d/6409.feature b/changelog.d/6409.feature new file mode 100644 index 0000000000..653ff5a5ad --- /dev/null +++ b/changelog.d/6409.feature @@ -0,0 +1 @@ +Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index e3f086f1c3..69cef369a5 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -147,3 +147,7 @@ class EventContentFields(object): # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 LABELS = "org.matrix.labels" + + # Timestamp to delete the event after + # cf https://github.com/matrix-org/matrix-doc/pull/2228 + SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" diff --git a/synapse/config/server.py b/synapse/config/server.py index 7a9d711669..837fbe1582 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -490,6 +490,8 @@ class ServerConfig(Config): "cleanup_extremities_with_dummy_events", True ) + self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False) + def has_tls_listener(self) -> bool: return any(l["tls"] for l in self.listeners) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d3267734f7..d9d0cd9eef 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -121,6 +121,7 @@ class FederationHandler(BaseHandler): self.pusher_pool = hs.get_pusherpool() self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() @@ -141,6 +142,8 @@ class FederationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): """ Process a PDU received via a federation /send/ transaction, or @@ -2715,6 +2718,11 @@ class FederationHandler(BaseHandler): event_and_contexts, backfilled=backfilled ) + if self._ephemeral_messages_enabled: + for (event, context) in event_and_contexts: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: yield self._notify_persisted_event(event, max_stream_id) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3b0156f516..4f53a5f5dc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from six import iteritems, itervalues, string_types @@ -22,9 +23,16 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed +from twisted.internet.interfaces import IDelayedCall from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, + UserTypes, +) from synapse.api.errors import ( AuthError, Codes, @@ -62,6 +70,17 @@ class MessageHandler(object): self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._is_worker_app = bool(hs.config.worker_app) + + # The scheduled call to self._expire_event. None if no call is currently + # scheduled. + self._scheduled_expiry = None # type: Optional[IDelayedCall] + + if not hs.config.worker_app: + run_as_background_process( + "_schedule_next_expiry", self._schedule_next_expiry + ) @defer.inlineCallbacks def get_room_data( @@ -225,6 +244,100 @@ class MessageHandler(object): for user_id, profile in iteritems(users_with_profile) } + def maybe_schedule_expiry(self, event): + """Schedule the expiry of an event if there's not already one scheduled, + or if the one running is for an event that will expire after the provided + timestamp. + + This function needs to invalidate the event cache, which is only possible on + the master process, and therefore needs to be run on there. + + Args: + event (EventBase): The event to schedule the expiry of. + """ + assert not self._is_worker_app + + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if not isinstance(expiry_ts, int) or event.is_state(): + return + + # _schedule_expiry_for_event won't actually schedule anything if there's already + # a task scheduled for a timestamp that's sooner than the provided one. + self._schedule_expiry_for_event(event.event_id, expiry_ts) + + @defer.inlineCallbacks + def _schedule_next_expiry(self): + """Retrieve the ID and the expiry timestamp of the next event to be expired, + and schedule an expiry task for it. + + If there's no event left to expire, set _expiry_scheduled to None so that a + future call to save_expiry_ts can schedule a new expiry task. + """ + # Try to get the expiry timestamp of the next event to expire. + res = yield self.store.get_next_event_to_expire() + if res: + event_id, expiry_ts = res + self._schedule_expiry_for_event(event_id, expiry_ts) + + def _schedule_expiry_for_event(self, event_id, expiry_ts): + """Schedule an expiry task for the provided event if there's not already one + scheduled at a timestamp that's sooner than the provided one. + + Args: + event_id (str): The ID of the event to expire. + expiry_ts (int): The timestamp at which to expire the event. + """ + if self._scheduled_expiry: + # If the provided timestamp refers to a time before the scheduled time of the + # next expiry task, cancel that task and reschedule it for this timestamp. + next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000 + if expiry_ts < next_scheduled_expiry_ts: + self._scheduled_expiry.cancel() + else: + return + + # Figure out how many seconds we need to wait before expiring the event. + now_ms = self.clock.time_msec() + delay = (expiry_ts - now_ms) / 1000 + + # callLater doesn't support negative delays, so trim the delay to 0 if we're + # in that case. + if delay < 0: + delay = 0 + + logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay) + + self._scheduled_expiry = self.clock.call_later( + delay, + run_as_background_process, + "_expire_event", + self._expire_event, + event_id, + ) + + @defer.inlineCallbacks + def _expire_event(self, event_id): + """Retrieve and expire an event that needs to be expired from the database. + + If the event doesn't exist in the database, log it and delete the expiry date + from the database (so that we don't try to expire it again). + """ + assert self._ephemeral_events_enabled + + self._scheduled_expiry = None + + logger.info("Expiring event %s", event_id) + + try: + # Expire the event if we know about it. This function also deletes the expiry + # date from the database in the same database transaction. + yield self.store.expire_event(event_id) + except Exception as e: + logger.error("Could not expire event %s: %r", event_id, e) + + # Schedule the expiry of the next event to expire. + yield self._schedule_next_expiry() + # The duration (in ms) after which rooms should be removed # `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try @@ -295,6 +408,10 @@ class EventCreationHandler(object): 5 * 60 * 1000, ) + self._message_handler = hs.get_message_handler() + + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def create_event( self, @@ -877,6 +994,10 @@ class EventCreationHandler(object): event, context=context ) + if self._ephemeral_events_enabled: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 2737a1d3ae..79c91fe284 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -130,6 +130,8 @@ class EventsStore( if self.hs.config.redaction_retention_period is not None: hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def _read_forward_extremities(self): def fetch(txn): @@ -940,6 +942,12 @@ class EventsStore( txn, event.event_id, labels, event.room_id, event.depth ) + if self._ephemeral_messages_enabled: + # If there's an expiry timestamp on the event, store it. + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if isinstance(expiry_ts, int) and not event.is_state(): + self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -1101,12 +1109,7 @@ class EventsStore( def _update_censor_txn(txn): for redaction_id, event_id, pruned_json in updates: if pruned_json: - self._simple_update_one_txn( - txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, - ) + self._censor_event_txn(txn, event_id, pruned_json) self._simple_update_one_txn( txn, @@ -1117,6 +1120,22 @@ class EventsStore( yield self.runInteraction("_update_censor_txn", _update_censor_txn) + def _censor_event_txn(self, txn, event_id, pruned_json): + """Censor an event by replacing its JSON in the event_json table with the + provided pruned JSON. + + Args: + txn (LoggingTransaction): The database transaction. + event_id (str): The ID of the event to censor. + pruned_json (str): The pruned JSON + """ + self._simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) + @defer.inlineCallbacks def count_daily_messages(self): """ @@ -1957,6 +1976,101 @@ class EventsStore( ], ) + def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + """Save the expiry timestamp associated with a given event ID. + + Args: + txn (LoggingTransaction): The database transaction to use. + event_id (str): The event ID the expiry timestamp is associated with. + expiry_ts (int): The timestamp at which to expire (delete) the event. + """ + return self._simple_insert_txn( + txn=txn, + table="event_expiry", + values={"event_id": event_id, "expiry_ts": expiry_ts}, + ) + + @defer.inlineCallbacks + def expire_event(self, event_id): + """Retrieve and expire an event that has expired, and delete its associated + expiry timestamp. If the event can't be retrieved, delete its associated + timestamp so we don't try to expire it again in the future. + + Args: + event_id (str): The ID of the event to delete. + """ + # Try to retrieve the event's content from the database or the event cache. + event = yield self.get_event(event_id) + + def delete_expired_event_txn(txn): + # Delete the expiry timestamp associated with this event from the database. + self._delete_event_expiry_txn(txn, event_id) + + if not event: + # If we can't find the event, log a warning and delete the expiry date + # from the database so that we don't try to expire it again in the + # future. + logger.warning( + "Can't expire event %s because we don't have it.", event_id + ) + return + + # Prune the event's dict then convert it to JSON. + pruned_json = encode_json(prune_event_dict(event.get_dict())) + + # Update the event_json table to replace the event's JSON with the pruned + # JSON. + self._censor_event_txn(txn, event.event_id, pruned_json) + + # We need to invalidate the event cache entry for this event because we + # changed its content in the database. We can't call + # self._invalidate_cache_and_stream because self.get_event_cache isn't of the + # right type. + txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + # Send that invalidation to replication so that other workers also invalidate + # the event cache. + self._send_invalidation_to_replication( + txn, "_get_event_cache", (event.event_id,) + ) + + yield self.runInteraction("delete_expired_event", delete_expired_event_txn) + + def _delete_event_expiry_txn(self, txn, event_id): + """Delete the expiry timestamp associated with an event ID without deleting the + actual event. + + Args: + txn (LoggingTransaction): The transaction to use to perform the deletion. + event_id (str): The event ID to delete the associated expiry timestamp of. + """ + return self._simple_delete_txn( + txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + ) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql new file mode 100644 index 0000000000..81a36a8b1d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_expiry ( + event_id TEXT PRIMARY KEY, + expiry_ts BIGINT NOT NULL +); + +CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts); diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py new file mode 100644 index 0000000000..5e9c07ebf3 --- /dev/null +++ b/tests/rest/client/test_ephemeral_message.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.constants import EventContentFields, EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import room + +from tests import unittest + + +class EphemeralMessageTestCase(unittest.HomeserverTestCase): + + user_id = "@user:test" + + servlets = [ + admin.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["enable_ephemeral_messages"] = True + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_message_expiry_no_delay(self): + """Tests that sending a message sent with a m.self_destruct_after field set to the + past results in that event being deleted right away. + """ + # Send a message in the room that has expired. From here, the reactor clock is + # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock + # is at 0ms the code path is the same if the event's expiry timestamp is the + # current timestamp. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: 0, + }, + ) + event_id = res["event_id"] + + # Check that we can't retrieve the content of the event. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def test_message_expiry_delay(self): + """Tests that sending a message with a m.self_destruct_after field set to the + future results in that event not being deleted right away, but advancing the + clock to after that expiry timestamp causes the event to be deleted. + """ + # Send a message in the room that'll expire in 1s. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000, + }, + ) + event_id = res["event_id"] + + # Check that we can retrieve the content of the event before it has expired. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertTrue(bool(event_content), event_content) + + # Advance the clock to after the deletion. + self.reactor.advance(1) + + # Check that we can't retrieve the content of the event anymore. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body -- cgit 1.5.1 From cb0aeb147e3b3defc27866ad0e4982e63600a7ee Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 4 Dec 2019 09:46:16 +0000 Subject: privacy by default for room dir (#6355) Ensure that the the default settings for the room directory are that the it is hidden from public view by default. --- UPGRADE.rst | 17 ++++++++++ changelog.d/6354.feature | 1 + docs/sample_config.yaml | 13 ++++---- synapse/config/server.py | 26 +++++++++------- tests/federation/transport/test_server.py | 52 +++++++++++++++++++++++++++++++ 5 files changed, 91 insertions(+), 18 deletions(-) create mode 100644 changelog.d/6354.feature create mode 100644 tests/federation/transport/test_server.py (limited to 'tests') diff --git a/UPGRADE.rst b/UPGRADE.rst index 5ebf16a73e..d9020f2663 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,23 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.7.0 +=================== + +In an attempt to configure Synapse in a privacy preserving way, the default +behaviours of ``allow_public_rooms_without_auth`` and +``allow_public_rooms_over_federation`` have been inverted. This means that by +default, only authenticated users querying the Client/Server API will be able +to query the room directory, and relatedly that the server will not share +room directory information with other servers over federation. + +If your installation does not explicitly set these settings one way or the other +and you want either setting to be ``true`` then it will necessary to update +your homeserver configuration file accordingly. + +For more details on the surrounding context see our `explainer +`_. + Upgrading to v1.5.0 =================== diff --git a/changelog.d/6354.feature b/changelog.d/6354.feature new file mode 100644 index 0000000000..fed9db884b --- /dev/null +++ b/changelog.d/6354.feature @@ -0,0 +1 @@ +Configure privacy preserving settings by default for the room directory. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c7391f0c48..10664ae8f7 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -54,15 +54,16 @@ pid_file: DATADIR/homeserver.pid # #require_auth_for_profile_requests: true -# If set to 'false', requires authentication to access the server's public rooms -# directory through the client API. Defaults to 'true'. +# If set to 'true', removes the need for authentication to access the server's +# public rooms directory through the client API, meaning that anyone can +# query the room directory. Defaults to 'false'. # -#allow_public_rooms_without_auth: false +#allow_public_rooms_without_auth: true -# If set to 'false', forbids any other homeserver to fetch the server's public -# rooms directory via federation. Defaults to 'true'. +# If set to 'true', allows any other homeserver to fetch the server's public +# rooms directory via federation. Defaults to 'false'. # -#allow_public_rooms_over_federation: false +#allow_public_rooms_over_federation: true # The default room version for newly created rooms. # diff --git a/synapse/config/server.py b/synapse/config/server.py index 837fbe1582..a4bef00936 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -118,15 +118,16 @@ class ServerConfig(Config): self.allow_public_rooms_without_auth = False self.allow_public_rooms_over_federation = False else: - # If set to 'False', requires authentication to access the server's public - # rooms directory through the client API. Defaults to 'True'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. self.allow_public_rooms_without_auth = config.get( - "allow_public_rooms_without_auth", True + "allow_public_rooms_without_auth", False ) - # If set to 'False', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'True'. + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. self.allow_public_rooms_over_federation = config.get( - "allow_public_rooms_over_federation", True + "allow_public_rooms_over_federation", False ) default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION) @@ -620,15 +621,16 @@ class ServerConfig(Config): # #require_auth_for_profile_requests: true - # If set to 'false', requires authentication to access the server's public rooms - # directory through the client API. Defaults to 'true'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. # - #allow_public_rooms_without_auth: false + #allow_public_rooms_without_auth: true - # If set to 'false', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'true'. + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. # - #allow_public_rooms_over_federation: false + #allow_public_rooms_over_federation: true # The default room version for newly created rooms. # diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py new file mode 100644 index 0000000000..27d83bb7d9 --- /dev/null +++ b/tests/federation/transport/test_server.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests import unittest +from tests.unittest import override_config + + +class RoomDirectoryFederationTests(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return defer.succeed("otherserver.nottld") + + ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig()) + server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + @override_config({"allow_public_rooms_over_federation": False}) + def test_blocked_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(403, channel.code) + + @override_config({"allow_public_rooms_over_federation": True}) + def test_open_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(200, channel.code) -- cgit 1.5.1 From 65c6aee621fecff1c6a863d6b910c973196ad6bc Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 4 Dec 2019 14:36:39 +0000 Subject: Un-remove room purge test --- tests/rest/client/v1/test_rooms.py | 72 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 4095e63aef..1ca7fa742f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -815,6 +815,78 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) + def test_room_messages_purge(self): + store = self.hs.get_datastore() + pagination_handler = self.hs.get_pagination_handler() + + # Send a first message in the room, which will be removed by the purge. + first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) + ) + + # Send a second message in the room, which won't be removed, and which we'll + # use as the marker to purge events before. + second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) + ) + + # Send a third event in the room to ensure we don't fall under any edge case + # due to our marker being the latest forward extremity in the room. + self.helper.send(self.room_id, "message 3") + + # Check that we get the first and second message when querying /messages. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) + + # Purge every event before the second event. + purge_id = random_string(16) + pagination_handler._purges_by_id[purge_id] = PurgeStatus() + self.get_success( + pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token, + delete_local_events=True, + ) + ) + + # Check that we only get the second message through /message now that the first + # has been purged. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) + + # Check that we get no event, but also no error, when querying /messages with + # the token that was pointing at the first event, because we don't have it + # anymore. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) + class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From ee86abb2d6e9c7d553858e814b4343bcf95af75a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Dec 2019 10:15:55 +0000 Subject: Remove underscore from SQLBaseStore functions --- scripts-dev/hash_history.py | 2 +- scripts/synapse_port_db | 18 +-- synapse/app/user_dir.py | 2 +- synapse/storage/_base.py | 142 ++++++++++----------- synapse/storage/background_updates.py | 14 +- synapse/storage/data_stores/main/__init__.py | 16 +-- synapse/storage/data_stores/main/account_data.py | 18 +-- synapse/storage/data_stores/main/appservice.py | 10 +- synapse/storage/data_stores/main/cache.py | 2 +- synapse/storage/data_stores/main/client_ips.py | 8 +- synapse/storage/data_stores/main/deviceinbox.py | 4 +- synapse/storage/data_stores/main/devices.py | 46 +++---- synapse/storage/data_stores/main/directory.py | 12 +- synapse/storage/data_stores/main/e2e_room_keys.py | 20 +-- .../storage/data_stores/main/end_to_end_keys.py | 20 +-- .../storage/data_stores/main/event_federation.py | 16 +-- .../storage/data_stores/main/event_push_actions.py | 12 +- synapse/storage/data_stores/main/events.py | 52 ++++---- .../storage/data_stores/main/events_bg_updates.py | 10 +- synapse/storage/data_stores/main/events_worker.py | 6 +- synapse/storage/data_stores/main/filtering.py | 2 +- synapse/storage/data_stores/main/group_server.py | 128 +++++++++---------- synapse/storage/data_stores/main/keys.py | 6 +- .../storage/data_stores/main/media_repository.py | 24 ++-- .../data_stores/main/monthly_active_users.py | 6 +- synapse/storage/data_stores/main/openid.py | 2 +- synapse/storage/data_stores/main/presence.py | 8 +- synapse/storage/data_stores/main/profile.py | 24 ++-- synapse/storage/data_stores/main/push_rule.py | 24 ++-- synapse/storage/data_stores/main/pusher.py | 26 ++-- synapse/storage/data_stores/main/receipts.py | 16 +-- synapse/storage/data_stores/main/registration.py | 92 ++++++------- synapse/storage/data_stores/main/rejections.py | 4 +- synapse/storage/data_stores/main/relations.py | 4 +- synapse/storage/data_stores/main/room.py | 34 ++--- synapse/storage/data_stores/main/roommember.py | 16 +-- synapse/storage/data_stores/main/search.py | 8 +- synapse/storage/data_stores/main/signatures.py | 2 +- synapse/storage/data_stores/main/state.py | 36 +++--- synapse/storage/data_stores/main/state_deltas.py | 2 +- synapse/storage/data_stores/main/stats.py | 28 ++-- synapse/storage/data_stores/main/stream.py | 14 +- synapse/storage/data_stores/main/tags.py | 6 +- synapse/storage/data_stores/main/transactions.py | 12 +- synapse/storage/data_stores/main/user_directory.py | 56 ++++---- .../storage/data_stores/main/user_erasure_store.py | 4 +- tests/handlers/test_stats.py | 30 ++--- tests/handlers/test_user_directory.py | 12 +- tests/rest/admin/test_admin.py | 2 +- tests/storage/test__base.py | 8 +- tests/storage/test_base.py | 18 +-- tests/storage/test_client_ips.py | 12 +- tests/storage/test_event_push_actions.py | 4 +- tests/storage/test_redaction.py | 4 +- tests/storage/test_roommember.py | 2 +- tests/unittest.py | 2 +- 56 files changed, 550 insertions(+), 558 deletions(-) (limited to 'tests') diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index d20f6db176..bf3862a386 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -27,7 +27,7 @@ class Store(object): "_store_pdu_reference_hash_txn" ] _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] - _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] + simple_insert_txn = SQLBaseStore.__dict__["simple_insert_txn"] store = Store() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index f24b8ffe67..9dd1700ff0 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -221,7 +221,7 @@ class Porter(object): def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. - row = yield self.postgres_store._simple_select_one( + row = yield self.postgres_store.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -236,7 +236,7 @@ class Porter(object): ) backward_chunk = 0 else: - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -266,7 +266,7 @@ class Porter(object): yield self.postgres_store.execute(delete_all) - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -320,7 +320,7 @@ class Porter(object): if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -375,7 +375,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store._simple_update_one_txn( + self.postgres_store.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -452,7 +452,7 @@ class Porter(object): ], ) - self.postgres_store._simple_update_one_txn( + self.postgres_store.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -591,11 +591,11 @@ class Porter(object): # Step 2. Get tables. self.progress.set_state("Fetching tables") - sqlite_tables = yield self.sqlite_store._simple_select_onecol( + sqlite_tables = yield self.sqlite_store.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = yield self.postgres_store._simple_select_onecol( + postgres_tables = yield self.postgres_store.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -722,7 +722,7 @@ class Porter(object): next_chunk = yield self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 6cb100319f..0fa2b50999 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -64,7 +64,7 @@ class UserDirectorySlaveStore( super(UserDirectorySlaveStore, self).__init__(db_conn, hs) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 1ed89d9f2a..9205e550bb 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -262,7 +262,7 @@ class SQLBaseStore(object): If the background updates have not completed, wait 15 sec and check again. """ - updates = yield self._simple_select_list( + updates = yield self.simple_select_list( "background_updates", keyvalues=None, retcols=["update_name"], @@ -307,7 +307,7 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) - def _new_transaction( + def new_transaction( self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs ): start = monotonic_time() @@ -444,7 +444,7 @@ class SQLBaseStore(object): try: result = yield self.runWithConnection( - self._new_transaction, + self.new_transaction, desc, after_callbacks, exception_callbacks, @@ -516,7 +516,7 @@ class SQLBaseStore(object): results = list(dict(zip(col_headers, row)) for row in cursor) return results - def _execute(self, desc, decoder, query, *args): + def execute(self, desc, decoder, query, *args): """Runs a single query for a result set. Args: @@ -541,7 +541,7 @@ class SQLBaseStore(object): # no complex WHERE clauses, just a dict of values for columns. @defer.inlineCallbacks - def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"): + def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): """Executes an INSERT query on the named table. Args: @@ -557,7 +557,7 @@ class SQLBaseStore(object): `or_ignore` is True """ try: - yield self.runInteraction(desc, self._simple_insert_txn, table, values) + yield self.runInteraction(desc, self.simple_insert_txn, table, values) except self.database_engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -567,7 +567,7 @@ class SQLBaseStore(object): return True @staticmethod - def _simple_insert_txn(txn, table, values): + def simple_insert_txn(txn, table, values): keys, vals = zip(*values.items()) sql = "INSERT INTO %s (%s) VALUES(%s)" % ( @@ -578,11 +578,11 @@ class SQLBaseStore(object): txn.execute(sql, vals) - def _simple_insert_many(self, table, values, desc): - return self.runInteraction(desc, self._simple_insert_many_txn, table, values) + def simple_insert_many(self, table, values, desc): + return self.runInteraction(desc, self.simple_insert_many_txn, table, values) @staticmethod - def _simple_insert_many_txn(txn, table, values): + def simple_insert_many_txn(txn, table, values): if not values: return @@ -611,13 +611,13 @@ class SQLBaseStore(object): txn.executemany(sql, vals) @defer.inlineCallbacks - def _simple_upsert( + def simple_upsert( self, table, keyvalues, values, insertion_values={}, - desc="_simple_upsert", + desc="simple_upsert", lock=True, ): """ @@ -649,7 +649,7 @@ class SQLBaseStore(object): try: result = yield self.runInteraction( desc, - self._simple_upsert_txn, + self.simple_upsert_txn, table, keyvalues, values, @@ -669,7 +669,7 @@ class SQLBaseStore(object): "IntegrityError when upserting into %s; retrying: %s", table, e ) - def _simple_upsert_txn( + def simple_upsert_txn( self, txn, table, keyvalues, values, insertion_values={}, lock=True ): """ @@ -693,11 +693,11 @@ class SQLBaseStore(object): self.database_engine.can_native_upsert and table not in self._unsafe_to_upsert_tables ): - return self._simple_upsert_txn_native_upsert( + return self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values ) else: - return self._simple_upsert_txn_emulated( + return self.simple_upsert_txn_emulated( txn, table, keyvalues, @@ -706,7 +706,7 @@ class SQLBaseStore(object): lock=lock, ) - def _simple_upsert_txn_emulated( + def simple_upsert_txn_emulated( self, txn, table, keyvalues, values, insertion_values={}, lock=True ): """ @@ -775,7 +775,7 @@ class SQLBaseStore(object): # successfully inserted return True - def _simple_upsert_txn_native_upsert( + def simple_upsert_txn_native_upsert( self, txn, table, keyvalues, values, insertion_values={} ): """ @@ -809,7 +809,7 @@ class SQLBaseStore(object): ) txn.execute(sql, list(allvalues.values())) - def _simple_upsert_many_txn( + def simple_upsert_many_txn( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -829,15 +829,15 @@ class SQLBaseStore(object): self.database_engine.can_native_upsert and table not in self._unsafe_to_upsert_tables ): - return self._simple_upsert_many_txn_native_upsert( + return self.simple_upsert_many_txn_native_upsert( txn, table, key_names, key_values, value_names, value_values ) else: - return self._simple_upsert_many_txn_emulated( + return self.simple_upsert_many_txn_emulated( txn, table, key_names, key_values, value_names, value_values ) - def _simple_upsert_many_txn_emulated( + def simple_upsert_many_txn_emulated( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -862,9 +862,9 @@ class SQLBaseStore(object): _keys = {x: y for x, y in zip(key_names, keyv)} _vals = {x: y for x, y in zip(value_names, valv)} - self._simple_upsert_txn_emulated(txn, table, _keys, _vals) + self.simple_upsert_txn_emulated(txn, table, _keys, _vals) - def _simple_upsert_many_txn_native_upsert( + def simple_upsert_many_txn_native_upsert( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -909,8 +909,8 @@ class SQLBaseStore(object): return txn.execute_batch(sql, args) - def _simple_select_one( - self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one" + def simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" ): """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -924,16 +924,16 @@ class SQLBaseStore(object): statement returns no rows """ return self.runInteraction( - desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none + desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def _simple_select_one_onecol( + def simple_select_one_onecol( self, table, keyvalues, retcol, allow_none=False, - desc="_simple_select_one_onecol", + desc="simple_select_one_onecol", ): """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -945,7 +945,7 @@ class SQLBaseStore(object): """ return self.runInteraction( desc, - self._simple_select_one_onecol_txn, + self.simple_select_one_onecol_txn, table, keyvalues, retcol, @@ -953,10 +953,10 @@ class SQLBaseStore(object): ) @classmethod - def _simple_select_one_onecol_txn( + def simple_select_one_onecol_txn( cls, txn, table, keyvalues, retcol, allow_none=False ): - ret = cls._simple_select_onecol_txn( + ret = cls.simple_select_onecol_txn( txn, table=table, keyvalues=keyvalues, retcol=retcol ) @@ -969,7 +969,7 @@ class SQLBaseStore(object): raise StoreError(404, "No row found") @staticmethod - def _simple_select_onecol_txn(txn, table, keyvalues, retcol): + def simple_select_onecol_txn(txn, table, keyvalues, retcol): sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} if keyvalues: @@ -980,8 +980,8 @@ class SQLBaseStore(object): return [r[0] for r in txn] - def _simple_select_onecol( - self, table, keyvalues, retcol, desc="_simple_select_onecol" + def simple_select_onecol( + self, table, keyvalues, retcol, desc="simple_select_onecol" ): """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -995,12 +995,10 @@ class SQLBaseStore(object): Deferred: Results in a list """ return self.runInteraction( - desc, self._simple_select_onecol_txn, table, keyvalues, retcol + desc, self.simple_select_onecol_txn, table, keyvalues, retcol ) - def _simple_select_list( - self, table, keyvalues, retcols, desc="_simple_select_list" - ): + def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1014,11 +1012,11 @@ class SQLBaseStore(object): defer.Deferred: resolves to list[dict[str, Any]] """ return self.runInteraction( - desc, self._simple_select_list_txn, table, keyvalues, retcols + desc, self.simple_select_list_txn, table, keyvalues, retcols ) @classmethod - def _simple_select_list_txn(cls, txn, table, keyvalues, retcols): + def simple_select_list_txn(cls, txn, table, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1044,14 +1042,14 @@ class SQLBaseStore(object): return cls.cursor_to_dict(txn) @defer.inlineCallbacks - def _simple_select_many_batch( + def simple_select_many_batch( self, table, column, iterable, retcols, keyvalues={}, - desc="_simple_select_many_batch", + desc="simple_select_many_batch", batch_size=100, ): """Executes a SELECT query on the named table, which may return zero or @@ -1080,7 +1078,7 @@ class SQLBaseStore(object): for chunk in chunks: rows = yield self.runInteraction( desc, - self._simple_select_many_txn, + self.simple_select_many_txn, table, column, chunk, @@ -1093,7 +1091,7 @@ class SQLBaseStore(object): return results @classmethod - def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1126,13 +1124,13 @@ class SQLBaseStore(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) - def _simple_update(self, table, keyvalues, updatevalues, desc): + def simple_update(self, table, keyvalues, updatevalues, desc): return self.runInteraction( - desc, self._simple_update_txn, table, keyvalues, updatevalues + desc, self.simple_update_txn, table, keyvalues, updatevalues ) @staticmethod - def _simple_update_txn(txn, table, keyvalues, updatevalues): + def simple_update_txn(txn, table, keyvalues, updatevalues): if keyvalues: where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) else: @@ -1148,8 +1146,8 @@ class SQLBaseStore(object): return txn.rowcount - def _simple_update_one( - self, table, keyvalues, updatevalues, desc="_simple_update_one" + def simple_update_one( + self, table, keyvalues, updatevalues, desc="simple_update_one" ): """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. @@ -1169,12 +1167,12 @@ class SQLBaseStore(object): the update column in the 'keyvalues' dict as well. """ return self.runInteraction( - desc, self._simple_update_one_txn, table, keyvalues, updatevalues + desc, self.simple_update_one_txn, table, keyvalues, updatevalues ) @classmethod - def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): - rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) + def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) if rowcount == 0: raise StoreError(404, "No row found (%s)" % (table,)) @@ -1182,7 +1180,7 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched (%s)" % (table,)) @staticmethod - def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): select_sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, @@ -1201,7 +1199,7 @@ class SQLBaseStore(object): return dict(zip(retcols, row)) - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): + def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -1209,10 +1207,10 @@ class SQLBaseStore(object): table : string giving the table name keyvalues : dict of column names and values to select the row with """ - return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues) + return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) @staticmethod - def _simple_delete_one_txn(txn, table, keyvalues): + def simple_delete_one_txn(txn, table, keyvalues): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -1231,11 +1229,11 @@ class SQLBaseStore(object): if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - def _simple_delete(self, table, keyvalues, desc): - return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues) + def simple_delete(self, table, keyvalues, desc): + return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) @staticmethod - def _simple_delete_txn(txn, table, keyvalues): + def simple_delete_txn(txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, " AND ".join("%s = ?" % (k,) for k in keyvalues), @@ -1244,13 +1242,13 @@ class SQLBaseStore(object): txn.execute(sql, list(keyvalues.values())) return txn.rowcount - def _simple_delete_many(self, table, column, iterable, keyvalues, desc): + def simple_delete_many(self, table, column, iterable, keyvalues, desc): return self.runInteraction( - desc, self._simple_delete_many_txn, table, column, iterable, keyvalues + desc, self.simple_delete_many_txn, table, column, iterable, keyvalues ) @staticmethod - def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): + def simple_delete_many_txn(txn, table, column, iterable, keyvalues): """Executes a DELETE query on the named table. Filters rows by if value of `column` is in `iterable`. @@ -1283,7 +1281,7 @@ class SQLBaseStore(object): return txn.rowcount - def _get_cache_dict( + def get_cache_dict( self, db_conn, table, entity_column, stream_column, max_value, limit=100000 ): # Fetch a mapping of room_id -> max stream position for "recent" rooms. @@ -1349,7 +1347,7 @@ class SQLBaseStore(object): # which is fine. pass - def _simple_select_list_paginate( + def simple_select_list_paginate( self, table, keyvalues, @@ -1358,7 +1356,7 @@ class SQLBaseStore(object): limit, retcols, order_direction="ASC", - desc="_simple_select_list_paginate", + desc="simple_select_list_paginate", ): """ Executes a SELECT query on the named table with start and limit, @@ -1380,7 +1378,7 @@ class SQLBaseStore(object): """ return self.runInteraction( desc, - self._simple_select_list_paginate_txn, + self.simple_select_list_paginate_txn, table, keyvalues, orderby, @@ -1391,7 +1389,7 @@ class SQLBaseStore(object): ) @classmethod - def _simple_select_list_paginate_txn( + def simple_select_list_paginate_txn( cls, txn, table, @@ -1452,9 +1450,7 @@ class SQLBaseStore(object): txn.execute(sql_count) return txn.fetchone()[0] - def _simple_search_list( - self, table, term, col, retcols, desc="_simple_search_list" - ): + def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1469,11 +1465,11 @@ class SQLBaseStore(object): """ return self.runInteraction( - desc, self._simple_search_list_txn, table, term, col, retcols + desc, self.simple_search_list_txn, table, term, col, retcols ) @classmethod - def _simple_search_list_txn(cls, txn, table, term, col, retcols): + def simple_search_list_txn(cls, txn, table, term, col, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 37d469ffd7..06955a0537 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -139,7 +139,7 @@ class BackgroundUpdateStore(SQLBaseStore): # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = yield self._simple_select_onecol( + updates = yield self.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", @@ -161,7 +161,7 @@ class BackgroundUpdateStore(SQLBaseStore): if update_name in self._background_update_queue: return False - update_exists = await self._simple_select_one_onecol( + update_exists = await self.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="1", @@ -184,7 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore): no more work to do. """ if not self._background_update_queue: - updates = yield self._simple_select_list( + updates = yield self.simple_select_list( "background_updates", keyvalues=None, retcols=("update_name", "depends_on"), @@ -226,7 +226,7 @@ class BackgroundUpdateStore(SQLBaseStore): else: batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE - progress_json = yield self._simple_select_one_onecol( + progress_json = yield self.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="progress_json", @@ -413,7 +413,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [] progress_json = json.dumps(progress) - return self._simple_insert( + return self.simple_insert( "background_updates", {"update_name": update_name, "progress_json": progress_json}, ) @@ -429,7 +429,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [ name for name in self._background_update_queue if name != update_name ] - return self._simple_delete_one( + return self.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) @@ -444,7 +444,7 @@ class BackgroundUpdateStore(SQLBaseStore): progress_json = json.dumps(progress) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "background_updates", keyvalues={"update_name": update_name}, diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 474924c68f..2a5b33dda1 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -173,7 +173,7 @@ class DataStore( self._presence_on_startup = self._get_active_presence(db_conn) - presence_cache_prefill, min_presence_val = self._get_cache_dict( + presence_cache_prefill, min_presence_val = self.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", @@ -187,7 +187,7 @@ class DataStore( ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( + device_inbox_prefill, min_device_inbox_id = self.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", @@ -202,7 +202,7 @@ class DataStore( ) # The federation outbox and the local device inbox uses the same # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( + device_outbox_prefill, min_device_outbox_id = self.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", @@ -228,7 +228,7 @@ class DataStore( ) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", @@ -242,7 +242,7 @@ class DataStore( prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self._get_cache_dict( + _group_updates_prefill, min_group_updates_id = self.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", @@ -482,7 +482,7 @@ class DataStore( Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self._simple_select_list( + return self.simple_select_list( table="users", keyvalues={}, retcols=["name", "password_hash", "is_guest", "admin", "user_type"], @@ -504,7 +504,7 @@ class DataStore( """ users = yield self.runInteraction( "get_users_paginate", - self._simple_select_list_paginate_txn, + self.simple_select_list_paginate_txn, table="users", keyvalues={"is_guest": False}, orderby=order, @@ -526,7 +526,7 @@ class DataStore( Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self._simple_search_list( + return self.simple_search_list( table="users", term=term, col="name", diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index 22093484ed..b0d22faf3f 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -67,7 +67,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_user_txn(txn): - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, @@ -78,7 +78,7 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: json.loads(row["content"]) for row in rows } - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, @@ -102,7 +102,7 @@ class AccountDataWorkerStore(SQLBaseStore): Returns: Deferred: A dict """ - result = yield self._simple_select_one_onecol( + result = yield self.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -127,7 +127,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_txn(txn): - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, @@ -156,7 +156,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_and_type_txn(txn): - content_json = self._simple_select_one_onecol_txn( + content_json = self.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ @@ -300,9 +300,9 @@ class AccountDataStore(AccountDataWorkerStore): with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint - # on (user_id, room_id, account_data_type) so _simple_upsert will + # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self._simple_upsert( + yield self.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -346,9 +346,9 @@ class AccountDataStore(AccountDataWorkerStore): with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on - # (user_id, account_data_type) so _simple_upsert will retry if + # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self._simple_upsert( + yield self.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 81babf2029..6b82fd392a 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -133,7 +133,7 @@ class ApplicationServiceTransactionWorkerStore( A Deferred which resolves to a list of ApplicationServices, which may be empty. """ - results = yield self._simple_select_list( + results = yield self.simple_select_list( "application_services_state", dict(state=state), ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -155,7 +155,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves to ApplicationServiceState. """ - result = yield self._simple_select_one( + result = yield self.simple_select_one( "application_services_state", dict(as_id=service.id), ["state"], @@ -175,7 +175,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves when the state was set successfully. """ - return self._simple_upsert( + return self.simple_upsert( "application_services_state", dict(as_id=service.id), dict(state=state) ) @@ -249,7 +249,7 @@ class ApplicationServiceTransactionWorkerStore( ) # Set current txn_id for AS to 'txn_id' - self._simple_upsert_txn( + self.simple_upsert_txn( txn, "application_services_state", dict(as_id=service.id), @@ -257,7 +257,7 @@ class ApplicationServiceTransactionWorkerStore( ) # Delete txn - self._simple_delete_txn( + self.simple_delete_txn( txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index 258c08722a..de3256049d 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -95,7 +95,7 @@ class CacheInvalidationStore(SQLBaseStore): txn.call_after(ctx.__exit__, None, None, None) txn.call_after(self.hs.get_notifier().on_new_replication_data) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="cache_invalidation_stream", values={ diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index cae93b0e22..66522a04b7 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -431,7 +431,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry try: - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="user_ips", keyvalues={ @@ -450,7 +450,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -483,7 +483,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - res = yield self._simple_select_list( + res = yield self.simple_select_list( table="devices", keyvalues=keyvalues, retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), @@ -516,7 +516,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index a23744f11c..206d39134d 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -314,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. - already_inserted = self._simple_select_one_txn( + already_inserted = self.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={"origin": origin, "message_id": message_id}, @@ -326,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Add an entry for this message_id so that we know we've processed # it. - self._simple_insert_txn( + self.simple_insert_txn( txn, table="device_federation_inbox", values={ diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index a3ad23e783..727c582121 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -61,7 +61,7 @@ class DeviceWorkerStore(SQLBaseStore): Raises: StoreError: if the device is not found """ - return self._simple_select_one( + return self.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -80,7 +80,7 @@ class DeviceWorkerStore(SQLBaseStore): containing "device_id", "user_id" and "display_name" for each device. """ - devices = yield self._simple_select_list( + devices = yield self.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -414,7 +414,7 @@ class DeviceWorkerStore(SQLBaseStore): from_user_id, stream_id, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, "user_signature_stream", values={ @@ -466,7 +466,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2, tree=True) def _get_cached_user_device(self, user_id, device_id): - content = yield self._simple_select_one_onecol( + content = yield self.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -476,7 +476,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): - devices = yield self._simple_select_list( + devices = yield self.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -584,7 +584,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self._execute( + rows = yield self.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) return set(user for row in rows for user in json.loads(row[0])) @@ -605,7 +605,7 @@ class DeviceWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id, destination """ - return self._execute( + return self.execute( "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) @@ -614,7 +614,7 @@ class DeviceWorkerStore(SQLBaseStore): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", @@ -628,7 +628,7 @@ class DeviceWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, @@ -722,7 +722,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = yield self._simple_insert( + inserted = yield self.simple_insert( "devices", values={ "user_id": user_id, @@ -736,7 +736,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self._simple_select_one_onecol( + hidden = yield self.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -771,7 +771,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_one( + yield self.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -789,7 +789,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_many( + yield self.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -818,7 +818,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updates["display_name"] = new_display_name if not updates: return defer.succeed(None) - return self._simple_update_one( + return self.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, @@ -829,7 +829,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ - yield self._simple_delete( + yield self.simple_delete( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, desc="mark_remote_user_device_list_as_unsubscribed", @@ -866,7 +866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn, user_id, device_id, content, stream_id ): if content.get("deleted"): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -874,7 +874,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) else: - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -890,7 +890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -923,11 +923,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="device_lists_remote_cache", values=[ @@ -946,7 +946,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -995,7 +995,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): [(user_id, device_id, stream_id) for device_id in device_ids], ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="device_lists_stream", values=[ @@ -1006,7 +1006,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context = get_active_span_text_map() - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", values=[ diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py index 297966d9f4..d332f8a409 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -36,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore): Deferred: results in namedtuple with keys "room_id" and "servers" or None if no association can be found """ - room_id = yield self._simple_select_one_onecol( + room_id = yield self.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -47,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore): if not room_id: return None - servers = yield self._simple_select_onecol( + servers = yield self.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -60,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) def get_room_alias_creator(self, room_alias): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", @@ -69,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore): @cached(max_entries=5000) def get_aliases_for_room(self, room_id): - return self._simple_select_onecol( + return self.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", @@ -93,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore): """ def alias_txn(txn): - self._simple_insert_txn( + self.simple_insert_txn( txn, "room_aliases", { @@ -103,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore): }, ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="room_alias_servers", values=[ diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 113224fd7c..df89eda337 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -38,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self._simple_update_one( + yield self.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -89,7 +89,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): } ) - yield self._simple_insert_many( + yield self.simple_insert_many( table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @@ -125,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -234,7 +234,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version (str): the version ID of the backup we're querying about """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", @@ -267,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - yield self._simple_delete( + yield self.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @@ -312,7 +312,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): # it isn't there. raise StoreError(404, "No row found") - result = self._simple_select_one_txn( + result = self.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, @@ -352,7 +352,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): new_version = str(int(current_version) + 1) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="e2e_room_keys_versions", values={ @@ -391,7 +391,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues["etag"] = version_etag if updatevalues: - return self._simple_update( + return self.simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, updatevalues=updatevalues, @@ -420,13 +420,13 @@ class EndToEndRoomKeyStore(SQLBaseStore): else: this_version = version - self._simple_delete_txn( + self.simple_delete_txn( txn, table="e2e_room_keys", keyvalues={"user_id": user_id, "version": this_version}, ) - return self._simple_update_one_txn( + return self.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 643327b57b..08bcdc4725 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -186,7 +186,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): key_id) to json string for key """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -219,7 +219,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # a unique constraint. If there is a race of two calls to # `add_e2e_one_time_keys` then they'll conflict and we will only # insert one set. - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="e2e_one_time_keys_json", values=[ @@ -350,7 +350,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id """ - return self._execute( + return self.execute( "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key ) @@ -367,7 +367,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): set_tag("time_now", time_now) set_tag("device_keys", device_keys) - old_key_json = self._simple_select_one_onecol_txn( + old_key_json = self.simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -383,7 +383,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"Message": "Device key already stored."}) return False - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -442,12 +442,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "user_id": user_id, } ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="e2e_one_time_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -492,7 +492,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # The "keys" property must only have one entry, which will be the public # key, so we just grab the first value in there pubkey = next(iter(key["keys"].values())) - self._simple_insert_txn( + self.simple_insert_txn( txn, "devices", values={ @@ -505,7 +505,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # and finally, store the key itself with self._cross_signing_id_gen.get_next() as stream_id: - self._simple_insert_txn( + self.simple_insert_txn( txn, "e2e_cross_signing_keys", values={ @@ -539,7 +539,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): user_id (str): the user who made the signatures signatures (iterable[SignatureListItem]): signatures to add """ - return self._simple_insert_many( + return self.simple_insert_many( "e2e_cross_signing_signatures", [ { diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 90bef0cd2c..051ac7a8cb 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -126,7 +126,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns Deferred[int] """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -140,7 +140,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return max(row["depth"] for row in rows) def _get_oldest_events_in_room_txn(self, txn, room_id): - return self._simple_select_onecol_txn( + return self.simple_select_onecol_txn( txn, table="event_backward_extremities", keyvalues={"room_id": room_id}, @@ -235,7 +235,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): - return self._simple_select_onecol( + return self.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", @@ -271,7 +271,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) def _get_min_depth_interaction(self, txn, room_id): - min_depth = self._simple_select_one_onecol_txn( + min_depth = self.simple_select_one_onecol_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -383,7 +383,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas queue = PriorityQueue() for event_id in event_list: - depth = self._simple_select_one_onecol_txn( + depth = self.simple_select_one_onecol_txn( txn, table="events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -468,7 +468,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: Deferred[list[str]] """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -508,7 +508,7 @@ class EventFederationStore(EventFederationWorkerStore): if min_depth and depth >= min_depth: return - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -520,7 +520,7 @@ class EventFederationStore(EventFederationWorkerStore): For the given event, update the event edges table and forward and backward extremities tables. """ - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_edges", values=[ diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 04ce21ac66..0a37847cfd 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -441,7 +441,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _add_push_actions_to_staging_txn(txn): - # We don't use _simple_insert_many here to avoid the overhead + # We don't use simple_insert_many here to avoid the overhead # of generating lists of dicts. sql = """ @@ -472,7 +472,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ try: - res = yield self._simple_delete( + res = yield self.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -677,7 +677,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) for event, _ in events_and_contexts: - user_ids = self._simple_select_onecol_txn( + user_ids = self.simple_select_onecol_txn( txn, table="event_push_actions_staging", keyvalues={"event_id": event.event_id}, @@ -844,7 +844,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): the archiving process has caught up or not. """ - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -880,7 +880,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return caught_up def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -912,7 +912,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the # existing table. - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_push_summary", values=[ diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 79c91fe284..98ae69e996 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -432,7 +432,7 @@ class EventsStore( # event's auth chain, but its easier for now just to store them (and # it doesn't take much storage compared to storing the entire event # anyway). - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -580,12 +580,12 @@ class EventsStore( self, txn, new_forward_extremities, max_stream_order ): for room_id, new_extrem in iteritems(new_forward_extremities): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_forward_extremities", values=[ @@ -598,7 +598,7 @@ class EventsStore( # new stream_ordering to new forward extremeties in the room. # This allows us to later efficiently look up the forward extremeties # for a room before a given stream_ordering - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="stream_ordering_to_exterm", values=[ @@ -722,7 +722,7 @@ class EventsStore( # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering state_group_id = context.state_group - self._simple_insert_txn( + self.simple_insert_txn( txn, table="ex_outlier_stream", values={ @@ -794,7 +794,7 @@ class EventsStore( d.pop("redacted_because", None) return d - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_json", values=[ @@ -811,7 +811,7 @@ class EventsStore( ], ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="events", values=[ @@ -841,7 +841,7 @@ class EventsStore( # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - self._simple_update_txn( + self.simple_update_txn( txn, table="redactions", keyvalues={"redacts": event.event_id}, @@ -983,7 +983,7 @@ class EventsStore( state_values.append(vals) - self._simple_insert_many_txn(txn, table="state_events", values=state_values) + self.simple_insert_many_txn(txn, table="state_events", values=state_values) # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1032,7 +1032,7 @@ class EventsStore( # invalidate the cache for the redacted event txn.call_after(self._invalidate_get_event_cache, event.redacts) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="redactions", values={ @@ -1077,9 +1077,7 @@ class EventsStore( LIMIT ? """ - rows = yield self._execute( - "_censor_redactions_fetch", None, sql, before_ts, 100 - ) + rows = yield self.execute("_censor_redactions_fetch", None, sql, before_ts, 100) updates = [] @@ -1111,7 +1109,7 @@ class EventsStore( if pruned_json: self._censor_event_txn(txn, event_id, pruned_json) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="redactions", keyvalues={"event_id": redaction_id}, @@ -1129,7 +1127,7 @@ class EventsStore( event_id (str): The ID of the event to censor. pruned_json (str): The pruned JSON """ - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="event_json", keyvalues={"event_id": event_id}, @@ -1780,7 +1778,7 @@ class EventsStore( "[purge] found %i state groups to delete", len(state_groups_to_delete) ) - rows = self._simple_select_many_txn( + rows = self.simple_select_many_txn( txn, table="state_group_edges", column="prev_state_group", @@ -1807,15 +1805,15 @@ class EventsStore( curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) curr_state = curr_state[sg] - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": sg} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": sg} ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1852,7 +1850,7 @@ class EventsStore( state group. """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, @@ -1882,7 +1880,7 @@ class EventsStore( # first we have to delete the state groups states logger.info("[purge] removing %s from state_groups_state", room_id) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn, table="state_groups_state", column="state_group", @@ -1893,7 +1891,7 @@ class EventsStore( # ... and the state group edges logger.info("[purge] removing %s from state_group_edges", room_id) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn, table="state_group_edges", column="state_group", @@ -1904,7 +1902,7 @@ class EventsStore( # ... and the state groups logger.info("[purge] removing %s from state_groups", room_id) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn, table="state_groups", column="id", @@ -1921,7 +1919,7 @@ class EventsStore( @cachedInlineCallbacks(max_entries=5000) def _get_event_ordering(self, event_id): - res = yield self._simple_select_one( + res = yield self.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, @@ -1962,7 +1960,7 @@ class EventsStore( room_id (str): The ID of the room the event was sent to. topological_ordering (int): The position of the event in the room's topology. """ - return self._simple_insert_many_txn( + return self.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -1984,7 +1982,7 @@ class EventsStore( event_id (str): The event ID the expiry timestamp is associated with. expiry_ts (int): The timestamp at which to expire (delete) the event. """ - return self._simple_insert_txn( + return self.simple_insert_txn( txn=txn, table="event_expiry", values={"event_id": event_id, "expiry_ts": expiry_ts}, @@ -2043,7 +2041,7 @@ class EventsStore( txn (LoggingTransaction): The transaction to use to perform the deletion. event_id (str): The event ID to delete the associated expiry timestamp of. """ - return self._simple_delete_txn( + return self.simple_delete_txn( txn=txn, table="event_expiry", keyvalues={"event_id": event_id} ) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index aa87f9abc5..37dfc8c871 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -189,7 +189,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: - ev_rows = self._simple_select_many_txn( + ev_rows = self.simple_select_many_txn( txn, table="event_json", column="event_id", @@ -366,7 +366,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): to_delete.intersection_update(original_set) - deleted = self._simple_delete_many_txn( + deleted = self.simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -382,7 +382,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if deleted: # We now need to invalidate the caches of these rooms - rows = self._simple_select_many_txn( + rows = self.simple_select_many_txn( txn, table="events", column="event_id", @@ -396,7 +396,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): self.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn=txn, table="_extremities_to_check", column="event_id", @@ -533,7 +533,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): try: event_json = json.loads(event_json_raw) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn=txn, table="event_labels", values=[ diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index e782e8f481..ec4af29299 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -78,7 +78,7 @@ class EventsWorkerStore(SQLBaseStore): Deferred[int|None]: Timestamp in milliseconds, or None for events that were persisted before received_ts was implemented. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", @@ -452,7 +452,7 @@ class EventsWorkerStore(SQLBaseStore): event_id for events, _ in event_list for event_id in events ) - row_dict = self._new_transaction( + row_dict = self.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) @@ -745,7 +745,7 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py index f05ace299a..17ef7b9354 100644 --- a/synapse/storage/data_stores/main/filtering.py +++ b/synapse/storage/data_stores/main/filtering.py @@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self._simple_select_one_onecol( + def_json = yield self.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 5ded539af8..9e1d12bcb7 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -35,7 +35,7 @@ class GroupServerStore(SQLBaseStore): * "invite" * "open" """ - return self._simple_update_one( + return self.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -43,7 +43,7 @@ class GroupServerStore(SQLBaseStore): ) def get_group(self, group_id): - return self._simple_select_one( + return self.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -65,7 +65,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self._simple_select_list( + return self.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), @@ -75,7 +75,7 @@ class GroupServerStore(SQLBaseStore): def get_invited_users_in_group(self, group_id): # TODO: Pagination - return self._simple_select_onecol( + return self.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -89,7 +89,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self._simple_select_list( + return self.simple_select_list( table="group_rooms", keyvalues=keyvalues, retcols=("room_id", "is_public"), @@ -180,7 +180,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the room first. Otherwise, the room gets added to the end. """ - room_in_group = self._simple_select_one_onecol_txn( + room_in_group = self.simple_select_one_onecol_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -193,7 +193,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID else: - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.simple_select_one_onecol_txn( txn, table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -204,7 +204,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Category doesn't exist") # TODO: Check category is part of summary already - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.simple_select_one_onecol_txn( txn, table="group_summary_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -224,7 +224,7 @@ class GroupServerStore(SQLBaseStore): (group_id, category_id, group_id, category_id), ) - existing = self._simple_select_one_txn( + existing = self.simple_select_one_txn( txn, table="group_summary_rooms", keyvalues={ @@ -257,7 +257,7 @@ class GroupServerStore(SQLBaseStore): to_update["room_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.simple_update_txn( txn, table="group_summary_rooms", keyvalues={ @@ -271,7 +271,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_summary_rooms", values={ @@ -287,7 +287,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self._simple_delete( + return self.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -299,7 +299,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_categories(self, group_id): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -316,7 +316,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_category(self, group_id, category_id): - category = yield self._simple_select_one( + category = yield self.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -343,7 +343,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -352,7 +352,7 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_category(self, group_id, category_id): - return self._simple_delete( + return self.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", @@ -360,7 +360,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_roles(self, group_id): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -377,7 +377,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_role(self, group_id, role_id): - role = yield self._simple_select_one( + role = yield self.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -404,7 +404,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -413,7 +413,7 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_role(self, group_id, role_id): - return self._simple_delete( + return self.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", @@ -444,7 +444,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the user first. Otherwise, the user gets added to the end. """ - user_in_group = self._simple_select_one_onecol_txn( + user_in_group = self.simple_select_one_onecol_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -457,7 +457,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID else: - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.simple_select_one_onecol_txn( txn, table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -468,7 +468,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Role doesn't exist") # TODO: Check role is part of the summary already - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.simple_select_one_onecol_txn( txn, table="group_summary_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -488,7 +488,7 @@ class GroupServerStore(SQLBaseStore): (group_id, role_id, group_id, role_id), ) - existing = self._simple_select_one_txn( + existing = self.simple_select_one_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, @@ -517,7 +517,7 @@ class GroupServerStore(SQLBaseStore): to_update["user_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.simple_update_txn( txn, table="group_summary_users", keyvalues={ @@ -531,7 +531,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_summary_users", values={ @@ -547,7 +547,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID - return self._simple_delete( + return self.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", @@ -561,7 +561,7 @@ class GroupServerStore(SQLBaseStore): Deferred[list[str]]: A twisted.Deferred containing a list of group ids containing this room """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -630,7 +630,7 @@ class GroupServerStore(SQLBaseStore): ) def is_user_in_group(self, user_id, group_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -639,7 +639,7 @@ class GroupServerStore(SQLBaseStore): ).addCallback(lambda r: bool(r)) def is_user_admin_in_group(self, group_id, user_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -650,7 +650,7 @@ class GroupServerStore(SQLBaseStore): def add_group_invite(self, group_id, user_id): """Record that the group server has invited a user """ - return self._simple_insert( + return self.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -659,7 +659,7 @@ class GroupServerStore(SQLBaseStore): def is_user_invited_to_local_group(self, group_id, user_id): """Has the group server invited a user? """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -682,7 +682,7 @@ class GroupServerStore(SQLBaseStore): """ def _get_users_membership_in_group_txn(txn): - row = self._simple_select_one_txn( + row = self.simple_select_one_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -697,7 +697,7 @@ class GroupServerStore(SQLBaseStore): "is_privileged": row["is_admin"], } - row = self._simple_select_one_onecol_txn( + row = self.simple_select_one_onecol_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -738,7 +738,7 @@ class GroupServerStore(SQLBaseStore): """ def _add_user_to_group_txn(txn): - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_users", values={ @@ -749,14 +749,14 @@ class GroupServerStore(SQLBaseStore): }, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) if local_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -766,7 +766,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -781,27 +781,27 @@ class GroupServerStore(SQLBaseStore): def remove_user_from_group(self, group_id, user_id): def _remove_user_from_group_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -812,14 +812,14 @@ class GroupServerStore(SQLBaseStore): ) def add_room_to_group(self, group_id, room_id, is_public): - return self._simple_insert( + return self.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", ) def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self._simple_update( + return self.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -828,13 +828,13 @@ class GroupServerStore(SQLBaseStore): def remove_room_from_group(self, group_id, room_id): def _remove_room_from_group_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_summary_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -847,7 +847,7 @@ class GroupServerStore(SQLBaseStore): def get_publicised_groups_for_user(self, user_id): """Get all groups a user is publicising """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -857,7 +857,7 @@ class GroupServerStore(SQLBaseStore): def update_group_publicity(self, group_id, user_id, publicise): """Update whether the user is publicising their membership of the group """ - return self._simple_update_one( + return self.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -893,12 +893,12 @@ class GroupServerStore(SQLBaseStore): def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? - self._simple_delete_txn( + self.simple_delete_txn( txn, table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="local_group_membership", values={ @@ -911,7 +911,7 @@ class GroupServerStore(SQLBaseStore): }, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="local_group_updates", values={ @@ -930,7 +930,7 @@ class GroupServerStore(SQLBaseStore): if membership == "join": if local_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -940,7 +940,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -951,12 +951,12 @@ class GroupServerStore(SQLBaseStore): }, ) else: - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -976,7 +976,7 @@ class GroupServerStore(SQLBaseStore): def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description ): - yield self._simple_insert( + yield self.simple_insert( table="groups", values={ "group_id": group_id, @@ -991,7 +991,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def update_group_profile(self, group_id, profile): - yield self._simple_update_one( + yield self.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, @@ -1017,7 +1017,7 @@ class GroupServerStore(SQLBaseStore): def update_attestation_renewal(self, group_id, user_id, attestation): """Update an attestation that we have renewed """ - return self._simple_update_one( + return self.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, @@ -1027,7 +1027,7 @@ class GroupServerStore(SQLBaseStore): def update_remote_attestion(self, group_id, user_id, attestation): """Update an attestation that a remote has renewed """ - return self._simple_update_one( + return self.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ @@ -1046,7 +1046,7 @@ class GroupServerStore(SQLBaseStore): group_id (str) user_id (str) """ - return self._simple_delete( + return self.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", @@ -1057,7 +1057,7 @@ class GroupServerStore(SQLBaseStore): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self._simple_select_one( + row = yield self.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -1072,7 +1072,7 @@ class GroupServerStore(SQLBaseStore): return None def get_joined_groups(self, user_id): - return self._simple_select_onecol( + return self.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -1188,7 +1188,7 @@ class GroupServerStore(SQLBaseStore): ] for table in tables: - self._simple_delete_txn( + self.simple_delete_txn( txn, table=table, keyvalues={"group_id": group_id} ) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index ebc7db3ed6..c7150432b3 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -129,7 +129,7 @@ class KeyStore(SQLBaseStore): return self.runInteraction( "store_server_verify_keys", - self._simple_upsert_many_txn, + self.simple_upsert_many_txn, table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -157,7 +157,7 @@ class KeyStore(SQLBaseStore): ts_valid_until_ms (int): The time when this json stops being valid. key_json (bytes): The encoded JSON. """ - return self._simple_upsert( + return self.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, @@ -196,7 +196,7 @@ class KeyStore(SQLBaseStore): keyvalues["key_id"] = key_id if from_server is not None: keyvalues["from_server"] = from_server - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "server_keys_json", keyvalues=keyvalues, diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 0f2887bdce..0cb9446f96 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -39,7 +39,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): Returns: None if the media_id doesn't exist. """ - return self._simple_select_one( + return self.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -64,7 +64,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): user_id, url_cache=None, ): - return self._simple_insert( + return self.simple_insert( "local_media_repository", { "media_id": media_id, @@ -129,7 +129,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self._simple_insert( + return self.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -144,7 +144,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_local_media_thumbnails(self, media_id): - return self._simple_select_list( + return self.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -166,7 +166,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -180,7 +180,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_cached_remote_media(self, origin, media_id): - return self._simple_select_one( + return self.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -205,7 +205,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name, filesystem_id, ): - return self._simple_insert( + return self.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -253,7 +253,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): return self.runInteraction("update_cached_last_access_time", update_cache_txn) def get_remote_media_thumbnails(self, origin, media_id): - return self._simple_select_list( + return self.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( @@ -278,7 +278,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, @@ -300,18 +300,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE last_access_ts < ?" ) - return self._execute( + return self.execute( "get_remote_media_before", self.cursor_to_dict, sql, before_ts ) def delete_remote_media(self, media_origin, media_id): def delete_remote_media_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, "remote_media_cache", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, "remote_media_cache_thumbnails", keyvalues={"media_origin": media_origin, "media_id": media_id}, diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index b41c3d317a..b8fc28f97b 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -32,7 +32,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs # Do not add more reserved users than the total allowable number - self._new_transaction( + self.new_transaction( dbconn, "initialise_mau_threepids", [], @@ -261,7 +261,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self._simple_upsert_txn( + is_insert = self.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -281,7 +281,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py index 79b40044d9..650e49750e 100644 --- a/synapse/storage/data_stores/main/openid.py +++ b/synapse/storage/data_stores/main/openid.py @@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self._simple_insert( + return self.simple_insert( table="open_id_tokens", values={ "token": token, diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index 523ed6575e..a5e121efd1 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -46,7 +46,7 @@ class PresenceStore(SQLBaseStore): txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) # Actually insert new rows - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="presence_stream", values=[ @@ -103,7 +103,7 @@ class PresenceStore(SQLBaseStore): inlineCallbacks=True, ) def get_presence_for_users(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, @@ -129,7 +129,7 @@ class PresenceStore(SQLBaseStore): return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_insert( + return self.simple_insert( table="presence_allow_inbound", values={ "observed_user_id": observed_localpart, @@ -140,7 +140,7 @@ class PresenceStore(SQLBaseStore): ) def disallow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_delete_one( + return self.simple_delete_one( table="presence_allow_inbound", keyvalues={ "observed_user_id": observed_localpart, diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py index e4e8a1c1d6..c8b5b60301 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_profileinfo(self, user_localpart): try: - profile = yield self._simple_select_one( + profile = yield self.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_displayname(self, user_localpart): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", @@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_avatar_url(self, user_localpart): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", @@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_from_remote_profile_cache(self, user_id): - return self._simple_select_one( + return self.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), @@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore): ) def create_profile(self, user_localpart): - return self._simple_insert( + return self.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( + return self.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, @@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self._simple_update_one( + return self.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore): This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self._simple_upsert( + return self.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore): ) def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self._simple_update( + return self.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore): """ subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self._simple_delete( + yield self.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index b520062d84..75bd499bcd 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -75,7 +75,7 @@ class PushRulesWorkerStore( def __init__(self, db_conn, hs): super(PushRulesWorkerStore, self).__init__(db_conn, hs) - push_rules_prefill, push_rules_id = self._get_cache_dict( + push_rules_prefill, push_rules_id = self.get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", @@ -100,7 +100,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( @@ -124,7 +124,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): - results = yield self._simple_select_list( + results = yield self.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), @@ -162,7 +162,7 @@ class PushRulesWorkerStore( results = {user_id: [] for user_id in user_ids} - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, @@ -320,7 +320,7 @@ class PushRulesWorkerStore( results = {user_id: {} for user_id in user_ids} - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, @@ -395,7 +395,7 @@ class PushRuleStore(PushRulesWorkerStore): relative_to_rule = before or after - res = self._simple_select_one_txn( + res = self.simple_select_one_txn( txn, table="push_rules", keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, @@ -499,7 +499,7 @@ class PushRuleStore(PushRulesWorkerStore): actions_json, update_stream=True, ): - """Specialised version of _simple_upsert_txn that picks a push_rule_id + """Specialised version of simple_upsert_txn that picks a push_rule_id using the _push_rule_id_gen if it needs to insert the rule. It assumes that the "push_rules" table is locked""" @@ -518,7 +518,7 @@ class PushRuleStore(PushRulesWorkerStore): # We didn't update a row with the given rule_id so insert one push_rule_id = self._push_rule_id_gen.get_next() - self._simple_insert_txn( + self.simple_insert_txn( txn, table="push_rules", values={ @@ -561,7 +561,7 @@ class PushRuleStore(PushRulesWorkerStore): """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self._simple_delete_one_txn( + self.simple_delete_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) @@ -596,7 +596,7 @@ class PushRuleStore(PushRulesWorkerStore): self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled ): new_id = self._push_rules_enable_id_gen.get_next() - self._simple_upsert_txn( + self.simple_upsert_txn( txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}, @@ -636,7 +636,7 @@ class PushRuleStore(PushRulesWorkerStore): update_stream=False, ) else: - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}, @@ -675,7 +675,7 @@ class PushRuleStore(PushRulesWorkerStore): if data is not None: values.update(data) - self._simple_insert_txn(txn, "push_rules_stream", values=values) + self.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index d76861cdc0..d5a169872b 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -59,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_has_pusher(self, user_id): - ret = yield self._simple_select_one_onecol( + ret = yield self.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -72,7 +72,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_pushers_by(self, keyvalues): - ret = yield self._simple_select_list( + ret = yield self.simple_select_list( "pushers", keyvalues, [ @@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_if_users_have_pushers(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, @@ -229,8 +229,8 @@ class PusherStore(PusherWorkerStore): ): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on - # (app_id, pushkey, user_name) so _simple_upsert will retry - yield self._simple_upsert( + # (app_id, pushkey, user_name) so simple_upsert will retry + yield self.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -269,7 +269,7 @@ class PusherStore(PusherWorkerStore): txn, self.get_if_user_has_pusher, (user_id,) ) - self._simple_delete_one_txn( + self.simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -278,7 +278,7 @@ class PusherStore(PusherWorkerStore): # it's possible for us to end up with duplicate rows for # (app_id, pushkey, user_id) at different stream_ids, but that # doesn't really matter. - self._simple_insert_txn( + self.simple_insert_txn( txn, table="deleted_pushers", values={ @@ -296,7 +296,7 @@ class PusherStore(PusherWorkerStore): def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering ): - yield self._simple_update_one( + yield self.simple_update_one( "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"last_stream_ordering": last_stream_ordering}, @@ -319,7 +319,7 @@ class PusherStore(PusherWorkerStore): Returns: Deferred[bool]: True if the pusher still exists; False if it has been deleted. """ - updated = yield self._simple_update( + updated = yield self.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={ @@ -333,7 +333,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self._simple_update( + yield self.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={"failing_since": failing_since}, @@ -342,7 +342,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def get_throttle_params_by_room(self, pusher_id): - res = yield self._simple_select_list( + res = yield self.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], @@ -361,8 +361,8 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def set_throttle_params(self, pusher_id, room_id, params): # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so _simple_upsert will retry - yield self._simple_upsert( + # (pusher, room_id) so simple_upsert will retry + yield self.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 8b17334ff4..380f388e30 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -61,7 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): - return self._simple_select_list( + return self.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), @@ -70,7 +70,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3) def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, @@ -84,7 +84,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -335,7 +335,7 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ - res = self._simple_select_one_txn( + res = self.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], @@ -388,7 +388,7 @@ class ReceiptsStore(ReceiptsWorkerStore): (user_id, room_id, receipt_type), ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="receipts_linearized", keyvalues={ @@ -398,7 +398,7 @@ class ReceiptsStore(ReceiptsWorkerStore): }, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="receipts_linearized", values={ @@ -514,7 +514,7 @@ class ReceiptsStore(ReceiptsWorkerStore): self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="receipts_graph", keyvalues={ @@ -523,7 +523,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "user_id": user_id, }, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="receipts_graph", values={ diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 653c9318cb..debc6706f5 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -45,7 +45,7 @@ class RegistrationWorkerStore(SQLBaseStore): @cached() def get_user_by_id(self, user_id): - return self._simple_select_one( + return self.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -109,7 +109,7 @@ class RegistrationWorkerStore(SQLBaseStore): otherwise int representation of the timestamp (as a number of milliseconds since epoch). """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", @@ -137,7 +137,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ def set_account_validity_for_user_txn(txn): - self._simple_update_txn( + self.simple_update_txn( txn=txn, table="account_validity", keyvalues={"user_id": user_id}, @@ -167,7 +167,7 @@ class RegistrationWorkerStore(SQLBaseStore): Raises: StoreError: The provided token is already set for another user. """ - yield self._simple_update_one( + yield self.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, @@ -184,7 +184,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The ID of the user to which the token belongs. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", @@ -203,7 +203,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The renewal token associated with this user ID. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", @@ -250,7 +250,7 @@ class RegistrationWorkerStore(SQLBaseStore): email_sent (bool): Flag which indicates whether a renewal email has been sent to this user. """ - yield self._simple_update_one( + yield self.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, @@ -265,7 +265,7 @@ class RegistrationWorkerStore(SQLBaseStore): Args: user_id (str): ID of the user to remove from the account validity table. """ - yield self._simple_delete_one( + yield self.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", @@ -281,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", @@ -299,7 +299,7 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self._simple_update_one( + return self.simple_update_one( table="users", keyvalues={"name": user.to_string()}, updatevalues={"admin": 1 if admin else 0}, @@ -351,7 +351,7 @@ class RegistrationWorkerStore(SQLBaseStore): return res def is_real_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -361,7 +361,7 @@ class RegistrationWorkerStore(SQLBaseStore): return res is None def is_support_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -394,7 +394,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: the mxid of the user, or None if they are not known """ - return await self._simple_select_one_onecol( + return await self.simple_select_one_onecol( table="user_external_ids", keyvalues={"auth_provider": auth_provider, "external_id": external_id}, retcol="user_id", @@ -536,7 +536,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: user id or None if no user id/threepid mapping exists """ - ret = self._simple_select_one_txn( + ret = self.simple_select_one_txn( txn, "user_threepids", {"medium": medium, "address": address}, @@ -549,7 +549,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self._simple_upsert( + yield self.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, @@ -557,7 +557,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_get_threepids(self, user_id): - ret = yield self._simple_select_list( + ret = yield self.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], @@ -566,7 +566,7 @@ class RegistrationWorkerStore(SQLBaseStore): return ret def user_delete_threepid(self, user_id, medium, address): - return self._simple_delete( + return self.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", @@ -579,7 +579,7 @@ class RegistrationWorkerStore(SQLBaseStore): user_id: The user id to delete all threepids of """ - return self._simple_delete( + return self.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", @@ -601,7 +601,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ # We need to use an upsert, in case they user had already bound the # threepid - return self._simple_upsert( + return self.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -627,7 +627,7 @@ class RegistrationWorkerStore(SQLBaseStore): medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self._simple_select_list( + return self.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -648,7 +648,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred """ - return self._simple_delete( + return self.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -671,7 +671,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[list[str]]: Resolves to a list of identity servers """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", @@ -689,7 +689,7 @@ class RegistrationWorkerStore(SQLBaseStore): defer.Deferred(bool): The requested value. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -776,12 +776,12 @@ class RegistrationWorkerStore(SQLBaseStore): """ def delete_threepid_session_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -961,7 +961,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ next_id = self._access_tokens_id_gen.get_next() - yield self._simple_insert( + yield self.simple_insert( "access_tokens", { "id": next_id, @@ -1037,7 +1037,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Ensure that the guest user actually exists # ``allow_none=False`` makes this raise an exception # if the row isn't in the database. - self._simple_select_one_txn( + self.simple_select_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1045,7 +1045,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): allow_none=False, ) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1059,7 +1059,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) else: - self._simple_insert_txn( + self.simple_insert_txn( txn, "users", values={ @@ -1114,7 +1114,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self._simple_insert( + return self.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1132,7 +1132,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def user_set_password_hash_txn(txn): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -1152,7 +1152,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1176,7 +1176,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1234,7 +1234,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def delete_access_token(self, access_token): def f(txn): - self._simple_delete_one_txn( + self.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -1246,7 +1246,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): @cachedInlineCallbacks() def is_guest(self, user_id): - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1261,7 +1261,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self._simple_insert( + return self.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", @@ -1274,7 +1274,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self._simple_delete( + return self.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", @@ -1285,7 +1285,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1315,7 +1315,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): - row = self._simple_select_one_txn( + row = self.simple_select_one_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1333,7 +1333,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): 400, "This client_secret does not match the provided session_id" ) - row = self._simple_select_one_txn( + row = self.simple_select_one_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id, "token": token}, @@ -1358,7 +1358,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Looks good. Validate the session - self._simple_update_txn( + self.simple_update_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1401,7 +1401,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): if validated_at: insertion_values["validated_at"] = validated_at - return self._simple_upsert( + return self.simple_upsert( table="threepid_validation_session", keyvalues={"session_id": session_id}, values={"last_send_attempt": send_attempt}, @@ -1439,7 +1439,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def start_or_continue_validation_session_txn(txn): # Create or update a validation session - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1452,7 +1452,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Create a new validation token with this session ID - self._simple_insert_txn( + self.simple_insert_txn( txn, table="threepid_validation_token", values={ @@ -1501,7 +1501,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self._simple_update_one_txn( + self.simple_update_one_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -1560,7 +1560,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expiration_ts, ) - self._simple_upsert_txn( + self.simple_upsert_txn( txn, "account_validity", keyvalues={"user_id": user_id}, diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py index 7d5de0ea2e..f81f9279a1 100644 --- a/synapse/storage/data_stores/main/rejections.py +++ b/synapse/storage/data_stores/main/rejections.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): def _store_rejections_txn(self, txn, event_id, reason): - self._simple_insert_txn( + self.simple_insert_txn( txn, table="rejections", values={ @@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore): ) def get_rejection_reason(self, event_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py index 858f65582b..aa5e10538b 100644 --- a/synapse/storage/data_stores/main/relations.py +++ b/synapse/storage/data_stores/main/relations.py @@ -352,7 +352,7 @@ class RelationsStore(RelationsWorkerStore): aggregation_key = relation.get("key") - self._simple_insert_txn( + self.simple_insert_txn( txn, table="event_relations", values={ @@ -380,6 +380,6 @@ class RelationsStore(RelationsWorkerStore): redacted_event_id (str): The event that was redacted. """ - self._simple_delete_txn( + self.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index b7f9024811..8f9b6365c1 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -53,7 +53,7 @@ class RoomWorkerStore(SQLBaseStore): Returns: A dict containing the room information, or None if the room is unknown. """ - return self._simple_select_one( + return self.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -62,7 +62,7 @@ class RoomWorkerStore(SQLBaseStore): ) def get_public_room_ids(self): - return self._simple_select_onecol( + return self.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", @@ -266,7 +266,7 @@ class RoomWorkerStore(SQLBaseStore): @cached(max_entries=10000) def is_room_blocked(self, room_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", @@ -287,7 +287,7 @@ class RoomWorkerStore(SQLBaseStore): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self._simple_select_one( + row = yield self.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -407,7 +407,7 @@ class RoomStore(RoomWorkerStore, SearchStore): ev = json.loads(row["json"]) retention_policy = json.dumps(ev["content"]) - self._simple_insert_txn( + self.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -453,7 +453,7 @@ class RoomStore(RoomWorkerStore, SearchStore): try: def store_room_txn(txn, next_id): - self._simple_insert_txn( + self.simple_insert_txn( txn, "rooms", { @@ -463,7 +463,7 @@ class RoomStore(RoomWorkerStore, SearchStore): }, ) if is_public: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -482,14 +482,14 @@ class RoomStore(RoomWorkerStore, SearchStore): @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="rooms", keyvalues={"room_id": room_id}, updatevalues={"is_public": is_public}, ) - entries = self._simple_select_list_txn( + entries = self.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -507,7 +507,7 @@ class RoomStore(RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -547,7 +547,7 @@ class RoomStore(RoomWorkerStore, SearchStore): def set_room_is_public_appservice_txn(txn, next_id): if is_public: try: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="appservice_room_list", values={ @@ -560,7 +560,7 @@ class RoomStore(RoomWorkerStore, SearchStore): # We've already inserted, nothing to do. return else: - self._simple_delete_txn( + self.simple_delete_txn( txn, table="appservice_room_list", keyvalues={ @@ -570,7 +570,7 @@ class RoomStore(RoomWorkerStore, SearchStore): }, ) - entries = self._simple_select_list_txn( + entries = self.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -588,7 +588,7 @@ class RoomStore(RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -652,7 +652,7 @@ class RoomStore(RoomWorkerStore, SearchStore): # Ignore the event if one of the value isn't an integer. return - self._simple_insert_txn( + self.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -671,7 +671,7 @@ class RoomStore(RoomWorkerStore, SearchStore): self, room_id, event_id, user_id, reason, content, received_ts ): next_id = self._event_reports_id_gen.get_next() - return self._simple_insert( + return self.simple_insert( table="event_reports", values={ "id": next_id, @@ -717,7 +717,7 @@ class RoomStore(RoomWorkerStore, SearchStore): Returns: Deferred """ - yield self._simple_upsert( + yield self.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index b314d75941..fe2428a281 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -128,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): membership column is up to date """ - pending_update = self._simple_select_one_txn( + pending_update = self.simple_select_one_txn( txn, table="background_updates", keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, @@ -603,7 +603,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): to `user_id` and ProfileInfo (or None if not join event). """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, @@ -643,7 +643,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) + rows = yield self.execute("is_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -683,7 +683,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) + rows = yield self.execute("was_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -805,7 +805,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): Deferred[set[str]]: Set of room IDs. """ - room_ids = yield self._simple_select_onecol( + room_ids = yield self.simple_select_onecol( table="room_memberships", keyvalues={"membership": Membership.JOIN, "user_id": user_id}, retcol="room_id", @@ -820,7 +820,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Get user_id and membership of a set of event IDs. """ - return self._simple_select_many_batch( + return self.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -990,7 +990,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="room_memberships", values=[ @@ -1028,7 +1028,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): is_mine = self.hs.is_mine_id(event.state_key) if is_new_state and is_mine: if event.membership == Membership.INVITE: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="local_invites", values={ diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index d1d7c6863d..f735cf095c 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -441,7 +441,7 @@ class SearchStore(SearchBackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args) + results = yield self.execute("search_msgs", self.cursor_to_dict, sql, *args) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -455,7 +455,7 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self._execute( + count_results = yield self.execute( "search_rooms_count", self.cursor_to_dict, count_sql, *count_args ) @@ -586,7 +586,7 @@ class SearchStore(SearchBackgroundUpdateStore): args.append(limit) - results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args) + results = yield self.execute("search_rooms", self.cursor_to_dict, sql, *args) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -600,7 +600,7 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self._execute( + count_results = yield self.execute( "search_rooms_count", self.cursor_to_dict, count_sql, *count_args ) diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py index 556191b76f..f3da29ce14 100644 --- a/synapse/storage/data_stores/main/signatures.py +++ b/synapse/storage/data_stores/main/signatures.py @@ -98,4 +98,4 @@ class SignatureStore(SignatureWorkerStore): } ) - self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) + self.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 6a90daea31..2b33ec1a35 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -89,7 +89,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): count = 0 while next_group: - next_group = self._simple_select_one_onecol_txn( + next_group = self.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -192,7 +192,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): ): break - next_group = self._simple_select_one_onecol_txn( + next_group = self.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -431,7 +431,7 @@ class StateGroupWorkerStore( """ def _get_state_group_delta_txn(txn): - prev_group = self._simple_select_one_onecol_txn( + prev_group = self.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, @@ -442,7 +442,7 @@ class StateGroupWorkerStore( if not prev_group: return _GetStateGroupDelta(None, None) - delta_ids = self._simple_select_list_txn( + delta_ids = self.simple_select_list_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, @@ -644,7 +644,7 @@ class StateGroupWorkerStore( @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", @@ -661,7 +661,7 @@ class StateGroupWorkerStore( def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids, @@ -902,7 +902,7 @@ class StateGroupWorkerStore( state_group = self.database_engine.get_next_state_group_id(txn) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="state_groups", values={"id": state_group, "room_id": room_id, "event_id": event_id}, @@ -911,7 +911,7 @@ class StateGroupWorkerStore( # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if prev_group: - is_in_db = self._simple_select_one_onecol_txn( + is_in_db = self.simple_select_one_onecol_txn( txn, table="state_groups", keyvalues={"id": prev_group}, @@ -926,13 +926,13 @@ class StateGroupWorkerStore( potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="state_group_edges", values={"state_group": state_group, "prev_state_group": prev_group}, ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -947,7 +947,7 @@ class StateGroupWorkerStore( ], ) else: - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1007,7 +1007,7 @@ class StateGroupWorkerStore( referenced. """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, @@ -1065,7 +1065,7 @@ class StateBackgroundUpdateStore( batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) if max_group is None: - rows = yield self._execute( + rows = yield self.execute( "_background_deduplicate_state", None, "SELECT coalesce(max(id), 0) FROM state_groups", @@ -1135,13 +1135,13 @@ class StateBackgroundUpdateStore( if prev_state.get(key, None) != value } - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="state_group_edges", values={ @@ -1150,13 +1150,13 @@ class StateBackgroundUpdateStore( }, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1263,7 +1263,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): state_groups[event.event_id] = context.state_group - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_to_state_groups", values=[ diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py index 28f33ec18f..03b908026b 100644 --- a/synapse/storage/data_stores/main/state_deltas.py +++ b/synapse/storage/data_stores/main/state_deltas.py @@ -105,7 +105,7 @@ class StateDeltasStore(SQLBaseStore): ) def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self._simple_select_one_onecol_txn( + return self.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", keyvalues={}, diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 45b3de7d56..3aeba859fd 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore): """ Returns the stats processor positions. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -215,7 +215,7 @@ class StatsStore(StateDeltasStore): if field and "\0" in field: fields[col] = None - return self._simple_upsert( + return self.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, @@ -257,7 +257,7 @@ class StatsStore(StateDeltasStore): ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] ) - slice_list = self._simple_select_list_paginate_txn( + slice_list = self.simple_select_list_paginate_txn( txn, table + "_historical", {id_col: stats_id}, @@ -282,7 +282,7 @@ class StatsStore(StateDeltasStore): "name", "topic", "canonical_alias", "avatar", "join_rules", "history_visibility" """ - return self._simple_select_one( + return self.simple_select_one( "room_stats_state", {"room_id": room_id}, retcols=( @@ -308,7 +308,7 @@ class StatsStore(StateDeltasStore): """ table, id_col = TYPE_TO_TABLE[stats_type] - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", @@ -344,7 +344,7 @@ class StatsStore(StateDeltasStore): complete_with_stream_id=stream_id, ) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="stats_incremental_position", keyvalues={}, @@ -517,17 +517,17 @@ class StatsStore(StateDeltasStore): else: self.database_engine.lock_table(txn, table) retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self._simple_select_one_txn( + current_row = self.simple_select_one_txn( txn, table, keyvalues, retcols, allow_none=True ) if current_row is None: merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self._simple_insert_txn(txn, table, merged_dict) + self.simple_insert_txn(txn, table, merged_dict) else: for (key, val) in additive_relatives.items(): current_row[key] += val current_row.update(absolutes) - self._simple_update_one_txn(txn, table, keyvalues, current_row) + self.simple_update_one_txn(txn, table, keyvalues, current_row) def _upsert_copy_from_table_with_additive_relatives_txn( self, @@ -614,11 +614,11 @@ class StatsStore(StateDeltasStore): txn.execute(sql, qargs) else: self.database_engine.lock_table(txn, into_table) - src_row = self._simple_select_one_txn( + src_row = self.simple_select_one_txn( txn, src_table, keyvalues, copy_columns ) all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self._simple_select_one_txn( + dest_current_row = self.simple_select_one_txn( txn, into_table, keyvalues=all_dest_keyvalues, @@ -634,11 +634,11 @@ class StatsStore(StateDeltasStore): **src_row, **additive_relatives, } - self._simple_insert_txn(txn, into_table, merged_dict) + self.simple_insert_txn(txn, into_table, merged_dict) else: for (key, val) in additive_relatives.items(): src_row[key] = dest_current_row[key] + val - self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + self.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): """Fetches the counts of events in the given range of stream IDs. @@ -735,7 +735,7 @@ class StatsStore(StateDeltasStore): def _fetch_current_state_stats(txn): pos = self.get_room_max_stream_ordering() - rows = self._simple_select_many_txn( + rows = self.simple_select_many_txn( txn, table="current_state_events", column="type", diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 21a410afd0..60487c4559 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -255,7 +255,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): super(StreamWorkerStore, self).__init__(db_conn, hs) events_max = self.get_room_max_stream_ordering() - event_cache_prefill, min_event_val = self._get_cache_dict( + event_cache_prefill, min_event_val = self.get_cache_dict( db_conn, "events", entity_column="room_id", @@ -576,7 +576,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "s%d" stream token. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" ).addCallback(lambda row: "s%d" % (row,)) @@ -589,7 +589,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "t%d-%d" topological token. """ - return self._simple_select_one( + return self.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), @@ -613,7 +613,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self._execute( + return self.execute( "get_max_topological_token", None, sql, room_id, stream_key ).addCallback(lambda r: r[0][0] if r else 0) @@ -709,7 +709,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = self._simple_select_one_txn( + results = self.simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -797,7 +797,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, events def get_federation_out_pos(self, typ): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, @@ -805,7 +805,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) def update_federation_out_pos(self, typ, stream_id): - return self._simple_update_one( + return self.simple_update_one( table="federation_stream_position", keyvalues={"type": typ}, updatevalues={"stream_id": stream_id}, diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index aa24339717..85012403be 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore): tag strings to tag content. """ - deferred = self._simple_select_list( + deferred = self.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) @@ -153,7 +153,7 @@ class TagsWorkerStore(AccountDataWorkerStore): Returns: A deferred list of string tags. """ - return self._simple_select_list( + return self.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), @@ -178,7 +178,7 @@ class TagsStore(TagsWorkerStore): content_json = json.dumps(content) def add_tag_txn(txn, next_id): - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index 01b1be5e14..c162f3ea16 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -85,7 +85,7 @@ class TransactionStore(SQLBaseStore): ) def _get_received_txn_response(self, txn, transaction_id, origin): - result = self._simple_select_one_txn( + result = self.simple_select_one_txn( txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, @@ -119,7 +119,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ - return self._simple_insert( + return self.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, @@ -160,7 +160,7 @@ class TransactionStore(SQLBaseStore): return result def _get_destination_retry_timings(self, txn, destination): - result = self._simple_select_one_txn( + result = self.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -227,7 +227,7 @@ class TransactionStore(SQLBaseStore): # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. - prev_row = self._simple_select_one_txn( + prev_row = self.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -236,7 +236,7 @@ class TransactionStore(SQLBaseStore): ) if not prev_row: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="destinations", values={ @@ -247,7 +247,7 @@ class TransactionStore(SQLBaseStore): }, ) elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "destinations", keyvalues={"destination": destination}, diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 652abe0e6a..1a85aabbfb 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -85,7 +85,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ txn.execute(sql) rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + self.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) del rooms # If search all users is on, get all the users we want to add. @@ -100,13 +100,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("SELECT name FROM users") users = [{"user_id": x[0]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) new_pos = yield self.get_max_stream_id_in_current_state_deltas() yield self.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + yield self.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) yield self._end_background_update("populate_user_directory_createtables") return 1 @@ -116,7 +116,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ Update the user directory stream position, then clean up the old tables. """ - position = yield self._simple_select_one_onecol( + position = yield self.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) yield self.update_user_directory_stream_pos(position) @@ -243,7 +243,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore to_insert.clear() # We've finished a room. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + yield self.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) # Update the remaining counter. progress["remaining"] -= 1 yield self.runInteraction( @@ -312,7 +312,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) # We've finished processing a user. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) + yield self.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) # Update the remaining counter. progress["remaining"] -= 1 yield self.runInteraction( @@ -361,7 +361,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _update_profile_in_user_dir_txn(txn): - new_entry = self._simple_upsert_txn( + new_entry = self.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -435,7 +435,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id}, @@ -462,7 +462,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _add_users_who_share_room_txn(txn): - self._simple_upsert_many_txn( + self.simple_upsert_many_txn( txn, table="users_who_share_private_rooms", key_names=["user_id", "other_user_id", "room_id"], @@ -489,7 +489,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore def _add_users_in_public_rooms_txn(txn): - self._simple_upsert_many_txn( + self.simple_upsert_many_txn( txn, table="users_in_public_rooms", key_names=["user_id", "room_id"], @@ -519,7 +519,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore @cached() def get_user_in_directory(self, user_id): - return self._simple_select_one( + return self.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -528,7 +528,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) def update_user_directory_stream_pos(self, stream_id): - return self._simple_update_one( + return self.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, @@ -547,21 +547,21 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id}, @@ -575,14 +575,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """Get all user_ids that are in the room directory because they're in the given room_id """ - user_ids_share_pub = yield self._simple_select_onecol( + user_ids_share_pub = yield self.simple_select_onecol( table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) - user_ids_share_priv = yield self._simple_select_onecol( + user_ids_share_priv = yield self.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"room_id": room_id}, retcol="other_user_id", @@ -605,17 +605,17 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """ def _remove_user_who_share_room_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, @@ -636,14 +636,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns: list: user_id """ - rows = yield self._simple_select_onecol( + rows = yield self.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self._simple_select_onecol( + pub_rows = yield self.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -674,14 +674,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) f2 USING (room_id) """ - rows = yield self._execute( + rows = yield self.execute( "get_rooms_in_common_for_users", None, sql, user_id, other_user_id ) return [room_id for room_id, in rows] def get_user_directory_stream_pos(self): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", @@ -786,9 +786,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self._execute( - "search_user_dir", self.cursor_to_dict, sql, *args - ) + results = yield self.execute("search_user_dir", self.cursor_to_dict, sql, *args) limited = len(results) > limit diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index aa4f0da5f0..37860af070 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if the user has requested erasure """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", @@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore): # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="erased_users", column="user_id", iterable=user_ids, diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index e0075ccd32..380fd0d107 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -45,13 +45,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) def get_all_room_state(self): - return self.store._simple_select_list( + return self.store.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store._simple_select_one( + self.store.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -180,7 +180,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.handler.stats_enabled = True self.store._all_done = False self.get_success( - self.store._simple_update_one( + self.store.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -188,7 +188,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) @@ -205,13 +205,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Now do the initial ingestion. self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -656,12 +656,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_delete( + self.store.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store._simple_delete( + self.store.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -675,7 +675,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -685,7 +685,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -695,7 +695,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c5e91a8c41..d5b1c5b4ac 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -158,7 +158,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_in_public_rooms(self): r = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_who_share_private_rooms(self): return self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -184,7 +184,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -193,7 +193,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -203,7 +203,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -213,7 +213,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 9575058252..124ce0768a 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -632,7 +632,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "state_groups_state", ): count = self.get_success( - self.store._simple_select_one_onecol( + self.store.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 9b81b536f5..7b7434a468 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -356,7 +356,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.get_success( self.storage.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.simple_select_list( self.table_name, None, ["id, username, value"] ) ) @@ -383,7 +383,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.get_success( self.storage.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.simple_select_list( self.table_name, None, ["id, username, value"] ) ) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index c778de1f0c..de5e4a5fce 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -65,7 +65,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -77,7 +77,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore._simple_select_one_onecol( + value = yield self.datastore.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -106,7 +106,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -122,7 +122,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -137,7 +137,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore._simple_select_list( + ret = yield self.datastore.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -150,7 +150,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -165,7 +165,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_delete_one( + yield self.datastore.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index afac5dec7f..25bdd2c163 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -81,7 +81,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -112,7 +112,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -218,7 +218,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But clear the associated entry in devices table self.get_success( - self.store._simple_update( + self.store.simple_update( table="devices", keyvalues={"user_id": user_id, "device_id": "device_id"}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, @@ -245,7 +245,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.simple_insert( table="background_updates", values={ "update_name": "devices_last_seen", @@ -297,7 +297,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should see that in the DB result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -323,7 +323,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should get no results. result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b114c6fb1d..2337a1ae46 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -116,7 +116,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store._simple_delete( + yield self.store.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -135,7 +135,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store._simple_insert( + return self.store.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 4561c3e383..4930b6777e 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -338,7 +338,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -356,7 +356,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 105a0c2b02..d389cf578f 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -132,7 +132,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", diff --git a/tests/unittest.py b/tests/unittest.py index 31997a0f31..295573bc46 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -544,7 +544,7 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore()._simple_insert( + self.hs.get_datastore().simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", -- cgit 1.5.1 From 756d4942f5707922f29fe1fdfd945d73a19d7ac3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Dec 2019 13:52:46 +0000 Subject: Move DB pool and helper functions into dedicated Database class --- scripts/synapse_port_db | 73 +- synapse/app/_base.py | 2 +- synapse/app/user_dir.py | 2 +- synapse/module_api/__init__.py | 2 +- synapse/storage/_base.py | 1468 +------------------ synapse/storage/background_updates.py | 16 +- synapse/storage/data_stores/main/__init__.py | 36 +- synapse/storage/data_stores/main/account_data.py | 26 +- synapse/storage/data_stores/main/appservice.py | 24 +- synapse/storage/data_stores/main/cache.py | 6 +- synapse/storage/data_stores/main/client_ips.py | 24 +- synapse/storage/data_stores/main/deviceinbox.py | 20 +- synapse/storage/data_stores/main/devices.py | 70 +- synapse/storage/data_stores/main/directory.py | 20 +- synapse/storage/data_stores/main/e2e_room_keys.py | 28 +- .../storage/data_stores/main/end_to_end_keys.py | 44 +- .../storage/data_stores/main/event_federation.py | 40 +- .../storage/data_stores/main/event_push_actions.py | 42 +- synapse/storage/data_stores/main/events.py | 90 +- .../storage/data_stores/main/events_bg_updates.py | 24 +- synapse/storage/data_stores/main/events_worker.py | 18 +- synapse/storage/data_stores/main/filtering.py | 4 +- synapse/storage/data_stores/main/group_server.py | 160 +-- synapse/storage/data_stores/main/keys.py | 12 +- .../storage/data_stores/main/media_repository.py | 44 +- .../data_stores/main/monthly_active_users.py | 12 +- synapse/storage/data_stores/main/openid.py | 6 +- synapse/storage/data_stores/main/presence.py | 12 +- synapse/storage/data_stores/main/profile.py | 28 +- synapse/storage/data_stores/main/push_rule.py | 36 +- synapse/storage/data_stores/main/pusher.py | 34 +- synapse/storage/data_stores/main/receipts.py | 36 +- synapse/storage/data_stores/main/registration.py | 164 +-- synapse/storage/data_stores/main/rejections.py | 4 +- synapse/storage/data_stores/main/relations.py | 12 +- synapse/storage/data_stores/main/room.py | 74 +- synapse/storage/data_stores/main/roommember.py | 44 +- synapse/storage/data_stores/main/search.py | 30 +- synapse/storage/data_stores/main/signatures.py | 4 +- synapse/storage/data_stores/main/state.py | 54 +- synapse/storage/data_stores/main/state_deltas.py | 8 +- synapse/storage/data_stores/main/stats.py | 48 +- synapse/storage/data_stores/main/stream.py | 30 +- synapse/storage/data_stores/main/tags.py | 18 +- synapse/storage/data_stores/main/transactions.py | 22 +- synapse/storage/data_stores/main/user_directory.py | 80 +- .../storage/data_stores/main/user_erasure_store.py | 6 +- synapse/storage/database.py | 1485 ++++++++++++++++++++ tests/handlers/test_stats.py | 30 +- tests/handlers/test_user_directory.py | 12 +- tests/rest/admin/test_admin.py | 2 +- tests/storage/test__base.py | 16 +- tests/storage/test_background_update.py | 2 +- tests/storage/test_base.py | 18 +- tests/storage/test_cleanup_extrems.py | 4 +- tests/storage/test_client_ips.py | 12 +- tests/storage/test_event_federation.py | 8 +- tests/storage/test_event_push_actions.py | 12 +- tests/storage/test_monthly_active_users.py | 6 +- tests/storage/test_redaction.py | 4 +- tests/storage/test_roommember.py | 2 +- tests/unittest.py | 2 +- 62 files changed, 2377 insertions(+), 2295 deletions(-) create mode 100644 synapse/storage/database.py (limited to 'tests') diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index c4cf11d19a..7a2e177d3d 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -173,14 +173,14 @@ class Store( return (yield self.db_pool.runWithConnection(r)) def execute(self, f, *args, **kwargs): - return self.runInteraction(f.__name__, f, *args, **kwargs) + return self.db.runInteraction(f.__name__, f, *args, **kwargs) def execute_sql(self, sql, *args): def r(txn): txn.execute(sql, args) return txn.fetchall() - return self.runInteraction("execute_sql", r) + return self.db.runInteraction("execute_sql", r) def insert_many_txn(self, txn, table, headers, rows): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( @@ -223,7 +223,7 @@ class Porter(object): def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. - row = yield self.postgres_store.simple_select_one( + row = yield self.postgres_store.db.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -233,12 +233,14 @@ class Porter(object): total_to_port = None if row is None: if table == "sent_transactions": - forward_chunk, already_ported, total_to_port = ( - yield self._setup_sent_transactions() - ) + ( + forward_chunk, + already_ported, + total_to_port, + ) = yield self._setup_sent_transactions() backward_chunk = 0 else: - yield self.postgres_store.simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -268,7 +270,7 @@ class Porter(object): yield self.postgres_store.execute(delete_all) - yield self.postgres_store.simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -322,7 +324,7 @@ class Porter(object): if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. - yield self.postgres_store.simple_insert( + yield self.postgres_store.db.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -363,7 +365,7 @@ class Porter(object): return headers, forward_rows, backward_rows - headers, frows, brows = yield self.sqlite_store.runInteraction("select", r) + headers, frows, brows = yield self.sqlite_store.db.runInteraction("select", r) if frows or brows: if frows: @@ -377,7 +379,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store.simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -416,7 +418,7 @@ class Porter(object): return headers, rows - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = yield self.sqlite_store.db.runInteraction("select", r) if rows: forward_chunk = rows[-1][0] + 1 @@ -433,8 +435,8 @@ class Porter(object): rows_dict = [] for row in rows: d = dict(zip(headers, row)) - if "\0" in d['value']: - logger.warning('dropping search row %s', d) + if "\0" in d["value"]: + logger.warning("dropping search row %s", d) else: rows_dict.append(d) @@ -454,7 +456,7 @@ class Porter(object): ], ) - self.postgres_store.simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -504,17 +506,14 @@ class Porter(object): self.progress.set_state("Preparing %s" % config["name"]) conn = self.setup_db(config, engine) - db_pool = adbapi.ConnectionPool( - config["name"], **config["args"] - ) + db_pool = adbapi.ConnectionPool(config["name"], **config["args"]) hs = MockHomeserver(self.hs_config, engine, conn, db_pool) store = Store(conn, hs) - yield store.runInteraction( - "%s_engine.check_database" % config["name"], - engine.check_database, + yield store.db.runInteraction( + "%s_engine.check_database" % config["name"], engine.check_database, ) return store @@ -541,7 +540,9 @@ class Porter(object): self.sqlite_store = yield self.build_db_store(self.sqlite_config) # Check if all background updates are done, abort if not. - updates_complete = yield self.sqlite_store.has_completed_background_updates() + updates_complete = ( + yield self.sqlite_store.has_completed_background_updates() + ) if not updates_complete: sys.stderr.write( "Pending background updates exist in the SQLite3 database." @@ -582,22 +583,22 @@ class Porter(object): ) try: - yield self.postgres_store.runInteraction("alter_table", alter_table) + yield self.postgres_store.db.runInteraction("alter_table", alter_table) except Exception: # On Error Resume Next pass - yield self.postgres_store.runInteraction( + yield self.postgres_store.db.runInteraction( "create_port_table", create_port_table ) # Step 2. Get tables. self.progress.set_state("Fetching tables") - sqlite_tables = yield self.sqlite_store.simple_select_onecol( + sqlite_tables = yield self.sqlite_store.db.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = yield self.postgres_store.simple_select_onecol( + postgres_tables = yield self.postgres_store.db.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -687,11 +688,11 @@ class Porter(object): rows = txn.fetchall() headers = [column[0] for column in txn.description] - ts_ind = headers.index('ts') + ts_ind = headers.index("ts") return headers, [r for r in rows if r[ts_ind] < yesterday] - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = yield self.sqlite_store.db.runInteraction("select", r) rows = self._convert_rows("sent_transactions", headers, rows) @@ -724,7 +725,7 @@ class Porter(object): next_chunk = yield self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - yield self.postgres_store.simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", @@ -737,7 +738,7 @@ class Porter(object): txn.execute( "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) ) - size, = txn.fetchone() + (size,) = txn.fetchone() return int(size) remaining_count = yield self.sqlite_store.execute(get_sent_table_size) @@ -790,7 +791,7 @@ class Porter(object): next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.runInteraction("setup_state_group_id_seq", r) + return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) ############################################## @@ -871,7 +872,7 @@ class CursesProgress(Progress): duration = int(now) - int(self.start_time) minutes, seconds = divmod(duration, 60) - duration_str = '%02dm %02ds' % (minutes, seconds) + duration_str = "%02dm %02ds" % (minutes, seconds) if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) @@ -881,7 +882,7 @@ class CursesProgress(Progress): left = float(self.total_remaining) / self.total_processed est_remaining = (int(now) - self.start_time) * left - est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) + est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60) else: est_remaining_str = "Unknown" status = "Time spent: %s (est. remaining: %s)" % ( @@ -967,7 +968,7 @@ if __name__ == "__main__": description="A script to port an existing synapse SQLite database to" " a new PostgreSQL database." ) - parser.add_argument("-v", action='store_true') + parser.add_argument("-v", action="store_true") parser.add_argument( "--sqlite-database", required=True, @@ -976,12 +977,12 @@ if __name__ == "__main__": ) parser.add_argument( "--postgres-config", - type=argparse.FileType('r'), + type=argparse.FileType("r"), required=True, help="The database config file for the PostgreSQL database", ) parser.add_argument( - "--curses", action='store_true', help="display a curses based progress UI" + "--curses", action="store_true", help="display a curses based progress UI" ) parser.add_argument( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 2ac7d5c064..9c96816096 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -269,7 +269,7 @@ def start(hs, listeners=None): # It is now safe to start your Synapse. hs.start_listening(listeners) - hs.get_datastore().start_profiling() + hs.get_datastore().db.start_profiling() setup_sentry(hs) setup_sdnotify(hs) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 0fa2b50999..b6d4481725 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -64,7 +64,7 @@ class UserDirectorySlaveStore( super(UserDirectorySlaveStore, self).__init__(db_conn, hs) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 735b882363..305b9b0178 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -175,4 +175,4 @@ class ModuleApi(object): Returns: Deferred[object]: result of func """ - return self._store.runInteraction(desc, func, *args, **kwargs) + return self._store.db.runInteraction(desc, func, *args, **kwargs) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 9205e550bb..fd5bb3e1de 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -16,1304 +16,28 @@ # limitations under the License. import logging import random -import sys -import time -from typing import Iterable, Tuple -from six import PY2, iteritems, iterkeys, itervalues -from six.moves import builtins, intern, range +from six import PY2 +from six.moves import builtins from canonicaljson import json -from prometheus_client import Histogram -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.logging.context import LoggingContext, make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.database import LoggingTransaction # noqa: F401 +from synapse.storage.database import make_in_list_sql_clause # noqa: F401 +from synapse.storage.database import Database from synapse.types import get_domain_from_id -from synapse.util.stringutils import exception_to_unicode - -# import a function which will return a monotonic time, in seconds -try: - # on python 3, use time.monotonic, since time.clock can go backwards - from time import monotonic as monotonic_time -except ImportError: - # ... but python 2 doesn't have it - from time import clock as monotonic_time logger = logging.getLogger(__name__) -try: - MAX_TXN_ID = sys.maxint - 1 -except AttributeError: - # python 3 does not have a maximum int value - MAX_TXN_ID = 2 ** 63 - 1 - -sql_logger = logging.getLogger("synapse.storage.SQL") -transaction_logger = logging.getLogger("synapse.storage.txn") -perf_logger = logging.getLogger("synapse.storage.TIME") - -sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") - -sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) -sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) - - -# Unique indexes which have been added in background updates. Maps from table name -# to the name of the background update which added the unique index to that table. -# -# This is used by the upsert logic to figure out which tables are safe to do a proper -# UPSERT on: until the relevant background update has completed, we -# have to emulate an upsert by locking the table. -# -UNIQUE_INDEX_BACKGROUND_UPDATES = { - "user_ips": "user_ips_device_unique_index", - "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", - "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", - "event_search": "event_search_event_id_idx", -} - - -class LoggingTransaction(object): - """An object that almost-transparently proxies for the 'txn' object - passed to the constructor. Adds logging and metrics to the .execute() - method. - - Args: - txn: The database transcation object to wrap. - name (str): The name of this transactions for logging. - database_engine (Sqlite3Engine|PostgresEngine) - after_callbacks(list|None): A list that callbacks will be appended to - that have been added by `call_after` which should be run on - successful completion of the transaction. None indicates that no - callbacks should be allowed to be scheduled to run. - exception_callbacks(list|None): A list that callbacks will be appended - to that have been added by `call_on_exception` which should be run - if transaction ends with an error. None indicates that no callbacks - should be allowed to be scheduled to run. - """ - - __slots__ = [ - "txn", - "name", - "database_engine", - "after_callbacks", - "exception_callbacks", - ] - - def __init__( - self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None - ): - object.__setattr__(self, "txn", txn) - object.__setattr__(self, "name", name) - object.__setattr__(self, "database_engine", database_engine) - object.__setattr__(self, "after_callbacks", after_callbacks) - object.__setattr__(self, "exception_callbacks", exception_callbacks) - - def call_after(self, callback, *args, **kwargs): - """Call the given callback on the main twisted thread after the - transaction has finished. Used to invalidate the caches on the - correct thread. - """ - self.after_callbacks.append((callback, args, kwargs)) - - def call_on_exception(self, callback, *args, **kwargs): - self.exception_callbacks.append((callback, args, kwargs)) - - def __getattr__(self, name): - return getattr(self.txn, name) - - def __setattr__(self, name, value): - setattr(self.txn, name, value) - - def __iter__(self): - return self.txn.__iter__() - - def execute_batch(self, sql, args): - if isinstance(self.database_engine, PostgresEngine): - from psycopg2.extras import execute_batch - - self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) - else: - for val in args: - self.execute(sql, val) - - def execute(self, sql, *args): - self._do_execute(self.txn.execute, sql, *args) - - def executemany(self, sql, *args): - self._do_execute(self.txn.executemany, sql, *args) - - def _make_sql_one_line(self, sql): - "Strip newlines out of SQL so that the loggers in the DB are on one line" - return " ".join(l.strip() for l in sql.splitlines() if l.strip()) - - def _do_execute(self, func, sql, *args): - sql = self._make_sql_one_line(sql) - - # TODO(paul): Maybe use 'info' and 'debug' for values? - sql_logger.debug("[SQL] {%s} %s", self.name, sql) - - sql = self.database_engine.convert_param_style(sql) - if args: - try: - sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) - except Exception: - # Don't let logging failures stop SQL from working - pass - - start = time.time() - - try: - return func(sql, *args) - except Exception as e: - logger.debug("[SQL FAIL] {%s} %s", self.name, e) - raise - finally: - secs = time.time() - start - sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) - sql_query_timer.labels(sql.split()[0]).observe(secs) - - -class PerformanceCounters(object): - def __init__(self): - self.current_counters = {} - self.previous_counters = {} - - def update(self, key, duration_secs): - count, cum_time = self.current_counters.get(key, (0, 0)) - count += 1 - cum_time += duration_secs - self.current_counters[key] = (count, cum_time) - - def interval(self, interval_duration_secs, limit=3): - counters = [] - for name, (count, cum_time) in iteritems(self.current_counters): - prev_count, prev_time = self.previous_counters.get(name, (0, 0)) - counters.append( - ( - (cum_time - prev_time) / interval_duration_secs, - count - prev_count, - name, - ) - ) - - self.previous_counters = dict(self.current_counters) - - counters.sort(reverse=True) - - top_n_counters = ", ".join( - "%s(%d): %.3f%%" % (name, count, 100 * ratio) - for ratio, count, name in counters[:limit] - ) - - return top_n_counters - class SQLBaseStore(object): - _TXN_ID = 0 - def __init__(self, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() - - self._previous_txn_total_time = 0 - self._current_txn_total_time = 0 - self._previous_loop_ts = 0 - - # TODO(paul): These can eventually be removed once the metrics code - # is running in mainline, and we have some nice monitoring frontends - # to watch it - self._txn_perf_counters = PerformanceCounters() - self.database_engine = hs.database_engine - - # A set of tables that are not safe to use native upserts in. - self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) - - # We add the user_directory_search table to the blacklist on SQLite - # because the existing search table does not have an index, making it - # unsafe to use native upserts. - if isinstance(self.database_engine, Sqlite3Engine): - self._unsafe_to_upsert_tables.add("user_directory_search") - - if self.database_engine.can_native_upsert: - # Check ASAP (and then later, every 1s) to see if we have finished - # background updates of tables that aren't safe to update. - self._clock.call_later( - 0.0, - run_as_background_process, - "upsert_safety_check", - self._check_safe_to_upsert, - ) - + self.db = Database(hs) self.rand = random.SystemRandom() - @defer.inlineCallbacks - def _check_safe_to_upsert(self): - """ - Is it safe to use native UPSERT? - - If there are background updates, we will need to wait, as they may be - the addition of indexes that set the UNIQUE constraint that we require. - - If the background updates have not completed, wait 15 sec and check again. - """ - updates = yield self.simple_select_list( - "background_updates", - keyvalues=None, - retcols=["update_name"], - desc="check_background_updates", - ) - updates = [x["update_name"] for x in updates] - - for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): - if update_name not in updates: - logger.debug("Now safe to upsert in %s", table) - self._unsafe_to_upsert_tables.discard(table) - - # If there's any updates still running, reschedule to run. - if updates: - self._clock.call_later( - 15.0, - run_as_background_process, - "upsert_safety_check", - self._check_safe_to_upsert, - ) - - def start_profiling(self): - self._previous_loop_ts = monotonic_time() - - def loop(): - curr = self._current_txn_total_time - prev = self._previous_txn_total_time - self._previous_txn_total_time = curr - - time_now = monotonic_time() - time_then = self._previous_loop_ts - self._previous_loop_ts = time_now - - duration = time_now - time_then - ratio = (curr - prev) / duration - - top_three_counters = self._txn_perf_counters.interval(duration, limit=3) - - perf_logger.info( - "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters - ) - - self._clock.looping_call(loop, 10000) - - def new_transaction( - self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs - ): - start = monotonic_time() - txn_id = self._TXN_ID - - # We don't really need these to be unique, so lets stop it from - # growing really large. - self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) - - name = "%s-%x" % (desc, txn_id) - - transaction_logger.debug("[TXN START] {%s}", name) - - try: - i = 0 - N = 5 - while True: - cursor = LoggingTransaction( - conn.cursor(), - name, - self.database_engine, - after_callbacks, - exception_callbacks, - ) - try: - r = func(cursor, *args, **kwargs) - conn.commit() - return r - except self.database_engine.module.OperationalError as e: - # This can happen if the database disappears mid - # transaction. - logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", - name, - exception_to_unicode(e), - i, - N, - ) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) - ) - continue - raise - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", - name, - exception_to_unicode(e1), - ) - continue - raise - finally: - # we're either about to retry with a new cursor, or we're about to - # release the connection. Once we release the connection, it could - # get used for another query, which might do a conn.rollback(). - # - # In the latter case, even though that probably wouldn't affect the - # results of this transaction, python's sqlite will reset all - # statements on the connection [1], which will make our cursor - # invalid [2]. - # - # In any case, continuing to read rows after commit()ing seems - # dubious from the PoV of ACID transactional semantics - # (sqlite explicitly says that once you commit, you may see rows - # from subsequent updates.) - # - # In psycopg2, cursors are essentially a client-side fabrication - - # all the data is transferred to the client side when the statement - # finishes executing - so in theory we could go on streaming results - # from the cursor, but attempting to do so would make us - # incompatible with sqlite, so let's make sure we're not doing that - # by closing the cursor. - # - # (*named* cursors in psycopg2 are different and are proper server- - # side things, but (a) we don't use them and (b) they are implicitly - # closed by ending the transaction anyway.) - # - # In short, if we haven't finished with the cursor yet, that's a - # problem waiting to bite us. - # - # TL;DR: we're done with the cursor, so we can close it. - # - # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 - # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 - cursor.close() - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", name, e) - raise - finally: - end = monotonic_time() - duration = end - start - - LoggingContext.current_context().add_database_transaction(duration) - - transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) - - self._current_txn_total_time += duration - self._txn_perf_counters.update(desc, duration) - sql_txn_timer.labels(desc).observe(duration) - - @defer.inlineCallbacks - def runInteraction(self, desc, func, *args, **kwargs): - """Starts a transaction on the database and runs a given function - - Arguments: - desc (str): description of the transaction, for logging and metrics - func (func): callback function, which will be called with a - database transaction (twisted.enterprise.adbapi.Transaction) as - its first argument, followed by `args` and `kwargs`. - - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` - - Returns: - Deferred: The result of func - """ - after_callbacks = [] - exception_callbacks = [] - - if LoggingContext.current_context() == LoggingContext.sentinel: - logger.warning("Starting db txn '%s' from sentinel context", desc) - - try: - result = yield self.runWithConnection( - self.new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - **kwargs - ) - - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except: # noqa: E722, as we reraise the exception this is fine. - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise - - return result - - @defer.inlineCallbacks - def runWithConnection(self, func, *args, **kwargs): - """Wraps the .runWithConnection() method on the underlying db_pool. - - Arguments: - func (func): callback function, which will be called with a - database connection (twisted.enterprise.adbapi.Connection) as - its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` - - Returns: - Deferred: The result of func - """ - parent_context = LoggingContext.current_context() - if parent_context == LoggingContext.sentinel: - logger.warning( - "Starting db connection from sentinel context: metrics will be lost" - ) - parent_context = None - - start_time = monotonic_time() - - def inner_func(conn, *args, **kwargs): - with LoggingContext("runWithConnection", parent_context) as context: - sched_duration_sec = monotonic_time() - start_time - sql_scheduling_timer.observe(sched_duration_sec) - context.add_database_scheduled(sched_duration_sec) - - if self.database_engine.is_connection_closed(conn): - logger.debug("Reconnecting closed database connection") - conn.reconnect() - - return func(conn, *args, **kwargs) - - result = yield make_deferred_yieldable( - self._db_pool.runWithConnection(inner_func, *args, **kwargs) - ) - - return result - - @staticmethod - def cursor_to_dict(cursor): - """Converts a SQL cursor into an list of dicts. - - Args: - cursor : The DBAPI cursor which has executed a query. - Returns: - A list of dicts where the key is the column header. - """ - col_headers = list(intern(str(column[0])) for column in cursor.description) - results = list(dict(zip(col_headers, row)) for row in cursor) - return results - - def execute(self, desc, decoder, query, *args): - """Runs a single query for a result set. - - Args: - decoder - The function which can resolve the cursor results to - something meaningful. - query - The query string to execute - *args - Query args. - Returns: - The result of decoder(results) - """ - - def interaction(txn): - txn.execute(query, args) - if decoder: - return decoder(txn) - else: - return txn.fetchall() - - return self.runInteraction(desc, interaction) - - # "Simple" SQL API methods that operate on a single table with no JOINs, - # no complex WHERE clauses, just a dict of values for columns. - - @defer.inlineCallbacks - def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): - """Executes an INSERT query on the named table. - - Args: - table : string giving the table name - values : dict of new column names and values for them - or_ignore : bool stating whether an exception should be raised - when a conflicting row already exists. If True, False will be - returned by the function instead - desc : string giving a description of the transaction - - Returns: - bool: Whether the row was inserted or not. Only useful when - `or_ignore` is True - """ - try: - yield self.runInteraction(desc, self.simple_insert_txn, table, values) - except self.database_engine.module.IntegrityError: - # We have to do or_ignore flag at this layer, since we can't reuse - # a cursor after we receive an error from the db. - if not or_ignore: - raise - return False - return True - - @staticmethod - def simple_insert_txn(txn, table, values): - keys, vals = zip(*values.items()) - - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys), - ", ".join("?" for _ in keys), - ) - - txn.execute(sql, vals) - - def simple_insert_many(self, table, values, desc): - return self.runInteraction(desc, self.simple_insert_many_txn, table, values) - - @staticmethod - def simple_insert_many_txn(txn, table, values): - if not values: - return - - # This is a *slight* abomination to get a list of tuples of key names - # and a list of tuples of value names. - # - # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] - # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] - # - # The sort is to ensure that we don't rely on dictionary iteration - # order. - keys, vals = zip( - *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] - ) - - for k in keys: - if k != keys[0]: - raise RuntimeError("All items must have the same keys") - - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]), - ) - - txn.executemany(sql, vals) - - @defer.inlineCallbacks - def simple_upsert( - self, - table, - keyvalues, - values, - insertion_values={}, - desc="simple_upsert", - lock=True, - ): - """ - - `lock` should generally be set to True (the default), but can be set - to False if either of the following are true: - - * there is a UNIQUE INDEX on the key columns. In this case a conflict - will cause an IntegrityError in which case this function will retry - the update. - - * we somehow know that we are the only thread which will be updating - this table. - - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key columns and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - Deferred(None or bool): Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. - """ - attempts = 0 - while True: - try: - result = yield self.runInteraction( - desc, - self.simple_upsert_txn, - table, - keyvalues, - values, - insertion_values, - lock=lock, - ) - return result - except self.database_engine.module.IntegrityError as e: - attempts += 1 - if attempts >= 5: - # don't retry forever, because things other than races - # can cause IntegrityErrors - raise - - # presumably we raced with another transaction: let's retry. - logger.warning( - "IntegrityError when upserting into %s; retrying: %s", table, e - ) - - def simple_upsert_txn( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): - """ - Pick the UPSERT method which works best on the platform. Either the - native one (Pg9.5+, recent SQLites), or fall back to an emulated method. - - Args: - txn: The transaction to use. - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - None or bool: Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. - """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): - return self.simple_upsert_txn_native_upsert( - txn, table, keyvalues, values, insertion_values=insertion_values - ) - else: - return self.simple_upsert_txn_emulated( - txn, - table, - keyvalues, - values, - insertion_values=insertion_values, - lock=lock, - ) - - def simple_upsert_txn_emulated( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): - """ - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - bool: Return True if a new entry was created, False if an existing - one was updated. - """ - # We need to lock the table :(, unless we're *really* careful - if lock: - self.database_engine.lock_table(txn, table) - - def _getwhere(key): - # If the value we're passing in is None (aka NULL), we need to use - # IS, not =, as NULL = NULL equals NULL (False). - if keyvalues[key] is None: - return "%s IS ?" % (key,) - else: - return "%s = ?" % (key,) - - if not values: - # If `values` is empty, then all of the values we care about are in - # the unique key, so there is nothing to UPDATE. We can just do a - # SELECT instead to see if it exists. - sql = "SELECT 1 FROM %s WHERE %s" % ( - table, - " AND ".join(_getwhere(k) for k in keyvalues), - ) - sqlargs = list(keyvalues.values()) - txn.execute(sql, sqlargs) - if txn.fetchall(): - # We have an existing record. - return False - else: - # First try to update. - sql = "UPDATE %s SET %s WHERE %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in values), - " AND ".join(_getwhere(k) for k in keyvalues), - ) - sqlargs = list(values.values()) + list(keyvalues.values()) - - txn.execute(sql, sqlargs) - if txn.rowcount > 0: - # successfully updated at least one row. - return False - - # We didn't find any existing rows, so insert a new one - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(values) - allvalues.update(insertion_values) - - sql = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues), - ) - txn.execute(sql, list(allvalues.values())) - # successfully inserted - return True - - def simple_upsert_txn_native_upsert( - self, txn, table, keyvalues, values, insertion_values={} - ): - """ - Use the native UPSERT functionality in recent PostgreSQL versions. - - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - Returns: - None - """ - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(insertion_values) - - if not values: - latter = "NOTHING" - else: - allvalues.update(values) - latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - - sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues), - ", ".join(k for k in keyvalues), - latter, - ) - txn.execute(sql, list(allvalues.values())) - - def simple_upsert_many_txn( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): - return self.simple_upsert_many_txn_native_upsert( - txn, table, key_names, key_values, value_names, value_values - ) - else: - return self.simple_upsert_many_txn_emulated( - txn, table, key_names, key_values, value_names, value_values - ) - - def simple_upsert_many_txn_emulated( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times, but without native UPSERT support or batching. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - # No value columns, therefore make a blank list so that the following - # zip() works correctly. - if not value_names: - value_values = [() for x in range(len(key_values))] - - for keyv, valv in zip(key_values, value_values): - _keys = {x: y for x, y in zip(key_names, keyv)} - _vals = {x: y for x, y in zip(value_names, valv)} - - self.simple_upsert_txn_emulated(txn, table, _keys, _vals) - - def simple_upsert_many_txn_native_upsert( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times, using batching where possible. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - allnames = [] - allnames.extend(key_names) - allnames.extend(value_names) - - if not value_names: - # No value columns, therefore make a blank list so that the - # following zip() works correctly. - latter = "NOTHING" - value_values = [() for x in range(len(key_values))] - else: - latter = "UPDATE SET " + ", ".join( - k + "=EXCLUDED." + k for k in value_names - ) - - sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( - table, - ", ".join(k for k in allnames), - ", ".join("?" for _ in allnames), - ", ".join(key_names), - latter, - ) - - args = [] - - for x, y in zip(key_values, value_values): - args.append(tuple(x) + tuple(y)) - - return txn.execute_batch(sql, args) - - def simple_select_one( - self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" - ): - """Executes a SELECT query on the named table, which is expected to - return a single row, returning multiple columns from it. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcols : list of strings giving the names of the columns to return - - allow_none : If true, return None instead of failing if the SELECT - statement returns no rows - """ - return self.runInteraction( - desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none - ) - - def simple_select_one_onecol( - self, - table, - keyvalues, - retcol, - allow_none=False, - desc="simple_select_one_onecol", - ): - """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcol : string giving the name of the column to return - """ - return self.runInteraction( - desc, - self.simple_select_one_onecol_txn, - table, - keyvalues, - retcol, - allow_none=allow_none, - ) - - @classmethod - def simple_select_one_onecol_txn( - cls, txn, table, keyvalues, retcol, allow_none=False - ): - ret = cls.simple_select_onecol_txn( - txn, table=table, keyvalues=keyvalues, retcol=retcol - ) - - if ret: - return ret[0] - else: - if allow_none: - return None - else: - raise StoreError(404, "No row found") - - @staticmethod - def simple_select_onecol_txn(txn, table, keyvalues, retcol): - sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} - - if keyvalues: - sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) - txn.execute(sql, list(keyvalues.values())) - else: - txn.execute(sql) - - return [r[0] for r in txn] - - def simple_select_onecol( - self, table, keyvalues, retcol, desc="simple_select_onecol" - ): - """Executes a SELECT query on the named table, which returns a list - comprising of the values of the named column from the selected rows. - - Args: - table (str): table name - keyvalues (dict|None): column names and values to select the rows with - retcol (str): column whos value we wish to retrieve. - - Returns: - Deferred: Results in a list - """ - return self.runInteraction( - desc, self.simple_select_onecol_txn, table, keyvalues, retcol - ) - - def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - table (str): the table name - keyvalues (dict[str, Any] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, self.simple_select_list_txn, table, keyvalues, retcols - ) - - @classmethod - def simple_select_list_txn(cls, txn, table, keyvalues, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return - """ - if keyvalues: - sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - txn.execute(sql, list(keyvalues.values())) - else: - sql = "SELECT %s FROM %s" % (", ".join(retcols), table) - txn.execute(sql) - - return cls.cursor_to_dict(txn) - - @defer.inlineCallbacks - def simple_select_many_batch( - self, - table, - column, - iterable, - retcols, - keyvalues={}, - desc="simple_select_many_batch", - batch_size=100, - ): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Filters rows by if value of `column` is in `iterable`. - - Args: - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return - """ - results = [] - - if not iterable: - return results - - # iterables can not be sliced, so convert it to a list first - it_list = list(iterable) - - chunks = [ - it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) - ] - for chunk in chunks: - rows = yield self.runInteraction( - desc, - self.simple_select_many_txn, - table, - column, - chunk, - keyvalues, - retcols, - ) - - results.extend(rows) - - return results - - @classmethod - def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Filters rows by if value of `column` is in `iterable`. - - Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return - """ - if not iterable: - return [] - - clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) - clauses = [clause] - - for key, value in iteritems(keyvalues): - clauses.append("%s = ?" % (key,)) - values.append(value) - - sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join(clauses), - ) - - txn.execute(sql, values) - return cls.cursor_to_dict(txn) - - def simple_update(self, table, keyvalues, updatevalues, desc): - return self.runInteraction( - desc, self.simple_update_txn, table, keyvalues, updatevalues - ) - - @staticmethod - def simple_update_txn(txn, table, keyvalues, updatevalues): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) - else: - where = "" - - update_sql = "UPDATE %s SET %s %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in updatevalues), - where, - ) - - txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) - - return txn.rowcount - - def simple_update_one( - self, table, keyvalues, updatevalues, desc="simple_update_one" - ): - """Executes an UPDATE query on the named table, setting new values for - columns in a row matching the key values. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - updatevalues : dict giving column names and values to update - retcols : optional list of column names to return - - If present, retcols gives a list of column names on which to perform - a SELECT statement *before* performing the UPDATE statement. The values - of these will be returned in a dict. - - These are performed within the same transaction, allowing an atomic - get-and-set. This can be used to implement compare-and-set by putting - the update column in the 'keyvalues' dict as well. - """ - return self.runInteraction( - desc, self.simple_update_one_txn, table, keyvalues, updatevalues - ) - - @classmethod - def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): - rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) - - if rowcount == 0: - raise StoreError(404, "No row found (%s)" % (table,)) - if rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - @staticmethod - def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): - select_sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(select_sql, list(keyvalues.values())) - row = txn.fetchone() - - if not row: - if allow_none: - return None - raise StoreError(404, "No row found (%s)" % (table,)) - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - return dict(zip(retcols, row)) - - def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): - """Executes a DELETE query on the named table, expecting to delete a - single row. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) - - @staticmethod - def simple_delete_one_txn(txn, table, keyvalues): - """Executes a DELETE query on the named table, expecting to delete a - single row. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - sql = "DELETE FROM %s WHERE %s" % ( - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(sql, list(keyvalues.values())) - if txn.rowcount == 0: - raise StoreError(404, "No row found (%s)" % (table,)) - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - def simple_delete(self, table, keyvalues, desc): - return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) - - @staticmethod - def simple_delete_txn(txn, table, keyvalues): - sql = "DELETE FROM %s WHERE %s" % ( - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(sql, list(keyvalues.values())) - return txn.rowcount - - def simple_delete_many(self, table, column, iterable, keyvalues, desc): - return self.runInteraction( - desc, self.simple_delete_many_txn, table, column, iterable, keyvalues - ) - - @staticmethod - def simple_delete_many_txn(txn, table, column, iterable, keyvalues): - """Executes a DELETE query on the named table. - - Filters rows by if value of `column` is in `iterable`. - - Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - - Returns: - int: Number rows deleted - """ - if not iterable: - return 0 - - sql = "DELETE FROM %s" % table - - clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) - clauses = [clause] - - for key, value in iteritems(keyvalues): - clauses.append("%s = ?" % (key,)) - values.append(value) - - if clauses: - sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) - txn.execute(sql, values) - - return txn.rowcount - - def get_cache_dict( - self, db_conn, table, entity_column, stream_column, max_value, limit=100000 - ): - # Fetch a mapping of room_id -> max stream position for "recent" rooms. - # It doesn't really matter how many we get, the StreamChangeCache will - # do the right thing to ensure it respects the max size of cache. - sql = ( - "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - %(limit)s" - " GROUP BY %(entity)s" - ) % { - "table": table, - "entity": entity_column, - "stream": stream_column, - "limit": limit, - } - - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (int(max_value),)) - - cache = {row[0]: int(row[1]) for row in txn} - - txn.close() - - if cache: - min_val = min(itervalues(cache)) - else: - min_val = max_value - - return cache, min_val - def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does not stream invalidations down replication. @@ -1347,159 +71,6 @@ class SQLBaseStore(object): # which is fine. pass - def simple_select_list_paginate( - self, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction="ASC", - desc="simple_select_list_paginate", - ): - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, - self.simple_select_list_paginate_txn, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction=order_direction, - ) - - @classmethod - def simple_select_list_paginate_txn( - cls, - txn, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction="ASC", - ): - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - if order_direction not in ["ASC", "DESC"]: - raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") - - if keyvalues: - where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) - else: - where_clause = "" - - sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( - ", ".join(retcols), - table, - where_clause, - orderby, - order_direction, - ) - txn.execute(sql, list(keyvalues.values()) + [limit, start]) - - return cls.cursor_to_dict(txn) - - def get_user_count_txn(self, txn): - """Get a total number of registered users in the users list. - - Args: - txn : Transaction object - Returns: - int : number of users - """ - sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" - txn.execute(sql_count) - return txn.fetchone()[0] - - def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None - """ - - return self.runInteraction( - desc, self.simple_search_list_txn, table, term, col, retcols - ) - - @classmethod - def simple_search_list_txn(cls, txn, table, term, col, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None - """ - if term: - sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) - termvalues = ["%%" + term + "%%"] - txn.execute(sql, termvalues) - else: - return 0 - - return cls.cursor_to_dict(txn) - - -class _RollbackButIsFineException(Exception): - """ This exception is used to rollback a transaction without implying - something went wrong. - """ - - pass - def db_to_json(db_content): """ @@ -1528,30 +99,3 @@ def db_to_json(db_content): except Exception: logging.warning("Tried to decode '%r' as JSON and failed", db_content) raise - - -def make_in_list_sql_clause( - database_engine, column: str, iterable: Iterable -) -> Tuple[str, Iterable]: - """Returns an SQL clause that checks the given column is in the iterable. - - On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres - it expands to `column = ANY(?)`. While both DBs support the `IN` form, - using the `ANY` form on postgres means that it views queries with - different length iterables as the same, helping the query stats. - - Args: - database_engine - column: Name of the column - iterable: The values to check the column against. - - Returns: - A tuple of SQL query and the args - """ - - if database_engine.supports_using_any_list: - # This should hopefully be faster, but also makes postgres query - # stats easier to understand. - return "%s = ANY(?)" % (column,), [list(iterable)] - else: - return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 06955a0537..dfca94b0e0 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -139,7 +139,7 @@ class BackgroundUpdateStore(SQLBaseStore): # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = yield self.simple_select_onecol( + updates = yield self.db.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", @@ -161,7 +161,7 @@ class BackgroundUpdateStore(SQLBaseStore): if update_name in self._background_update_queue: return False - update_exists = await self.simple_select_one_onecol( + update_exists = await self.db.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="1", @@ -184,7 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore): no more work to do. """ if not self._background_update_queue: - updates = yield self.simple_select_list( + updates = yield self.db.simple_select_list( "background_updates", keyvalues=None, retcols=("update_name", "depends_on"), @@ -226,7 +226,7 @@ class BackgroundUpdateStore(SQLBaseStore): else: batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE - progress_json = yield self.simple_select_one_onecol( + progress_json = yield self.db.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="progress_json", @@ -391,7 +391,7 @@ class BackgroundUpdateStore(SQLBaseStore): def updater(progress, batch_size): if runner is not None: logger.info("Adding index %s to %s", index_name, table) - yield self.runWithConnection(runner) + yield self.db.runWithConnection(runner) yield self._end_background_update(update_name) return 1 @@ -413,7 +413,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [] progress_json = json.dumps(progress) - return self.simple_insert( + return self.db.simple_insert( "background_updates", {"update_name": update_name, "progress_json": progress_json}, ) @@ -429,7 +429,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [ name for name in self._background_update_queue if name != update_name ] - return self.simple_delete_one( + return self.db.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) @@ -444,7 +444,7 @@ class BackgroundUpdateStore(SQLBaseStore): progress_json = json.dumps(progress) - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, "background_updates", keyvalues={"update_name": update_name}, diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 2a5b33dda1..46f0f26af6 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -171,9 +171,11 @@ class DataStore( else: self._cache_id_gen = None + super(DataStore, self).__init__(db_conn, hs) + self._presence_on_startup = self._get_active_presence(db_conn) - presence_cache_prefill, min_presence_val = self.get_cache_dict( + presence_cache_prefill, min_presence_val = self.db.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", @@ -187,7 +189,7 @@ class DataStore( ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self.get_cache_dict( + device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", @@ -202,7 +204,7 @@ class DataStore( ) # The federation outbox and the local device inbox uses the same # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self.get_cache_dict( + device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", @@ -228,7 +230,7 @@ class DataStore( ) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", @@ -242,7 +244,7 @@ class DataStore( prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self.get_cache_dict( + _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", @@ -262,8 +264,6 @@ class DataStore( # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() - super(DataStore, self).__init__(db_conn, hs) - def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None @@ -283,7 +283,7 @@ class DataStore( txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) txn.close() for row in rows: @@ -296,7 +296,7 @@ class DataStore( Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.runInteraction("count_daily_users", self._count_users, yesterday) + return self.db.runInteraction("count_daily_users", self._count_users, yesterday) def count_monthly_users(self): """ @@ -306,7 +306,7 @@ class DataStore( amongst other things, includes a 3 day grace period before a user counts. """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.runInteraction( + return self.db.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -406,7 +406,7 @@ class DataStore( return results - return self.runInteraction("count_r30_users", _count_r30_users) + return self.db.runInteraction("count_r30_users", _count_r30_users) def _get_start_of_day(self): """ @@ -471,7 +471,7 @@ class DataStore( # frequently self._last_user_visit_update = now - return self.runInteraction( + return self.db.runInteraction( "generate_user_daily_visits", _generate_user_daily_visits ) @@ -482,7 +482,7 @@ class DataStore( Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self.simple_select_list( + return self.db.simple_select_list( table="users", keyvalues={}, retcols=["name", "password_hash", "is_guest", "admin", "user_type"], @@ -502,9 +502,9 @@ class DataStore( Returns: defer.Deferred: resolves to json object {list[dict[str, Any]], count} """ - users = yield self.runInteraction( + users = yield self.db.runInteraction( "get_users_paginate", - self.simple_select_list_paginate_txn, + self.db.simple_select_list_paginate_txn, table="users", keyvalues={"is_guest": False}, orderby=order, @@ -512,7 +512,9 @@ class DataStore( limit=limit, retcols=["name", "password_hash", "is_guest", "admin", "user_type"], ) - count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) + count = yield self.db.runInteraction( + "get_users_paginate", self.get_user_count_txn + ) retval = {"users": users, "total": count} return retval @@ -526,7 +528,7 @@ class DataStore( Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self.simple_search_list( + return self.db.simple_search_list( table="users", term=term, col="name", diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index b0d22faf3f..a96fe9485c 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -67,7 +67,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_user_txn(txn): - rows = self.simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, @@ -78,7 +78,7 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: json.loads(row["content"]) for row in rows } - rows = self.simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, @@ -92,7 +92,7 @@ class AccountDataWorkerStore(SQLBaseStore): return global_account_data, by_room - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) @@ -102,7 +102,7 @@ class AccountDataWorkerStore(SQLBaseStore): Returns: Deferred: A dict """ - result = yield self.simple_select_one_onecol( + result = yield self.db.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -127,7 +127,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_txn(txn): - rows = self.simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, @@ -138,7 +138,7 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: json.loads(row["content"]) for row in rows } - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @@ -156,7 +156,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_and_type_txn(txn): - content_json = self.simple_select_one_onecol_txn( + content_json = self.db.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ @@ -170,7 +170,7 @@ class AccountDataWorkerStore(SQLBaseStore): return json.loads(content_json) if content_json else None - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) @@ -207,7 +207,7 @@ class AccountDataWorkerStore(SQLBaseStore): room_results = txn.fetchall() return global_results, room_results - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_account_data_txn", get_updated_account_data_txn ) @@ -252,7 +252,7 @@ class AccountDataWorkerStore(SQLBaseStore): if not changed: return {}, {} - return self.runInteraction( + return self.db.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) @@ -302,7 +302,7 @@ class AccountDataStore(AccountDataWorkerStore): # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self.simple_upsert( + yield self.db.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -348,7 +348,7 @@ class AccountDataStore(AccountDataWorkerStore): # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self.simple_upsert( + yield self.db.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, @@ -388,4 +388,4 @@ class AccountDataStore(AccountDataWorkerStore): ) txn.execute(update_max_id_sql, (next_id, next_id)) - return self.runInteraction("update_account_data_max_stream_id", _update) + return self.db.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 6b82fd392a..6b2e12719c 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -133,7 +133,7 @@ class ApplicationServiceTransactionWorkerStore( A Deferred which resolves to a list of ApplicationServices, which may be empty. """ - results = yield self.simple_select_list( + results = yield self.db.simple_select_list( "application_services_state", dict(state=state), ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -155,7 +155,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves to ApplicationServiceState. """ - result = yield self.simple_select_one( + result = yield self.db.simple_select_one( "application_services_state", dict(as_id=service.id), ["state"], @@ -175,7 +175,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves when the state was set successfully. """ - return self.simple_upsert( + return self.db.simple_upsert( "application_services_state", dict(as_id=service.id), dict(state=state) ) @@ -216,7 +216,7 @@ class ApplicationServiceTransactionWorkerStore( ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.runInteraction("create_appservice_txn", _create_appservice_txn) + return self.db.runInteraction("create_appservice_txn", _create_appservice_txn) def complete_appservice_txn(self, txn_id, service): """Completes an application service transaction. @@ -249,7 +249,7 @@ class ApplicationServiceTransactionWorkerStore( ) # Set current txn_id for AS to 'txn_id' - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, "application_services_state", dict(as_id=service.id), @@ -257,11 +257,13 @@ class ApplicationServiceTransactionWorkerStore( ) # Delete txn - self.simple_delete_txn( + self.db.simple_delete_txn( txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) ) - return self.runInteraction("complete_appservice_txn", _complete_appservice_txn) + return self.db.runInteraction( + "complete_appservice_txn", _complete_appservice_txn + ) @defer.inlineCallbacks def get_oldest_unsent_txn(self, service): @@ -283,7 +285,7 @@ class ApplicationServiceTransactionWorkerStore( " ORDER BY txn_id ASC LIMIT 1", (service.id,), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return None @@ -291,7 +293,7 @@ class ApplicationServiceTransactionWorkerStore( return entry - entry = yield self.runInteraction( + entry = yield self.db.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) @@ -321,7 +323,7 @@ class ApplicationServiceTransactionWorkerStore( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.runInteraction( + return self.db.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) @@ -350,7 +352,7 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.runInteraction( + upper_bound, event_ids = yield self.db.runInteraction( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index de3256049d..54ed8574c4 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -95,7 +95,7 @@ class CacheInvalidationStore(SQLBaseStore): txn.call_after(ctx.__exit__, None, None, None) txn.call_after(self.hs.get_notifier().on_new_replication_data) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="cache_invalidation_stream", values={ @@ -122,7 +122,9 @@ class CacheInvalidationStore(SQLBaseStore): txn.execute(sql, (last_id, limit)) return txn.fetchall() - return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) + return self.db.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) def get_cache_stream_token(self): if self._cache_id_gen: diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 66522a04b7..6f2a720b97 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -91,7 +91,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() - yield self.runWithConnection(f) + yield self.db.runWithConnection(f) yield self._end_background_update("user_ips_drop_nonunique_index") return 1 @@ -106,7 +106,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): def user_ips_analyze(txn): txn.execute("ANALYZE user_ips") - yield self.runInteraction("user_ips_analyze", user_ips_analyze) + yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) yield self._end_background_update("user_ips_analyze") @@ -140,7 +140,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): return None # Get a last seen that has roughly `batch_size` since `begin_last_seen` - end_last_seen = yield self.runInteraction( + end_last_seen = yield self.db.runInteraction( "user_ips_dups_get_last_seen", get_last_seen ) @@ -275,7 +275,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) - yield self.runInteraction("user_ips_dups_remove", remove) + yield self.db.runInteraction("user_ips_dups_remove", remove) if last: yield self._end_background_update("user_ips_remove_dupes") @@ -352,7 +352,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): return len(rows) - updated = yield self.runInteraction( + updated = yield self.db.runInteraction( "_devices_last_seen_update", _devices_last_seen_update_txn ) @@ -417,12 +417,12 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): to_update = self._batch_row_update self._batch_row_update = {} - return self.runInteraction( + return self.db.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) def _update_client_ips_batch_txn(self, txn, to_update): - if "user_ips" in self._unsafe_to_upsert_tables or ( + if "user_ips" in self.db._unsafe_to_upsert_tables or ( not self.database_engine.can_native_upsert ): self.database_engine.lock_table(txn, "user_ips") @@ -431,7 +431,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry try: - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="user_ips", keyvalues={ @@ -450,7 +450,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -483,7 +483,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - res = yield self.simple_select_list( + res = yield self.db.simple_select_list( table="devices", keyvalues=keyvalues, retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), @@ -516,7 +516,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], @@ -577,4 +577,4 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def _prune_old_user_ips_txn(txn): txn.execute(sql, (timestamp,)) - await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) + await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 206d39134d..440793ad49 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.runInteraction( + return self.db.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn ) @@ -109,7 +109,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - count = yield self.runInteraction( + count = yield self.db.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) @@ -178,7 +178,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.runInteraction( + return self.db.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @@ -203,7 +203,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.runInteraction( + return self.db.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) @@ -232,7 +232,7 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() - yield self.runWithConnection(reindex_txn) + yield self.db.runWithConnection(reindex_txn) yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) @@ -294,7 +294,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.runInteraction( + yield self.db.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) for user_id in local_messages_by_user_then_device.keys(): @@ -314,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. - already_inserted = self.simple_select_one_txn( + already_inserted = self.db.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={"origin": origin, "message_id": message_id}, @@ -326,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Add an entry for this message_id so that we know we've processed # it. - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="device_federation_inbox", values={ @@ -344,7 +344,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.runInteraction( + yield self.db.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, @@ -465,6 +465,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) return rows - return self.runInteraction( + return self.db.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn ) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 727c582121..d98511ddd4 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -61,7 +61,7 @@ class DeviceWorkerStore(SQLBaseStore): Raises: StoreError: if the device is not found """ - return self.simple_select_one( + return self.db.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -80,7 +80,7 @@ class DeviceWorkerStore(SQLBaseStore): containing "device_id", "user_id" and "display_name" for each device. """ - devices = yield self.simple_select_list( + devices = yield self.db.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -122,7 +122,7 @@ class DeviceWorkerStore(SQLBaseStore): # consider the device update to be too large, and simply skip the # stream_id; the rationale being that such a large device list update # is likely an error. - updates = yield self.runInteraction( + updates = yield self.db.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, @@ -283,7 +283,7 @@ class DeviceWorkerStore(SQLBaseStore): """ devices = ( - yield self.runInteraction( + yield self.db.runInteraction( "_get_e2e_device_keys_txn", self._get_e2e_device_keys_txn, query_map.keys(), @@ -340,12 +340,12 @@ class DeviceWorkerStore(SQLBaseStore): rows = txn.fetchall() return rows[0][0] - return self.runInteraction("get_last_device_update_for_remote_user", f) + return self.db.runInteraction("get_last_device_update_for_remote_user", f) def mark_as_sent_devices_by_remote(self, destination, stream_id): """Mark that updates have successfully been sent to the destination. """ - return self.runInteraction( + return self.db.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, @@ -399,7 +399,7 @@ class DeviceWorkerStore(SQLBaseStore): """ with self._device_list_id_gen.get_next() as stream_id: - yield self.runInteraction( + yield self.db.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, from_user_id, @@ -414,7 +414,7 @@ class DeviceWorkerStore(SQLBaseStore): from_user_id, stream_id, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "user_signature_stream", values={ @@ -466,7 +466,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2, tree=True) def _get_cached_user_device(self, user_id, device_id): - content = yield self.simple_select_one_onecol( + content = yield self.db.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -476,7 +476,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): - devices = yield self.simple_select_list( + devices = yield self.db.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -492,7 +492,7 @@ class DeviceWorkerStore(SQLBaseStore): Returns: (stream_id, devices) """ - return self.runInteraction( + return self.db.runInteraction( "get_devices_with_keys_by_user", self._get_devices_with_keys_by_user_txn, user_id, @@ -565,7 +565,7 @@ class DeviceWorkerStore(SQLBaseStore): return changes - return self.runInteraction( + return self.db.runInteraction( "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) @@ -584,7 +584,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self.execute( + rows = yield self.db.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) return set(user for row in rows for user in json.loads(row[0])) @@ -605,7 +605,7 @@ class DeviceWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id, destination """ - return self.execute( + return self.db.execute( "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) @@ -614,7 +614,7 @@ class DeviceWorkerStore(SQLBaseStore): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", @@ -628,7 +628,7 @@ class DeviceWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, @@ -685,7 +685,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.close() - yield self.runWithConnection(f) + yield self.db.runWithConnection(f) yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) return 1 @@ -722,7 +722,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = yield self.simple_insert( + inserted = yield self.db.simple_insert( "devices", values={ "user_id": user_id, @@ -736,7 +736,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self.simple_select_one_onecol( + hidden = yield self.db.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -771,7 +771,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self.simple_delete_one( + yield self.db.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -789,7 +789,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self.simple_delete_many( + yield self.db.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -818,7 +818,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updates["display_name"] = new_display_name if not updates: return defer.succeed(None) - return self.simple_update_one( + return self.db.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, @@ -829,7 +829,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ - yield self.simple_delete( + yield self.db.simple_delete( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, desc="mark_remote_user_device_list_as_unsubscribed", @@ -853,7 +853,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: Deferred[None] """ - return self.runInteraction( + return self.db.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, user_id, @@ -866,7 +866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn, user_id, device_id, content, stream_id ): if content.get("deleted"): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -874,7 +874,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) else: - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -890,7 +890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -914,7 +914,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: Deferred[None] """ - return self.runInteraction( + return self.db.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, user_id, @@ -923,11 +923,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_remote_cache", values=[ @@ -946,7 +946,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -962,7 +962,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): (if any) should be poked. """ with self._device_list_id_gen.get_next() as stream_id: - yield self.runInteraction( + yield self.db.runInteraction( "add_device_change_to_streams", self._add_device_change_txn, user_id, @@ -995,7 +995,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): [(user_id, device_id, stream_id) for device_id in device_ids], ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_stream", values=[ @@ -1006,7 +1006,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context = get_active_span_text_map() - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", values=[ @@ -1069,7 +1069,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return run_as_background_process( "prune_old_outbound_device_pokes", - self.runInteraction, + self.db.runInteraction, "_prune_old_outbound_device_pokes", _prune_txn, ) diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py index d332f8a409..c9e7de7d12 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -36,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore): Deferred: results in namedtuple with keys "room_id" and "servers" or None if no association can be found """ - room_id = yield self.simple_select_one_onecol( + room_id = yield self.db.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -47,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore): if not room_id: return None - servers = yield self.simple_select_onecol( + servers = yield self.db.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -60,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) def get_room_alias_creator(self, room_alias): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", @@ -69,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore): @cached(max_entries=5000) def get_aliases_for_room(self, room_id): - return self.simple_select_onecol( + return self.db.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", @@ -93,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore): """ def alias_txn(txn): - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "room_aliases", { @@ -103,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore): }, ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="room_alias_servers", values=[ @@ -117,7 +117,9 @@ class DirectoryStore(DirectoryWorkerStore): ) try: - ret = yield self.runInteraction("create_room_alias_association", alias_txn) + ret = yield self.db.runInteraction( + "create_room_alias_association", alias_txn + ) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() @@ -126,7 +128,7 @@ class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def delete_room_alias(self, room_alias): - room_id = yield self.runInteraction( + room_id = yield self.db.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) @@ -168,6 +170,6 @@ class DirectoryStore(DirectoryWorkerStore): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.runInteraction( + return self.db.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index df89eda337..84594cf0a9 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -38,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self.simple_update_one( + yield self.db.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -89,7 +89,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): } ) - yield self.simple_insert_many( + yield self.db.simple_insert_many( table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @@ -125,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -170,7 +170,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key """ - return self.runInteraction( + return self.db.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, @@ -234,7 +234,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version (str): the version ID of the backup we're querying about """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", @@ -267,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - yield self.simple_delete( + yield self.db.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @@ -312,7 +312,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): # it isn't there. raise StoreError(404, "No row found") - result = self.simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, @@ -324,7 +324,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): result["etag"] = 0 return result - return self.runInteraction( + return self.db.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @@ -352,7 +352,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): new_version = str(int(current_version) + 1) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="e2e_room_keys_versions", values={ @@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): return new_version - return self.runInteraction( + return self.db.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @@ -391,7 +391,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues["etag"] = version_etag if updatevalues: - return self.simple_update( + return self.db.simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, updatevalues=updatevalues, @@ -420,19 +420,19 @@ class EndToEndRoomKeyStore(SQLBaseStore): else: this_version = version - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_room_keys", keyvalues={"user_id": user_id, "version": this_version}, ) - return self.simple_update_one_txn( + return self.db.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.runInteraction( + return self.db.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 08bcdc4725..38cd0ca9b8 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -48,7 +48,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): if not query_list: return {} - results = yield self.runInteraction( + results = yield self.db.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, @@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) txn.execute(sql, query_params) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) result = {} for row in rows: @@ -143,7 +143,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) txn.execute(signature_sql, signature_query_params) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) # add each cross-signing signature to the correct device in the result dict. for row in rows: @@ -186,7 +186,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): key_id) to json string for key """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -219,7 +219,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # a unique constraint. If there is a race of two calls to # `add_e2e_one_time_keys` then they'll conflict and we will only # insert one set. - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="e2e_one_time_keys_json", values=[ @@ -238,7 +238,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - yield self.runInteraction( + yield self.db.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) @@ -261,7 +261,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): result[algorithm] = key_count return result - return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) + return self.db.runInteraction( + "count_e2e_one_time_keys", _count_e2e_one_time_keys + ) def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): """Returns a user's cross-signing key. @@ -322,7 +324,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Returns: dict of the key data or None if not found """ - return self.runInteraction( + return self.db.runInteraction( "get_e2e_cross_signing_key", self._get_e2e_cross_signing_key_txn, user_id, @@ -350,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id """ - return self.execute( + return self.db.execute( "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key ) @@ -367,7 +369,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): set_tag("time_now", time_now) set_tag("device_keys", device_keys) - old_key_json = self.simple_select_one_onecol_txn( + old_key_json = self.db.simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -383,7 +385,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"Message": "Device key already stored."}) return False - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -392,7 +394,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"message": "Device keys stored."}) return True - return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) + return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) def claim_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" @@ -431,7 +433,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) return result - return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys) + return self.db.runInteraction( + "claim_e2e_one_time_keys", _claim_e2e_one_time_keys + ) def delete_e2e_keys_by_device(self, user_id, device_id): def delete_e2e_keys_by_device_txn(txn): @@ -442,12 +446,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "user_id": user_id, } ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_one_time_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -456,7 +460,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - return self.runInteraction( + return self.db.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) @@ -492,7 +496,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # The "keys" property must only have one entry, which will be the public # key, so we just grab the first value in there pubkey = next(iter(key["keys"].values())) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "devices", values={ @@ -505,7 +509,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # and finally, store the key itself with self._cross_signing_id_gen.get_next() as stream_id: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "e2e_cross_signing_keys", values={ @@ -524,7 +528,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.runInteraction( + return self.db.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, user_id, @@ -539,7 +543,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): user_id (str): the user who made the signatures signatures (iterable[SignatureListItem]): signatures to add """ - return self.simple_insert_many( + return self.db.simple_insert_many( "e2e_cross_signing_signatures", [ { diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 051ac7a8cb..77e4353b59 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -58,7 +58,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: list of event_ids """ - return self.runInteraction( + return self.db.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given ) @@ -90,12 +90,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) def get_oldest_events_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id ) def get_oldest_events_with_depth_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -126,7 +126,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns Deferred[int] """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -140,7 +140,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return max(row["depth"] for row in rows) def _get_oldest_events_in_room_txn(self, txn, room_id): - return self.simple_select_onecol_txn( + return self.db.simple_select_onecol_txn( txn, table="event_backward_extremities", keyvalues={"room_id": room_id}, @@ -188,7 +188,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas where *hashes* is a map from algorithm to hash. """ - return self.runInteraction( + return self.db.runInteraction( "get_latest_event_ids_and_hashes_in_room", self._get_latest_event_ids_and_hashes_in_room, room_id, @@ -229,13 +229,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.runInteraction( + return self.db.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", @@ -266,12 +266,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas def get_min_depth(self, room_id): """ For hte given room, get the minimum depth we have seen for it. """ - return self.runInteraction( + return self.db.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) def _get_min_depth_interaction(self, txn, room_id): - min_depth = self.simple_select_one_onecol_txn( + min_depth = self.db.simple_select_one_onecol_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -337,7 +337,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.runInteraction( + return self.db.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) @@ -352,7 +352,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas limit (int) """ return ( - self.runInteraction( + self.db.runInteraction( "get_backfill_events", self._get_backfill_events, room_id, @@ -383,7 +383,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas queue = PriorityQueue() for event_id in event_list: - depth = self.simple_select_one_onecol_txn( + depth = self.db.simple_select_one_onecol_txn( txn, table="events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -415,7 +415,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas @defer.inlineCallbacks def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.runInteraction( + ids = yield self.db.runInteraction( "get_missing_events", self._get_missing_events, room_id, @@ -468,7 +468,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: Deferred[list[str]] """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -508,7 +508,7 @@ class EventFederationStore(EventFederationWorkerStore): if min_depth and depth >= min_depth: return - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -520,7 +520,7 @@ class EventFederationStore(EventFederationWorkerStore): For the given event, update the event edges table and forward and backward extremities tables. """ - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_edges", values=[ @@ -604,13 +604,13 @@ class EventFederationStore(EventFederationWorkerStore): return run_as_background_process( "delete_old_forward_extrem_cache", - self.runInteraction, + self.db.runInteraction, "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn, ) def clean_room_for_join(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) @@ -660,7 +660,7 @@ class EventFederationStore(EventFederationWorkerStore): return min_stream_id >= target_min_stream_id - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_AUTH_STATE_ONLY, delete_event_auth ) diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 0a37847cfd..725d0881dc 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -93,7 +93,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): - ret = yield self.runInteraction( + ret = yield self.db.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, room_id, @@ -177,7 +177,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = yield self.runInteraction("get_push_action_users_in_range", f) + ret = yield self.db.runInteraction("get_push_action_users_in_range", f) return ret @defer.inlineCallbacks @@ -229,7 +229,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.runInteraction( + after_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -257,7 +257,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.runInteraction( + no_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -329,7 +329,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.runInteraction( + after_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -357,7 +357,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.runInteraction( + no_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -407,7 +407,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, min_stream_ordering)) return bool(txn.fetchone()) - return self.runInteraction( + return self.db.runInteraction( "get_if_maybe_push_in_range_for_user", _get_if_maybe_push_in_range_for_user_txn, ) @@ -458,7 +458,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ), ) - return self.runInteraction( + return self.db.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) @@ -472,7 +472,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ try: - res = yield self.simple_delete( + res = yield self.db.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -489,7 +489,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _find_stream_orderings_for_times(self): return run_as_background_process( "event_push_action_stream_orderings", - self.runInteraction, + self.db.runInteraction, "_find_stream_orderings_for_times", self._find_stream_orderings_for_times_txn, ) @@ -525,7 +525,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): Deferred[int]: stream ordering of the first event received on/after the timestamp """ - return self.runInteraction( + return self.db.runInteraction( "_find_first_stream_ordering_after_ts_txn", self._find_first_stream_ordering_after_ts_txn, ts, @@ -677,7 +677,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) for event, _ in events_and_contexts: - user_ids = self.simple_select_onecol_txn( + user_ids = self.db.simple_select_onecol_txn( txn, table="event_push_actions_staging", keyvalues={"event_id": event.event_id}, @@ -727,9 +727,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore): " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - push_actions = yield self.runInteraction("get_push_actions_for_user", f) + push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions @@ -748,7 +748,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): txn.execute(sql, (stream_ordering,)) return txn.fetchone() - result = yield self.runInteraction("get_time_of_last_push_action_before", f) + result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) return result[0] if result else None @defer.inlineCallbacks @@ -757,7 +757,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = yield self.runInteraction("get_latest_push_action_stream_ordering", f) + result = yield self.db.runInteraction( + "get_latest_push_action_stream_ordering", f + ) return result[0] or 0 def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): @@ -830,7 +832,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): while True: logger.info("Rotating notifications") - caught_up = yield self.runInteraction( + caught_up = yield self.db.runInteraction( "_rotate_notifs", self._rotate_notifs_txn ) if caught_up: @@ -844,7 +846,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): the archiving process has caught up or not. """ - old_rotate_stream_ordering = self.simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -880,7 +882,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return caught_up def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self.simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -912,7 +914,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the # existing table. - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_push_summary", values=[ diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 98ae69e996..01ec9ec397 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -143,7 +143,7 @@ class EventsStore( ) return txn.fetchall() - res = yield self.runInteraction("read_forward_extremities", fetch) + res = yield self.db.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) @_retry_on_integrity_error @@ -208,7 +208,7 @@ class EventsStore( for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream - yield self.runInteraction( + yield self.db.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, @@ -281,7 +281,7 @@ class EventsStore( results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( + yield self.db.runInteraction( "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk ) @@ -345,7 +345,7 @@ class EventsStore( existing_prevs.add(prev_event_id) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( + yield self.db.runInteraction( "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk ) @@ -432,7 +432,7 @@ class EventsStore( # event's auth chain, but its easier for now just to store them (and # it doesn't take much storage compared to storing the entire event # anyway). - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -580,12 +580,12 @@ class EventsStore( self, txn, new_forward_extremities, max_stream_order ): for room_id, new_extrem in iteritems(new_forward_extremities): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_forward_extremities", values=[ @@ -598,7 +598,7 @@ class EventsStore( # new stream_ordering to new forward extremeties in the room. # This allows us to later efficiently look up the forward extremeties # for a room before a given stream_ordering - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="stream_ordering_to_exterm", values=[ @@ -722,7 +722,7 @@ class EventsStore( # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering state_group_id = context.state_group - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="ex_outlier_stream", values={ @@ -794,7 +794,7 @@ class EventsStore( d.pop("redacted_because", None) return d - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_json", values=[ @@ -811,7 +811,7 @@ class EventsStore( ], ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="events", values=[ @@ -841,7 +841,7 @@ class EventsStore( # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - self.simple_update_txn( + self.db.simple_update_txn( txn, table="redactions", keyvalues={"redacts": event.event_id}, @@ -983,7 +983,7 @@ class EventsStore( state_values.append(vals) - self.simple_insert_many_txn(txn, table="state_events", values=state_values) + self.db.simple_insert_many_txn(txn, table="state_events", values=state_values) # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1014,7 +1014,7 @@ class EventsStore( ) txn.execute(sql + clause, args) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: @@ -1032,7 +1032,7 @@ class EventsStore( # invalidate the cache for the redacted event txn.call_after(self._invalidate_get_event_cache, event.redacts) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="redactions", values={ @@ -1077,7 +1077,9 @@ class EventsStore( LIMIT ? """ - rows = yield self.execute("_censor_redactions_fetch", None, sql, before_ts, 100) + rows = yield self.db.execute( + "_censor_redactions_fetch", None, sql, before_ts, 100 + ) updates = [] @@ -1109,14 +1111,14 @@ class EventsStore( if pruned_json: self._censor_event_txn(txn, event_id, pruned_json) - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="redactions", keyvalues={"event_id": redaction_id}, updatevalues={"have_censored": True}, ) - yield self.runInteraction("_update_censor_txn", _update_censor_txn) + yield self.db.runInteraction("_update_censor_txn", _update_censor_txn) def _censor_event_txn(self, txn, event_id, pruned_json): """Censor an event by replacing its JSON in the event_json table with the @@ -1127,7 +1129,7 @@ class EventsStore( event_id (str): The ID of the event to censor. pruned_json (str): The pruned JSON """ - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="event_json", keyvalues={"event_id": event_id}, @@ -1153,7 +1155,7 @@ class EventsStore( (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_messages", _count_messages) + ret = yield self.db.runInteraction("count_messages", _count_messages) return ret @defer.inlineCallbacks @@ -1174,7 +1176,7 @@ class EventsStore( (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) + ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) return ret @defer.inlineCallbacks @@ -1189,7 +1191,7 @@ class EventsStore( (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_daily_active_rooms", _count) + ret = yield self.db.runInteraction("count_daily_active_rooms", _count) return ret def get_current_backfill_token(self): @@ -1241,7 +1243,7 @@ class EventsStore( return new_event_updates - return self.runInteraction( + return self.db.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) @@ -1286,7 +1288,7 @@ class EventsStore( return new_event_updates - return self.runInteraction( + return self.db.runInteraction( "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) @@ -1379,7 +1381,7 @@ class EventsStore( backward_ex_outliers, ) - return self.runInteraction("get_all_new_events", get_all_new_events_txn) + return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) def purge_history(self, room_id, token, delete_local_events): """Deletes room history before a certain point @@ -1399,7 +1401,7 @@ class EventsStore( deleted events. """ - return self.runInteraction( + return self.db.runInteraction( "purge_history", self._purge_history_txn, room_id, @@ -1647,7 +1649,7 @@ class EventsStore( Deferred[List[int]]: The list of state groups to delete. """ - return self.runInteraction("purge_room", self._purge_room_txn, room_id) + return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) def _purge_room_txn(self, txn, room_id): # First we fetch all the state groups that should be deleted, before @@ -1766,7 +1768,7 @@ class EventsStore( to delete. """ - return self.runInteraction( + return self.db.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, @@ -1778,7 +1780,7 @@ class EventsStore( "[purge] found %i state groups to delete", len(state_groups_to_delete) ) - rows = self.simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="state_group_edges", column="prev_state_group", @@ -1805,15 +1807,15 @@ class EventsStore( curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) curr_state = curr_state[sg] - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": sg} ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": sg} ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1850,7 +1852,7 @@ class EventsStore( state group. """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, @@ -1869,7 +1871,7 @@ class EventsStore( state_groups_to_delete (list[int]): State groups to delete """ - return self.runInteraction( + return self.db.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, @@ -1880,7 +1882,7 @@ class EventsStore( # first we have to delete the state groups states logger.info("[purge] removing %s from state_groups_state", room_id) - self.simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_groups_state", column="state_group", @@ -1891,7 +1893,7 @@ class EventsStore( # ... and the state group edges logger.info("[purge] removing %s from state_group_edges", room_id) - self.simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_group_edges", column="state_group", @@ -1902,7 +1904,7 @@ class EventsStore( # ... and the state groups logger.info("[purge] removing %s from state_groups", room_id) - self.simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_groups", column="id", @@ -1919,7 +1921,7 @@ class EventsStore( @cachedInlineCallbacks(max_entries=5000) def _get_event_ordering(self, event_id): - res = yield self.simple_select_one( + res = yield self.db.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, @@ -1942,7 +1944,7 @@ class EventsStore( txn.execute(sql, (from_token, to_token, limit)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) @@ -1960,7 +1962,7 @@ class EventsStore( room_id (str): The ID of the room the event was sent to. topological_ordering (int): The position of the event in the room's topology. """ - return self.simple_insert_many_txn( + return self.db.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -1982,7 +1984,7 @@ class EventsStore( event_id (str): The event ID the expiry timestamp is associated with. expiry_ts (int): The timestamp at which to expire (delete) the event. """ - return self.simple_insert_txn( + return self.db.simple_insert_txn( txn=txn, table="event_expiry", values={"event_id": event_id, "expiry_ts": expiry_ts}, @@ -2031,7 +2033,7 @@ class EventsStore( txn, "_get_event_cache", (event.event_id,) ) - yield self.runInteraction("delete_expired_event", delete_expired_event_txn) + yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) def _delete_event_expiry_txn(self, txn, event_id): """Delete the expiry timestamp associated with an event ID without deleting the @@ -2041,7 +2043,7 @@ class EventsStore( txn (LoggingTransaction): The transaction to use to perform the deletion. event_id (str): The event ID to delete the associated expiry timestamp of. """ - return self.simple_delete_txn( + return self.db.simple_delete_txn( txn=txn, table="event_expiry", keyvalues={"event_id": event_id} ) @@ -2065,7 +2067,7 @@ class EventsStore( return txn.fetchone() - return self.runInteraction( + return self.db.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 37dfc8c871..365e966956 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -151,7 +151,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) @@ -189,7 +189,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: - ev_rows = self.simple_select_many_txn( + ev_rows = self.db.simple_select_many_txn( txn, table="event_json", column="event_id", @@ -228,7 +228,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(rows_to_update) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) @@ -366,7 +366,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): to_delete.intersection_update(original_set) - deleted = self.simple_delete_many_txn( + deleted = self.db.simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -382,7 +382,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if deleted: # We now need to invalidate the caches of these rooms - rows = self.simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="events", column="event_id", @@ -396,7 +396,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): self.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self.simple_delete_many_txn( + self.db.simple_delete_many_txn( txn=txn, table="_extremities_to_check", column="event_id", @@ -406,7 +406,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(original_set) - num_handled = yield self.runInteraction( + num_handled = yield self.db.runInteraction( "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) @@ -416,7 +416,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") - yield self.runInteraction( + yield self.db.runInteraction( "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) @@ -470,7 +470,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(rows) - count = yield self.runInteraction( + count = yield self.db.runInteraction( "_redactions_received_ts", _redactions_received_ts_txn ) @@ -501,7 +501,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): txn.execute("DROP INDEX redactions_censored_redacts") - yield self.runInteraction( + yield self.db.runInteraction( "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn ) @@ -533,7 +533,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): try: event_json = json.loads(event_json_raw) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -565,7 +565,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return nbrows - num_rows = yield self.runInteraction( + num_rows = yield self.db.runInteraction( desc="event_store_labels", func=_event_store_labels_txn ) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 6a08a746b6..e041fc5eac 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -78,7 +78,7 @@ class EventsWorkerStore(SQLBaseStore): Deferred[int|None]: Timestamp in milliseconds, or None for events that were persisted before received_ts was implemented. """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", @@ -117,7 +117,7 @@ class EventsWorkerStore(SQLBaseStore): return ts - return self.runInteraction( + return self.db.runInteraction( "get_approximate_received_ts", _get_approximate_received_ts_txn ) @@ -452,7 +452,7 @@ class EventsWorkerStore(SQLBaseStore): event_id for events, _ in event_list for event_id in events ) - row_dict = self.new_transaction( + row_dict = self.db.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) @@ -584,7 +584,7 @@ class EventsWorkerStore(SQLBaseStore): if should_start: run_as_background_process( - "fetch_events", self.runWithConnection, self._do_fetch + "fetch_events", self.db.runWithConnection, self._do_fetch ) logger.debug("Loading %d events: %s", len(events), events) @@ -745,7 +745,7 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", @@ -780,7 +780,9 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk) + yield self.db.runInteraction( + "have_seen_events", have_seen_events_txn, chunk + ) return results def _get_total_state_event_counts_txn(self, txn, room_id): @@ -807,7 +809,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred[int] """ - return self.runInteraction( + return self.db.runInteraction( "get_total_state_event_counts", self._get_total_state_event_counts_txn, room_id, @@ -832,7 +834,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred[int] """ - return self.runInteraction( + return self.db.runInteraction( "get_current_state_event_counts", self._get_current_state_event_counts_txn, room_id, diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py index 17ef7b9354..342d6622a4 100644 --- a/synapse/storage/data_stores/main/filtering.py +++ b/synapse/storage/data_stores/main/filtering.py @@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self.simple_select_one_onecol( + def_json = yield self.db.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", @@ -71,4 +71,4 @@ class FilteringStore(SQLBaseStore): return filter_id - return self.runInteraction("add_user_filter", _do_txn) + return self.db.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 9e1d12bcb7..7f5e8dce66 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -35,7 +35,7 @@ class GroupServerStore(SQLBaseStore): * "invite" * "open" """ - return self.simple_update_one( + return self.db.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -43,7 +43,7 @@ class GroupServerStore(SQLBaseStore): ) def get_group(self, group_id): - return self.simple_select_one( + return self.db.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -65,7 +65,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self.simple_select_list( + return self.db.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), @@ -75,7 +75,7 @@ class GroupServerStore(SQLBaseStore): def get_invited_users_in_group(self, group_id): # TODO: Pagination - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -89,7 +89,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self.simple_select_list( + return self.db.simple_select_list( table="group_rooms", keyvalues=keyvalues, retcols=("room_id", "is_public"), @@ -153,10 +153,12 @@ class GroupServerStore(SQLBaseStore): return rooms, categories - return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn) + return self.db.runInteraction( + "get_rooms_for_summary", _get_rooms_for_summary_txn + ) def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.runInteraction( + return self.db.runInteraction( "add_room_to_summary", self._add_room_to_summary_txn, group_id, @@ -180,7 +182,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the room first. Otherwise, the room gets added to the end. """ - room_in_group = self.simple_select_one_onecol_txn( + room_in_group = self.db.simple_select_one_onecol_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -193,7 +195,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID else: - cat_exists = self.simple_select_one_onecol_txn( + cat_exists = self.db.simple_select_one_onecol_txn( txn, table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -204,7 +206,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Category doesn't exist") # TODO: Check category is part of summary already - cat_exists = self.simple_select_one_onecol_txn( + cat_exists = self.db.simple_select_one_onecol_txn( txn, table="group_summary_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -224,7 +226,7 @@ class GroupServerStore(SQLBaseStore): (group_id, category_id, group_id, category_id), ) - existing = self.simple_select_one_txn( + existing = self.db.simple_select_one_txn( txn, table="group_summary_rooms", keyvalues={ @@ -257,7 +259,7 @@ class GroupServerStore(SQLBaseStore): to_update["room_order"] = order if is_public is not None: to_update["is_public"] = is_public - self.simple_update_txn( + self.db.simple_update_txn( txn, table="group_summary_rooms", keyvalues={ @@ -271,7 +273,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_summary_rooms", values={ @@ -287,7 +289,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self.simple_delete( + return self.db.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -299,7 +301,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_categories(self, group_id): - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -316,7 +318,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_category(self, group_id, category_id): - category = yield self.simple_select_one( + category = yield self.db.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -343,7 +345,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self.simple_upsert( + return self.db.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -352,7 +354,7 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_category(self, group_id, category_id): - return self.simple_delete( + return self.db.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", @@ -360,7 +362,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_roles(self, group_id): - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -377,7 +379,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_role(self, group_id, role_id): - role = yield self.simple_select_one( + role = yield self.db.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -404,7 +406,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self.simple_upsert( + return self.db.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -413,14 +415,14 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_role(self, group_id, role_id): - return self.simple_delete( + return self.db.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", ) def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.runInteraction( + return self.db.runInteraction( "add_user_to_summary", self._add_user_to_summary_txn, group_id, @@ -444,7 +446,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the user first. Otherwise, the user gets added to the end. """ - user_in_group = self.simple_select_one_onecol_txn( + user_in_group = self.db.simple_select_one_onecol_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -457,7 +459,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID else: - role_exists = self.simple_select_one_onecol_txn( + role_exists = self.db.simple_select_one_onecol_txn( txn, table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -468,7 +470,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Role doesn't exist") # TODO: Check role is part of the summary already - role_exists = self.simple_select_one_onecol_txn( + role_exists = self.db.simple_select_one_onecol_txn( txn, table="group_summary_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -488,7 +490,7 @@ class GroupServerStore(SQLBaseStore): (group_id, role_id, group_id, role_id), ) - existing = self.simple_select_one_txn( + existing = self.db.simple_select_one_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, @@ -517,7 +519,7 @@ class GroupServerStore(SQLBaseStore): to_update["user_order"] = order if is_public is not None: to_update["is_public"] = is_public - self.simple_update_txn( + self.db.simple_update_txn( txn, table="group_summary_users", keyvalues={ @@ -531,7 +533,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_summary_users", values={ @@ -547,7 +549,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID - return self.simple_delete( + return self.db.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", @@ -561,7 +563,7 @@ class GroupServerStore(SQLBaseStore): Deferred[list[str]]: A twisted.Deferred containing a list of group ids containing this room """ - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -625,12 +627,12 @@ class GroupServerStore(SQLBaseStore): return users, roles - return self.runInteraction( + return self.db.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) def is_user_in_group(self, user_id, group_id): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -639,7 +641,7 @@ class GroupServerStore(SQLBaseStore): ).addCallback(lambda r: bool(r)) def is_user_admin_in_group(self, group_id, user_id): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -650,7 +652,7 @@ class GroupServerStore(SQLBaseStore): def add_group_invite(self, group_id, user_id): """Record that the group server has invited a user """ - return self.simple_insert( + return self.db.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -659,7 +661,7 @@ class GroupServerStore(SQLBaseStore): def is_user_invited_to_local_group(self, group_id, user_id): """Has the group server invited a user? """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -682,7 +684,7 @@ class GroupServerStore(SQLBaseStore): """ def _get_users_membership_in_group_txn(txn): - row = self.simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -697,7 +699,7 @@ class GroupServerStore(SQLBaseStore): "is_privileged": row["is_admin"], } - row = self.simple_select_one_onecol_txn( + row = self.db.simple_select_one_onecol_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -710,7 +712,7 @@ class GroupServerStore(SQLBaseStore): return {} - return self.runInteraction( + return self.db.runInteraction( "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) @@ -738,7 +740,7 @@ class GroupServerStore(SQLBaseStore): """ def _add_user_to_group_txn(txn): - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_users", values={ @@ -749,14 +751,14 @@ class GroupServerStore(SQLBaseStore): }, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) if local_attestation: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -766,7 +768,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -777,49 +779,49 @@ class GroupServerStore(SQLBaseStore): }, ) - return self.runInteraction("add_user_to_group", _add_user_to_group_txn) + return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn) def remove_user_from_group(self, group_id, user_id): def _remove_user_from_group_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_user_from_group", _remove_user_from_group_txn ) def add_room_to_group(self, group_id, room_id, is_public): - return self.simple_insert( + return self.db.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", ) def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self.simple_update( + return self.db.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -828,26 +830,26 @@ class GroupServerStore(SQLBaseStore): def remove_room_from_group(self, group_id, room_id): def _remove_room_from_group_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_summary_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_room_from_group", _remove_room_from_group_txn ) def get_publicised_groups_for_user(self, user_id): """Get all groups a user is publicising """ - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -857,7 +859,7 @@ class GroupServerStore(SQLBaseStore): def update_group_publicity(self, group_id, user_id, publicise): """Update whether the user is publicising their membership of the group """ - return self.simple_update_one( + return self.db.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -893,12 +895,12 @@ class GroupServerStore(SQLBaseStore): def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_group_membership", values={ @@ -911,7 +913,7 @@ class GroupServerStore(SQLBaseStore): }, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_group_updates", values={ @@ -930,7 +932,7 @@ class GroupServerStore(SQLBaseStore): if membership == "join": if local_attestation: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -940,7 +942,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -951,12 +953,12 @@ class GroupServerStore(SQLBaseStore): }, ) else: - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -965,7 +967,7 @@ class GroupServerStore(SQLBaseStore): return next_id with self._group_updates_id_gen.get_next() as next_id: - res = yield self.runInteraction( + res = yield self.db.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, next_id, @@ -976,7 +978,7 @@ class GroupServerStore(SQLBaseStore): def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description ): - yield self.simple_insert( + yield self.db.simple_insert( table="groups", values={ "group_id": group_id, @@ -991,7 +993,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def update_group_profile(self, group_id, profile): - yield self.simple_update_one( + yield self.db.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, @@ -1008,16 +1010,16 @@ class GroupServerStore(SQLBaseStore): WHERE valid_until_ms <= ? """ txn.execute(sql, (valid_until_ms,)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) def update_attestation_renewal(self, group_id, user_id, attestation): """Update an attestation that we have renewed """ - return self.simple_update_one( + return self.db.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, @@ -1027,7 +1029,7 @@ class GroupServerStore(SQLBaseStore): def update_remote_attestion(self, group_id, user_id, attestation): """Update an attestation that a remote has renewed """ - return self.simple_update_one( + return self.db.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ @@ -1046,7 +1048,7 @@ class GroupServerStore(SQLBaseStore): group_id (str) user_id (str) """ - return self.simple_delete( + return self.db.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", @@ -1057,7 +1059,7 @@ class GroupServerStore(SQLBaseStore): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self.simple_select_one( + row = yield self.db.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -1072,7 +1074,7 @@ class GroupServerStore(SQLBaseStore): return None def get_joined_groups(self, user_id): - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -1099,7 +1101,7 @@ class GroupServerStore(SQLBaseStore): for row in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_all_groups_for_user", _get_all_groups_for_user_txn ) @@ -1129,7 +1131,7 @@ class GroupServerStore(SQLBaseStore): for group_id, membership, gtype, content_json in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) @@ -1154,7 +1156,7 @@ class GroupServerStore(SQLBaseStore): for stream_id, group_id, user_id, gtype, content_json in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_all_groups_changes", _get_all_groups_changes_txn ) @@ -1188,8 +1190,8 @@ class GroupServerStore(SQLBaseStore): ] for table in tables: - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table=table, keyvalues={"group_id": group_id} ) - return self.runInteraction("delete_group", _delete_group_txn) + return self.db.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index c7150432b3..6b12f5a75f 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -92,7 +92,7 @@ class KeyStore(SQLBaseStore): _get_keys(txn, batch) return keys - return self.runInteraction("get_server_verify_keys", _txn) + return self.db.runInteraction("get_server_verify_keys", _txn) def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): """Stores NACL verification keys for remote servers. @@ -127,9 +127,9 @@ class KeyStore(SQLBaseStore): f((i,)) return res - return self.runInteraction( + return self.db.runInteraction( "store_server_verify_keys", - self.simple_upsert_many_txn, + self.db.simple_upsert_many_txn, table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -157,7 +157,7 @@ class KeyStore(SQLBaseStore): ts_valid_until_ms (int): The time when this json stops being valid. key_json (bytes): The encoded JSON. """ - return self.simple_upsert( + return self.db.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, @@ -196,7 +196,7 @@ class KeyStore(SQLBaseStore): keyvalues["key_id"] = key_id if from_server is not None: keyvalues["from_server"] = from_server - rows = self.simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "server_keys_json", keyvalues=keyvalues, @@ -211,4 +211,4 @@ class KeyStore(SQLBaseStore): results[(server_name, key_id, from_server)] = rows return results - return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn) + return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 0cb9446f96..ea02497784 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -39,7 +39,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): Returns: None if the media_id doesn't exist. """ - return self.simple_select_one( + return self.db.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -64,7 +64,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): user_id, url_cache=None, ): - return self.simple_insert( + return self.db.simple_insert( "local_media_repository", { "media_id": media_id, @@ -124,12 +124,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) ) - return self.runInteraction("get_url_cache", get_url_cache_txn) + return self.db.runInteraction("get_url_cache", get_url_cache_txn) def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self.simple_insert( + return self.db.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -144,7 +144,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_local_media_thumbnails(self, media_id): - return self.simple_select_list( + return self.db.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -166,7 +166,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.simple_insert( + return self.db.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -180,7 +180,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_cached_remote_media(self, origin, media_id): - return self.simple_select_one( + return self.db.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -205,7 +205,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name, filesystem_id, ): - return self.simple_insert( + return self.db.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -250,10 +250,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) - return self.runInteraction("update_cached_last_access_time", update_cache_txn) + return self.db.runInteraction( + "update_cached_last_access_time", update_cache_txn + ) def get_remote_media_thumbnails(self, origin, media_id): - return self.simple_select_list( + return self.db.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( @@ -278,7 +280,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.simple_insert( + return self.db.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, @@ -300,24 +302,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE last_access_ts < ?" ) - return self.execute( - "get_remote_media_before", self.cursor_to_dict, sql, before_ts + return self.db.execute( + "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts ) def delete_remote_media(self, media_origin, media_id): def delete_remote_media_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, "remote_media_cache", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, "remote_media_cache_thumbnails", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - return self.runInteraction("delete_remote_media", delete_remote_media_txn) + return self.db.runInteraction("delete_remote_media", delete_remote_media_txn) def get_expired_url_cache(self, now_ts): sql = ( @@ -331,7 +333,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute(sql, (now_ts,)) return [row[0] for row in txn] - return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) + return self.db.runInteraction( + "get_expired_url_cache", _get_expired_url_cache_txn + ) def delete_url_cache(self, media_ids): if len(media_ids) == 0: @@ -342,7 +346,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.runInteraction("delete_url_cache", _delete_url_cache_txn) + return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) def get_url_cache_media_before(self, before_ts): sql = ( @@ -356,7 +360,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute(sql, (before_ts,)) return [row[0] for row in txn] - return self.runInteraction( + return self.db.runInteraction( "get_url_cache_media_before", _get_url_cache_media_before_txn ) @@ -373,6 +377,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.runInteraction( + return self.db.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index b8fc28f97b..34bf3a1880 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -32,7 +32,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs # Do not add more reserved users than the total allowable number - self.new_transaction( + self.db.new_transaction( dbconn, "initialise_mau_threepids", [], @@ -146,7 +146,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): txn.execute(sql, query_args) reserved_users = yield self.get_registered_reserved_users() - yield self.runInteraction( + yield self.db.runInteraction( "reap_monthly_active_users", _reap_users, reserved_users ) # It seems poor to invalidate the whole cache, Postgres supports @@ -174,7 +174,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): (count,) = txn.fetchone() return count - return self.runInteraction("count_users", _count_users) + return self.db.runInteraction("count_users", _count_users) @defer.inlineCallbacks def get_registered_reserved_users(self): @@ -217,7 +217,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): if is_support: return - yield self.runInteraction( + yield self.db.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) @@ -261,7 +261,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self.simple_upsert_txn( + is_insert = self.db.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -281,7 +281,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py index 650e49750e..cc21437e92 100644 --- a/synapse/storage/data_stores/main/openid.py +++ b/synapse/storage/data_stores/main/openid.py @@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self.simple_insert( + return self.db.simple_insert( table="open_id_tokens", values={ "token": token, @@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore): else: return rows[0][0] - return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn) + return self.db.runInteraction( + "get_user_id_for_token", get_user_id_for_token_txn + ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index a5e121efd1..a2c83e0867 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -29,7 +29,7 @@ class PresenceStore(SQLBaseStore): ) with stream_ordering_manager as stream_orderings: - yield self.runInteraction( + yield self.db.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, @@ -46,7 +46,7 @@ class PresenceStore(SQLBaseStore): txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) # Actually insert new rows - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="presence_stream", values=[ @@ -88,7 +88,7 @@ class PresenceStore(SQLBaseStore): txn.execute(sql, (last_id, current_id)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_presence_updates", get_all_presence_updates_txn ) @@ -103,7 +103,7 @@ class PresenceStore(SQLBaseStore): inlineCallbacks=True, ) def get_presence_for_users(self, user_ids): - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, @@ -129,7 +129,7 @@ class PresenceStore(SQLBaseStore): return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): - return self.simple_insert( + return self.db.simple_insert( table="presence_allow_inbound", values={ "observed_user_id": observed_localpart, @@ -140,7 +140,7 @@ class PresenceStore(SQLBaseStore): ) def disallow_presence_visible(self, observed_localpart, observer_userid): - return self.simple_delete_one( + return self.db.simple_delete_one( table="presence_allow_inbound", keyvalues={ "observed_user_id": observed_localpart, diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py index c8b5b60301..2b52cf9c1a 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_profileinfo(self, user_localpart): try: - profile = yield self.simple_select_one( + profile = yield self.db.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_displayname(self, user_localpart): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", @@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_avatar_url(self, user_localpart): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", @@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_from_remote_profile_cache(self, user_id): - return self.simple_select_one( + return self.db.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), @@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore): ) def create_profile(self, user_localpart): - return self.simple_insert( + return self.db.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) def set_profile_displayname(self, user_localpart, new_displayname): - return self.simple_update_one( + return self.db.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, @@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self.simple_update_one( + return self.db.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore): This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self.simple_upsert( + return self.db.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore): ) def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self.simple_update( + return self.db.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore): """ subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self.simple_delete( + yield self.db.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -144,9 +144,9 @@ class ProfileStore(ProfileWorkerStore): txn.execute(sql, (last_checked,)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) @@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 75bd499bcd..de682cc63a 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -75,7 +75,7 @@ class PushRulesWorkerStore( def __init__(self, db_conn, hs): super(PushRulesWorkerStore, self).__init__(db_conn, hs) - push_rules_prefill, push_rules_id = self.get_cache_dict( + push_rules_prefill, push_rules_id = self.db.get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", @@ -100,7 +100,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( @@ -124,7 +124,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): - results = yield self.simple_select_list( + results = yield self.db.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), @@ -146,7 +146,7 @@ class PushRulesWorkerStore( (count,) = txn.fetchone() return bool(count) - return self.runInteraction( + return self.db.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) @@ -162,7 +162,7 @@ class PushRulesWorkerStore( results = {user_id: [] for user_id in user_ids} - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, @@ -320,7 +320,7 @@ class PushRulesWorkerStore( results = {user_id: {} for user_id in user_ids} - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, @@ -350,7 +350,7 @@ class PushRuleStore(PushRulesWorkerStore): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids if before or after: - yield self.runInteraction( + yield self.db.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, stream_id, @@ -364,7 +364,7 @@ class PushRuleStore(PushRulesWorkerStore): after, ) else: - yield self.runInteraction( + yield self.db.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, stream_id, @@ -395,7 +395,7 @@ class PushRuleStore(PushRulesWorkerStore): relative_to_rule = before or after - res = self.simple_select_one_txn( + res = self.db.simple_select_one_txn( txn, table="push_rules", keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, @@ -518,7 +518,7 @@ class PushRuleStore(PushRulesWorkerStore): # We didn't update a row with the given rule_id so insert one push_rule_id = self._push_rule_id_gen.get_next() - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="push_rules", values={ @@ -561,7 +561,7 @@ class PushRuleStore(PushRulesWorkerStore): """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self.simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) @@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "delete_push_rule", delete_push_rule_txn, stream_id, @@ -582,7 +582,7 @@ class PushRuleStore(PushRulesWorkerStore): def set_push_rule_enabled(self, user_id, rule_id, enabled): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, stream_id, @@ -596,7 +596,7 @@ class PushRuleStore(PushRulesWorkerStore): self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled ): new_id = self._push_rules_enable_id_gen.get_next() - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}, @@ -636,7 +636,7 @@ class PushRuleStore(PushRulesWorkerStore): update_stream=False, ) else: - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}, @@ -655,7 +655,7 @@ class PushRuleStore(PushRulesWorkerStore): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "set_push_rule_actions", set_push_rule_actions_txn, stream_id, @@ -675,7 +675,7 @@ class PushRuleStore(PushRulesWorkerStore): if data is not None: values.update(data) - self.simple_insert_txn(txn, "push_rules_stream", values=values) + self.db.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) @@ -699,7 +699,7 @@ class PushRuleStore(PushRulesWorkerStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_push_rule_updates", get_all_push_rule_updates_txn ) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index d5a169872b..f07309ef09 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -59,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_has_pusher(self, user_id): - ret = yield self.simple_select_one_onecol( + ret = yield self.db.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -72,7 +72,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_pushers_by(self, keyvalues): - ret = yield self.simple_select_list( + ret = yield self.db.simple_select_list( "pushers", keyvalues, [ @@ -100,11 +100,11 @@ class PusherWorkerStore(SQLBaseStore): def get_all_pushers(self): def get_pushers(txn): txn.execute("SELECT * FROM pushers") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - rows = yield self.runInteraction("get_all_pushers", get_pushers) + rows = yield self.db.runInteraction("get_all_pushers", get_pushers) return rows def get_all_updated_pushers(self, last_id, current_id, limit): @@ -134,7 +134,7 @@ class PusherWorkerStore(SQLBaseStore): return updated, deleted - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_pushers", get_all_updated_pushers_txn ) @@ -177,7 +177,7 @@ class PusherWorkerStore(SQLBaseStore): return results - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) @@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_if_users_have_pushers(self, user_ids): - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, @@ -230,7 +230,7 @@ class PusherStore(PusherWorkerStore): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry - yield self.simple_upsert( + yield self.db.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -255,7 +255,7 @@ class PusherStore(PusherWorkerStore): if user_has_pusher is not True: # invalidate, since we the user might not have had a pusher before - yield self.runInteraction( + yield self.db.runInteraction( "add_pusher", self._invalidate_cache_and_stream, self.get_if_user_has_pusher, @@ -269,7 +269,7 @@ class PusherStore(PusherWorkerStore): txn, self.get_if_user_has_pusher, (user_id,) ) - self.simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -278,7 +278,7 @@ class PusherStore(PusherWorkerStore): # it's possible for us to end up with duplicate rows for # (app_id, pushkey, user_id) at different stream_ids, but that # doesn't really matter. - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="deleted_pushers", values={ @@ -290,13 +290,13 @@ class PusherStore(PusherWorkerStore): ) with self._pushers_id_gen.get_next() as stream_id: - yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id) + yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) @defer.inlineCallbacks def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering ): - yield self.simple_update_one( + yield self.db.simple_update_one( "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"last_stream_ordering": last_stream_ordering}, @@ -319,7 +319,7 @@ class PusherStore(PusherWorkerStore): Returns: Deferred[bool]: True if the pusher still exists; False if it has been deleted. """ - updated = yield self.simple_update( + updated = yield self.db.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={ @@ -333,7 +333,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self.simple_update( + yield self.db.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={"failing_since": failing_since}, @@ -342,7 +342,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def get_throttle_params_by_room(self, pusher_id): - res = yield self.simple_select_list( + res = yield self.db.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], @@ -362,7 +362,7 @@ class PusherStore(PusherWorkerStore): def set_throttle_params(self, pusher_id, room_id, params): # no need to lock because `pusher_throttle` has a primary key on # (pusher, room_id) so simple_upsert will retry - yield self.simple_upsert( + yield self.db.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 380f388e30..ac2d45bd5c 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -61,7 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): - return self.simple_select_list( + return self.db.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), @@ -70,7 +70,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3) def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, @@ -84,7 +84,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self.simple_select_list( + rows = yield self.db.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -108,7 +108,7 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return txn.fetchall() - rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f) + rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f) return { row[0]: { "event_id": row[1], @@ -187,11 +187,11 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (room_id, to_key)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) return rows - rows = yield self.runInteraction("get_linearized_receipts_for_room", f) + rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f) if not rows: return [] @@ -237,9 +237,11 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql + clause, [to_key] + list(args)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f) + txn_results = yield self.db.runInteraction( + "_get_linearized_receipts_for_rooms", f + ) results = {} for row in txn_results: @@ -282,7 +284,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return list(r[0:5] + (json.loads(r[5]),) for r in txn) - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) @@ -335,7 +337,7 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ - res = self.simple_select_one_txn( + res = self.db.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], @@ -388,7 +390,7 @@ class ReceiptsStore(ReceiptsWorkerStore): (user_id, room_id, receipt_type), ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="receipts_linearized", keyvalues={ @@ -398,7 +400,7 @@ class ReceiptsStore(ReceiptsWorkerStore): }, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="receipts_linearized", values={ @@ -453,13 +455,13 @@ class ReceiptsStore(ReceiptsWorkerStore): else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = yield self.runInteraction( + linearized_event_id = yield self.db.runInteraction( "insert_receipt_conv", graph_to_linear ) stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: - event_ts = yield self.runInteraction( + event_ts = yield self.db.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, room_id, @@ -488,7 +490,7 @@ class ReceiptsStore(ReceiptsWorkerStore): return stream_id, max_persisted_id def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.runInteraction( + return self.db.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -514,7 +516,7 @@ class ReceiptsStore(ReceiptsWorkerStore): self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="receipts_graph", keyvalues={ @@ -523,7 +525,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "user_id": user_id, }, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="receipts_graph", values={ diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index debc6706f5..8f9aa87ceb 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -45,7 +45,7 @@ class RegistrationWorkerStore(SQLBaseStore): @cached() def get_user_by_id(self, user_id): - return self.simple_select_one( + return self.db.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -94,7 +94,7 @@ class RegistrationWorkerStore(SQLBaseStore): including the keys `name`, `is_guest`, `device_id`, `token_id`, `valid_until_ms`. """ - return self.runInteraction( + return self.db.runInteraction( "get_user_by_access_token", self._query_for_auth, token ) @@ -109,7 +109,7 @@ class RegistrationWorkerStore(SQLBaseStore): otherwise int representation of the timestamp (as a number of milliseconds since epoch). """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", @@ -137,7 +137,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ def set_account_validity_for_user_txn(txn): - self.simple_update_txn( + self.db.simple_update_txn( txn=txn, table="account_validity", keyvalues={"user_id": user_id}, @@ -151,7 +151,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn, self.get_expiration_ts_for_user, (user_id,) ) - yield self.runInteraction( + yield self.db.runInteraction( "set_account_validity_for_user", set_account_validity_for_user_txn ) @@ -167,7 +167,7 @@ class RegistrationWorkerStore(SQLBaseStore): Raises: StoreError: The provided token is already set for another user. """ - yield self.simple_update_one( + yield self.db.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, @@ -184,7 +184,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The ID of the user to which the token belongs. """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", @@ -203,7 +203,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The renewal token associated with this user ID. """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", @@ -229,9 +229,9 @@ class RegistrationWorkerStore(SQLBaseStore): ) values = [False, now_ms, renew_at] txn.execute(sql, values) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - res = yield self.runInteraction( + res = yield self.db.runInteraction( "get_users_expiring_soon", select_users_txn, self.clock.time_msec(), @@ -250,7 +250,7 @@ class RegistrationWorkerStore(SQLBaseStore): email_sent (bool): Flag which indicates whether a renewal email has been sent to this user. """ - yield self.simple_update_one( + yield self.db.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, @@ -265,7 +265,7 @@ class RegistrationWorkerStore(SQLBaseStore): Args: user_id (str): ID of the user to remove from the account validity table. """ - yield self.simple_delete_one( + yield self.db.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", @@ -281,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", @@ -299,7 +299,7 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self.simple_update_one( + return self.db.simple_update_one( table="users", keyvalues={"name": user.to_string()}, updatevalues={"admin": 1 if admin else 0}, @@ -316,7 +316,7 @@ class RegistrationWorkerStore(SQLBaseStore): ) txn.execute(sql, (token,)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0] @@ -332,7 +332,9 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if user 'user_type' is null or empty string """ - res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id) + res = yield self.db.runInteraction( + "is_real_user", self.is_real_user_txn, user_id + ) return res @cachedInlineCallbacks() @@ -345,13 +347,13 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if user is of type UserTypes.SUPPORT """ - res = yield self.runInteraction( + res = yield self.db.runInteraction( "is_support_user", self.is_support_user_txn, user_id ) return res def is_real_user_txn(self, txn, user_id): - res = self.simple_select_one_onecol_txn( + res = self.db.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -361,7 +363,7 @@ class RegistrationWorkerStore(SQLBaseStore): return res is None def is_support_user_txn(self, txn, user_id): - res = self.simple_select_one_onecol_txn( + res = self.db.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -380,7 +382,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return dict(txn) - return self.runInteraction("get_users_by_id_case_insensitive", f) + return self.db.runInteraction("get_users_by_id_case_insensitive", f) async def get_user_by_external_id( self, auth_provider: str, external_id: str @@ -394,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: the mxid of the user, or None if they are not known """ - return await self.simple_select_one_onecol( + return await self.db.simple_select_one_onecol( table="user_external_ids", keyvalues={"auth_provider": auth_provider, "external_id": external_id}, retcol="user_id", @@ -408,12 +410,12 @@ class RegistrationWorkerStore(SQLBaseStore): def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.runInteraction("count_users", _count_users) + ret = yield self.db.runInteraction("count_users", _count_users) return ret def count_daily_user_type(self): @@ -445,7 +447,7 @@ class RegistrationWorkerStore(SQLBaseStore): results[row[0]] = row[1] return results - return self.runInteraction("count_daily_user_type", _count_daily_user_type) + return self.db.runInteraction("count_daily_user_type", _count_daily_user_type) @defer.inlineCallbacks def count_nonbridged_users(self): @@ -459,7 +461,7 @@ class RegistrationWorkerStore(SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_users", _count_users) + ret = yield self.db.runInteraction("count_users", _count_users) return ret @defer.inlineCallbacks @@ -468,12 +470,12 @@ class RegistrationWorkerStore(SQLBaseStore): def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.runInteraction("count_real_users", _count_users) + ret = yield self.db.runInteraction("count_real_users", _count_users) return ret @defer.inlineCallbacks @@ -503,7 +505,7 @@ class RegistrationWorkerStore(SQLBaseStore): return ( ( - yield self.runInteraction( + yield self.db.runInteraction( "find_next_generated_user_id", _find_next_generated_user_id ) ) @@ -520,7 +522,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[str|None]: user id or None if no user id/threepid mapping exists """ - user_id = yield self.runInteraction( + user_id = yield self.db.runInteraction( "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address ) return user_id @@ -536,7 +538,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: user id or None if no user id/threepid mapping exists """ - ret = self.simple_select_one_txn( + ret = self.db.simple_select_one_txn( txn, "user_threepids", {"medium": medium, "address": address}, @@ -549,7 +551,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self.simple_upsert( + yield self.db.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, @@ -557,7 +559,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_get_threepids(self, user_id): - ret = yield self.simple_select_list( + ret = yield self.db.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], @@ -566,7 +568,7 @@ class RegistrationWorkerStore(SQLBaseStore): return ret def user_delete_threepid(self, user_id, medium, address): - return self.simple_delete( + return self.db.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", @@ -579,7 +581,7 @@ class RegistrationWorkerStore(SQLBaseStore): user_id: The user id to delete all threepids of """ - return self.simple_delete( + return self.db.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", @@ -601,7 +603,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ # We need to use an upsert, in case they user had already bound the # threepid - return self.simple_upsert( + return self.db.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -627,7 +629,7 @@ class RegistrationWorkerStore(SQLBaseStore): medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self.simple_select_list( + return self.db.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -648,7 +650,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred """ - return self.simple_delete( + return self.db.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -671,7 +673,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[list[str]]: Resolves to a list of identity servers """ - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", @@ -689,7 +691,7 @@ class RegistrationWorkerStore(SQLBaseStore): defer.Deferred(bool): The requested value. """ - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -756,13 +758,13 @@ class RegistrationWorkerStore(SQLBaseStore): sql += " LIMIT 1" txn.execute(sql, list(keyvalues.values())) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return None return rows[0] - return self.runInteraction( + return self.db.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn ) @@ -776,18 +778,18 @@ class RegistrationWorkerStore(SQLBaseStore): """ def delete_threepid_session_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, ) - return self.runInteraction( + return self.db.runInteraction( "delete_threepid_session", delete_threepid_session_txn ) @@ -857,7 +859,7 @@ class RegistrationBackgroundUpdateStore( (last_user, batch_size), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return True, 0 @@ -880,7 +882,7 @@ class RegistrationBackgroundUpdateStore( else: return False, len(rows) - end, nb_processed = yield self.runInteraction( + end, nb_processed = yield self.db.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn ) @@ -911,7 +913,7 @@ class RegistrationBackgroundUpdateStore( txn.executemany(sql, [(id_server,) for id_server in id_servers]) if id_servers: - yield self.runInteraction( + yield self.db.runInteraction( "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) @@ -961,7 +963,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ next_id = self._access_tokens_id_gen.get_next() - yield self.simple_insert( + yield self.db.simple_insert( "access_tokens", { "id": next_id, @@ -1003,7 +1005,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Raises: StoreError if the user_id could not be registered. """ - return self.runInteraction( + return self.db.runInteraction( "register_user", self._register_user, user_id, @@ -1037,7 +1039,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Ensure that the guest user actually exists # ``allow_none=False`` makes this raise an exception # if the row isn't in the database. - self.simple_select_one_txn( + self.db.simple_select_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1045,7 +1047,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): allow_none=False, ) - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1059,7 +1061,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) else: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "users", values={ @@ -1114,7 +1116,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self.simple_insert( + return self.db.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1132,12 +1134,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def user_set_password_hash_txn(txn): - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_password_hash", user_set_password_hash_txn) + return self.db.runInteraction( + "user_set_password_hash", user_set_password_hash_txn + ) def user_set_consent_version(self, user_id, consent_version): """Updates the user table to record privacy policy consent @@ -1152,7 +1156,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1160,7 +1164,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_consent_version", f) + return self.db.runInteraction("user_set_consent_version", f) def user_set_consent_server_notice_sent(self, user_id, consent_version): """Updates the user table to record that we have sent the user a server @@ -1176,7 +1180,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1184,7 +1188,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_consent_server_notice_sent", f) + return self.db.runInteraction("user_set_consent_server_notice_sent", f) def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): """ @@ -1230,11 +1234,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return tokens_and_devices - return self.runInteraction("user_delete_access_tokens", f) + return self.db.runInteraction("user_delete_access_tokens", f) def delete_access_token(self, access_token): def f(txn): - self.simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -1242,11 +1246,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, self.get_user_by_access_token, (access_token,) ) - return self.runInteraction("delete_access_token", f) + return self.db.runInteraction("delete_access_token", f) @cachedInlineCallbacks() def is_guest(self, user_id): - res = yield self.simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1261,7 +1265,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self.simple_insert( + return self.db.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", @@ -1274,7 +1278,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self.simple_delete( + return self.db.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", @@ -1285,7 +1289,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1315,7 +1319,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): - row = self.simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1333,7 +1337,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): 400, "This client_secret does not match the provided session_id" ) - row = self.simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id, "token": token}, @@ -1358,7 +1362,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Looks good. Validate the session - self.simple_update_txn( + self.db.simple_update_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1368,7 +1372,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return next_link # Return next_link if it exists - return self.runInteraction( + return self.db.runInteraction( "validate_threepid_session_txn", validate_threepid_session_txn ) @@ -1401,7 +1405,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): if validated_at: insertion_values["validated_at"] = validated_at - return self.simple_upsert( + return self.db.simple_upsert( table="threepid_validation_session", keyvalues={"session_id": session_id}, values={"last_send_attempt": send_attempt}, @@ -1439,7 +1443,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def start_or_continue_validation_session_txn(txn): # Create or update a validation session - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1452,7 +1456,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Create a new validation token with this session ID - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="threepid_validation_token", values={ @@ -1463,7 +1467,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) - return self.runInteraction( + return self.db.runInteraction( "start_or_continue_validation_session", start_or_continue_validation_session_txn, ) @@ -1478,7 +1482,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ return txn.execute(sql, (ts,)) - return self.runInteraction( + return self.db.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, self.clock.time_msec(), @@ -1493,7 +1497,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): deactivated (bool): The value to set for `deactivated`. """ - yield self.runInteraction( + yield self.db.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, user_id, @@ -1501,7 +1505,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -1529,14 +1533,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) txn.execute(sql, []) - res = self.cursor_to_dict(txn) + res = self.db.cursor_to_dict(txn) if res: for user in res: self.set_expiration_date_for_user_txn( txn, user["name"], use_delta=True ) - yield self.runInteraction( + yield self.db.runInteraction( "get_users_with_no_expiration_date", select_users_with_no_expiration_date_txn, ) @@ -1560,7 +1564,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expiration_ts, ) - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, "account_validity", keyvalues={"user_id": user_id}, diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py index f81f9279a1..1c07c7a425 100644 --- a/synapse/storage/data_stores/main/rejections.py +++ b/synapse/storage/data_stores/main/rejections.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): def _store_rejections_txn(self, txn, event_id, reason): - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="rejections", values={ @@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore): ) def get_rejection_reason(self, event_id): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py index aa5e10538b..046c2b4845 100644 --- a/synapse/storage/data_stores/main/relations.py +++ b/synapse/storage/data_stores/main/relations.py @@ -129,7 +129,7 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.runInteraction( + return self.db.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn ) @@ -223,7 +223,7 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.runInteraction( + return self.db.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) @@ -268,7 +268,7 @@ class RelationsWorkerStore(SQLBaseStore): if row: return row[0] - edit_id = yield self.runInteraction( + edit_id = yield self.db.runInteraction( "get_applicable_edit", _get_applicable_edit_txn ) @@ -318,7 +318,7 @@ class RelationsWorkerStore(SQLBaseStore): return bool(txn.fetchone()) - return self.runInteraction( + return self.db.runInteraction( "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) @@ -352,7 +352,7 @@ class RelationsStore(RelationsWorkerStore): aggregation_key = relation.get("key") - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="event_relations", values={ @@ -380,6 +380,6 @@ class RelationsStore(RelationsWorkerStore): redacted_event_id (str): The event that was redacted. """ - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index f309e3640c..a26ed47afc 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -54,7 +54,7 @@ class RoomWorkerStore(SQLBaseStore): Returns: A dict containing the room information, or None if the room is unknown. """ - return self.simple_select_one( + return self.db.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -63,7 +63,7 @@ class RoomWorkerStore(SQLBaseStore): ) def get_public_room_ids(self): - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", @@ -120,7 +120,7 @@ class RoomWorkerStore(SQLBaseStore): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.runInteraction("count_public_rooms", _count_public_rooms_txn) + return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) @defer.inlineCallbacks def get_largest_public_rooms( @@ -253,21 +253,21 @@ class RoomWorkerStore(SQLBaseStore): def _get_largest_public_rooms_txn(txn): txn.execute(sql, query_args) - results = self.cursor_to_dict(txn) + results = self.db.cursor_to_dict(txn) if not forwards: results.reverse() return results - ret_val = yield self.runInteraction( + ret_val = yield self.db.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) defer.returnValue(ret_val) @cached(max_entries=10000) def is_room_blocked(self, room_id): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", @@ -288,7 +288,7 @@ class RoomWorkerStore(SQLBaseStore): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self.simple_select_one( + row = yield self.db.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -330,9 +330,9 @@ class RoomWorkerStore(SQLBaseStore): (room_id,), ) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - ret = yield self.runInteraction( + ret = yield self.db.runInteraction( "get_retention_policy_for_room", get_retention_policy_for_room_txn, ) @@ -396,7 +396,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): (last_room, batch_size), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return True @@ -408,7 +408,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): ev = json.loads(row["json"]) retention_policy = json.dumps(ev["content"]) - self.simple_insert_txn( + self.db.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -430,7 +430,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): else: return False - end = yield self.runInteraction( + end = yield self.db.runInteraction( "insert_room_retention", _background_insert_retention_txn, ) @@ -461,7 +461,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): try: def store_room_txn(txn, next_id): - self.simple_insert_txn( + self.db.simple_insert_txn( txn, "rooms", { @@ -471,7 +471,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) if is_public: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -482,7 +482,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction("store_room_txn", store_room_txn, next_id) + yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -490,14 +490,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="rooms", keyvalues={"room_id": room_id}, updatevalues={"is_public": is_public}, ) - entries = self.simple_select_list_txn( + entries = self.db.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -515,7 +515,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -528,7 +528,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( + yield self.db.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() @@ -555,7 +555,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def set_room_is_public_appservice_txn(txn, next_id): if is_public: try: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="appservice_room_list", values={ @@ -568,7 +568,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): # We've already inserted, nothing to do. return else: - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="appservice_room_list", keyvalues={ @@ -578,7 +578,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - entries = self.simple_select_list_txn( + entries = self.db.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -596,7 +596,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -609,7 +609,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( + yield self.db.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, next_id, @@ -626,7 +626,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.runInteraction("get_rooms", f) + return self.db.runInteraction("get_rooms", f) def _store_room_topic_txn(self, txn, event): if hasattr(event, "content") and "topic" in event.content: @@ -660,7 +660,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): # Ignore the event if one of the value isn't an integer. return - self.simple_insert_txn( + self.db.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -679,7 +679,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self, room_id, event_id, user_id, reason, content, received_ts ): next_id = self._event_reports_id_gen.get_next() - return self.simple_insert( + return self.db.simple_insert( table="event_reports", values={ "id": next_id, @@ -712,7 +712,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): if prev_id == current_id: return defer.succeed([]) - return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms) + return self.db.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) @defer.inlineCallbacks def block_room(self, room_id, user_id): @@ -725,14 +727,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): Returns: Deferred """ - yield self.simple_upsert( + yield self.db.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, insertion_values={"user_id": user_id}, desc="block_room", ) - yield self.runInteraction( + yield self.db.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, self.is_room_blocked, @@ -763,7 +765,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return local_media_mxcs, remote_media_mxcs - return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) + return self.db.runInteraction( + "get_media_ids_in_room", _get_media_mxcs_in_room_txn + ) def quarantine_media_ids_in_room(self, room_id, quarantined_by): """For a room loops through all events with media and quarantines @@ -802,7 +806,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return total_media_quarantined - return self.runInteraction( + return self.db.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -907,7 +911,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): txn.execute(sql, args) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) rooms_dict = {} for row in rows: @@ -923,7 +927,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): txn.execute(sql) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) # If a room isn't already in the dict (i.e. it doesn't have a retention # policy in its state), add it with a null policy. @@ -936,7 +940,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return rooms_dict - rooms = yield self.runInteraction( + rooms = yield self.db.runInteraction( "get_rooms_for_retention_period_in_range", get_rooms_for_retention_period_in_range_txn, ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index fe2428a281..7f4d02b25b 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -116,7 +116,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(query) return list(txn)[0][0] - count = yield self.runInteraction("get_known_servers", _transact) + count = yield self.db.runInteraction("get_known_servers", _transact) # We always know about ourselves, even if we have nothing in # room_memberships (for example, the server is new). @@ -128,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): membership column is up to date """ - pending_update = self.simple_select_one_txn( + pending_update = self.db.simple_select_one_txn( txn, table="background_updates", keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, @@ -144,7 +144,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): 15.0, run_as_background_process, "_check_safe_current_state_events_membership_updated", - self.runInteraction, + self.db.runInteraction, "_check_safe_current_state_events_membership_updated", self._check_safe_current_state_events_membership_updated_txn, ) @@ -161,7 +161,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) @@ -269,7 +269,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return res - return self.runInteraction("get_room_summary", _get_room_summary_txn) + return self.db.runInteraction("get_room_summary", _get_room_summary_txn) def _get_user_counts_in_room_txn(self, txn, room_id): """ @@ -339,7 +339,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): if not membership_list: return defer.succeed(None) - rooms = yield self.runInteraction( + rooms = yield self.db.runInteraction( "get_rooms_for_user_where_membership_is", self._get_rooms_for_user_where_membership_is_txn, user_id, @@ -392,7 +392,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] + results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] if do_invite: sql = ( @@ -412,7 +412,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): stream_ordering=r["stream_ordering"], membership=Membership.INVITE, ) - for r in self.cursor_to_dict(txn) + for r in self.db.cursor_to_dict(txn) ) return results @@ -603,7 +603,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): to `user_id` and ProfileInfo (or None if not join event). """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, @@ -643,7 +643,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self.execute("is_host_joined", None, sql, room_id, like_clause) + rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -683,7 +683,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self.execute("was_host_joined", None, sql, room_id, like_clause) + rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -753,7 +753,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): rows = txn.fetchall() return rows[0][0] - count = yield self.runInteraction("did_forget_membership", f) + count = yield self.db.runInteraction("did_forget_membership", f) return count == 0 @cached() @@ -790,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (user_id,)) return set(row[0] for row in txn if row[1] == 0) - return self.runInteraction( + return self.db.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) @@ -805,7 +805,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): Deferred[set[str]]: Set of room IDs. """ - room_ids = yield self.simple_select_onecol( + room_ids = yield self.db.simple_select_onecol( table="room_memberships", keyvalues={"membership": Membership.JOIN, "user_id": user_id}, retcol="room_id", @@ -820,7 +820,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Get user_id and membership of a set of event IDs. """ - return self.simple_select_many_batch( + return self.db.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -874,7 +874,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return 0 @@ -915,7 +915,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): return len(rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn ) @@ -971,7 +971,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): # string, which will compare before all room IDs correctly. last_processed_room = progress.get("last_processed_room", "") - row_count, finished = yield self.runInteraction( + row_count, finished = yield self.db.runInteraction( "_background_current_state_membership_update", _background_current_state_membership_txn, last_processed_room, @@ -990,7 +990,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="room_memberships", values=[ @@ -1028,7 +1028,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): is_mine = self.hs.is_mine_id(event.state_key) if is_new_state and is_mine: if event.membership == Membership.INVITE: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_invites", values={ @@ -1068,7 +1068,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): txn.execute(sql, (stream_ordering, True, room_id, user_id)) with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) + yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" @@ -1091,7 +1091,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): txn, self.get_forgotten_rooms_for_user, (user_id,) ) - return self.runInteraction("forget_membership", f) + return self.db.runInteraction("forget_membership", f) class _JoinedHostsCache(object): diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index f735cf095c..55a604850e 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -93,7 +93,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): # store_search_entries_txn with a generator function, but that # would mean having two cursors open on the database at once. # Instead we just build a list of results. - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return 0 @@ -159,7 +159,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): return len(event_search_rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn ) @@ -206,7 +206,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): - yield self.runWithConnection(create_index) + yield self.db.runWithConnection(create_index) yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) return 1 @@ -237,12 +237,12 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): ) conn.set_session(autocommit=False) - yield self.runWithConnection(create_index) + yield self.db.runWithConnection(create_index) pg = dict(progress) pg["have_added_indexes"] = True - yield self.runInteraction( + yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, @@ -280,7 +280,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): return len(rows), True - num_rows, finished = yield self.runInteraction( + num_rows, finished = yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn ) @@ -441,7 +441,9 @@ class SearchStore(SearchBackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self.execute("search_msgs", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_msgs", self.db.cursor_to_dict, sql, *args + ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -455,8 +457,8 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self.execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + count_results = yield self.db.execute( + "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -586,7 +588,9 @@ class SearchStore(SearchBackgroundUpdateStore): args.append(limit) - results = yield self.execute("search_rooms", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_rooms", self.db.cursor_to_dict, sql, *args + ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -600,8 +604,8 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self.execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + count_results = yield self.db.execute( + "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -686,7 +690,7 @@ class SearchStore(SearchBackgroundUpdateStore): return highlight_words - return self.runInteraction("_find_highlights", f) + return self.db.runInteraction("_find_highlights", f) def _to_postgres_options(options_dict): diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py index f3da29ce14..563216b63c 100644 --- a/synapse/storage/data_stores/main/signatures.py +++ b/synapse/storage/data_stores/main/signatures.py @@ -48,7 +48,7 @@ class SignatureWorkerStore(SQLBaseStore): for event_id in event_ids } - return self.runInteraction("get_event_reference_hashes", f) + return self.db.runInteraction("get_event_reference_hashes", f) @defer.inlineCallbacks def add_event_hashes(self, event_ids): @@ -98,4 +98,4 @@ class SignatureStore(SignatureWorkerStore): } ) - self.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) + self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 2b33ec1a35..851e81d6b3 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -89,7 +89,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): count = 0 while next_group: - next_group = self.simple_select_one_onecol_txn( + next_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -192,7 +192,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): ): break - next_group = self.simple_select_one_onecol_txn( + next_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -348,7 +348,9 @@ class StateGroupWorkerStore( (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn } - return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn) + return self.db.runInteraction( + "get_current_state_ids", _get_current_state_ids_txn + ) # FIXME: how should this be cached? def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): @@ -392,7 +394,7 @@ class StateGroupWorkerStore( return results - return self.runInteraction( + return self.db.runInteraction( "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) @@ -431,7 +433,7 @@ class StateGroupWorkerStore( """ def _get_state_group_delta_txn(txn): - prev_group = self.simple_select_one_onecol_txn( + prev_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, @@ -442,7 +444,7 @@ class StateGroupWorkerStore( if not prev_group: return _GetStateGroupDelta(None, None) - delta_ids = self.simple_select_list_txn( + delta_ids = self.db.simple_select_list_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, @@ -454,7 +456,9 @@ class StateGroupWorkerStore( {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, ) - return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn) + return self.db.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -540,7 +544,7 @@ class StateGroupWorkerStore( chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: - res = yield self.runInteraction( + res = yield self.db.runInteraction( "_get_state_groups_from_groups", self._get_state_groups_from_groups_txn, chunk, @@ -644,7 +648,7 @@ class StateGroupWorkerStore( @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", @@ -661,7 +665,7 @@ class StateGroupWorkerStore( def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids, @@ -902,7 +906,7 @@ class StateGroupWorkerStore( state_group = self.database_engine.get_next_state_group_id(txn) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_groups", values={"id": state_group, "room_id": room_id, "event_id": event_id}, @@ -911,7 +915,7 @@ class StateGroupWorkerStore( # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if prev_group: - is_in_db = self.simple_select_one_onecol_txn( + is_in_db = self.db.simple_select_one_onecol_txn( txn, table="state_groups", keyvalues={"id": prev_group}, @@ -926,13 +930,13 @@ class StateGroupWorkerStore( potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_group_edges", values={"state_group": state_group, "prev_state_group": prev_group}, ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -947,7 +951,7 @@ class StateGroupWorkerStore( ], ) else: - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -993,7 +997,7 @@ class StateGroupWorkerStore( return state_group - return self.runInteraction("store_state_group", _store_state_group_txn) + return self.db.runInteraction("store_state_group", _store_state_group_txn) @defer.inlineCallbacks def get_referenced_state_groups(self, state_groups): @@ -1007,7 +1011,7 @@ class StateGroupWorkerStore( referenced. """ - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, @@ -1065,7 +1069,7 @@ class StateBackgroundUpdateStore( batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) if max_group is None: - rows = yield self.execute( + rows = yield self.db.execute( "_background_deduplicate_state", None, "SELECT coalesce(max(id), 0) FROM state_groups", @@ -1135,13 +1139,13 @@ class StateBackgroundUpdateStore( if prev_state.get(key, None) != value } - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, ) - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_group_edges", values={ @@ -1150,13 +1154,13 @@ class StateBackgroundUpdateStore( }, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, ) - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1183,7 +1187,7 @@ class StateBackgroundUpdateStore( return False, batch_size - finished, result = yield self.runInteraction( + finished, result = yield self.db.runInteraction( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn ) @@ -1218,7 +1222,7 @@ class StateBackgroundUpdateStore( ) txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - yield self.runWithConnection(reindex_txn) + yield self.db.runWithConnection(reindex_txn) yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) @@ -1263,7 +1267,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): state_groups[event.event_id] = context.state_group - self.simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_to_state_groups", values=[ diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py index 03b908026b..12c982cb26 100644 --- a/synapse/storage/data_stores/main/state_deltas.py +++ b/synapse/storage/data_stores/main/state_deltas.py @@ -98,14 +98,14 @@ class StateDeltasStore(SQLBaseStore): ORDER BY stream_id ASC """ txn.execute(sql, (prev_stream_id, clipped_stream_id)) - return clipped_stream_id, self.cursor_to_dict(txn) + return clipped_stream_id, self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn ) def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self.simple_select_one_onecol_txn( + return self.db.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", keyvalues={}, @@ -113,7 +113,7 @@ class StateDeltasStore(SQLBaseStore): ) def get_max_stream_id_in_current_state_deltas(self): - return self.runInteraction( + return self.db.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 3aeba859fd..974ffc15bd 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -117,7 +117,7 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_user_id, batch_size)) return [r for r, in txn] - users_to_work_on = yield self.runInteraction( + users_to_work_on = yield self.db.runInteraction( "_populate_stats_process_users", _get_next_batch ) @@ -130,7 +130,7 @@ class StatsStore(StateDeltasStore): yield self._calculate_and_set_initial_state_for_user(user_id) progress["last_user_id"] = user_id - yield self.runInteraction( + yield self.db.runInteraction( "populate_stats_process_users", self._background_update_progress_txn, "populate_stats_process_users", @@ -160,7 +160,7 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_room_id, batch_size)) return [r for r, in txn] - rooms_to_work_on = yield self.runInteraction( + rooms_to_work_on = yield self.db.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch ) @@ -173,7 +173,7 @@ class StatsStore(StateDeltasStore): yield self._calculate_and_set_initial_state_for_room(room_id) progress["last_room_id"] = room_id - yield self.runInteraction( + yield self.db.runInteraction( "_populate_stats_process_rooms", self._background_update_progress_txn, "populate_stats_process_rooms", @@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore): """ Returns the stats processor positions. """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -215,7 +215,7 @@ class StatsStore(StateDeltasStore): if field and "\0" in field: fields[col] = None - return self.simple_upsert( + return self.db.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, @@ -236,7 +236,7 @@ class StatsStore(StateDeltasStore): Deferred[list[dict]], where the dict has the keys of ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". """ - return self.runInteraction( + return self.db.runInteraction( "get_statistics_for_subject", self._get_statistics_for_subject_txn, stats_type, @@ -257,7 +257,7 @@ class StatsStore(StateDeltasStore): ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] ) - slice_list = self.simple_select_list_paginate_txn( + slice_list = self.db.simple_select_list_paginate_txn( txn, table + "_historical", {id_col: stats_id}, @@ -282,7 +282,7 @@ class StatsStore(StateDeltasStore): "name", "topic", "canonical_alias", "avatar", "join_rules", "history_visibility" """ - return self.simple_select_one( + return self.db.simple_select_one( "room_stats_state", {"room_id": room_id}, retcols=( @@ -308,7 +308,7 @@ class StatsStore(StateDeltasStore): """ table, id_col = TYPE_TO_TABLE[stats_type] - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", @@ -344,14 +344,14 @@ class StatsStore(StateDeltasStore): complete_with_stream_id=stream_id, ) - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": stream_id}, ) - return self.runInteraction( + return self.db.runInteraction( "bulk_update_stats_delta", _bulk_update_stats_delta_txn ) @@ -382,7 +382,7 @@ class StatsStore(StateDeltasStore): Does not work with per-slice fields. """ - return self.runInteraction( + return self.db.runInteraction( "update_stats_delta", self._update_stats_delta_txn, ts, @@ -517,17 +517,17 @@ class StatsStore(StateDeltasStore): else: self.database_engine.lock_table(txn, table) retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self.simple_select_one_txn( + current_row = self.db.simple_select_one_txn( txn, table, keyvalues, retcols, allow_none=True ) if current_row is None: merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self.simple_insert_txn(txn, table, merged_dict) + self.db.simple_insert_txn(txn, table, merged_dict) else: for (key, val) in additive_relatives.items(): current_row[key] += val current_row.update(absolutes) - self.simple_update_one_txn(txn, table, keyvalues, current_row) + self.db.simple_update_one_txn(txn, table, keyvalues, current_row) def _upsert_copy_from_table_with_additive_relatives_txn( self, @@ -614,11 +614,11 @@ class StatsStore(StateDeltasStore): txn.execute(sql, qargs) else: self.database_engine.lock_table(txn, into_table) - src_row = self.simple_select_one_txn( + src_row = self.db.simple_select_one_txn( txn, src_table, keyvalues, copy_columns ) all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self.simple_select_one_txn( + dest_current_row = self.db.simple_select_one_txn( txn, into_table, keyvalues=all_dest_keyvalues, @@ -634,11 +634,11 @@ class StatsStore(StateDeltasStore): **src_row, **additive_relatives, } - self.simple_insert_txn(txn, into_table, merged_dict) + self.db.simple_insert_txn(txn, into_table, merged_dict) else: for (key, val) in additive_relatives.items(): src_row[key] = dest_current_row[key] + val - self.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): """Fetches the counts of events in the given range of stream IDs. @@ -652,7 +652,7 @@ class StatsStore(StateDeltasStore): changes. """ - return self.runInteraction( + return self.db.runInteraction( "stats_incremental_total_events_and_bytes", self.get_changes_room_total_events_and_bytes_txn, min_pos, @@ -735,7 +735,7 @@ class StatsStore(StateDeltasStore): def _fetch_current_state_stats(txn): pos = self.get_room_max_stream_ordering() - rows = self.simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="current_state_events", column="type", @@ -791,7 +791,7 @@ class StatsStore(StateDeltasStore): current_state_events_count, users_in_room, pos, - ) = yield self.runInteraction( + ) = yield self.db.runInteraction( "get_initial_state_for_room", _fetch_current_state_stats ) @@ -866,7 +866,7 @@ class StatsStore(StateDeltasStore): (count,) = txn.fetchone() return count, pos - joined_rooms, pos = yield self.runInteraction( + joined_rooms, pos = yield self.db.runInteraction( "calculate_and_set_initial_state_for_user", _calculate_and_set_initial_state_for_user_txn, ) diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 60487c4559..2ff8c57109 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -255,7 +255,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): super(StreamWorkerStore, self).__init__(db_conn, hs) events_max = self.get_room_max_stream_ordering() - event_cache_prefill, min_event_val = self.get_cache_dict( + event_cache_prefill, min_event_val = self.db.get_cache_dict( db_conn, "events", entity_column="room_id", @@ -400,7 +400,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.runInteraction("get_room_events_stream_for_room", f) + rows = yield self.db.runInteraction("get_room_events_stream_for_room", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -450,7 +450,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows - rows = yield self.runInteraction("get_membership_changes_for_user", f) + rows = yield self.db.runInteraction("get_membership_changes_for_user", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -511,7 +511,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token = RoomStreamToken.parse(end_token) - rows, token = yield self.runInteraction( + rows, token = yield self.db.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, @@ -548,7 +548,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() - return self.runInteraction("get_room_event_after_stream_ordering", _f) + return self.db.runInteraction("get_room_event_after_stream_ordering", _f) @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): @@ -562,7 +562,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if room_id is None: return "s%d" % (token,) else: - topo = yield self.runInteraction( + topo = yield self.db.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) return "t%d-%d" % (topo, token) @@ -576,7 +576,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "s%d" stream token. """ - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" ).addCallback(lambda row: "s%d" % (row,)) @@ -589,7 +589,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "t%d-%d" topological token. """ - return self.simple_select_one( + return self.db.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), @@ -613,7 +613,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self.execute( + return self.db.execute( "get_max_topological_token", None, sql, room_id, stream_key ).addCallback(lambda r: r[0][0] if r else 0) @@ -667,7 +667,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = yield self.runInteraction( + results = yield self.db.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -709,7 +709,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = self.simple_select_one_txn( + results = self.db.simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -788,7 +788,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.runInteraction( + upper_bound, event_ids = yield self.db.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) @@ -797,7 +797,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, events def get_federation_out_pos(self, typ): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, @@ -805,7 +805,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) def update_federation_out_pos(self, typ, stream_id): - return self.simple_update_one( + return self.db.simple_update_one( table="federation_stream_position", keyvalues={"type": typ}, updatevalues={"stream_id": stream_id}, @@ -956,7 +956,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.runInteraction( + rows, token = yield self.db.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index 85012403be..2aa1bafd48 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore): tag strings to tag content. """ - deferred = self.simple_select_list( + deferred = self.db.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) @@ -78,7 +78,7 @@ class TagsWorkerStore(AccountDataWorkerStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - tag_ids = yield self.runInteraction( + tag_ids = yield self.db.runInteraction( "get_all_updated_tags", get_all_updated_tags_txn ) @@ -98,7 +98,7 @@ class TagsWorkerStore(AccountDataWorkerStore): batch_size = 50 results = [] for i in range(0, len(tag_ids), batch_size): - tags = yield self.runInteraction( + tags = yield self.db.runInteraction( "get_all_updated_tag_content", get_tag_content, tag_ids[i : i + batch_size], @@ -135,7 +135,9 @@ class TagsWorkerStore(AccountDataWorkerStore): if not changed: return {} - room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn) + room_ids = yield self.db.runInteraction( + "get_updated_tags", get_updated_tags_txn + ) results = {} if room_ids: @@ -153,7 +155,7 @@ class TagsWorkerStore(AccountDataWorkerStore): Returns: A deferred list of string tags. """ - return self.simple_select_list( + return self.db.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), @@ -178,7 +180,7 @@ class TagsStore(TagsWorkerStore): content_json = json.dumps(content) def add_tag_txn(txn, next_id): - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, @@ -187,7 +189,7 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("add_tag", add_tag_txn, next_id) + yield self.db.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -210,7 +212,7 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("remove_tag", remove_tag_txn, next_id) + yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index c162f3ea16..c0d155a43c 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -77,7 +77,7 @@ class TransactionStore(SQLBaseStore): this transaction or a 2-tuple of (int, dict) """ - return self.runInteraction( + return self.db.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -85,7 +85,7 @@ class TransactionStore(SQLBaseStore): ) def _get_received_txn_response(self, txn, transaction_id, origin): - result = self.simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, @@ -119,7 +119,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ - return self.simple_insert( + return self.db.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, @@ -148,7 +148,7 @@ class TransactionStore(SQLBaseStore): if result is not SENTINEL: return result - result = yield self.runInteraction( + result = yield self.db.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, @@ -160,7 +160,7 @@ class TransactionStore(SQLBaseStore): return result def _get_destination_retry_timings(self, txn, destination): - result = self.simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -187,7 +187,7 @@ class TransactionStore(SQLBaseStore): """ self._destination_retry_cache.pop(destination, None) - return self.runInteraction( + return self.db.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, destination, @@ -227,7 +227,7 @@ class TransactionStore(SQLBaseStore): # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. - prev_row = self.simple_select_one_txn( + prev_row = self.db.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -236,7 +236,7 @@ class TransactionStore(SQLBaseStore): ) if not prev_row: - self.simple_insert_txn( + self.db.simple_insert_txn( txn, table="destinations", values={ @@ -247,7 +247,7 @@ class TransactionStore(SQLBaseStore): }, ) elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self.simple_update_one_txn( + self.db.simple_update_one_txn( txn, "destinations", keyvalues={"destination": destination}, @@ -270,4 +270,6 @@ class TransactionStore(SQLBaseStore): def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn) + return self.db.runInteraction( + "_cleanup_transactions", _cleanup_transactions_txn + ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 1a85aabbfb..7118bd62f3 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -85,7 +85,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ txn.execute(sql) rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) del rooms # If search all users is on, get all the users we want to add. @@ -100,13 +100,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("SELECT name FROM users") users = [{"user_id": x[0]} for x in txn.fetchall()] - self.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) yield self._end_background_update("populate_user_directory_createtables") return 1 @@ -116,7 +116,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ Update the user directory stream position, then clean up the old tables. """ - position = yield self.simple_select_one_onecol( + position = yield self.db.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) yield self.update_user_directory_stream_pos(position) @@ -126,7 +126,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory_cleanup", _delete_staging_area ) @@ -170,7 +170,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore return rooms_to_work_on - rooms_to_work_on = yield self.runInteraction( + rooms_to_work_on = yield self.db.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) @@ -243,10 +243,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore to_insert.clear() # We've finished a room. Delete it from the table. - yield self.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) # Update the remaining counter. progress["remaining"] -= 1 - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory", self._background_update_progress_txn, "populate_user_directory_process_rooms", @@ -291,7 +291,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore return users_to_work_on - users_to_work_on = yield self.runInteraction( + users_to_work_on = yield self.db.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) @@ -312,10 +312,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) # We've finished processing a user. Delete it from the table. - yield self.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) + yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) # Update the remaining counter. progress["remaining"] -= 1 - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory", self._background_update_progress_txn, "populate_user_directory_process_users", @@ -361,7 +361,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _update_profile_in_user_dir_txn(txn): - new_entry = self.simple_upsert_txn( + new_entry = self.db.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -435,7 +435,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id - self.simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id}, @@ -448,7 +448,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.runInteraction( + return self.db.runInteraction( "update_profile_in_user_dir", _update_profile_in_user_dir_txn ) @@ -462,7 +462,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _add_users_who_share_room_txn(txn): - self.simple_upsert_many_txn( + self.db.simple_upsert_many_txn( txn, table="users_who_share_private_rooms", key_names=["user_id", "other_user_id", "room_id"], @@ -474,7 +474,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore value_values=None, ) - return self.runInteraction( + return self.db.runInteraction( "add_users_who_share_room", _add_users_who_share_room_txn ) @@ -489,7 +489,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore def _add_users_in_public_rooms_txn(txn): - self.simple_upsert_many_txn( + self.db.simple_upsert_many_txn( txn, table="users_in_public_rooms", key_names=["user_id", "room_id"], @@ -498,7 +498,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore value_values=None, ) - return self.runInteraction( + return self.db.runInteraction( "add_users_in_public_rooms", _add_users_in_public_rooms_txn ) @@ -513,13 +513,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("DELETE FROM users_who_share_private_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) - return self.runInteraction( + return self.db.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) @cached() def get_user_in_directory(self, user_id): - return self.simple_select_one( + return self.db.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -528,7 +528,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) def update_user_directory_stream_pos(self, stream_id): - return self.simple_update_one( + return self.db.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, @@ -547,42 +547,42 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id} ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id}, ) txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) + return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) @defer.inlineCallbacks def get_users_in_dir_due_to_room(self, room_id): """Get all user_ids that are in the room directory because they're in the given room_id """ - user_ids_share_pub = yield self.simple_select_onecol( + user_ids_share_pub = yield self.db.simple_select_onecol( table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) - user_ids_share_priv = yield self.simple_select_onecol( + user_ids_share_priv = yield self.db.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"room_id": room_id}, retcol="other_user_id", @@ -605,23 +605,23 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """ def _remove_user_who_share_room_txn(txn): - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id, "room_id": room_id}, ) - self.simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_user_who_share_room", _remove_user_who_share_room_txn ) @@ -636,14 +636,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns: list: user_id """ - rows = yield self.simple_select_onecol( + rows = yield self.db.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self.simple_select_onecol( + pub_rows = yield self.db.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -674,14 +674,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) f2 USING (room_id) """ - rows = yield self.execute( + rows = yield self.db.execute( "get_rooms_in_common_for_users", None, sql, user_id, other_user_id ) return [room_id for room_id, in rows] def get_user_directory_stream_pos(self): - return self.simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", @@ -786,7 +786,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self.execute("search_user_dir", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_user_dir", self.db.cursor_to_dict, sql, *args + ) limited = len(results) > limit diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index 37860af070..af8025bc17 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if the user has requested erasure """ - return self.simple_select_onecol( + return self.db.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", @@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore): # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - rows = yield self.simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="erased_users", column="user_id", iterable=user_ids, @@ -88,4 +88,4 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.runInteraction("mark_user_erased", f) + return self.db.runInteraction("mark_user_erased", f) diff --git a/synapse/storage/database.py b/synapse/storage/database.py new file mode 100644 index 0000000000..c2e121a001 --- /dev/null +++ b/synapse/storage/database.py @@ -0,0 +1,1485 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random +import sys +import time +from typing import Iterable, Tuple + +from six import iteritems, iterkeys, itervalues +from six.moves import intern, range + +from prometheus_client import Histogram + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.logging.context import LoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.util.stringutils import exception_to_unicode + +# import a function which will return a monotonic time, in seconds +try: + # on python 3, use time.monotonic, since time.clock can go backwards + from time import monotonic as monotonic_time +except ImportError: + # ... but python 2 doesn't have it + from time import clock as monotonic_time + +logger = logging.getLogger(__name__) + +try: + MAX_TXN_ID = sys.maxint - 1 +except AttributeError: + # python 3 does not have a maximum int value + MAX_TXN_ID = 2 ** 63 - 1 + +sql_logger = logging.getLogger("synapse.storage.SQL") +transaction_logger = logging.getLogger("synapse.storage.txn") +perf_logger = logging.getLogger("synapse.storage.TIME") + +sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") + +sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) +sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) + + +# Unique indexes which have been added in background updates. Maps from table name +# to the name of the background update which added the unique index to that table. +# +# This is used by the upsert logic to figure out which tables are safe to do a proper +# UPSERT on: until the relevant background update has completed, we +# have to emulate an upsert by locking the table. +# +UNIQUE_INDEX_BACKGROUND_UPDATES = { + "user_ips": "user_ips_device_unique_index", + "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", + "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", + "event_search": "event_search_event_id_idx", +} + + +class LoggingTransaction(object): + """An object that almost-transparently proxies for the 'txn' object + passed to the constructor. Adds logging and metrics to the .execute() + method. + + Args: + txn: The database transcation object to wrap. + name (str): The name of this transactions for logging. + database_engine (Sqlite3Engine|PostgresEngine) + after_callbacks(list|None): A list that callbacks will be appended to + that have been added by `call_after` which should be run on + successful completion of the transaction. None indicates that no + callbacks should be allowed to be scheduled to run. + exception_callbacks(list|None): A list that callbacks will be appended + to that have been added by `call_on_exception` which should be run + if transaction ends with an error. None indicates that no callbacks + should be allowed to be scheduled to run. + """ + + __slots__ = [ + "txn", + "name", + "database_engine", + "after_callbacks", + "exception_callbacks", + ] + + def __init__( + self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None + ): + object.__setattr__(self, "txn", txn) + object.__setattr__(self, "name", name) + object.__setattr__(self, "database_engine", database_engine) + object.__setattr__(self, "after_callbacks", after_callbacks) + object.__setattr__(self, "exception_callbacks", exception_callbacks) + + def call_after(self, callback, *args, **kwargs): + """Call the given callback on the main twisted thread after the + transaction has finished. Used to invalidate the caches on the + correct thread. + """ + self.after_callbacks.append((callback, args, kwargs)) + + def call_on_exception(self, callback, *args, **kwargs): + self.exception_callbacks.append((callback, args, kwargs)) + + def __getattr__(self, name): + return getattr(self.txn, name) + + def __setattr__(self, name, value): + setattr(self.txn, name, value) + + def __iter__(self): + return self.txn.__iter__() + + def execute_batch(self, sql, args): + if isinstance(self.database_engine, PostgresEngine): + from psycopg2.extras import execute_batch + + self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) + else: + for val in args: + self.execute(sql, val) + + def execute(self, sql, *args): + self._do_execute(self.txn.execute, sql, *args) + + def executemany(self, sql, *args): + self._do_execute(self.txn.executemany, sql, *args) + + def _make_sql_one_line(self, sql): + "Strip newlines out of SQL so that the loggers in the DB are on one line" + return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + + def _do_execute(self, func, sql, *args): + sql = self._make_sql_one_line(sql) + + # TODO(paul): Maybe use 'info' and 'debug' for values? + sql_logger.debug("[SQL] {%s} %s", self.name, sql) + + sql = self.database_engine.convert_param_style(sql) + if args: + try: + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) + except Exception: + # Don't let logging failures stop SQL from working + pass + + start = time.time() + + try: + return func(sql, *args) + except Exception as e: + logger.debug("[SQL FAIL] {%s} %s", self.name, e) + raise + finally: + secs = time.time() - start + sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) + sql_query_timer.labels(sql.split()[0]).observe(secs) + + +class PerformanceCounters(object): + def __init__(self): + self.current_counters = {} + self.previous_counters = {} + + def update(self, key, duration_secs): + count, cum_time = self.current_counters.get(key, (0, 0)) + count += 1 + cum_time += duration_secs + self.current_counters[key] = (count, cum_time) + + def interval(self, interval_duration_secs, limit=3): + counters = [] + for name, (count, cum_time) in iteritems(self.current_counters): + prev_count, prev_time = self.previous_counters.get(name, (0, 0)) + counters.append( + ( + (cum_time - prev_time) / interval_duration_secs, + count - prev_count, + name, + ) + ) + + self.previous_counters = dict(self.current_counters) + + counters.sort(reverse=True) + + top_n_counters = ", ".join( + "%s(%d): %.3f%%" % (name, count, 100 * ratio) + for ratio, count, name in counters[:limit] + ) + + return top_n_counters + + +class Database(object): + _TXN_ID = 0 + + def __init__(self, hs): + self.hs = hs + self._clock = hs.get_clock() + self._db_pool = hs.get_db_pool() + + self._previous_txn_total_time = 0 + self._current_txn_total_time = 0 + self._previous_loop_ts = 0 + + # TODO(paul): These can eventually be removed once the metrics code + # is running in mainline, and we have some nice monitoring frontends + # to watch it + self._txn_perf_counters = PerformanceCounters() + + self.database_engine = hs.database_engine + + # A set of tables that are not safe to use native upserts in. + self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) + + # We add the user_directory_search table to the blacklist on SQLite + # because the existing search table does not have an index, making it + # unsafe to use native upserts. + if isinstance(self.database_engine, Sqlite3Engine): + self._unsafe_to_upsert_tables.add("user_directory_search") + + if self.database_engine.can_native_upsert: + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates of tables that aren't safe to update. + self._clock.call_later( + 0.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + self.rand = random.SystemRandom() + + @defer.inlineCallbacks + def _check_safe_to_upsert(self): + """ + Is it safe to use native UPSERT? + + If there are background updates, we will need to wait, as they may be + the addition of indexes that set the UNIQUE constraint that we require. + + If the background updates have not completed, wait 15 sec and check again. + """ + updates = yield self.simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ) + updates = [x["update_name"] for x in updates] + + for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): + if update_name not in updates: + logger.debug("Now safe to upsert in %s", table) + self._unsafe_to_upsert_tables.discard(table) + + # If there's any updates still running, reschedule to run. + if updates: + self._clock.call_later( + 15.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + def start_profiling(self): + self._previous_loop_ts = monotonic_time() + + def loop(): + curr = self._current_txn_total_time + prev = self._previous_txn_total_time + self._previous_txn_total_time = curr + + time_now = monotonic_time() + time_then = self._previous_loop_ts + self._previous_loop_ts = time_now + + duration = time_now - time_then + ratio = (curr - prev) / duration + + top_three_counters = self._txn_perf_counters.interval(duration, limit=3) + + perf_logger.info( + "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters + ) + + self._clock.looping_call(loop, 10000) + + def new_transaction( + self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs + ): + start = monotonic_time() + txn_id = self._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) + + name = "%s-%x" % (desc, txn_id) + + transaction_logger.debug("[TXN START] {%s}", name) + + try: + i = 0 + N = 5 + while True: + cursor = LoggingTransaction( + conn.cursor(), + name, + self.database_engine, + after_callbacks, + exception_callbacks, + ) + try: + r = func(cursor, *args, **kwargs) + conn.commit() + return r + except self.database_engine.module.OperationalError as e: + # This can happen if the database disappears mid + # transaction. + logger.warning( + "[TXN OPERROR] {%s} %s %d/%d", + name, + exception_to_unicode(e), + i, + N, + ) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warning( + "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) + ) + continue + raise + except self.database_engine.module.DatabaseError as e: + if self.database_engine.is_deadlock(e): + logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warning( + "[TXN EROLL] {%s} %s", + name, + exception_to_unicode(e1), + ) + continue + raise + finally: + # we're either about to retry with a new cursor, or we're about to + # release the connection. Once we release the connection, it could + # get used for another query, which might do a conn.rollback(). + # + # In the latter case, even though that probably wouldn't affect the + # results of this transaction, python's sqlite will reset all + # statements on the connection [1], which will make our cursor + # invalid [2]. + # + # In any case, continuing to read rows after commit()ing seems + # dubious from the PoV of ACID transactional semantics + # (sqlite explicitly says that once you commit, you may see rows + # from subsequent updates.) + # + # In psycopg2, cursors are essentially a client-side fabrication - + # all the data is transferred to the client side when the statement + # finishes executing - so in theory we could go on streaming results + # from the cursor, but attempting to do so would make us + # incompatible with sqlite, so let's make sure we're not doing that + # by closing the cursor. + # + # (*named* cursors in psycopg2 are different and are proper server- + # side things, but (a) we don't use them and (b) they are implicitly + # closed by ending the transaction anyway.) + # + # In short, if we haven't finished with the cursor yet, that's a + # problem waiting to bite us. + # + # TL;DR: we're done with the cursor, so we can close it. + # + # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 + # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 + cursor.close() + except Exception as e: + logger.debug("[TXN FAIL] {%s} %s", name, e) + raise + finally: + end = monotonic_time() + duration = end - start + + LoggingContext.current_context().add_database_transaction(duration) + + transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) + + self._current_txn_total_time += duration + self._txn_perf_counters.update(desc, duration) + sql_txn_timer.labels(desc).observe(duration) + + @defer.inlineCallbacks + def runInteraction(self, desc, func, *args, **kwargs): + """Starts a transaction on the database and runs a given function + + Arguments: + desc (str): description of the transaction, for logging and metrics + func (func): callback function, which will be called with a + database transaction (twisted.enterprise.adbapi.Transaction) as + its first argument, followed by `args` and `kwargs`. + + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + after_callbacks = [] + exception_callbacks = [] + + if LoggingContext.current_context() == LoggingContext.sentinel: + logger.warning("Starting db txn '%s' from sentinel context", desc) + + try: + result = yield self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + **kwargs + ) + + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + except: # noqa: E722, as we reraise the exception this is fine. + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + return result + + @defer.inlineCallbacks + def runWithConnection(self, func, *args, **kwargs): + """Wraps the .runWithConnection() method on the underlying db_pool. + + Arguments: + func (func): callback function, which will be called with a + database connection (twisted.enterprise.adbapi.Connection) as + its first argument, followed by `args` and `kwargs`. + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + parent_context = LoggingContext.current_context() + if parent_context == LoggingContext.sentinel: + logger.warning( + "Starting db connection from sentinel context: metrics will be lost" + ) + parent_context = None + + start_time = monotonic_time() + + def inner_func(conn, *args, **kwargs): + with LoggingContext("runWithConnection", parent_context) as context: + sched_duration_sec = monotonic_time() - start_time + sql_scheduling_timer.observe(sched_duration_sec) + context.add_database_scheduled(sched_duration_sec) + + if self.database_engine.is_connection_closed(conn): + logger.debug("Reconnecting closed database connection") + conn.reconnect() + + return func(conn, *args, **kwargs) + + result = yield make_deferred_yieldable( + self._db_pool.runWithConnection(inner_func, *args, **kwargs) + ) + + return result + + @staticmethod + def cursor_to_dict(cursor): + """Converts a SQL cursor into an list of dicts. + + Args: + cursor : The DBAPI cursor which has executed a query. + Returns: + A list of dicts where the key is the column header. + """ + col_headers = list(intern(str(column[0])) for column in cursor.description) + results = list(dict(zip(col_headers, row)) for row in cursor) + return results + + def execute(self, desc, decoder, query, *args): + """Runs a single query for a result set. + + Args: + decoder - The function which can resolve the cursor results to + something meaningful. + query - The query string to execute + *args - Query args. + Returns: + The result of decoder(results) + """ + + def interaction(txn): + txn.execute(query, args) + if decoder: + return decoder(txn) + else: + return txn.fetchall() + + return self.runInteraction(desc, interaction) + + # "Simple" SQL API methods that operate on a single table with no JOINs, + # no complex WHERE clauses, just a dict of values for columns. + + @defer.inlineCallbacks + def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + """Executes an INSERT query on the named table. + + Args: + table : string giving the table name + values : dict of new column names and values for them + or_ignore : bool stating whether an exception should be raised + when a conflicting row already exists. If True, False will be + returned by the function instead + desc : string giving a description of the transaction + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True + """ + try: + yield self.runInteraction(desc, self.simple_insert_txn, table, values) + except self.database_engine.module.IntegrityError: + # We have to do or_ignore flag at this layer, since we can't reuse + # a cursor after we receive an error from the db. + if not or_ignore: + raise + return False + return True + + @staticmethod + def simple_insert_txn(txn, table, values): + keys, vals = zip(*values.items()) + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys), + ", ".join("?" for _ in keys), + ) + + txn.execute(sql, vals) + + def simple_insert_many(self, table, values, desc): + return self.runInteraction(desc, self.simple_insert_many_txn, table, values) + + @staticmethod + def simple_insert_many_txn(txn, table, values): + if not values: + return + + # This is a *slight* abomination to get a list of tuples of key names + # and a list of tuples of value names. + # + # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] + # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] + # + # The sort is to ensure that we don't rely on dictionary iteration + # order. + keys, vals = zip( + *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + ) + + for k in keys: + if k != keys[0]: + raise RuntimeError("All items must have the same keys") + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys[0]), + ", ".join("?" for _ in keys[0]), + ) + + txn.executemany(sql, vals) + + @defer.inlineCallbacks + def simple_upsert( + self, + table, + keyvalues, + values, + insertion_values={}, + desc="simple_upsert", + lock=True, + ): + """ + + `lock` should generally be set to True (the default), but can be set + to False if either of the following are true: + + * there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + + * we somehow know that we are the only thread which will be updating + this table. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key columns and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + attempts = 0 + while True: + try: + result = yield self.runInteraction( + desc, + self.simple_upsert_txn, + table, + keyvalues, + values, + insertion_values, + lock=lock, + ) + return result + except self.database_engine.module.IntegrityError as e: + attempts += 1 + if attempts >= 5: + # don't retry forever, because things other than races + # can cause IntegrityErrors + raise + + # presumably we raced with another transaction: let's retry. + logger.warning( + "IntegrityError when upserting into %s; retrying: %s", table, e + ) + + def simple_upsert_txn( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Pick the UPSERT method which works best on the platform. Either the + native one (Pg9.5+, recent SQLites), or fall back to an emulated method. + + Args: + txn: The transaction to use. + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + None or bool: Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + if ( + self.database_engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ): + return self.simple_upsert_txn_native_upsert( + txn, table, keyvalues, values, insertion_values=insertion_values + ) + else: + return self.simple_upsert_txn_emulated( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + lock=lock, + ) + + def simple_upsert_txn_emulated( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + bool: Return True if a new entry was created, False if an existing + one was updated. + """ + # We need to lock the table :(, unless we're *really* careful + if lock: + self.database_engine.lock_table(txn, table) + + def _getwhere(key): + # If the value we're passing in is None (aka NULL), we need to use + # IS, not =, as NULL = NULL equals NULL (False). + if keyvalues[key] is None: + return "%s IS ?" % (key,) + else: + return "%s = ?" % (key,) + + if not values: + # If `values` is empty, then all of the values we care about are in + # the unique key, so there is nothing to UPDATE. We can just do a + # SELECT instead to see if it exists. + sql = "SELECT 1 FROM %s WHERE %s" % ( + table, + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(keyvalues.values()) + txn.execute(sql, sqlargs) + if txn.fetchall(): + # We have an existing record. + return False + else: + # First try to update. + sql = "UPDATE %s SET %s WHERE %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in values), + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(values.values()) + list(keyvalues.values()) + + txn.execute(sql, sqlargs) + if txn.rowcount > 0: + # successfully updated at least one row. + return False + + # We didn't find any existing rows, so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) + + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ) + txn.execute(sql, list(allvalues.values())) + # successfully inserted + return True + + def simple_upsert_txn_native_upsert( + self, txn, table, keyvalues, values, insertion_values={} + ): + """ + Use the native UPSERT functionality in recent PostgreSQL versions. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + Returns: + None + """ + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(insertion_values) + + if not values: + latter = "NOTHING" + else: + allvalues.update(values) + latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) + + sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ", ".join(k for k in keyvalues), + latter, + ) + txn.execute(sql, list(allvalues.values())) + + def simple_upsert_many_txn( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + if ( + self.database_engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ): + return self.simple_upsert_many_txn_native_upsert( + txn, table, key_names, key_values, value_names, value_values + ) + else: + return self.simple_upsert_many_txn_emulated( + txn, table, key_names, key_values, value_names, value_values + ) + + def simple_upsert_many_txn_emulated( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, but without native UPSERT support or batching. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + if not value_names: + value_values = [() for x in range(len(key_values))] + + for keyv, valv in zip(key_values, value_values): + _keys = {x: y for x, y in zip(key_names, keyv)} + _vals = {x: y for x, y in zip(value_names, valv)} + + self.simple_upsert_txn_emulated(txn, table, _keys, _vals) + + def simple_upsert_many_txn_native_upsert( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, using batching where possible. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + allnames = [] + allnames.extend(key_names) + allnames.extend(value_names) + + if not value_names: + # No value columns, therefore make a blank list so that the + # following zip() works correctly. + latter = "NOTHING" + value_values = [() for x in range(len(key_values))] + else: + latter = "UPDATE SET " + ", ".join( + k + "=EXCLUDED." + k for k in value_names + ) + + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( + table, + ", ".join(k for k in allnames), + ", ".join("?" for _ in allnames), + ", ".join(key_names), + latter, + ) + + args = [] + + for x, y in zip(key_values, value_values): + args.append(tuple(x) + tuple(y)) + + return txn.execute_batch(sql, args) + + def simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning multiple columns from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcols : list of strings giving the names of the columns to return + + allow_none : If true, return None instead of failing if the SELECT + statement returns no rows + """ + return self.runInteraction( + desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none + ) + + def simple_select_one_onecol( + self, + table, + keyvalues, + retcol, + allow_none=False, + desc="simple_select_one_onecol", + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning a single column from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcol : string giving the name of the column to return + """ + return self.runInteraction( + desc, + self.simple_select_one_onecol_txn, + table, + keyvalues, + retcol, + allow_none=allow_none, + ) + + @classmethod + def simple_select_one_onecol_txn( + cls, txn, table, keyvalues, retcol, allow_none=False + ): + ret = cls.simple_select_onecol_txn( + txn, table=table, keyvalues=keyvalues, retcol=retcol + ) + + if ret: + return ret[0] + else: + if allow_none: + return None + else: + raise StoreError(404, "No row found") + + @staticmethod + def simple_select_onecol_txn(txn, table, keyvalues, retcol): + sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} + + if keyvalues: + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + txn.execute(sql, list(keyvalues.values())) + else: + txn.execute(sql) + + return [r[0] for r in txn] + + def simple_select_onecol( + self, table, keyvalues, retcol, desc="simple_select_onecol" + ): + """Executes a SELECT query on the named table, which returns a list + comprising of the values of the named column from the selected rows. + + Args: + table (str): table name + keyvalues (dict|None): column names and values to select the rows with + retcol (str): column whos value we wish to retrieve. + + Returns: + Deferred: Results in a list + """ + return self.runInteraction( + desc, self.simple_select_onecol_txn, table, keyvalues, retcol + ) + + def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, self.simple_select_list_txn, table, keyvalues, retcols + ) + + @classmethod + def simple_select_list_txn(cls, txn, table, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + """ + if keyvalues: + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + txn.execute(sql, list(keyvalues.values())) + else: + sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + txn.execute(sql) + + return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def simple_select_many_batch( + self, + table, + column, + iterable, + retcols, + keyvalues={}, + desc="simple_select_many_batch", + batch_size=100, + ): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + results = [] + + if not iterable: + return results + + # iterables can not be sliced, so convert it to a list first + it_list = list(iterable) + + chunks = [ + it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) + ] + for chunk in chunks: + rows = yield self.runInteraction( + desc, + self.simple_select_many_txn, + table, + column, + chunk, + keyvalues, + retcols, + ) + + results.extend(rows) + + return results + + @classmethod + def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + if not iterable: + return [] + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" % (key,)) + values.append(value) + + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join(clauses), + ) + + txn.execute(sql, values) + return cls.cursor_to_dict(txn) + + def simple_update(self, table, keyvalues, updatevalues, desc): + return self.runInteraction( + desc, self.simple_update_txn, table, keyvalues, updatevalues + ) + + @staticmethod + def simple_update_txn(txn, table, keyvalues, updatevalues): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in updatevalues), + where, + ) + + txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) + + return txn.rowcount + + def simple_update_one( + self, table, keyvalues, updatevalues, desc="simple_update_one" + ): + """Executes an UPDATE query on the named table, setting new values for + columns in a row matching the key values. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + updatevalues : dict giving column names and values to update + retcols : optional list of column names to return + + If present, retcols gives a list of column names on which to perform + a SELECT statement *before* performing the UPDATE statement. The values + of these will be returned in a dict. + + These are performed within the same transaction, allowing an atomic + get-and-set. This can be used to implement compare-and-set by putting + the update column in the 'keyvalues' dict as well. + """ + return self.runInteraction( + desc, self.simple_update_one_txn, table, keyvalues, updatevalues + ) + + @classmethod + def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) + + if rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + @staticmethod + def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + select_sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(select_sql, list(keyvalues.values())) + row = txn.fetchone() + + if not row: + if allow_none: + return None + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + return dict(zip(retcols, row)) + + def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + + @staticmethod + def simple_delete_one_txn(txn, table, keyvalues): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + if txn.rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + def simple_delete(self, table, keyvalues, desc): + return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + + @staticmethod + def simple_delete_txn(txn, table, keyvalues): + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + return txn.rowcount + + def simple_delete_many(self, table, column, iterable, keyvalues, desc): + return self.runInteraction( + desc, self.simple_delete_many_txn, table, column, iterable, keyvalues + ) + + @staticmethod + def simple_delete_many_txn(txn, table, column, iterable, keyvalues): + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + + Returns: + int: Number rows deleted + """ + if not iterable: + return 0 + + sql = "DELETE FROM %s" % table + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" % (key,)) + values.append(value) + + if clauses: + sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) + txn.execute(sql, values) + + return txn.rowcount + + def get_cache_dict( + self, db_conn, table, entity_column, stream_column, max_value, limit=100000 + ): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - %(limit)s" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + "limit": limit, + } + + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + + cache = {row[0]: int(row[1]) for row in txn} + + txn.close() + + if cache: + min_val = min(itervalues(cache)) + else: + min_val = max_value + + return cache, min_val + + def simple_select_list_paginate( + self, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", + desc="simple_select_list_paginate", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, + self.simple_select_list_paginate_txn, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction=order_direction, + ) + + @classmethod + def simple_select_list_paginate_txn( + cls, + txn, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + if order_direction not in ["ASC", "DESC"]: + raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") + + if keyvalues: + where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) + else: + where_clause = "" + + sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( + ", ".join(retcols), + table, + where_clause, + orderby, + order_direction, + ) + txn.execute(sql, list(keyvalues.values()) + [limit, start]) + + return cls.cursor_to_dict(txn) + + def get_user_count_txn(self, txn): + """Get a total number of registered users in the users list. + + Args: + txn : Transaction object + Returns: + int : number of users + """ + sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" + txn.execute(sql_count) + return txn.fetchone()[0] + + def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + + return self.runInteraction( + desc, self.simple_search_list_txn, table, term, col, retcols + ) + + @classmethod + def simple_search_list_txn(cls, txn, table, term, col, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + if term: + sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) + termvalues = ["%%" + term + "%%"] + txn.execute(sql, termvalues) + else: + return 0 + + return cls.cursor_to_dict(txn) + + +def make_in_list_sql_clause( + database_engine, column: str, iterable: Iterable +) -> Tuple[str, Iterable]: + """Returns an SQL clause that checks the given column is in the iterable. + + On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres + it expands to `column = ANY(?)`. While both DBs support the `IN` form, + using the `ANY` form on postgres means that it views queries with + different length iterables as the same, helping the query stats. + + Args: + database_engine + column: Name of the column + iterable: The values to check the column against. + + Returns: + A tuple of SQL query and the args + """ + + if database_engine.supports_using_any_list: + # This should hopefully be faster, but also makes postgres query + # stats easier to understand. + return "%s = ANY(?)" % (column,), [list(iterable)] + else: + return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 380fd0d107..7f7962c3dd 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -45,13 +45,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) def get_all_room_state(self): - return self.store.simple_select_list( + return self.store.db.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store.simple_select_one( + self.store.db.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -180,7 +180,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.handler.stats_enabled = True self.store._all_done = False self.get_success( - self.store.simple_update_one( + self.store.db.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -188,7 +188,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) @@ -205,13 +205,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Now do the initial ingestion. self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -656,12 +656,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store.simple_delete( + self.store.db.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store.simple_delete( + self.store.db.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -675,7 +675,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -685,7 +685,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -695,7 +695,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index d5b1c5b4ac..bc9d441541 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -158,7 +158,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_in_public_rooms(self): r = self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_who_share_private_rooms(self): return self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -184,7 +184,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -193,7 +193,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -203,7 +203,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -213,7 +213,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 124ce0768a..0ed2594381 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -632,7 +632,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "state_groups_state", ): count = self.get_success( - self.store.simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 7b7434a468..d491ea2924 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.table_name = "table_" + hs.get_secrets().token_hex(6) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "create", lambda x, *a: x.execute(*a), "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)" @@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "index", lambda x, *a: x.execute(*a), "CREATE UNIQUE INDEX %sindex ON %s(id, username)" @@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["hello"], ["there"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage.simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage.simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) @@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["bleb"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage.simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage.simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 9fabe3fbc0..e360297df9 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -37,7 +37,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): def update(progress, count): self.clock.advance_time_msec(count * duration_ms) progress = {"my_key": progress["my_key"] + 1} - yield self.store.runInteraction( + yield self.store.db.runInteraction( "update_progress", self.store._background_update_progress_txn, "test_update", diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index de5e4a5fce..7915d48a9e 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -65,7 +65,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.simple_insert( + yield self.datastore.db.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -77,7 +77,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.simple_insert( + yield self.datastore.db.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore.simple_select_one_onecol( + value = yield self.datastore.db.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -106,7 +106,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore.simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -122,7 +122,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore.simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -137,7 +137,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore.simple_select_list( + ret = yield self.datastore.db.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -150,7 +150,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -165,7 +165,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore.simple_delete_one( + yield self.datastore.db.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 69dcaa63d5..e454bbff29 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -62,7 +62,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): prepare_database.executescript(txn, schema_path) self.get_success( - self.store.runInteraction("test_delete_forward_extremities", run_delta_file) + self.store.db.runInteraction( + "test_delete_forward_extremities", run_delta_file + ) ) # Ugh, have to reset this flag diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 25bdd2c163..c4f838907c 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -81,7 +81,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -112,7 +112,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -218,7 +218,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But clear the associated entry in devices table self.get_success( - self.store.simple_update( + self.store.db.simple_update( table="devices", keyvalues={"user_id": user_id, "device_id": "device_id"}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, @@ -245,7 +245,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( table="background_updates", values={ "update_name": "devices_last_seen", @@ -297,7 +297,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should see that in the DB result = self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -323,7 +323,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should get no results. result = self.get_success( - self.store.simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 2fe50377f8..eadfb90a22 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -61,7 +61,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 11): - yield self.store.runInteraction("insert", insert_event, i) + yield self.store.db.runInteraction("insert", insert_event, i) # this should get the last five and five others r = yield self.store.get_prev_events_for_room(room_id) @@ -93,9 +93,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 20): - yield self.store.runInteraction("insert", insert_event, i, room1) - yield self.store.runInteraction("insert", insert_event, i, room2) - yield self.store.runInteraction("insert", insert_event, i, room3) + yield self.store.db.runInteraction("insert", insert_event, i, room1) + yield self.store.db.runInteraction("insert", insert_event, i, room2) + yield self.store.db.runInteraction("insert", insert_event, i, room3) # Test simple case r = yield self.store.get_rooms_with_many_extremities(5, 5, []) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 2337a1ae46..d4bcf1821e 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -55,7 +55,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.runInteraction( + counts = yield self.store.db.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( @@ -74,7 +74,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield self.store.add_push_actions_to_staging( event.event_id, {user_id: action} ) - yield self.store.runInteraction( + yield self.store.db.runInteraction( "", self.store._set_push_actions_for_event_and_users_txn, [(event, None)], @@ -82,12 +82,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) def _rotate(stream): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) def _mark_read(stream, depth): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._remove_old_push_actions_before_txn, room_id, @@ -116,7 +116,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store.simple_delete( + yield self.store.db.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -135,7 +135,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store.simple_insert( + return self.store.db.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 90a63dc477..3c78faab45 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -65,7 +65,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.pump() @@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.hs.config.mau_limits_reserved_threepids = threepids - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) count = self.store.get_monthly_active_count() @@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": user2_email}, ] self.hs.config.mau_limits_reserved_threepids = threepids - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 4930b6777e..dc45173355 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -338,7 +338,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) event_json = self.get_success( - self.store.simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -356,7 +356,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store.simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index d389cf578f..5f957680a2 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -132,7 +132,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store.simple_insert( + self.store.db.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", diff --git a/tests/unittest.py b/tests/unittest.py index 295573bc46..fc856a574a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -544,7 +544,7 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore().simple_insert( + self.hs.get_datastore().db.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", -- cgit 1.5.1 From 4a33a6dd19590b8e6626a5af5a69507dc11236f8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Dec 2019 15:09:36 +0000 Subject: Move background update handling out of store --- synapse/app/homeserver.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/storage/background_updates.py | 15 +++---- synapse/storage/data_stores/main/client_ips.py | 36 ++++++++-------- synapse/storage/data_stores/main/deviceinbox.py | 9 ++-- synapse/storage/data_stores/main/devices.py | 15 ++++--- .../storage/data_stores/main/event_federation.py | 6 +-- .../storage/data_stores/main/event_push_actions.py | 4 +- synapse/storage/data_stores/main/events.py | 6 +-- .../storage/data_stores/main/events_bg_updates.py | 49 +++++++++++---------- .../storage/data_stores/main/media_repository.py | 6 +-- synapse/storage/data_stores/main/registration.py | 21 ++++----- synapse/storage/data_stores/main/room.py | 9 ++-- synapse/storage/data_stores/main/roommember.py | 27 +++++++----- synapse/storage/data_stores/main/search.py | 31 ++++++++------ synapse/storage/data_stores/main/state.py | 19 ++++---- synapse/storage/data_stores/main/stats.py | 20 ++++----- synapse/storage/data_stores/main/user_directory.py | 33 ++++++++------ synapse/storage/database.py | 3 ++ synmark/__init__.py | 6 +-- tests/handlers/test_stats.py | 50 +++++++++++++++------- tests/handlers/test_user_directory.py | 18 +++++--- tests/storage/test_background_update.py | 26 +++++++---- tests/storage/test_cleanup_extrems.py | 14 ++++-- tests/storage/test_client_ips.py | 26 ++++++++--- tests/storage/test_roommember.py | 18 +++++--- tests/unittest.py | 10 +++-- 27 files changed, 281 insertions(+), 200 deletions(-) (limited to 'tests') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 267aebaae9..9f81a857ab 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -436,7 +436,7 @@ def setup(config_options): _base.start(hs, config.listeners) hs.get_pusherpool().start() - hs.get_datastore().start_doing_background_updates() + hs.get_datastore().db.updates.start_doing_background_updates() except Exception: # Print the exception and bail out. print("Error during startup:", file=sys.stderr) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index fb0d02aa83..6b978be876 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -402,7 +402,7 @@ class PreviewUrlResource(DirectServeResource): logger.info("Running url preview cache expiry") - if not (yield self.store.has_completed_background_updates()): + if not (yield self.store.db.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index dfca94b0e0..a9a13a2658 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -22,7 +22,6 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from . import engines -from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -74,7 +73,7 @@ class BackgroundUpdatePerformance(object): return float(self.total_item_count) / float(self.total_duration_ms) -class BackgroundUpdateStore(SQLBaseStore): +class BackgroundUpdater(object): """ Background updates are updates to the database that run in the background. Each update processes a batch of data at once. We attempt to limit the impact of each update by monitoring how long each batch takes to @@ -86,8 +85,10 @@ class BackgroundUpdateStore(SQLBaseStore): BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, db_conn, hs): - super(BackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, hs, database): + self._clock = hs.get_clock() + self.db = database + self._background_update_performance = {} self._background_update_queue = [] self._background_update_handlers = {} @@ -101,9 +102,7 @@ class BackgroundUpdateStore(SQLBaseStore): logger.info("Starting background schema updates") while True: if sleep: - yield self.hs.get_clock().sleep( - self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0 - ) + yield self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) try: result = yield self.do_next_background_update( @@ -380,7 +379,7 @@ class BackgroundUpdateStore(SQLBaseStore): logger.debug("[SQL] %s", sql) c.execute(sql) - if isinstance(self.database_engine, engines.PostgresEngine): + if isinstance(self.db.database_engine, engines.PostgresEngine): runner = create_index_psql elif psql_only: runner = None diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 6f2a720b97..7b470a58f1 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -20,7 +20,7 @@ from six import iteritems from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage import background_updates +from synapse.storage._base import SQLBaseStore from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -32,41 +32,41 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120 * 1000 -class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): +class ClientIpBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_index", index_name="user_ips_device_id", table="user_ips", columns=["user_id", "device_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_index", index_name="user_ips_last_seen", table="user_ips", columns=["user_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_only_index", index_name="user_ips_last_seen_only", table="user_ips", columns=["last_seen"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_analyze", self._analyze_user_ip ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_remove_dupes", self._remove_user_ip_dupes ) # Register a unique index - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_unique_index", index_name="user_ips_user_token_ip_unique_index", table="user_ips", @@ -75,12 +75,12 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): ) # Drop the old non-unique index - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique ) # Update the last seen info in devices. - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "devices_last_seen", self._devices_last_seen_update ) @@ -92,7 +92,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn.close() yield self.db.runWithConnection(f) - yield self._end_background_update("user_ips_drop_nonunique_index") + yield self.db.updates._end_background_update("user_ips_drop_nonunique_index") return 1 @defer.inlineCallbacks @@ -108,7 +108,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) - yield self._end_background_update("user_ips_analyze") + yield self.db.updates._end_background_update("user_ips_analyze") return 1 @@ -271,14 +271,14 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): (user_id, access_token, ip, device_id, user_agent, last_seen), ) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) yield self.db.runInteraction("user_ips_dups_remove", remove) if last: - yield self._end_background_update("user_ips_remove_dupes") + yield self.db.updates._end_background_update("user_ips_remove_dupes") return batch_size @@ -344,7 +344,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn.execute_batch(sql, rows) _, _, _, user_id, device_id = rows[-1] - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "devices_last_seen", {"last_user_id": user_id, "last_device_id": device_id}, @@ -357,7 +357,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): ) if not updated: - yield self._end_background_update("devices_last_seen") + yield self.db.updates._end_background_update("devices_last_seen") return updated @@ -546,7 +546,9 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Nothing to do return - if not await self.has_completed_background_update("devices_last_seen"): + if not await self.db.updates.has_completed_background_update( + "devices_last_seen" + ): # Only start pruning if we have finished populating the devices # last seen info. return diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 440793ad49..3c9f09301a 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -21,7 +21,6 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -208,20 +207,20 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) -class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): +class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, db_conn, hs): super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_inbox_stream_index", index_name="device_inbox_stream_id_user_id", table="device_inbox", columns=["stream_id", "user_id"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) @@ -234,7 +233,7 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): yield self.db.runWithConnection(reindex_txn) - yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) return 1 diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index d98511ddd4..91ddaf137e 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -31,7 +31,6 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter from synapse.util.caches.descriptors import ( @@ -642,11 +641,11 @@ class DeviceWorkerStore(SQLBaseStore): return results -class DeviceBackgroundUpdateStore(BackgroundUpdateStore): +class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_stream_idx", index_name="device_lists_stream_user_id", table="device_lists_stream", @@ -654,7 +653,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # create a unique index on device_lists_remote_cache - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_cache_unique_idx", index_name="device_lists_remote_cache_unique_id", table="device_lists_remote_cache", @@ -663,7 +662,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # And one on device_lists_remote_extremeties - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_extremeties_unique_idx", index_name="device_lists_remote_extremeties_unique_idx", table="device_lists_remote_extremeties", @@ -672,7 +671,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # once they complete, we can remove the old non-unique indexes. - self.register_background_update_handler( + self.db.updates.register_background_update_handler( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, self._drop_device_list_streams_non_unique_indexes, ) @@ -686,7 +685,9 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): txn.close() yield self.db.runWithConnection(f) - yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) + yield self.db.updates._end_background_update( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES + ) return 1 diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 77e4353b59..31d2e8eb28 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -494,7 +494,7 @@ class EventFederationStore(EventFederationWorkerStore): def __init__(self, db_conn, hs): super(EventFederationStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) @@ -654,7 +654,7 @@ class EventFederationStore(EventFederationWorkerStore): "max_stream_id_exclusive": min_stream_id, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_AUTH_STATE_ONLY, new_progress ) @@ -665,6 +665,6 @@ class EventFederationStore(EventFederationWorkerStore): ) if not result: - yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) + yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY) return batch_size diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 725d0881dc..eec054cd48 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -614,14 +614,14 @@ class EventPushActionsStore(EventPushActionsWorkerStore): def __init__(self, db_conn, hs): super(EventPushActionsStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, index_name="event_push_actions_u_highlight", table="event_push_actions", columns=["user_id", "stream_ordering"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_push_actions_highlights_index", index_name="event_push_actions_highlights_index", table="event_push_actions", diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 01ec9ec397..d644c82784 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -38,7 +38,6 @@ from synapse.logging.utils import log_function from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.event_federation import EventFederationStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore @@ -94,10 +93,7 @@ def _retry_on_integrity_error(func): # inherits from EventFederationStore so that we can call _update_backward_extremities # and _handle_mult_prev_events (though arguably those could both be moved in here) class EventsStore( - StateGroupWorkerStore, - EventFederationStore, - EventsWorkerStore, - BackgroundUpdateStore, + StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, ): def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 365e966956..cb1fc30c31 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -22,13 +22,12 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import EventContentFields -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause logger = logging.getLogger(__name__) -class EventsBackgroundUpdatesStore(BackgroundUpdateStore): +class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" @@ -37,15 +36,15 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): def __init__(self, db_conn, hs): super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self._background_reindex_fields_sender, ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_contains_url_index", index_name="event_contains_url_index", table="events", @@ -56,7 +55,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): # an event_id index on event_search is useful for the purge_history # api. Plus it means we get to enforce some integrity with a UNIQUE # clause - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_search_event_id_idx", index_name="event_search_event_id_idx", table="event_search", @@ -65,16 +64,16 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): psql_only=True, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "redactions_received_ts", self._redactions_received_ts ) # This index gets deleted in `event_fix_redactions_bytes` update - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_fix_redactions_bytes_create_index", index_name="redactions_censored_redacts", table="redactions", @@ -82,11 +81,11 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): where_clause="have_censored", ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "event_fix_redactions_bytes", self._event_fix_redactions_bytes ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "event_store_labels", self._event_store_labels ) @@ -145,7 +144,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) @@ -156,7 +155,9 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) + yield self.db.updates._end_background_update( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME + ) return result @@ -222,7 +223,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(rows_to_update), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress ) @@ -233,7 +234,9 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + yield self.db.updates._end_background_update( + self.EVENT_ORIGIN_SERVER_TS_NAME + ) return result @@ -411,7 +414,9 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not num_handled: - yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES) + yield self.db.updates._end_background_update( + self.DELETE_SOFT_FAILED_EXTREMITIES + ) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") @@ -464,7 +469,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "redactions_received_ts", {"last_event_id": upper_event_id} ) @@ -475,7 +480,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not count: - yield self._end_background_update("redactions_received_ts") + yield self.db.updates._end_background_update("redactions_received_ts") return count @@ -505,7 +510,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn ) - yield self._end_background_update("event_fix_redactions_bytes") + yield self.db.updates._end_background_update("event_fix_redactions_bytes") return 1 @@ -559,7 +564,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): nbrows += 1 last_row_event_id = event_id - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "event_store_labels", {"last_event_id": last_row_event_id} ) @@ -570,6 +575,6 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not num_rows: - yield self._end_background_update("event_store_labels") + yield self.db.updates._end_background_update("event_store_labels") return num_rows diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index ea02497784..03c9c6f8ae 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -12,14 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore -class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore): +class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( update_name="local_media_repository_url_idx", index_name="local_media_repository_url_idx", table="local_media_repository", diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 8f9aa87ceb..1ef143c6d8 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -26,7 +26,6 @@ from twisted.internet.defer import Deferred from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -794,23 +793,21 @@ class RegistrationWorkerStore(SQLBaseStore): ) -class RegistrationBackgroundUpdateStore( - RegistrationWorkerStore, background_updates.BackgroundUpdateStore -): +class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, db_conn, hs): super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) self.clock = hs.get_clock() self.config = hs.config - self.register_background_index_update( + self.db.updates.register_background_index_update( "access_tokens_device_index", index_name="access_tokens_device_id", table="access_tokens", columns=["user_id", "device_id"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "users_creation_ts", index_name="users_creation_ts", table="users", @@ -820,13 +817,13 @@ class RegistrationBackgroundUpdateStore( # we no longer use refresh tokens, but it's possible that some people # might have a background update queued to build this index. Just # clear the background update. - self.register_noop_background_update("refresh_tokens_device_index") + self.db.updates.register_noop_background_update("refresh_tokens_device_index") - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_threepids_grandfather", self._bg_user_threepids_grandfather ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) @@ -873,7 +870,7 @@ class RegistrationBackgroundUpdateStore( logger.info("Marked %d rows as deactivated", rows_processed_nb) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} ) @@ -887,7 +884,7 @@ class RegistrationBackgroundUpdateStore( ) if end: - yield self._end_background_update("users_set_deactivated_flag") + yield self.db.updates._end_background_update("users_set_deactivated_flag") return nb_processed @@ -917,7 +914,7 @@ class RegistrationBackgroundUpdateStore( "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) - yield self._end_background_update("user_threepids_grandfather") + yield self.db.updates._end_background_update("user_threepids_grandfather") return 1 diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index a26ed47afc..da42dae243 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -28,7 +28,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.search import SearchStore from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -361,13 +360,13 @@ class RoomWorkerStore(SQLBaseStore): defer.returnValue(row) -class RoomBackgroundUpdateStore(BackgroundUpdateStore): +class RoomBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs) self.config = hs.config - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "insert_room_retention", self._background_insert_retention, ) @@ -421,7 +420,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): logger.info("Inserted %d rows into room_retention", len(rows)) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} ) @@ -435,7 +434,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): ) if end: - yield self._end_background_update("insert_room_retention") + yield self.db.updates._end_background_update("insert_room_retention") defer.returnValue(batch_size) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 7f4d02b25b..929f6b0d39 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -26,8 +26,11 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import ( + LoggingTransaction, + SQLBaseStore, + make_in_list_sql_clause, +) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( @@ -831,17 +834,17 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) -class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): +class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, self._background_current_state_membership, ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "room_membership_forgotten_idx", index_name="room_memberships_user_room_forgotten", table="room_memberships", @@ -909,7 +912,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): "max_stream_id_exclusive": min_stream_id, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress ) @@ -920,7 +923,9 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) + yield self.db.updates._end_background_update( + _MEMBERSHIP_PROFILE_UPDATE_NAME + ) return result @@ -959,7 +964,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): last_processed_room = next_room - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, {"last_processed_room": last_processed_room}, @@ -978,7 +983,9 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): ) if finished: - yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) + yield self.db.updates._end_background_update( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME + ) return row_count diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 55a604850e..ffa1817e64 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -24,8 +24,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.engines import PostgresEngine, Sqlite3Engine logger = logging.getLogger(__name__) @@ -36,7 +35,7 @@ SearchEntry = namedtuple( ) -class SearchBackgroundUpdateStore(BackgroundUpdateStore): +class SearchBackgroundUpdateStore(SQLBaseStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" @@ -49,10 +48,10 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): if not hs.config.enable_search: return - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) @@ -61,9 +60,11 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): # a GIN index. However, it's possible that some people might still have # the background update queued, so we register a handler to clear the # background update. - self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) + self.db.updates.register_noop_background_update( + self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME + ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) @@ -153,7 +154,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(event_search_rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_UPDATE_NAME, progress ) @@ -164,7 +165,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) + yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) return result @@ -208,7 +209,9 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): yield self.db.runWithConnection(create_index) - yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME + ) return 1 @defer.inlineCallbacks @@ -244,7 +247,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg, ) @@ -274,7 +277,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): "have_added_indexes": True, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress ) @@ -285,7 +288,9 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): ) if not finished: - yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME) + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_ORDER_UPDATE_NAME + ) return num_rows diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 851e81d6b3..7d5a9f8128 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -27,7 +27,6 @@ from synapse.api.errors import NotFoundError from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter @@ -1023,9 +1022,7 @@ class StateGroupWorkerStore( return set(row["state_group"] for row in rows) -class StateBackgroundUpdateStore( - StateGroupBackgroundUpdateStore, BackgroundUpdateStore -): +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" @@ -1034,21 +1031,21 @@ class StateBackgroundUpdateStore( def __init__(self, db_conn, hs): super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state ) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", table="current_state_events", columns=["state_key"], where_clause="type='m.room.member'", ) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, index_name="event_to_state_groups_sg_index", table="event_to_state_groups", @@ -1181,7 +1178,7 @@ class StateBackgroundUpdateStore( "max_group": max_group, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress ) @@ -1192,7 +1189,7 @@ class StateBackgroundUpdateStore( ) if finished: - yield self._end_background_update( + yield self.db.updates._end_background_update( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME ) @@ -1224,7 +1221,7 @@ class StateBackgroundUpdateStore( yield self.db.runWithConnection(reindex_txn) - yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) return 1 diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 974ffc15bd..6b91988c2a 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -68,17 +68,17 @@ class StatsStore(StateDeltasStore): self.stats_delta_processing_lock = DeferredLock() - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) # we no longer need to perform clean-up, but we will give ourselves # the potential to reintroduce it in the future – so documentation # will still encourage the use of this no-op handler. - self.register_noop_background_update("populate_stats_cleanup") - self.register_noop_background_update("populate_stats_prepare") + self.db.updates.register_noop_background_update("populate_stats_cleanup") + self.db.updates.register_noop_background_update("populate_stats_prepare") def quantise_stats_time(self, ts): """ @@ -102,7 +102,7 @@ class StatsStore(StateDeltasStore): This is a background update which regenerates statistics for users. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 last_user_id = progress.get("last_user_id", "") @@ -123,7 +123,7 @@ class StatsStore(StateDeltasStore): # No more rooms -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 for user_id in users_to_work_on: @@ -132,7 +132,7 @@ class StatsStore(StateDeltasStore): yield self.db.runInteraction( "populate_stats_process_users", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_users", progress, ) @@ -145,7 +145,7 @@ class StatsStore(StateDeltasStore): This is a background update which regenerates statistics for rooms. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 last_room_id = progress.get("last_room_id", "") @@ -166,7 +166,7 @@ class StatsStore(StateDeltasStore): # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 for room_id in rooms_to_work_on: @@ -175,7 +175,7 @@ class StatsStore(StateDeltasStore): yield self.db.runInteraction( "_populate_stats_process_rooms", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_rooms", progress, ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 7118bd62f3..62ffb34b29 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -19,7 +19,6 @@ import re from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.state import StateFilter from synapse.storage.data_stores.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -32,7 +31,7 @@ logger = logging.getLogger(__name__) TEMP_TABLE = "_temp_populate_user_directory" -class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore): +class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? @@ -43,19 +42,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore self.server_name = hs.hostname - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_createtables", self._populate_user_directory_createtables, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_rooms", self._populate_user_directory_process_rooms, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_users", self._populate_user_directory_process_users, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) @@ -108,7 +107,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) - yield self._end_background_update("populate_user_directory_createtables") + yield self.db.updates._end_background_update( + "populate_user_directory_createtables" + ) return 1 @defer.inlineCallbacks @@ -130,7 +131,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore "populate_user_directory_cleanup", _delete_staging_area ) - yield self._end_background_update("populate_user_directory_cleanup") + yield self.db.updates._end_background_update("populate_user_directory_cleanup") return 1 @defer.inlineCallbacks @@ -176,7 +177,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self._end_background_update("populate_user_directory_process_rooms") + yield self.db.updates._end_background_update( + "populate_user_directory_process_rooms" + ) return 1 logger.info( @@ -248,7 +251,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore progress["remaining"] -= 1 yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_rooms", progress, ) @@ -267,7 +270,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore If search_all_users is enabled, add all of the users to the user directory. """ if not self.hs.config.user_directory_search_all_users: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 def _get_next_batch(txn): @@ -297,7 +302,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore # No more users -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 logger.info( @@ -317,7 +324,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore progress["remaining"] -= 1 yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_users", progress, ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ac64d80806..be36c1b829 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -30,6 +30,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.util.stringutils import exception_to_unicode @@ -223,6 +224,8 @@ class Database(object): self._clock = hs.get_clock() self._db_pool = hs.get_db_pool() + self.updates = BackgroundUpdater(hs, self) + self._previous_txn_total_time = 0 self._current_txn_total_time = 0 self._previous_loop_ts = 0 diff --git a/synmark/__init__.py b/synmark/__init__.py index 570eb818d9..afe4fad8cb 100644 --- a/synmark/__init__.py +++ b/synmark/__init__.py @@ -47,9 +47,9 @@ async def make_homeserver(reactor, config=None): stor = hs.get_datastore() # Run the database background updates. - if hasattr(stor, "do_next_background_update"): - while not await stor.has_completed_background_updates(): - await stor.do_next_background_update(1) + if hasattr(stor.db.updates, "do_next_background_update"): + while not await stor.db.updates.has_completed_background_updates(): + await stor.db.updates.do_next_background_update(1) def cleanup(): for i in cleanup_tasks: diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 7f7962c3dd..d9d312f0fb 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -42,7 +42,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_insert( @@ -108,8 +108,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the stats via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_initial_room(self): """ @@ -141,8 +145,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r = self.get_success(self.get_all_room_state()) @@ -178,7 +186,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # the position that the deltas should begin at, once they take over. self.hs.config.stats_enabled = True self.handler.stats_enabled = True - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_update_one( table="stats_incremental_position", @@ -194,8 +202,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now, before the table is actually ingested, add some more events. self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) @@ -221,9 +233,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - self.store._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + self.store.db.updates._all_done = False + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) self.reactor.advance(86401) @@ -653,7 +669,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # preparation stage of the initial background update # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_delete( @@ -673,7 +689,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # now do the background updates - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_insert( "background_updates", @@ -705,8 +721,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r1stats_complete = self._get_current_stats("room", r1) u1stats_complete = self._get_current_stats("user", u1) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index bc9d441541..26071059d2 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -181,7 +181,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_insert( @@ -255,8 +255,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() @@ -290,8 +294,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index e360297df9..aec76f4ab1 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -15,7 +15,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler = Mock() - yield self.store.register_background_update_handler( + yield self.store.db.updates.register_background_update_handler( "test_update", self.update_handler ) @@ -23,7 +23,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): # (perhaps we should run them as part of the test HS setup, since we # run all of the other schema setup stuff there?) while True: - res = yield self.store.do_next_background_update(1000) + res = yield self.store.db.updates.do_next_background_update(1000) if res is None: break @@ -39,7 +39,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): progress = {"my_key": progress["my_key"] + 1} yield self.store.db.runInteraction( "update_progress", - self.store._background_update_progress_txn, + self.store.db.updates._background_update_progress_txn, "test_update", progress, ) @@ -47,29 +47,37 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler.side_effect = update - yield self.store.start_background_update("test_update", {"my_key": 1}) + yield self.store.db.updates.start_background_update( + "test_update", {"my_key": 1} + ) self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with( - {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update @defer.inlineCallbacks def update(progress, count): - yield self.store._end_background_update("test_update") + yield self.store.db.updates._end_background_update("test_update") return count self.update_handler.side_effect = update self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with({"my_key": 2}, desired_count) # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNone(result) self.assertFalse(self.update_handler.called) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index e454bbff29..029ac26454 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -46,7 +46,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """Re run the background update to clean up the extremities. """ # Make sure we don't clash with in progress updates. - self.assertTrue(self.store._all_done, "Background updates are still ongoing") + self.assertTrue( + self.store.db.updates._all_done, "Background updates are still ongoing" + ) schema_path = os.path.join( prepare_database.dir_path, @@ -68,10 +70,14 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): ) # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c4f838907c..fc279340d4 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -202,8 +202,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_devices_last_seen_bg_update(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Insert a user IP user_id = "@user:id" @@ -256,11 +260,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # We should now get the correct result again result = self.get_success( @@ -281,8 +289,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_old_user_ips_pruned(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Insert a user IP user_id = "@user:id" diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5f957680a2..7840f63fe3 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -122,8 +122,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def test_can_rerun_update(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now let's create a room, which will insert a membership user = UserID("alice", "test") @@ -143,8 +147,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) diff --git a/tests/unittest.py b/tests/unittest.py index fc856a574a..68d245ec9f 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -401,10 +401,12 @@ class HomeserverTestCase(TestCase): hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() - # Run the database background updates. - if hasattr(stor, "do_next_background_update"): - while not self.get_success(stor.has_completed_background_updates()): - self.get_success(stor.do_next_background_update(1)) + # Run the database background updates, when running against "master". + if hs.__class__.__name__ == "TestHomeServer": + while not self.get_success( + stor.db.updates.has_completed_background_updates() + ): + self.get_success(stor.db.updates.do_next_background_update(1)) return hs -- cgit 1.5.1 From 4ca3ef10b9a8d15cf351d67d574088d944c2e3b1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Dec 2019 15:53:10 +0000 Subject: Fixup tests --- tests/rest/client/v1/test_presence.py | 3 +++ tests/rest/client/v1/test_profile.py | 10 +++++++++- tests/utils.py | 4 +++- 3 files changed, 15 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 66c2b68707..0fdff79aa7 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -15,6 +15,8 @@ from mock import Mock +from twisted.internet import defer + from synapse.rest.client.v1 import presence from synapse.types import UserID @@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) hs.presence_handler = Mock() + hs.presence_handler.set_state.return_value = defer.succeed(None) return hs diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 140d8b3772..12c5e95cb5 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -52,6 +52,14 @@ class MockHandlerProfileTestCase(unittest.TestCase): ] ) + self.mock_handler.get_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.set_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.check_profile_query_allowed.return_value = defer.succeed( + Mock() + ) + hs = yield setup_test_homeserver( self.addCleanup, "test", @@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None, allow_guest=False): - return synapse.types.create_requester(myid) + return defer.succeed(synapse.types.create_requester(myid)) hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/utils.py b/tests/utils.py index de2ac1ed33..c57da59191 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -461,7 +461,9 @@ class MockHttpResource(HttpServer): try: args = [urlparse.unquote(u) for u in matcher.groups()] - (code, response) = yield func(mock_request, *args) + (code, response) = yield defer.ensureDeferred( + func(mock_request, *args) + ) return code, response except CodeMessageException as e: return (e.code, cs_error(e.msg, code=e.errcode)) -- cgit 1.5.1 From 8437e2383ed2dffacca5395851023eeacb33d7ba Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Dec 2019 17:58:25 +0000 Subject: Port SyncHandler to async/await --- synapse/handlers/events.py | 30 +++--- synapse/handlers/sync.py | 251 +++++++++++++++++++++----------------------- synapse/notifier.py | 29 +++-- synapse/util/metrics.py | 23 ++-- tests/handlers/test_sync.py | 33 +++--- tests/unittest.py | 7 +- 6 files changed, 182 insertions(+), 191 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 45fe13c62f..ec18a42a68 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -16,8 +16,6 @@ import logging import random -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase @@ -50,9 +48,8 @@ class EventStreamHandler(BaseHandler): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks @log_function - def get_stream( + async def get_stream( self, auth_user_id, pagin_config, @@ -69,17 +66,17 @@ class EventStreamHandler(BaseHandler): """ if room_id: - blocked = yield self.store.is_room_blocked(room_id) + blocked = await self.store.is_room_blocked(room_id) if blocked: raise SynapseError(403, "This room has been blocked on this server") # send any outstanding server notices to the user. - yield self._server_notices_sender.on_user_syncing(auth_user_id) + await self._server_notices_sender.on_user_syncing(auth_user_id) auth_user = UserID.from_string(auth_user_id) presence_handler = self.hs.get_presence_handler() - context = yield presence_handler.user_syncing( + context = await presence_handler.user_syncing( auth_user_id, affect_presence=affect_presence ) with context: @@ -91,7 +88,7 @@ class EventStreamHandler(BaseHandler): # thundering herds on restart. timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) - events, tokens = yield self.notifier.get_events_for( + events, tokens = await self.notifier.get_events_for( auth_user, pagin_config, timeout, @@ -112,14 +109,14 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.state.get_current_users_in_room( + users = await self.state.get_current_users_in_room( event.room_id ) - states = yield presence_handler.get_states(users, as_event=True) + states = await presence_handler.get_states(users, as_event=True) to_add.extend(states) else: - ev = yield presence_handler.get_state( + ev = await presence_handler.get_state( UserID.from_string(event.state_key), as_event=True ) to_add.append(ev) @@ -128,7 +125,7 @@ class EventStreamHandler(BaseHandler): time_now = self.clock.time_msec() - chunks = yield self._event_serializer.serialize_events( + chunks = await self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event, @@ -151,8 +148,7 @@ class EventHandler(BaseHandler): super(EventHandler, self).__init__(hs) self.storage = hs.get_storage() - @defer.inlineCallbacks - def get_event(self, user, room_id, event_id): + async def get_event(self, user, room_id, event_id): """Retrieve a single specified event. Args: @@ -167,15 +163,15 @@ class EventHandler(BaseHandler): AuthError if the user does not have the rights to inspect this event. """ - event = yield self.store.get_event(event_id, check_room_id=room_id) + event = await self.store.get_event(event_id, check_room_id=room_id) if not event: return None - users = yield self.store.get_users_in_room(event.room_id) + users = await self.store.get_users_in_room(event.room_id) is_peeking = user.to_string() not in users - filtered = yield filter_events_for_client( + filtered = await filter_events_for_client( self.storage, user.to_string(), [event], is_peeking=is_peeking ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b536d410e5..12751fd8c0 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -22,8 +22,6 @@ from six import iteritems, itervalues from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.logging.context import LoggingContext from synapse.push.clientformat import format_push_rules_for_user @@ -241,8 +239,7 @@ class SyncHandler(object): expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) - @defer.inlineCallbacks - def wait_for_sync_for_user( + async def wait_for_sync_for_user( self, sync_config, since_token=None, timeout=0, full_state=False ): """Get the sync for a client if we have new data for it now. Otherwise @@ -255,9 +252,9 @@ class SyncHandler(object): # not been exceeded (if not part of the group by this point, almost certain # auth_blocking will occur) user_id = sync_config.user.to_string() - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) - res = yield self.response_cache.wrap( + res = await self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, sync_config, @@ -267,8 +264,9 @@ class SyncHandler(object): ) return res - @defer.inlineCallbacks - def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): + async def _wait_for_sync_for_user( + self, sync_config, since_token, timeout, full_state + ): if since_token is None: sync_type = "initial_sync" elif full_state: @@ -283,7 +281,7 @@ class SyncHandler(object): if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. - result = yield self.current_sync_for_user( + result = await self.current_sync_for_user( sync_config, since_token, full_state=full_state ) else: @@ -291,7 +289,7 @@ class SyncHandler(object): def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) - result = yield self.notifier.wait_for_events( + result = await self.notifier.wait_for_events( sync_config.user.to_string(), timeout, current_sync_callback, @@ -314,15 +312,13 @@ class SyncHandler(object): """ return self.generate_sync_result(sync_config, since_token, full_state) - @defer.inlineCallbacks - def push_rules_for_user(self, user): + async def push_rules_for_user(self, user): user_id = user.to_string() - rules = yield self.store.get_push_rules_for_user(user_id) + rules = await self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(user, rules) return rules - @defer.inlineCallbacks - def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): + async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): """Get the ephemeral events for each room the user is in Args: sync_result_builder(SyncResultBuilder) @@ -343,7 +339,7 @@ class SyncHandler(object): room_ids = sync_result_builder.joined_room_ids typing_source = self.event_sources.sources["typing"] - typing, typing_key = yield typing_source.get_new_events( + typing, typing_key = typing_source.get_new_events( user=sync_config.user, from_key=typing_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -365,7 +361,7 @@ class SyncHandler(object): receipt_key = since_token.receipt_key if since_token else "0" receipt_source = self.event_sources.sources["receipt"] - receipts, receipt_key = yield receipt_source.get_new_events( + receipts, receipt_key = await receipt_source.get_new_events( user=sync_config.user, from_key=receipt_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -382,8 +378,7 @@ class SyncHandler(object): return now_token, ephemeral_by_room - @defer.inlineCallbacks - def _load_filtered_recents( + async def _load_filtered_recents( self, room_id, sync_config, @@ -415,10 +410,10 @@ class SyncHandler(object): # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in recents): - current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = await self.state.get_current_state_ids(room_id) current_state_ids = frozenset(itervalues(current_state_ids)) - recents = yield filter_events_for_client( + recents = await filter_events_for_client( self.storage, sync_config.user.to_string(), recents, @@ -449,14 +444,14 @@ class SyncHandler(object): # Otherwise, we want to return the last N events in the room # in toplogical ordering. if since_key: - events, end_key = yield self.store.get_room_events_stream_for_room( + events, end_key = await self.store.get_room_events_stream_for_room( room_id, limit=load_limit + 1, from_key=since_key, to_key=end_key, ) else: - events, end_key = yield self.store.get_recent_events_for_room( + events, end_key = await self.store.get_recent_events_for_room( room_id, limit=load_limit + 1, end_token=end_key ) loaded_recents = sync_config.filter_collection.filter_room_timeline( @@ -468,10 +463,10 @@ class SyncHandler(object): # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in loaded_recents): - current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = await self.state.get_current_state_ids(room_id) current_state_ids = frozenset(itervalues(current_state_ids)) - loaded_recents = yield filter_events_for_client( + loaded_recents = await filter_events_for_client( self.storage, sync_config.user.to_string(), loaded_recents, @@ -498,8 +493,7 @@ class SyncHandler(object): limited=limited or newly_joined_room, ) - @defer.inlineCallbacks - def get_state_after_event(self, event, state_filter=StateFilter.all()): + async def get_state_after_event(self, event, state_filter=StateFilter.all()): """ Get the room state after the given event @@ -511,7 +505,7 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( event.event_id, state_filter=state_filter ) if event.is_state(): @@ -519,8 +513,9 @@ class SyncHandler(object): state_ids[(event.type, event.state_key)] = event.event_id return state_ids - @defer.inlineCallbacks - def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): + async def get_state_at( + self, room_id, stream_position, state_filter=StateFilter.all() + ): """ Get the room state at a particular stream position Args: @@ -536,13 +531,13 @@ class SyncHandler(object): # get_recent_events_for_room operates by topo ordering. This therefore # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) - last_events, _ = yield self.store.get_recent_events_for_room( + last_events, _ = await self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1 ) if last_events: last_event = last_events[-1] - state = yield self.get_state_after_event( + state = await self.get_state_after_event( last_event, state_filter=state_filter ) @@ -551,8 +546,7 @@ class SyncHandler(object): state = {} return state - @defer.inlineCallbacks - def compute_summary(self, room_id, sync_config, batch, state, now_token): + async def compute_summary(self, room_id, sync_config, batch, state, now_token): """ Works out a room summary block for this room, summarising the number of joined members in the room, and providing the 'hero' members if the room has no name so clients can consistently name rooms. Also adds @@ -574,7 +568,7 @@ class SyncHandler(object): # FIXME: we could/should get this from room_stats when matthew/stats lands # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 - last_events, _ = yield self.store.get_recent_event_ids_for_room( + last_events, _ = await self.store.get_recent_event_ids_for_room( room_id, end_token=now_token.room_key, limit=1 ) @@ -582,7 +576,7 @@ class SyncHandler(object): return None last_event = last_events[-1] - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -590,7 +584,7 @@ class SyncHandler(object): ) # this is heavily cached, thus: fast. - details = yield self.store.get_room_summary(room_id) + details = await self.store.get_room_summary(room_id) name_id = state_ids.get((EventTypes.Name, "")) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) @@ -608,12 +602,12 @@ class SyncHandler(object): # calculating heroes. Empty strings are falsey, so we check # for the "name" value and default to an empty string. if name_id: - name = yield self.store.get_event(name_id, allow_none=True) + name = await self.store.get_event(name_id, allow_none=True) if name and name.content.get("name"): return summary if canonical_alias_id: - canonical_alias = yield self.store.get_event( + canonical_alias = await self.store.get_event( canonical_alias_id, allow_none=True ) if canonical_alias and canonical_alias.content.get("alias"): @@ -678,7 +672,7 @@ class SyncHandler(object): ) ] - missing_hero_state = yield self.store.get_events(missing_hero_event_ids) + missing_hero_state = await self.store.get_events(missing_hero_event_ids) missing_hero_state = missing_hero_state.values() for s in missing_hero_state: @@ -697,8 +691,7 @@ class SyncHandler(object): logger.debug("found LruCache for %r", cache_key) return cache - @defer.inlineCallbacks - def compute_state_delta( + async def compute_state_delta( self, room_id, batch, sync_config, since_token, now_token, full_state ): """ Works out the difference in state between the start of the timeline @@ -759,16 +752,16 @@ class SyncHandler(object): if full_state: if batch: - current_state_ids = yield self.state_store.get_state_ids_for_event( + current_state_ids = await self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: - current_state_ids = yield self.get_state_at( + current_state_ids = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -783,13 +776,13 @@ class SyncHandler(object): ) elif batch.limited: if batch: - state_at_timeline_start = yield self.state_store.get_state_ids_for_event( + state_at_timeline_start = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: # We can get here if the user has ignored the senders of all # the recent events. - state_at_timeline_start = yield self.get_state_at( + state_at_timeline_start = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -807,19 +800,19 @@ class SyncHandler(object): # about them). state_filter = StateFilter.all() - state_at_previous_sync = yield self.get_state_at( + state_at_previous_sync = await self.get_state_at( room_id, stream_position=since_token, state_filter=state_filter ) if batch: - current_state_ids = yield self.state_store.get_state_ids_for_event( + current_state_ids = await self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) else: # Its not clear how we get here, but empirically we do # (#5407). Logging has been added elsewhere to try and # figure out where this state comes from. - current_state_ids = yield self.get_state_at( + current_state_ids = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -843,7 +836,7 @@ class SyncHandler(object): # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( @@ -883,7 +876,7 @@ class SyncHandler(object): state = {} if state_ids: - state = yield self.store.get_events(list(state_ids.values())) + state = await self.store.get_events(list(state_ids.values())) return { (e.type, e.state_key): e @@ -892,10 +885,9 @@ class SyncHandler(object): ) } - @defer.inlineCallbacks - def unread_notifs_for_room_id(self, room_id, sync_config): + async def unread_notifs_for_room_id(self, room_id, sync_config): with Measure(self.clock, "unread_notifs_for_room_id"): - last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user( + last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, receipt_type="m.read", @@ -903,7 +895,7 @@ class SyncHandler(object): notifs = [] if last_unread_event_id: - notifs = yield self.store.get_unread_event_push_actions_by_room_for_user( + notifs = await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), last_unread_event_id ) return notifs @@ -912,8 +904,9 @@ class SyncHandler(object): # count is whatever it was last time. return None - @defer.inlineCallbacks - def generate_sync_result(self, sync_config, since_token=None, full_state=False): + async def generate_sync_result( + self, sync_config, since_token=None, full_state=False + ): """Generates a sync result. Args: @@ -928,7 +921,7 @@ class SyncHandler(object): # this is due to some of the underlying streams not supporting the ability # to query up to a given point. # Always use the `now_token` in `SyncResultBuilder` - now_token = yield self.event_sources.get_current_token() + now_token = await self.event_sources.get_current_token() logger.info( "Calculating sync response for %r between %s and %s", @@ -944,10 +937,9 @@ class SyncHandler(object): # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() else: - joined_room_ids = yield self.get_rooms_for_user_at( + joined_room_ids = await self.get_rooms_for_user_at( user_id, now_token.room_stream_id ) - sync_result_builder = SyncResultBuilder( sync_config, full_state, @@ -956,11 +948,11 @@ class SyncHandler(object): joined_room_ids=joined_room_ids, ) - account_data_by_room = yield self._generate_sync_entry_for_account_data( + account_data_by_room = await self._generate_sync_entry_for_account_data( sync_result_builder ) - res = yield self._generate_sync_entry_for_rooms( + res = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) newly_joined_rooms, newly_joined_or_invited_users, _, _ = res @@ -970,13 +962,13 @@ class SyncHandler(object): since_token is None and sync_config.filter_collection.blocks_all_presence() ) if self.hs_config.use_presence and not block_all_presence_data: - yield self._generate_sync_entry_for_presence( + await self._generate_sync_entry_for_presence( sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ) - yield self._generate_sync_entry_for_to_device(sync_result_builder) + await self._generate_sync_entry_for_to_device(sync_result_builder) - device_lists = yield self._generate_sync_entry_for_device_list( + device_lists = await self._generate_sync_entry_for_device_list( sync_result_builder, newly_joined_rooms=newly_joined_rooms, newly_joined_or_invited_users=newly_joined_or_invited_users, @@ -987,11 +979,11 @@ class SyncHandler(object): device_id = sync_config.device_id one_time_key_counts = {} if device_id: - one_time_key_counts = yield self.store.count_e2e_one_time_keys( + one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) - yield self._generate_sync_entry_for_groups(sync_result_builder) + await self._generate_sync_entry_for_groups(sync_result_builder) # debug for https://github.com/matrix-org/synapse/issues/4422 for joined_room in sync_result_builder.joined: @@ -1015,18 +1007,17 @@ class SyncHandler(object): ) @measure_func("_generate_sync_entry_for_groups") - @defer.inlineCallbacks - def _generate_sync_entry_for_groups(self, sync_result_builder): + async def _generate_sync_entry_for_groups(self, sync_result_builder): user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token if since_token and since_token.groups_key: - results = yield self.store.get_groups_changes_for_user( + results = self.store.get_groups_changes_for_user( user_id, since_token.groups_key, now_token.groups_key ) else: - results = yield self.store.get_all_groups_for_user( + results = await self.store.get_all_groups_for_user( user_id, now_token.groups_key ) @@ -1059,8 +1050,7 @@ class SyncHandler(object): ) @measure_func("_generate_sync_entry_for_device_list") - @defer.inlineCallbacks - def _generate_sync_entry_for_device_list( + async def _generate_sync_entry_for_device_list( self, sync_result_builder, newly_joined_rooms, @@ -1108,32 +1098,32 @@ class SyncHandler(object): # room with by looking at all users that have left a room plus users # that were in a room we've left. - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) # Step 1a, check for changes in devices of users we share a room with - users_that_have_changed = yield self.store.get_users_whose_devices_changed( + users_that_have_changed = await self.store.get_users_whose_devices_changed( since_token.device_list_key, users_who_share_room ) # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: - joined_users = yield self.state.get_current_users_in_room(room_id) + joined_users = await self.state.get_current_users_in_room(room_id) newly_joined_or_invited_users.update(joined_users) # TODO: Check that these users are actually new, i.e. either they # weren't in the previous sync *or* they left and rejoined. users_that_have_changed.update(newly_joined_or_invited_users) - user_signatures_changed = yield self.store.get_users_whose_signatures_changed( + user_signatures_changed = await self.store.get_users_whose_signatures_changed( user_id, since_token.device_list_key ) users_that_have_changed.update(user_signatures_changed) # Now find users that we no longer track for room_id in newly_left_rooms: - left_users = yield self.state.get_current_users_in_room(room_id) + left_users = await self.state.get_current_users_in_room(room_id) newly_left_users.update(left_users) # Remove any users that we still share a room with. @@ -1143,8 +1133,7 @@ class SyncHandler(object): else: return DeviceLists(changed=[], left=[]) - @defer.inlineCallbacks - def _generate_sync_entry_for_to_device(self, sync_result_builder): + async def _generate_sync_entry_for_to_device(self, sync_result_builder): """Generates the portion of the sync response. Populates `sync_result_builder` with the result. @@ -1165,14 +1154,14 @@ class SyncHandler(object): # We only delete messages when a new message comes in, but that's # fine so long as we delete them at some point. - deleted = yield self.store.delete_messages_for_device( + deleted = await self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) logger.debug( "Deleted %d to-device messages up to %d", deleted, since_stream_id ) - messages, stream_id = yield self.store.get_new_messages_for_device( + messages, stream_id = await self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key ) @@ -1190,8 +1179,7 @@ class SyncHandler(object): else: sync_result_builder.to_device = [] - @defer.inlineCallbacks - def _generate_sync_entry_for_account_data(self, sync_result_builder): + async def _generate_sync_entry_for_account_data(self, sync_result_builder): """Generates the account data portion of the sync response. Populates `sync_result_builder` with the result. @@ -1209,25 +1197,25 @@ class SyncHandler(object): ( account_data, account_data_by_room, - ) = yield self.store.get_updated_account_data_for_user( + ) = self.store.get_updated_account_data_for_user( user_id, since_token.account_data_key ) - push_rules_changed = yield self.store.have_push_rules_changed_for_user( + push_rules_changed = await self.store.have_push_rules_changed_for_user( user_id, int(since_token.push_rules_key) ) if push_rules_changed: - account_data["m.push_rules"] = yield self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) else: ( account_data, account_data_by_room, - ) = yield self.store.get_account_data_for_user(sync_config.user.to_string()) + ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) - account_data["m.push_rules"] = yield self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) @@ -1242,8 +1230,7 @@ class SyncHandler(object): return account_data_by_room - @defer.inlineCallbacks - def _generate_sync_entry_for_presence( + async def _generate_sync_entry_for_presence( self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ): """Generates the presence portion of the sync response. Populates the @@ -1271,7 +1258,7 @@ class SyncHandler(object): presence_key = None include_offline = False - presence, presence_key = yield presence_source.get_new_events( + presence, presence_key = await presence_source.get_new_events( user=user, from_key=presence_key, is_guest=sync_config.is_guest, @@ -1283,12 +1270,12 @@ class SyncHandler(object): extra_users_ids = set(newly_joined_or_invited_users) for room_id in newly_joined_rooms: - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) if extra_users_ids: - states = yield self.presence_handler.get_states(extra_users_ids) + states = await self.presence_handler.get_states(extra_users_ids) presence.extend(states) # Deduplicate the presence entries so that there's at most one per user @@ -1298,8 +1285,9 @@ class SyncHandler(object): sync_result_builder.presence = presence - @defer.inlineCallbacks - def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room): + async def _generate_sync_entry_for_rooms( + self, sync_result_builder, account_data_by_room + ): """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1321,7 +1309,7 @@ class SyncHandler(object): if block_all_room_ephemeral: ephemeral_by_room = {} else: - now_token, ephemeral_by_room = yield self.ephemeral_by_room( + now_token, ephemeral_by_room = await self.ephemeral_by_room( sync_result_builder, now_token=sync_result_builder.now_token, since_token=sync_result_builder.since_token, @@ -1333,16 +1321,16 @@ class SyncHandler(object): since_token = sync_result_builder.since_token if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: - have_changed = yield self._have_rooms_changed(sync_result_builder) + have_changed = await self._have_rooms_changed(sync_result_builder) if not have_changed: - tags_by_room = yield self.store.get_updated_tags( + tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) if not tags_by_room: logger.debug("no-oping sync") return [], [], [], [] - ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( + ignored_account_data = await self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id ) @@ -1352,18 +1340,18 @@ class SyncHandler(object): ignored_users = frozenset() if since_token: - res = yield self._get_rooms_changed(sync_result_builder, ignored_users) + res = await self._get_rooms_changed(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms, newly_left_rooms = res - tags_by_room = yield self.store.get_updated_tags( + tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - res = yield self._get_all_rooms(sync_result_builder, ignored_users) + res = await self._get_all_rooms(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms = res newly_left_rooms = [] - tags_by_room = yield self.store.get_tags_for_user(user_id) + tags_by_room = await self.store.get_tags_for_user(user_id) def handle_room_entries(room_entry): return self._generate_room_entry( @@ -1376,7 +1364,7 @@ class SyncHandler(object): always_include=sync_result_builder.full_state, ) - yield concurrently_execute(handle_room_entries, room_entries, 10) + await concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builder.invited.extend(invited) @@ -1410,8 +1398,7 @@ class SyncHandler(object): newly_left_users, ) - @defer.inlineCallbacks - def _have_rooms_changed(self, sync_result_builder): + async def _have_rooms_changed(self, sync_result_builder): """Returns whether there may be any new events that should be sent down the sync. Returns True if there are. """ @@ -1422,7 +1409,7 @@ class SyncHandler(object): assert since_token # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( + rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) @@ -1435,8 +1422,7 @@ class SyncHandler(object): return True return False - @defer.inlineCallbacks - def _get_rooms_changed(self, sync_result_builder, ignored_users): + async def _get_rooms_changed(self, sync_result_builder, ignored_users): """Gets the the changes that have happened since the last sync. Args: @@ -1461,7 +1447,7 @@ class SyncHandler(object): assert since_token # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( + rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) @@ -1499,11 +1485,11 @@ class SyncHandler(object): continue if room_id in sync_result_builder.joined_room_ids or has_join: - old_state_ids = yield self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev = None if old_mem_ev_id: - old_mem_ev = yield self.store.get_event( + old_mem_ev = await self.store.get_event( old_mem_ev_id, allow_none=True ) @@ -1536,13 +1522,13 @@ class SyncHandler(object): newly_left_rooms.append(room_id) else: if not old_state_ids: - old_state_ids = yield self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get( (EventTypes.Member, user_id), None ) old_mem_ev = None if old_mem_ev_id: - old_mem_ev = yield self.store.get_event( + old_mem_ev = await self.store.get_event( old_mem_ev_id, allow_none=True ) if old_mem_ev and old_mem_ev.membership == Membership.JOIN: @@ -1566,7 +1552,7 @@ class SyncHandler(object): if leave_events: leave_event = leave_events[-1] - leave_stream_token = yield self.store.get_stream_token_for_event( + leave_stream_token = await self.store.get_stream_token_for_event( leave_event.event_id ) leave_token = since_token.copy_and_replace( @@ -1603,7 +1589,7 @@ class SyncHandler(object): timeline_limit = sync_config.filter_collection.timeline_limit() # Get all events for rooms we're currently joined to. - room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=sync_result_builder.joined_room_ids, from_key=since_token.room_key, to_key=now_token.room_key, @@ -1652,8 +1638,7 @@ class SyncHandler(object): return room_entries, invited, newly_joined_rooms, newly_left_rooms - @defer.inlineCallbacks - def _get_all_rooms(self, sync_result_builder, ignored_users): + async def _get_all_rooms(self, sync_result_builder, ignored_users): """Returns entries for all rooms for the user. Args: @@ -1677,7 +1662,7 @@ class SyncHandler(object): Membership.BAN, ) - room_list = yield self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_user_where_membership_is( user_id=user_id, membership_list=membership_list ) @@ -1700,7 +1685,7 @@ class SyncHandler(object): elif event.membership == Membership.INVITE: if event.sender in ignored_users: continue - invite = yield self.store.get_event(event.event_id) + invite = await self.store.get_event(event.event_id) invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) elif event.membership in (Membership.LEAVE, Membership.BAN): # Always send down rooms we were banned or kicked from. @@ -1726,8 +1711,7 @@ class SyncHandler(object): return room_entries, invited, [] - @defer.inlineCallbacks - def _generate_room_entry( + async def _generate_room_entry( self, sync_result_builder, ignored_users, @@ -1769,7 +1753,7 @@ class SyncHandler(object): since_token = room_builder.since_token upto_token = room_builder.upto_token - batch = yield self._load_filtered_recents( + batch = await self._load_filtered_recents( room_id, sync_config, now_token=upto_token, @@ -1796,7 +1780,7 @@ class SyncHandler(object): # tag was added by synapse e.g. for server notice rooms. if full_state: user_id = sync_result_builder.sync_config.user.to_string() - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) # If there aren't any tags, don't send the empty tags list down # sync @@ -1821,7 +1805,7 @@ class SyncHandler(object): ): return - state = yield self.compute_state_delta( + state = await self.compute_state_delta( room_id, batch, sync_config, since_token, now_token, full_state=full_state ) @@ -1844,7 +1828,7 @@ class SyncHandler(object): ) or since_token is None ): - summary = yield self.compute_summary( + summary = await self.compute_summary( room_id, sync_config, batch, state, now_token ) @@ -1861,7 +1845,7 @@ class SyncHandler(object): ) if room_sync or always_include: - notifs = yield self.unread_notifs_for_room_id(room_id, sync_config) + notifs = await self.unread_notifs_for_room_id(room_id, sync_config) if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"] @@ -1887,8 +1871,7 @@ class SyncHandler(object): else: raise Exception("Unrecognized rtype: %r", room_builder.rtype) - @defer.inlineCallbacks - def get_rooms_for_user_at(self, user_id, stream_ordering): + async def get_rooms_for_user_at(self, user_id, stream_ordering): """Get set of joined rooms for a user at the given stream ordering. The stream ordering *must* be recent, otherwise this may throw an @@ -1903,7 +1886,7 @@ class SyncHandler(object): Deferred[frozenset[str]]: Set of room_ids the user is in at given stream_ordering. """ - joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id) + joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id) joined_room_ids = set() @@ -1921,10 +1904,10 @@ class SyncHandler(object): logger.info("User joined room after current token: %s", room_id) - extrems = yield self.store.get_forward_extremeties_for_room( + extrems = await self.store.get_forward_extremeties_for_room( room_id, stream_ordering ) - users_in_room = yield self.state.get_current_users_in_room(room_id, extrems) + users_in_room = await self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: joined_room_ids.add(room_id) diff --git a/synapse/notifier.py b/synapse/notifier.py index af161a81d7..5f5f765bea 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -304,8 +304,7 @@ class Notifier(object): without waking up any of the normal user event streams""" self.notify_replication() - @defer.inlineCallbacks - def wait_for_events( + async def wait_for_events( self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START ): """Wait until the callback returns a non empty response or the @@ -313,9 +312,9 @@ class Notifier(object): """ user_stream = self.user_to_user_stream.get(user_id) if user_stream is None: - current_token = yield self.event_sources.get_current_token() + current_token = await self.event_sources.get_current_token() if room_ids is None: - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) user_stream = _NotifierUserStream( user_id=user_id, rooms=room_ids, @@ -344,11 +343,11 @@ class Notifier(object): self.hs.get_reactor(), ) with PreserveLoggingContext(): - yield listener.deferred + await listener.deferred current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) if result: break @@ -364,12 +363,11 @@ class Notifier(object): # This happened if there was no timeout or if the timeout had # already expired. current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) return result - @defer.inlineCallbacks - def get_events_for( + async def get_events_for( self, user, pagination_config, @@ -391,15 +389,14 @@ class Notifier(object): """ from_token = pagination_config.from_token if not from_token: - from_token = yield self.event_sources.get_current_token() + from_token = await self.event_sources.get_current_token() limit = pagination_config.limit - room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id) + room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) is_peeking = not is_joined - @defer.inlineCallbacks - def check_for_updates(before_token, after_token): + async def check_for_updates(before_token, after_token): if not after_token.is_after(before_token): return EventStreamResult([], (from_token, from_token)) @@ -415,7 +412,7 @@ class Notifier(object): if only_keys and name not in only_keys: continue - new_events, new_key = yield source.get_new_events( + new_events, new_key = await source.get_new_events( user=user, from_key=getattr(from_token, keyname), limit=limit, @@ -425,7 +422,7 @@ class Notifier(object): ) if name == "room": - new_events = yield filter_events_for_client( + new_events = await filter_events_for_client( self.storage, user.to_string(), new_events, @@ -461,7 +458,7 @@ class Notifier(object): user_id_for_stream, ) - result = yield self.wait_for_events( + result = await self.wait_for_events( user_id_for_stream, timeout, check_for_updates, diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 3286804322..63ddaaba87 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging from functools import wraps @@ -64,12 +65,22 @@ def measure_func(name=None): def wrapper(func): block_name = func.__name__ if name is None else name - @wraps(func) - @defer.inlineCallbacks - def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = yield func(self, *args, **kwargs) - return r + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = await func(self, *args, **kwargs) + return r + + else: + + @wraps(func) + @defer.inlineCallbacks + def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = yield func(self, *args, **kwargs) + return r return measured_func diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 31f54bbd7d..758ee071a5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,54 +12,53 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION -from synapse.handlers.sync import SyncConfig, SyncHandler +from synapse.handlers.sync import SyncConfig from synapse.types import UserID import tests.unittest import tests.utils -from tests.utils import setup_test_homeserver -class SyncTestCase(tests.unittest.TestCase): +class SyncTestCase(tests.unittest.HomeserverTestCase): """ Tests Sync Handler. """ - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.sync_handler = SyncHandler(self.hs) + def prepare(self, reactor, clock, hs): + self.hs = hs + self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() - @defer.inlineCallbacks def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:server" user_id2 = "@user2:server" sync_config = self._generate_sync_config(user_id1) + self.reactor.advance(100) # So we get not 0 time self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 1 # Check that the happy case does not throw errors - yield self.store.upsert_monthly_active_user(user_id1) - yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.get_success(self.store.upsert_monthly_active_user(user_id1)) + self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) # Test that global lock works self.hs.config.hs_disabled = True - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.hs.config.hs_disabled = False sync_config = self._generate_sync_config(user_id2) - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def _generate_sync_config(self, user_id): return SyncConfig( diff --git a/tests/unittest.py b/tests/unittest.py index 295573bc46..a1bdd963e6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -18,6 +18,7 @@ import gc import hashlib import hmac +import inspect import logging import time @@ -25,7 +26,7 @@ from mock import Mock from canonicaljson import json -from twisted.internet.defer import Deferred, succeed +from twisted.internet.defer import Deferred, ensureDeferred, succeed from twisted.python.threadpool import ThreadPool from twisted.trial import unittest @@ -415,6 +416,8 @@ class HomeserverTestCase(TestCase): self.reactor.pump([by] * 100) def get_success(self, d, by=0.0): + if inspect.isawaitable(d): + d = ensureDeferred(d) if not isinstance(d, Deferred): return d self.pump(by=by) @@ -424,6 +427,8 @@ class HomeserverTestCase(TestCase): """ Run a Deferred and get a Failure from it. The failure must be of the type `exc`. """ + if inspect.isawaitable(d): + d = ensureDeferred(d) if not isinstance(d, Deferred): return d self.pump() -- cgit 1.5.1 From b3a4e35ca84a29fe4ccdfb1125ed098c68405d6c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 10:14:59 +0000 Subject: Fixup functions to consistently return deferreds --- synapse/handlers/sync.py | 6 +++--- synapse/handlers/typing.py | 2 +- synapse/storage/data_stores/main/account_data.py | 2 +- synapse/storage/data_stores/main/group_server.py | 4 ++-- tests/handlers/test_typing.py | 24 ++++++++++++++++++------ tests/rest/client/v1/test_typing.py | 4 +++- 6 files changed, 28 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 12751fd8c0..2d3b8ba73c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -339,7 +339,7 @@ class SyncHandler(object): room_ids = sync_result_builder.joined_room_ids typing_source = self.event_sources.sources["typing"] - typing, typing_key = typing_source.get_new_events( + typing, typing_key = await typing_source.get_new_events( user=sync_config.user, from_key=typing_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -1013,7 +1013,7 @@ class SyncHandler(object): now_token = sync_result_builder.now_token if since_token and since_token.groups_key: - results = self.store.get_groups_changes_for_user( + results = await self.store.get_groups_changes_for_user( user_id, since_token.groups_key, now_token.groups_key ) else: @@ -1197,7 +1197,7 @@ class SyncHandler(object): ( account_data, account_data_by_room, - ) = self.store.get_updated_account_data_for_user( + ) = await self.store.get_updated_account_data_for_user( user_id, since_token.account_data_key ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 856337b7e2..6f78454322 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -313,7 +313,7 @@ class TypingNotificationEventSource(object): events.append(self._make_event_for(room_id)) - return events, handler._latest_room_serial + return defer.succeed((events, handler._latest_room_serial)) def get_current_key(self): return self.get_typing_handler()._latest_room_serial diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index b0d22faf3f..ed97b3ffe5 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -250,7 +250,7 @@ class AccountDataWorkerStore(SQLBaseStore): user_id, int(stream_id) ) if not changed: - return {}, {} + return defer.succeed(({}, {})) return self.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 9e1d12bcb7..d29155a3b5 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -1109,7 +1109,7 @@ class GroupServerStore(SQLBaseStore): user_id, from_token ) if not has_changed: - return [] + return defer.succeed([]) def _get_groups_changes_for_user_txn(txn): sql = """ @@ -1139,7 +1139,7 @@ class GroupServerStore(SQLBaseStore): from_token ) if not has_changed: - return [] + return defer.succeed([]) def _get_all_groups_changes_txn(txn): sql = """ diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f6d8660285..92b8726093 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -163,7 +163,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -227,7 +229,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -279,7 +283,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -300,7 +306,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -317,7 +325,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -335,7 +345,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 30fb77bac8..4bc3aaf02d 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -109,7 +109,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + events = self.get_success( + self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + ) self.assertEquals( events[0], [ -- cgit 1.5.1 From 9a4fb457cf5918c85068ea249cd2d58b3e2e3cfc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 13:08:40 +0000 Subject: Change DataStores to accept 'database' param. --- synapse/app/federation_sender.py | 5 +++-- synapse/app/user_dir.py | 5 +++-- synapse/replication/slave/storage/_base.py | 5 +++-- synapse/replication/slave/storage/account_data.py | 5 +++-- synapse/replication/slave/storage/client_ips.py | 5 +++-- synapse/replication/slave/storage/deviceinbox.py | 5 +++-- synapse/replication/slave/storage/devices.py | 5 +++-- synapse/replication/slave/storage/events.py | 5 +++-- synapse/replication/slave/storage/filtering.py | 5 +++-- synapse/replication/slave/storage/groups.py | 5 +++-- synapse/replication/slave/storage/presence.py | 5 +++-- synapse/replication/slave/storage/push_rule.py | 5 +++-- synapse/replication/slave/storage/pushers.py | 5 +++-- synapse/replication/slave/storage/receipts.py | 5 +++-- synapse/replication/slave/storage/room.py | 5 +++-- synapse/storage/_base.py | 2 +- synapse/storage/data_stores/main/__init__.py | 5 +++-- synapse/storage/data_stores/main/account_data.py | 9 +++++---- synapse/storage/data_stores/main/appservice.py | 5 +++-- synapse/storage/data_stores/main/client_ips.py | 9 +++++---- synapse/storage/data_stores/main/deviceinbox.py | 9 +++++---- synapse/storage/data_stores/main/devices.py | 9 +++++---- synapse/storage/data_stores/main/event_federation.py | 5 +++-- synapse/storage/data_stores/main/event_push_actions.py | 9 +++++---- synapse/storage/data_stores/main/events.py | 5 +++-- synapse/storage/data_stores/main/events_bg_updates.py | 5 +++-- synapse/storage/data_stores/main/events_worker.py | 5 +++-- synapse/storage/data_stores/main/media_repository.py | 11 +++++++---- synapse/storage/data_stores/main/monthly_active_users.py | 7 ++++--- synapse/storage/data_stores/main/push_rule.py | 5 +++-- synapse/storage/data_stores/main/receipts.py | 9 +++++---- synapse/storage/data_stores/main/registration.py | 13 +++++++------ synapse/storage/data_stores/main/room.py | 9 +++++---- synapse/storage/data_stores/main/roommember.py | 13 +++++++------ synapse/storage/data_stores/main/search.py | 9 +++++---- synapse/storage/data_stores/main/state.py | 13 +++++++------ synapse/storage/data_stores/main/stats.py | 5 +++-- synapse/storage/data_stores/main/stream.py | 5 +++-- synapse/storage/data_stores/main/transactions.py | 5 +++-- synapse/storage/data_stores/main/user_directory.py | 9 +++++---- tests/storage/test_appservice.py | 5 +++-- 41 files changed, 156 insertions(+), 114 deletions(-) (limited to 'tests') diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 448e45e00f..f24920a7d6 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -40,6 +40,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer @@ -59,8 +60,8 @@ class FederationSenderSlaveStore( SlavedDeviceStore, SlavedPresenceStore, ): - def __init__(self, db_conn, hs): - super(FederationSenderSlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs) # We pull out the current federation stream position now so that we # always have a known value for the federation position in memory so diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index b6d4481725..c01fb34a9b 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -43,6 +43,7 @@ from synapse.replication.tcp.streams.events import ( from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree @@ -60,8 +61,8 @@ class UserDirectorySlaveStore( UserDirectoryStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): - super(UserDirectorySlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 6ece1d6745..b91a528245 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -20,6 +20,7 @@ import six from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from ._slaved_id_tracker import SlavedIdTracker @@ -35,8 +36,8 @@ def __func__(inp): class BaseSlavedStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(BaseSlavedStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( db_conn, "cache_invalidation_stream", "stream_id" diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index bc2f6a12ae..ebe94909cb 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -18,15 +18,16 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.storage.data_stores.main.tags import TagsWorkerStore +from synapse.storage.database import Database class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id" ) - super(SlavedAccountDataStore, self).__init__(db_conn, hs) + super(SlavedAccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index b4f58cea19..fbf996e33a 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -21,8 +22,8 @@ from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedClientIpStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 9fb6c5c6ff..0c237c6e0f 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -16,13 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_max_stream_id", "stream_id" ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index de50748c30..dc625e0d7a 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -18,12 +18,13 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage.data_stores.main.devices import DeviceWorkerStore from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d0a0eaf75b..29f35b9915 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -31,6 +31,7 @@ from synapse.storage.data_stores.main.signatures import SignatureWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.data_stores.main.stream import StreamWorkerStore from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -59,13 +60,13 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) - super(SlavedEventStore, self).__init__(db_conn, hs) + super(SlavedEventStore, self).__init__(database, db_conn, hs) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 5c84ebd125..bcb0688954 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -14,13 +14,14 @@ # limitations under the License. from synapse.storage.data_stores.main.filtering import FilteringStore +from synapse.storage.database import Database from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedFilteringStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedFilteringStore, self).__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 28a46edd28..69a4ae42f9 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage import DataStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -21,8 +22,8 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedGroupServerStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedGroupServerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 747ced0c84..f552e7c972 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -15,6 +15,7 @@ from synapse.storage import DataStore from synapse.storage.data_stores.main.presence import PresenceStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -22,8 +23,8 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPresenceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 3655f05e54..eebd5a1fb6 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -15,17 +15,18 @@ # limitations under the License. from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore +from synapse.storage.database import Database from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" ) - super(SlavedPushRuleStore, self).__init__(db_conn, hs) + super(SlavedPushRuleStore, self).__init__(database, db_conn, hs) def get_push_rules_stream_token(self): return ( diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index b4331d0799..f22c2d44a3 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -15,14 +15,15 @@ # limitations under the License. from synapse.storage.data_stores.main.pusher import PusherWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPusherStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPusherStore, self).__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 43d823c601..d40dc6e1f5 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -29,14 +30,14 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(db_conn, hs) + super(SlavedReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index d9ad386b28..3a20f45316 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -14,14 +14,15 @@ # limitations under the License. from synapse.storage.data_stores.main.room import RoomWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b7e27d4e97..f9e7f9a71e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -37,7 +37,7 @@ class SQLBaseStore(object): per data store (and not one per physical database). """ - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 6adb8adb04..7f5fd81bcf 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -20,6 +20,7 @@ import logging import time from synapse.api.constants import PresenceState +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( ChainedIdGenerator, @@ -111,7 +112,7 @@ class DataStore( RelationsStore, CacheInvalidationStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine @@ -169,7 +170,7 @@ class DataStore( else: self._cache_id_gen = None - super(DataStore, self).__init__(db_conn, hs) + super(DataStore, self).__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index a96fe9485c..44d20c19bf 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -22,6 +22,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max ) - super(AccountDataWorkerStore, self).__init__(db_conn, hs) + super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): @@ -270,12 +271,12 @@ class AccountDataWorkerStore(SQLBaseStore): class AccountDataStore(AccountDataWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) - super(AccountDataStore, self).__init__(db_conn, hs) + super(AccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 6b2e12719c..b2f39649fd 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -24,6 +24,7 @@ from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache): class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) + super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) def get_app_services(self): return self.services_cache diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 7b470a58f1..320c5b0f07 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -33,8 +34,8 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "user_ips_device_index", @@ -363,13 +364,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpStore(ClientIpBackgroundUpdateStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR ) - super(ClientIpStore, self).__init__(db_conn, hs) + super(ClientIpStore, self).__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.user_ips_max_age diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 3c9f09301a..85cfa16850 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -210,8 +211,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "device_inbox_stream_index", @@ -241,8 +242,8 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 91ddaf137e..9a828231c4 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -31,6 +31,7 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter from synapse.util.caches.descriptors import ( @@ -642,8 +643,8 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( "device_lists_stream_idx", @@ -692,8 +693,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 31d2e8eb28..1f517e8fad 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -28,6 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.signatures import SignatureWorkerStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -491,8 +492,8 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, db_conn, hs): - super(EventFederationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventFederationStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index eec054cd48..9988a6d3fc 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago = None @@ -611,8 +612,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, db_conn, hs): - super(EventPushActionsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index d644c82784..da1529f6ea 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -41,6 +41,7 @@ from synapse.storage._base import make_in_list_sql_clause from synapse.storage.data_stores.main.event_federation import EventFederationStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.database import Database from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -95,8 +96,8 @@ def _retry_on_integrity_error(func): class EventsStore( StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, ): - def __init__(self, db_conn, hs): - super(EventsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsStore, self).__init__(database, db_conn, hs) # Collect metrics on the number of forward extremities that exist. # Counter of number of extremities to count diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index cb1fc30c31..efee17b929 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventContentFields from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -33,8 +34,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index e041fc5eac..9ee117ce0f 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -33,6 +33,7 @@ from synapse.events.utils import prune_event from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_domain_from_id from synapse.util import batch_iter from synapse.util.caches.descriptors import Cache @@ -55,8 +56,8 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) class EventsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsWorkerStore, self).__init__(database, db_conn, hs) self._get_event_cache = Cache( "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 03c9c6f8ae..80ca36dedf 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryBackgroundUpdateStore, self).__init__( + database, db_conn, hs + ) self.db.updates.register_background_index_update( update_name="local_media_repository_url_idx", @@ -31,8 +34,8 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, db_conn, hs): - super(MediaRepositoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryStore, self).__init__(database, db_conn, hs) def get_local_media(self, media_id): """Get the metadata for a local piece of media diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 34bf3a1880..27158534cb 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -27,13 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersStore(SQLBaseStore): - def __init__(self, dbconn, hs): - super(MonthlyActiveUsersStore, self).__init__(None, hs) + def __init__(self, database: Database, db_conn, hs): + super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs # Do not add more reserved users than the total allowable number self.db.new_transaction( - dbconn, + db_conn, "initialise_mau_threepids", [], [], diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index de682cc63a..5ba13aa973 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.appservice import ApplicationServiceWorker from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.database import Database from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -72,8 +73,8 @@ class PushRulesWorkerStore( # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) push_rules_prefill, push_rules_id = self.db.get_cache_dict( db_conn, diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index ac2d45bd5c..96e54d145e 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -22,6 +22,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() @@ -315,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(db_conn, hs) + super(ReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 1ef143c6d8..5e8ecac0ea 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -27,6 +27,7 @@ from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -36,8 +37,8 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() @@ -794,8 +795,8 @@ class RegistrationWorkerStore(SQLBaseStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config @@ -920,8 +921,8 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RegistrationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index da42dae243..0148be20d3 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -29,6 +29,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.search import SearchStore +from synapse.storage.database import Database from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -361,8 +362,8 @@ class RoomWorkerStore(SQLBaseStore): class RoomBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.config = hs.config @@ -440,8 +441,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) self.config = hs.config diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 929f6b0d39..92e3b9c512 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -32,6 +32,7 @@ from synapse.storage._base import ( make_in_list_sql_clause, ) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( GetRoomsForUserWithStreamOrdering, @@ -54,8 +55,8 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the # background update still running? @@ -835,8 +836,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) @@ -991,8 +992,8 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberStore, self).__init__(database, db_conn, hs) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index ffa1817e64..4eec2fae5e 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine logger = logging.getLogger(__name__) @@ -42,8 +43,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, db_conn, hs): - super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) if not hs.config.enable_search: return @@ -342,8 +343,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(SearchStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchStore, self).__init__(database, db_conn, hs) def store_event_search_txn(self, txn, event, key, value): """Add event to the search table diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 7d5a9f8128..9ef7b48c74 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -28,6 +28,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter from synapse.util.caches import get_cache_factor_for, intern_string @@ -213,8 +214,8 @@ class StateGroupWorkerStore( STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -1029,8 +1030,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" - def __init__(self, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.db.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, @@ -1245,8 +1246,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateStore, self).__init__(database, db_conn, hs) def _store_event_state_mappings_txn( self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 40579bf965..7bc186e9a1 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -22,6 +22,7 @@ from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.util.caches.descriptors import cached @@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): - def __init__(self, db_conn, hs): - super(StatsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StatsStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname self.clock = self.hs.get_clock() diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 2ff8c57109..140da8dad6 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -47,6 +47,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -251,8 +252,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(StreamWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StreamWorkerStore, self).__init__(database, db_conn, hs) events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self.db.get_cache_dict( diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index c0d155a43c..5b07c2fbc0 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache # py2 sqlite has buffer hardcoded as only binary type, so we must use it, @@ -52,8 +53,8 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - def __init__(self, db_conn, hs): - super(TransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TransactionStore, self).__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 62ffb34b29..90c180ec6d 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.storage.data_stores.main.state import StateFilter from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -37,8 +38,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -549,8 +550,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryStore, self).__init__(database, db_conn, hs) def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index dfeea24599..1679112d82 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.storage.database import Database from tests import unittest from tests.utils import setup_test_homeserver @@ -382,8 +383,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, db_conn, hs): - super(TestTransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TestTransactionStore, self).__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): -- cgit 1.5.1 From 852f80d8a697c9d556b1b2641a2e4c1797cbbb46 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 16:02:50 +0000 Subject: Fixup tests --- tests/replication/slave/storage/_base.py | 5 ++++- tests/storage/test_appservice.py | 12 +++++++----- tests/storage/test_base.py | 3 ++- tests/storage/test_profile.py | 3 +-- tests/storage/test_user_directory.py | 4 +--- tests/test_federation.py | 16 +++++----------- 6 files changed, 20 insertions(+), 23 deletions(-) (limited to 'tests') diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index e7472e3a93..3dae83c543 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -20,6 +20,7 @@ from synapse.replication.tcp.client import ( ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.storage.database import Database from tests import unittest from tests.server import FakeTransport @@ -42,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.master_store = self.hs.get_datastore() self.storage = hs.get_storage() - self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) + self.slaved_store = self.STORE_TYPE( + Database(hs), self.hs.get_db_conn(), self.hs + ) self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 1679112d82..2e521e9ab7 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -55,7 +55,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(hs.get_db_conn(), hs) + database = Database(hs) + self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -124,7 +125,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - self.store = TestTransactionStore(hs.get_db_conn(), hs) + database = Database(hs) + self.store = TestTransactionStore(database, hs.get_db_conn(), hs) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -417,7 +419,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -433,7 +435,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) @@ -454,7 +456,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 7915d48a9e..537cfe9f64 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -21,6 +21,7 @@ from mock import Mock from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from tests import unittest @@ -59,7 +60,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): "test", db_pool=self.db_pool, config=config, database_engine=fake_engine ) - self.datastore = SQLBaseStore(None, hs) + self.datastore = SQLBaseStore(Database(hs), None, hs) @defer.inlineCallbacks def test_insert_1col(self): diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 24c7fe16c3..9b6f7211ae 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -16,7 +16,6 @@ from twisted.internet import defer -from synapse.storage.data_stores.main.profile import ProfileStore from synapse.types import UserID from tests import unittest @@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = ProfileStore(hs.get_db_conn(), hs) + self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7eea57c0e2..6a545d2eb0 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -15,8 +15,6 @@ from twisted.internet import defer -from synapse.storage.data_stores.main.user_directory import UserDirectoryStore - from tests import unittest from tests.utils import setup_test_homeserver @@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) + self.store = self.hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_federation.py b/tests/test_federation.py index 7d82b58466..ad165d7295 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -33,6 +33,8 @@ class MessageAcceptTests(unittest.TestCase): self.reactor.advance(0.1) self.room_id = self.successResultOf(room)["room_id"] + self.store = self.homeserver.get_datastore() + # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( @@ -77,10 +79,7 @@ class MessageAcceptTests(unittest.TestCase): # Make sure we actually joined the room self.assertEqual( self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0], "$join:test.serv", ) @@ -100,10 +99,7 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0] # Now lie about an event @@ -141,7 +137,5 @@ class MessageAcceptTests(unittest.TestCase): ) # Make sure the invalid event isn't there - extrem = maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id - ) + extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") -- cgit 1.5.1 From f166a8d1f52f1fda1a64d8cd50f17275d63c8843 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 9 Dec 2019 13:42:49 +0000 Subject: Remove SnapshotCache in favour of ResponseCache --- synapse/handlers/initial_sync.py | 19 +++---- synapse/util/caches/snapshot_cache.py | 94 ----------------------------------- tests/util/test_snapshot_cache.py | 63 ----------------------- 3 files changed, 8 insertions(+), 168 deletions(-) delete mode 100644 synapse/util/caches/snapshot_cache.py delete mode 100644 tests/util/test_snapshot_cache.py (limited to 'tests') diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 81dce96f4b..73c110a92b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute -from synapse.util.caches.snapshot_cache import SnapshotCache +from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() + self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state @@ -79,17 +79,14 @@ class InitialSyncHandler(BaseHandler): as_client_event, include_archived, ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - return self.snapshot_cache.set( - now_ms, + return self.snapshot_cache.wrap( key, - self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - ), + self._snapshot_all_rooms, + user_id, + pagin_config, + as_client_event, + include_archived, ) @defer.inlineCallbacks diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py deleted file mode 100644 index 8318db8d2c..0000000000 --- a/synapse/util/caches/snapshot_cache.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.async_helpers import ObservableDeferred - - -class SnapshotCache(object): - """Cache for snapshots like the response of /initialSync. - The response of initialSync only has to be a recent snapshot of the - server state. It shouldn't matter to clients if it is a few minutes out - of date. - - This caches a deferred response. Until the deferred completes it will be - returned from the cache. This means that if the client retries the request - while the response is still being computed, that original response will be - used rather than trying to compute a new response. - - Once the deferred completes it will removed from the cache after 5 minutes. - We delay removing it from the cache because a client retrying its request - could race with us finishing computing the response. - - Rather than tracking precisely how long something has been in the cache we - keep two generations of completed responses. Every 5 minutes discard the - old generation, move the new generation to the old generation, and set the - new generation to be empty. This means that a result will be in the cache - somewhere between 5 and 10 minutes. - """ - - DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes. - - def __init__(self): - self.pending_result_cache = {} # Request that haven't finished yet. - self.prev_result_cache = {} # The older requests that have finished. - self.next_result_cache = {} # The newer requests that have finished. - self.time_last_rotated_ms = 0 - - def rotate(self, time_now_ms): - # Rotate once if the cache duration has passed since the last rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms += self.DURATION_MS - - # Rotate again if the cache duration has passed twice since the last - # rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms = time_now_ms - - def get(self, time_now_ms, key): - self.rotate(time_now_ms) - # This cache is intended to deduplicate requests, so we expect it to be - # missed most of the time. So we just lookup the key in all of the - # dictionaries rather than trying to short circuit the lookup if the - # key is found. - result = self.prev_result_cache.get(key) - result = self.next_result_cache.get(key, result) - result = self.pending_result_cache.get(key, result) - if result is not None: - return result.observe() - else: - return None - - def set(self, time_now_ms, key, deferred): - self.rotate(time_now_ms) - - result = ObservableDeferred(deferred) - - self.pending_result_cache[key] = result - - def shuffle_along(r): - # When the deferred completes we shuffle it along to the first - # generation of the result cache. So that it will eventually - # expire from the rotation of that cache. - self.next_result_cache[key] = result - self.pending_result_cache.pop(key, None) - return r - - result.addBoth(shuffle_along) - - return result.observe() diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py deleted file mode 100644 index 1a44f72425..0000000000 --- a/tests/util/test_snapshot_cache.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet.defer import Deferred - -from synapse.util.caches.snapshot_cache import SnapshotCache - -from .. import unittest - - -class SnapshotCacheTestCase(unittest.TestCase): - def setUp(self): - self.cache = SnapshotCache() - self.cache.DURATION_MS = 1 - - def test_get_set(self): - # Check that getting a missing key returns None - self.assertEquals(self.cache.get(0, "key"), None) - - # Check that setting a key with a deferred returns - # a deferred that resolves when the initial deferred does - d = Deferred() - set_result = self.cache.set(0, "key", d) - self.assertIsNotNone(set_result) - self.assertFalse(set_result.called) - - # Check that getting the key before the deferred has resolved - # returns a deferred that resolves when the initial deferred does. - get_result_at_10 = self.cache.get(10, "key") - self.assertIsNotNone(get_result_at_10) - self.assertFalse(get_result_at_10.called) - - # Check that the returned deferreds resolve when the initial deferred - # does. - d.callback("v") - self.assertTrue(set_result.called) - self.assertTrue(get_result_at_10.called) - - # Check that getting the key after the deferred has resolved - # before the cache expires returns a resolved deferred. - get_result_at_11 = self.cache.get(11, "key") - self.assertIsNotNone(get_result_at_11) - if isinstance(get_result_at_11, Deferred): - # The cache may return the actual result rather than a deferred - self.assertTrue(get_result_at_11.called) - - # Check that getting the key after the deferred has resolved - # after the cache expires returns None - get_result_at_12 = self.cache.get(12, "key") - self.assertIsNone(get_result_at_12) -- cgit 1.5.1 From adfdd82b21ae296ed77453b2f51d55414890f162 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 9 Dec 2019 13:59:27 +0000 Subject: Back out perf regression from get_cross_signing_keys_from_cache. (#6494) Back out cross-signing code added in Synapse 1.5.0, which caused a performance regression. --- changelog.d/6494.bugfix | 1 + synapse/handlers/e2e_keys.py | 38 ++++++++------------------------------ sytest-blacklist | 3 +++ tests/handlers/test_e2e_keys.py | 8 ++++++++ 4 files changed, 20 insertions(+), 30 deletions(-) create mode 100644 changelog.d/6494.bugfix (limited to 'tests') diff --git a/changelog.d/6494.bugfix b/changelog.d/6494.bugfix new file mode 100644 index 0000000000..78726d5d7f --- /dev/null +++ b/changelog.d/6494.bugfix @@ -0,0 +1 @@ +Back out cross-signing code added in Synapse 1.5.0, which caused a performance regression. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 28c12753c1..57a10daefd 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -264,7 +264,6 @@ class E2eKeysHandler(object): return ret - @defer.inlineCallbacks def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database @@ -284,35 +283,14 @@ class E2eKeysHandler(object): self_signing_keys = {} user_signing_keys = {} - for user_id in query: - # XXX: consider changing the store functions to allow querying - # multiple users simultaneously. - key = yield self.store.get_e2e_cross_signing_key( - user_id, "master", from_user_id - ) - if key: - master_keys[user_id] = key - - key = yield self.store.get_e2e_cross_signing_key( - user_id, "self_signing", from_user_id - ) - if key: - self_signing_keys[user_id] = key - - # users can see other users' master and self-signing keys, but can - # only see their own user-signing keys - if from_user_id == user_id: - key = yield self.store.get_e2e_cross_signing_key( - user_id, "user_signing", from_user_id - ) - if key: - user_signing_keys[user_id] = key - - return { - "master_keys": master_keys, - "self_signing_keys": self_signing_keys, - "user_signing_keys": user_signing_keys, - } + # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486 + return defer.succeed( + { + "master_keys": master_keys, + "self_signing_keys": self_signing_keys, + "user_signing_keys": user_signing_keys, + } + ) @trace @defer.inlineCallbacks diff --git a/sytest-blacklist b/sytest-blacklist index 411cce0692..79b2d4402a 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -33,3 +33,6 @@ New federated private chats get full presence information (SYN-115) # Blacklisted due to https://github.com/matrix-org/matrix-doc/pull/2314 removing # this requirement from the spec Inbound federation of state requires event_id as a mandatory paramater + +# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands +Can upload self-signing keys diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 854eb6c024..fdfa2cbbc4 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -183,6 +183,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) + test_replace_master_key.skip = ( + "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" + ) + @defer.inlineCallbacks def test_reupload_signatures(self): """re-uploading a signature should not fail""" @@ -503,3 +507,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ], other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) + + test_upload_signatures.skip = ( + "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" + ) -- cgit 1.5.1 From 9a2223d4c8c2a46f7ab072597bdba2614bc3e6ea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 11:22:12 +0000 Subject: Fix make_deferred_yieldable to work with coroutines --- synapse/logging/context.py | 9 ++++++++- tests/util/test_logcontext.py | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 2c1fb9ddac..9f484ce59e 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -23,6 +23,7 @@ them. See doc/log_contexts.rst for details on how this works. """ +import inspect import logging import threading import types @@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs): def make_deferred_yieldable(deferred): - """Given a deferred, make it follow the Synapse logcontext rules: + """Given a deferred (or coroutine), make it follow the Synapse logcontext + rules: If the deferred has completed (or is not actually a Deferred), essentially does nothing (just returns another completed deferred with the @@ -624,6 +626,11 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ + if inspect.isawaitable(deferred): + # If we're given a coroutine we need to convert it to a deferred so that + # we can attach callbacks (and not immediately return). + deferred = defer.ensureDeferred(deferred) + if not isinstance(deferred, defer.Deferred): return deferred diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 8b8455c8b7..281b32c4b8 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase): nested_context = nested_logging_context(suffix="bar") self.assertEqual(nested_context.request, "foo-bar") + @defer.inlineCallbacks + def test_make_deferred_yieldable_with_await(self): + # an async function which retuns an incomplete coroutine, but doesn't + # follow the synapse rules. + + async def blocking_function(): + d = defer.Deferred() + reactor.callLater(0, d.callback, None) + await d + + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = make_deferred_yieldable(blocking_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on -- cgit 1.5.1 From bc5cb8bfe8c5b0b551de9a97912cd00caeaf6b48 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 17:42:18 +0000 Subject: Remove database config parsing from apps. --- synapse/app/admin_cmd.py | 5 ----- synapse/app/appservice.py | 5 ----- synapse/app/client_reader.py | 5 ----- synapse/app/event_creator.py | 5 ----- synapse/app/federation_reader.py | 5 ----- synapse/app/federation_sender.py | 5 ----- synapse/app/frontend_proxy.py | 5 ----- synapse/app/homeserver.py | 7 +------ synapse/app/media_repository.py | 5 ----- synapse/app/pusher.py | 5 ----- synapse/app/synchrotron.py | 5 ----- synapse/app/user_dir.py | 5 ----- synapse/server.py | 10 +++++++++- tests/test_types.py | 19 ++++++++----------- tests/utils.py | 2 -- 15 files changed, 18 insertions(+), 75 deletions(-) (limited to 'tests') diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 04751a6a5e..51a909419f 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -45,7 +45,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -229,14 +228,10 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = AdminCmdServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 02b900f382..e82e0f11e3 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -34,7 +34,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -143,8 +142,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.notify_appservices: sys.stderr.write( "\nThe appservices must be disabled in the main synapse process" @@ -159,10 +156,8 @@ def start(config_options): ps = AppserviceServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ps, config, use_worker_options=True) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index dadb487d5f..3edfe19567 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -62,7 +62,6 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -181,14 +180,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = ClientReaderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index d110599a35..d0ddbe38fc 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -57,7 +57,6 @@ from synapse.rest.client.v1.room import ( ) from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -180,14 +179,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = EventCreatorServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 418c086254..311523e0ed 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -46,7 +46,6 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -162,14 +161,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = FederationReaderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index f24920a7d6..83c436229c 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -41,7 +41,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer from synapse.storage.database import Database -from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree @@ -174,8 +173,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.send_federation: sys.stderr.write( "\nThe send_federation must be disabled in the main synapse process" @@ -190,10 +187,8 @@ def start(config_options): ss = FederationSenderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index e647459d0e..30e435eead 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -39,7 +39,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v2_alpha._base import client_patterns from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -234,14 +233,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = FrontendProxyServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index df65d0a989..79341784c5 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -69,7 +69,7 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import IncorrectDatabaseSetup, create_engine +from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree @@ -328,15 +328,10 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection - hs = SynapseHomeServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) synapse.config.logger.setup_logging(hs, config, use_worker_options=False) diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 2c6dd3ef02..4c80f257e2 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -40,7 +40,6 @@ from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -157,14 +156,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = MediaRepositoryServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 01a5ffc363..e157cbf64b 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -36,7 +36,6 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -198,14 +197,10 @@ def start(config_options): # Force the pushers to start since they will be disabled in the main config config.start_pushers = True - database_engine = create_engine(config.database_config) - ps = PusherServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ps, config, use_worker_options=True) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 288ee64b42..dd2132e608 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -55,7 +55,6 @@ from synapse.rest.client.v1.room import RoomInitialSyncRestServlet from synapse.rest.client.v2_alpha import sync from synapse.server import HomeServer from synapse.storage.data_stores.main.presence import UserPresenceState -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.stringutils import random_string @@ -437,14 +436,10 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = SynchrotronServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, application_service_handler=SynchrotronApplicationService(), ) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index c01fb34a9b..1257098f92 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -44,7 +44,6 @@ from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.storage.database import Database -from synapse.storage.engines import create_engine from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole @@ -200,8 +199,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.update_user_directory: sys.stderr.write( "\nThe update_user_directory must be disabled in the main synapse process" @@ -216,10 +213,8 @@ def start(config_options): ss = UserDirectoryServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/server.py b/synapse/server.py index 2db3dab221..a8b5459ff3 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -97,6 +97,7 @@ from synapse.server_notices.worker_server_notices_sender import ( ) from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import DataStores, Storage +from synapse.storage.engines import create_engine from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -209,7 +210,7 @@ class HomeServer(object): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() - def __init__(self, hostname, reactor=None, **kwargs): + def __init__(self, hostname, config, reactor=None, **kwargs): """ Args: hostname : The hostname for the server. @@ -219,6 +220,7 @@ class HomeServer(object): self._reactor = reactor self.hostname = hostname + self.config = config self._building = {} self._listening_services = [] self.start_time = None @@ -229,6 +231,12 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() + self.database_engine = create_engine(config.database_config) + config.database_config.setdefault("args", {})[ + "cp_openfun" + ] = self.database_engine.on_new_connection + self.db_config = config.database_config + self.datastores = None # Other kwargs are explicit dependencies diff --git a/tests/test_types.py b/tests/test_types.py index 9ab5f829b0..8d97c751ea 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -17,18 +17,15 @@ from synapse.api.errors import SynapseError from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart from tests import unittest -from tests.utils import TestHomeServer -mock_homeserver = TestHomeServer(hostname="my.domain") - -class UserIDTestCase(unittest.TestCase): +class UserIDTestCase(unittest.HomeserverTestCase): def test_parse(self): - user = UserID.from_string("@1234abcd:my.domain") + user = UserID.from_string("@1234abcd:test") self.assertEquals("1234abcd", user.localpart) - self.assertEquals("my.domain", user.domain) - self.assertEquals(True, mock_homeserver.is_mine(user)) + self.assertEquals("test", user.domain) + self.assertEquals(True, self.hs.is_mine(user)) def test_pase_empty(self): with self.assertRaises(SynapseError): @@ -48,13 +45,13 @@ class UserIDTestCase(unittest.TestCase): self.assertTrue(userA != userB) -class RoomAliasTestCase(unittest.TestCase): +class RoomAliasTestCase(unittest.HomeserverTestCase): def test_parse(self): - room = RoomAlias.from_string("#channel:my.domain") + room = RoomAlias.from_string("#channel:test") self.assertEquals("channel", room.localpart) - self.assertEquals("my.domain", room.domain) - self.assertEquals(True, mock_homeserver.is_mine(room)) + self.assertEquals("test", room.domain) + self.assertEquals(True, self.hs.is_mine(room)) def test_build(self): room = RoomAlias("channel", "my.domain") diff --git a/tests/utils.py b/tests/utils.py index c57da59191..585f305b9a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -260,9 +260,7 @@ def setup_test_homeserver( hs = homeserverToUse( name, config=config, - db_config=config.database_config, version_string="Synapse/tests", - database_engine=db_engine, tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, -- cgit 1.5.1 From 257ef2c727abbb289c460b930e8bac18ab80e053 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 11:13:15 +0000 Subject: Port handlers.account_validity to async/await. --- synapse/handlers/account_data.py | 2 +- synapse/handlers/account_validity.py | 86 ++++++++++++++--------------- tests/rest/client/v2_alpha/test_register.py | 3 +- 3 files changed, 42 insertions(+), 49 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index fe44d62a1c..20ec1ca01b 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -25,7 +25,7 @@ class AccountDataEventSource(object): user_id = user.to_string() last_stream_id = from_key - current_stream_id = await self.store.get_max_account_data_stream_id() + current_stream_id = self.store.get_max_account_data_stream_id() results = [] tags = await self.store.get_updated_tags(user_id, last_stream_id) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d04e0fe576..829f52eca1 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -18,8 +18,7 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText - -from twisted.internet import defer +from typing import List from synapse.api.errors import StoreError from synapse.logging.context import make_deferred_yieldable @@ -78,42 +77,39 @@ class AccountValidityHandler(object): # run as a background process to make sure that the database transactions # have a logcontext to report to return run_as_background_process( - "send_renewals", self.send_renewal_emails + "send_renewals", self._send_renewal_emails ) self.clock.looping_call(send_emails, 30 * 60 * 1000) - @defer.inlineCallbacks - def send_renewal_emails(self): + async def _send_renewal_emails(self): """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` configuration, and sends renewal emails to all of these users as long as they have an email 3PID attached to their account. """ - expiring_users = yield self.store.get_users_expiring_soon() + expiring_users = await self.store.get_users_expiring_soon() if expiring_users: for user in expiring_users: - yield self._send_renewal_email( + await self._send_renewal_email( user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) - @defer.inlineCallbacks - def send_renewal_email_to_user(self, user_id): - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - yield self._send_renewal_email(user_id, expiration_ts) + async def send_renewal_email_to_user(self, user_id: str): + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) + await self._send_renewal_email(user_id, expiration_ts) - @defer.inlineCallbacks - def _send_renewal_email(self, user_id, expiration_ts): + async def _send_renewal_email(self, user_id: str, expiration_ts: int): """Sends out a renewal email to every email address attached to the given user with a unique link allowing them to renew their account. Args: - user_id (str): ID of the user to send email(s) to. - expiration_ts (int): Timestamp in milliseconds for the expiration date of + user_id: ID of the user to send email(s) to. + expiration_ts: Timestamp in milliseconds for the expiration date of this user's account (used in the email templates). """ - addresses = yield self._get_email_addresses_for_user(user_id) + addresses = await self._get_email_addresses_for_user(user_id) # Stop right here if the user doesn't have at least one email address. # In this case, they will have to ask their server admin to renew their @@ -125,7 +121,7 @@ class AccountValidityHandler(object): return try: - user_display_name = yield self.store.get_profile_displayname( + user_display_name = await self.store.get_profile_displayname( UserID.from_string(user_id).localpart ) if user_display_name is None: @@ -133,7 +129,7 @@ class AccountValidityHandler(object): except StoreError: user_display_name = user_id - renewal_token = yield self._get_renewal_token(user_id) + renewal_token = await self._get_renewal_token(user_id) url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( self.hs.config.public_baseurl, renewal_token, @@ -165,7 +161,7 @@ class AccountValidityHandler(object): logger.info("Sending renewal email to %s", address) - yield make_deferred_yieldable( + await make_deferred_yieldable( self.sendmail( self.hs.config.email_smtp_host, self._raw_from, @@ -180,19 +176,18 @@ class AccountValidityHandler(object): ) ) - yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) + await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) - @defer.inlineCallbacks - def _get_email_addresses_for_user(self, user_id): + async def _get_email_addresses_for_user(self, user_id: str) -> List[str]: """Retrieve the list of email addresses attached to a user's account. Args: - user_id (str): ID of the user to lookup email addresses for. + user_id: ID of the user to lookup email addresses for. Returns: - defer.Deferred[list[str]]: Email addresses for this account. + Email addresses for this account. """ - threepids = yield self.store.user_get_threepids(user_id) + threepids = await self.store.user_get_threepids(user_id) addresses = [] for threepid in threepids: @@ -201,16 +196,15 @@ class AccountValidityHandler(object): return addresses - @defer.inlineCallbacks - def _get_renewal_token(self, user_id): + async def _get_renewal_token(self, user_id: str) -> str: """Generates a 32-byte long random string that will be inserted into the user's renewal email's unique link, then saves it into the database. Args: - user_id (str): ID of the user to generate a string for. + user_id: ID of the user to generate a string for. Returns: - defer.Deferred[str]: The generated string. + The generated string. Raises: StoreError(500): Couldn't generate a unique string after 5 attempts. @@ -219,52 +213,52 @@ class AccountValidityHandler(object): while attempts < 5: try: renewal_token = stringutils.random_string(32) - yield self.store.set_renewal_token_for_user(user_id, renewal_token) + await self.store.set_renewal_token_for_user(user_id, renewal_token) return renewal_token except StoreError: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - @defer.inlineCallbacks - def renew_account(self, renewal_token): + async def renew_account(self, renewal_token: str) -> bool: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. + renewal_token: Token sent with the renewal request. Returns: - bool: Whether the provided token is valid. + Whether the provided token is valid. """ try: - user_id = yield self.store.get_user_from_renewal_token(renewal_token) + user_id = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - defer.returnValue(False) + return False logger.debug("Renewing an account for user %s", user_id) - yield self.renew_account_for_user(user_id) + await self.renew_account_for_user(user_id) - defer.returnValue(True) + return True - @defer.inlineCallbacks - def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): + async def renew_account_for_user( + self, user_id: str, expiration_ts: int = None, email_sent: bool = False + ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. - expiration_ts (int): New expiration date. Defaults to now + validity period. - email_sent (bool): Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. + expiration_ts: New expiration date. Defaults to now + validity period. + email_sen: Whether an email has been sent for this validity period. Defaults to False. Returns: - defer.Deferred[int]: New expiration date for this account, as a timestamp - in milliseconds since epoch. + New expiration date for this account, as a timestamp in + milliseconds since epoch. """ if expiration_ts is None: expiration_ts = self.clock.time_msec() + self._account_validity.period - yield self.store.set_account_validity_for_user( + await self.store.set_account_validity_for_user( user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index c0d0d2b44e..d0c997e385 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Email config. self.email_attempts = [] - def sendmail(*args, **kwargs): + async def sendmail(*args, **kwargs): self.email_attempts.append((args, kwargs)) - return config["email"] = { "enable_notifs": True, -- cgit 1.5.1 From 40eda849338b6e47a5804b4cf7000e9d2417c4d8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 10 Dec 2019 16:22:29 +0000 Subject: Fix race which caused deleted devices to reappear (#6514) Stop the `update_client_ips` background job from recreating deleted devices. --- changelog.d/6514.bugfix | 1 + synapse/storage/data_stores/main/client_ips.py | 8 +++-- tests/storage/test_client_ips.py | 49 +++++++++++++++----------- 3 files changed, 35 insertions(+), 23 deletions(-) create mode 100644 changelog.d/6514.bugfix (limited to 'tests') diff --git a/changelog.d/6514.bugfix b/changelog.d/6514.bugfix new file mode 100644 index 0000000000..6dc1985c24 --- /dev/null +++ b/changelog.d/6514.bugfix @@ -0,0 +1 @@ +Fix race which occasionally caused deleted devices to reappear. diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 320c5b0f07..add3037b69 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self.db.simple_upsert_txn( + # this is always an update rather than an upsert: the row should + # already exist, and if it doesn't, that may be because it has been + # deleted, and we don't want to re-create it. + self.db.simple_update_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, - values={ + updatevalues={ "user_agent": user_agent, "last_seen": last_seen, "ip": ip, }, - lock=False, ) except Exception as e: # Failed to upsert, log and continue diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index fc279340d4..bf674dd184 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(12345678) user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, @@ -209,14 +213,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.store.db.updates.do_next_background_update(100), by=0.1 ) - # Insert a user IP user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) - # Force persisting to disk self.reactor.advance(200) @@ -224,7 +230,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.get_success( self.store.db.simple_update( table="devices", - keyvalues={"user_id": user_id, "device_id": "device_id"}, + keyvalues={"user_id": user_id, "device_id": device_id}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, desc="test_devices_last_seen_bg_update", ) @@ -232,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should now get nulls when querying result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": None, "user_agent": None, "last_seen": None, @@ -272,14 +278,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should now get the correct result again result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, @@ -296,11 +302,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.store.db.updates.do_next_background_update(100), by=0.1 ) - # Insert a user IP user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -324,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "access_token": "access_token", "ip": "ip", "user_agent": "user_agent", - "device_id": "device_id", + "device_id": device_id, "last_seen": 0, } ], @@ -347,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But we should still get the correct values for the device result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, -- cgit 1.5.1 From e77237b9350a7054657b1a641883a09c6b5d44f3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 10 Dec 2019 17:01:37 +0000 Subject: convert to async: FederationHandler.on_receive_pdu and associated functions: * on_receive_pdu * handle_queued_pdus * get_missing_events_for_pdu --- synapse/handlers/federation.py | 49 +++++++++++++++++++----------------------- tests/test_federation.py | 14 +++++++----- 2 files changed, 31 insertions(+), 32 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e54d509b62..01d9c5120e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -165,8 +165,7 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): + async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -176,8 +175,6 @@ class FederationHandler(BaseHandler): pdu (FrozenEvent): received PDU sent_to_us_directly (bool): True if this event was pushed to us; False if we pulled it as the result of a missing prev_event. - - Returns (Deferred): completes with None """ room_id = pdu.room_id @@ -186,7 +183,7 @@ class FederationHandler(BaseHandler): logger.info("handling received PDU: %s", pdu) # We reprocess pdus when we have seen them only as outliers - existing = yield self.store.get_event( + existing = await self.store.get_event( event_id, allow_none=True, allow_rejected=True ) @@ -230,7 +227,7 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( "[%s %s] Ignoring PDU from %s as we're not in the room", @@ -246,12 +243,12 @@ class FederationHandler(BaseHandler): # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. - min_depth = yield self.get_min_depth_for_context(pdu.room_id) + min_depth = await self.get_min_depth_for_context(pdu.room_id) logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if min_depth and pdu.depth < min_depth: # This is so that we don't notify the user about this @@ -271,7 +268,7 @@ class FederationHandler(BaseHandler): len(missing_prevs), shortstr(missing_prevs), ) - with (yield self._room_pdu_linearizer.queue(pdu.room_id)): + with (await self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( "[%s %s] Acquired room lock to fetch %d missing prev_events", room_id, @@ -280,7 +277,7 @@ class FederationHandler(BaseHandler): ) try: - yield self._get_missing_events_for_pdu( + await self._get_missing_events_for_pdu( origin, pdu, prevs, min_depth ) except Exception as e: @@ -291,7 +288,7 @@ class FederationHandler(BaseHandler): # Update the set of things we've seen after trying to # fetch the missing stuff - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: logger.info( @@ -355,7 +352,7 @@ class FederationHandler(BaseHandler): event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.state_store.get_state_groups_ids(room_id, seen) + ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps = list( @@ -372,7 +369,7 @@ class FederationHandler(BaseHandler): "Requesting state at missing prev_event %s", event_id, ) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) with nested_logging_context(p): # note that if any of the missing prevs share missing state or @@ -381,11 +378,11 @@ class FederationHandler(BaseHandler): ( remote_state, got_auth_chain, - ) = yield self._get_state_for_room(origin, room_id, p) + ) = await self._get_state_for_room(origin, room_id, p) # we want the state *after* p; _get_state_for_room returns the # state *before* p. - remote_event = yield self.federation_client.get_pdu( + remote_event = await self.federation_client.get_pdu( [origin], p, room_version, outlier=True ) @@ -410,7 +407,7 @@ class FederationHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x - state_map = yield resolve_events_with_store( + state_map = await resolve_events_with_store( room_version, state_maps, event_map, @@ -422,7 +419,7 @@ class FederationHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. - evs = yield self.store.get_events( + evs = await self.store.get_events( list(state_map.values()), get_prev_content=False, redact_behaviour=EventRedactBehaviour.AS_IS, @@ -446,12 +443,11 @@ class FederationHandler(BaseHandler): affected=event_id, ) - yield self._process_received_pdu( + await self._process_received_pdu( origin, pdu, state=state, auth_chain=auth_chain ) - @defer.inlineCallbacks - def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): """ Args: origin (str): Origin of the pdu. Will be called to get the missing events @@ -463,12 +459,12 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: return - latest = yield self.store.get_latest_event_ids_in_room(room_id) + latest = await self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -532,7 +528,7 @@ class FederationHandler(BaseHandler): # All that said: Let's try increasing the timout to 60s and see what happens. try: - missing_events = yield self.federation_client.get_missing_events( + missing_events = await self.federation_client.get_missing_events( origin, room_id, earliest_events_ids=list(latest), @@ -571,7 +567,7 @@ class FederationHandler(BaseHandler): ) with nested_logging_context(ev.event_id): try: - yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) + await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: logger.warning( @@ -1328,8 +1324,7 @@ class FederationHandler(BaseHandler): return True - @defer.inlineCallbacks - def _handle_queued_pdus(self, room_queue): + async def _handle_queued_pdus(self, room_queue): """Process PDUs which got queued up while we were busy send_joining. Args: @@ -1345,7 +1340,7 @@ class FederationHandler(BaseHandler): p.room_id, ) with nested_logging_context(p.event_id): - yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) + await self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e diff --git a/tests/test_federation.py b/tests/test_federation.py index ad165d7295..68684460c6 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,6 +1,6 @@ from mock import Mock -from twisted.internet.defer import maybeDeferred, succeed +from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed from synapse.events import FrozenEvent from synapse.logging.context import LoggingContext @@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase): ) # Send the join, it should return None (which is not an error) - d = self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True + d = ensureDeferred( + self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) ) self.reactor.advance(1) self.assertEqual(self.successResultOf(d), None) @@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase): ) with LoggingContext(request="lying_event"): - d = self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True + d = ensureDeferred( + self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) ) # Step the reactor, so the database fetches come back -- cgit 1.5.1 From cb2db179945f567410b565f29725dff28449f013 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 12 Dec 2019 12:03:28 -0500 Subject: look up cross-signing keys from the DB in bulk (#6486) --- changelog.d/6486.bugfix | 1 + synapse/handlers/e2e_keys.py | 35 +++- .../storage/data_stores/main/end_to_end_keys.py | 217 ++++++++++++++++++++- synapse/util/caches/descriptors.py | 2 +- tests/handlers/test_e2e_keys.py | 8 - 5 files changed, 242 insertions(+), 21 deletions(-) create mode 100644 changelog.d/6486.bugfix (limited to 'tests') diff --git a/changelog.d/6486.bugfix b/changelog.d/6486.bugfix new file mode 100644 index 0000000000..b98c5a9ae5 --- /dev/null +++ b/changelog.d/6486.bugfix @@ -0,0 +1 @@ +Improve performance of looking up cross-signing keys. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 57a10daefd..2d889364d4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -264,6 +264,7 @@ class E2eKeysHandler(object): return ret + @defer.inlineCallbacks def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database @@ -283,14 +284,32 @@ class E2eKeysHandler(object): self_signing_keys = {} user_signing_keys = {} - # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486 - return defer.succeed( - { - "master_keys": master_keys, - "self_signing_keys": self_signing_keys, - "user_signing_keys": user_signing_keys, - } - ) + user_ids = list(query) + + keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + + for user_id, user_info in keys.items(): + if user_info is None: + continue + if "master" in user_info: + master_keys[user_id] = user_info["master"] + if "self_signing" in user_info: + self_signing_keys[user_id] = user_info["self_signing"] + + if ( + from_user_id in keys + and keys[from_user_id] is not None + and "user_signing" in keys[from_user_id] + ): + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + + return { + "master_keys": master_keys, + "self_signing_keys": self_signing_keys, + "user_signing_keys": user_signing_keys, + } @trace @defer.inlineCallbacks diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 38cd0ca9b8..e551606f9d 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -14,15 +14,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, List + from six import iteritems from canonicaljson import encode_canonical_json, json +from twisted.enterprise.adbapi import Connection from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList class EndToEndKeyWorkerStore(SQLBaseStore): @@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Args: txn (twisted.enterprise.adbapi.Connection): db connection user_id (str): the user whose key is being requested - key_type (str): the type of key that is being set: either 'master' + key_type (str): the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on @@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): """Returns a user's cross-signing key. Args: - user_id (str): the user whose self-signing key is being requested - key_type (str): the type of cross-signing key to get + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being requested: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on the self-signing key will be included in the result @@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore): from_user_id, ) + @cached(num_args=1) + def _get_bare_e2e_cross_signing_keys(self, user_id): + """Dummy function. Only used to make a cache for + _get_bare_e2e_cross_signing_keys_bulk. + """ + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_bare_e2e_cross_signing_keys", + list_name="user_ids", + num_args=1, + ) + def _get_bare_e2e_cross_signing_keys_bulk( + self, user_ids: List[str] + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + + """ + return self.db.runInteraction( + "get_bare_e2e_cross_signing_keys_bulk", + self._get_bare_e2e_cross_signing_keys_bulk_txn, + user_ids, + ) + + def _get_bare_e2e_cross_signing_keys_bulk_txn( + self, txn: Connection, user_ids: List[str], + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, their user + ID will not be in the dict. + + """ + result = {} + + batch_size = 100 + chunks = [ + user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT k.user_id, k.keytype, k.keydata, k.stream_id + FROM e2e_cross_signing_keys k + INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id + FROM e2e_cross_signing_keys + GROUP BY user_id, keytype) s + USING (user_id, stream_id, keytype) + WHERE k.user_id IN (%s) + """ % ( + ",".join("?" for u in user_chunk), + ) + query_params = [] + query_params.extend(user_chunk) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + for row in rows: + user_id = row["user_id"] + key_type = row["keytype"] + key = json.loads(row["keydata"]) + user_info = result.setdefault(user_id, {}) + user_info[key_type] = key + + return result + + def _get_e2e_cross_signing_signatures_txn( + self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing signatures made by a user on a set of keys. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + keys (dict[str, dict[str, dict]]): a map of user ID to key type to + key data. This dict will be modified to add signatures. + from_user_id (str): fetch the signatures made by this user + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. The return value will be the same as the keys argument, + with the modifications included. + """ + + # find out what cross-signing keys (a.k.a. devices) we need to get + # signatures for. This is a map of (user_id, device_id) to key type + # (device_id is the key's public part). + devices = {} + + for user_id, user_info in keys.items(): + if user_info is None: + continue + for key_type, key in user_info.items(): + device_id = None + for k in key["keys"].values(): + device_id = k + devices[(user_id, device_id)] = key_type + + device_list = list(devices) + + # split into batches + batch_size = 100 + chunks = [ + device_list[i : i + batch_size] + for i in range(0, len(device_list), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT target_user_id, target_device_id, key_id, signature + FROM e2e_cross_signing_signatures + WHERE user_id = ? + AND (%s) + """ % ( + " OR ".join( + "(target_user_id = ? AND target_device_id = ?)" for d in devices + ) + ) + query_params = [from_user_id] + for item in devices: + # item is a (user_id, device_id) tuple + query_params.extend(item) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + # and add the signatures to the appropriate keys + for row in rows: + key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + key_type = devices[(target_user_id, target_device_id)] + # We need to copy everything, because the result may have come + # from the cache. dict.copy only does a shallow copy, so we + # need to recursively copy the dicts that will be modified. + user_info = keys[target_user_id] = keys[target_user_id].copy() + target_user_key = user_info[key_type] = user_info[key_type].copy() + if "signatures" in target_user_key: + signatures = target_user_key["signatures"] = target_user_key[ + "signatures" + ].copy() + if from_user_id in signatures: + user_sigs = signatures[from_user_id] = signatures[from_user_id] + user_sigs[key_id] = row["signature"] + else: + signatures[from_user_id] = {key_id: row["signature"]} + else: + target_user_key["signatures"] = { + from_user_id: {key_id: row["signature"]} + } + + return keys + + @defer.inlineCallbacks + def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: str = None + ) -> defer.Deferred: + """Returns the cross-signing keys for a set of users. + + Args: + user_ids (list[str]): the users whose keys are being requested + from_user_id (str): if specified, signatures made by this user on + the self-signing keys will be included in the result + + Returns: + Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to + key data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + """ + + result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + + if from_user_id: + result = yield self.db.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ) + + return result + def get_all_user_signature_changes_for_remotes(self, from_key, to_key): """Return a list of changes from the user signature stream to notify remotes. Note that the user signature stream represents when a user signs their @@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): }, ) + self._invalidate_cache_and_stream( + txn, self._get_bare_e2e_cross_signing_keys, (user_id,) + ) + def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 84f5ae22c3..2e8f6543e5 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -271,7 +271,7 @@ class _CacheDescriptorBase(object): else: self.function_to_call = orig - arg_spec = inspect.getargspec(orig) + arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args if "cache_context" in all_args: diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index fdfa2cbbc4..854eb6c024 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -183,10 +183,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) - test_replace_master_key.skip = ( - "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" - ) - @defer.inlineCallbacks def test_reupload_signatures(self): """re-uploading a signature should not fail""" @@ -507,7 +503,3 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ], other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) - - test_upload_signatures.skip = ( - "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" - ) -- cgit 1.5.1 From 1da15f05f5c9c1e47c9fd1323caff869c2e55aa3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 13 Dec 2019 12:55:32 +0000 Subject: sanity-checking for events used in state res (#6531) When we perform state resolution, check that all of the events involved are in the right room. --- changelog.d/6531.misc | 1 + synapse/handlers/federation.py | 1 + synapse/state/__init__.py | 32 ++++++++----- synapse/state/v1.py | 34 +++++++++++--- synapse/state/v2.py | 100 ++++++++++++++++++++++++++++++----------- tests/state/test_v2.py | 3 ++ 6 files changed, 128 insertions(+), 43 deletions(-) create mode 100644 changelog.d/6531.misc (limited to 'tests') diff --git a/changelog.d/6531.misc b/changelog.d/6531.misc new file mode 100644 index 0000000000..598efb79fc --- /dev/null +++ b/changelog.d/6531.misc @@ -0,0 +1 @@ +Improve sanity-checking when receiving events over federation. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2ea69c5468..1d39a9a4f5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -396,6 +396,7 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x state_map = await resolve_events_with_store( + room_id, room_version, state_maps, event_map, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 3e6d62eef1..5accc071ab 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Iterable, Optional +from typing import Dict, Iterable, List, Optional, Tuple from six import iteritems, itervalues @@ -417,6 +417,7 @@ class StateHandler(object): with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + event.room_id, room_version, state_set_ids, event_map=state_map, @@ -462,7 +463,7 @@ class StateResolutionHandler(object): not be called for a single state group Args: - room_id (str): room we are resolving for (used for logging) + room_id (str): room we are resolving for (used for logging and sanity checks) room_version (str): version of the room state_groups_ids (dict[int, dict[(str, str), str]]): map from state group id to the state in that state group @@ -518,6 +519,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + room_id, room_version, list(itervalues(state_groups_ids)), event_map=event_map, @@ -589,36 +591,44 @@ def _make_state_cache_entry(new_state, state_groups_ids): ) -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", +): """ Args: - room_version(str): Version of the room + room_id: the room we are working in + + room_version: Version of the room - state_sets(list): List of dicts of (type, state_key) -> event_id, + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing events will be requested via state_map_factory. - If None, all events will be fetched via state_map_factory. + If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: a place to fetch events from - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( - state_sets, event_map, state_res_store.get_events + room_id, state_sets, event_map, state_res_store.get_events ) else: return v2.resolve_events_with_store( - room_version, state_sets, event_map, state_res_store + room_id, room_version, state_sets, event_map, state_res_store ) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index a2f92d9ff9..b2f9865f39 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,6 +15,7 @@ import hashlib import logging +from typing import Callable, Dict, List, Optional, Tuple from six import iteritems, iterkeys, itervalues @@ -24,6 +25,7 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase logger = logging.getLogger(__name__) @@ -32,13 +34,20 @@ POWER_KEY = (EventTypes.PowerLevels, "") @defer.inlineCallbacks -def resolve_events_with_store(state_sets, event_map, state_map_factory): +def resolve_events_with_store( + room_id: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_map_factory: Callable, +): """ Args: - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_id: the room we are working in + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,11 +55,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): If None, all events will be fetched via state_map_factory. - state_map_factory(func): will be called + state_map_factory: will be called with a list of event_ids that are needed, and should return with a Deferred of dict of event_id to event. - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -76,6 +85,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): if event_map is not None: state_map.update(event_map) + # everything in the state map should be in the right room + for event in state_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. # @@ -95,6 +112,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): ) state_map_new = yield state_map_factory(new_needed_events) + for event in state_map_new.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + state_map.update(state_map_new) return _resolve_with_state( diff --git a/synapse/state/v2.py b/synapse/state/v2.py index b327c86f40..cb77ed5b78 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,29 +16,40 @@ import heapq import itertools import logging +from typing import Dict, List, Optional, Tuple from six import iteritems, itervalues from twisted.internet import defer +import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError +from synapse.events import EventBase logger = logging.getLogger(__name__) @defer.inlineCallbacks -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "synapse.state.StateResolutionStore", +): """Resolves the state using the v2 state resolution algorithm Args: - room_version (str): The room version + room_id: the room we are working in + + room_version: The room version - state_sets(list): List of dicts of (type, state_key) -> event_id, + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,9 +57,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -84,6 +95,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) event_map.update(events) + # everything in the event map should be in the right room + for event in event_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) @@ -94,13 +113,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) sorted_power_events = yield _reverse_topological_power_sort( - power_events, event_map, state_res_store, full_conflicted_set + room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one resolved_state = yield _iterative_auth_checks( + room_id, room_version, sorted_power_events, unconflicted_state, @@ -121,13 +141,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - leftover_events, pl, event_map, state_res_store + room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") resolved_state = yield _iterative_auth_checks( - room_version, leftover_events, resolved_state, event_map, state_res_store + room_id, + room_version, + leftover_events, + resolved_state, + event_map, + state_res_store, ) logger.debug("resolved") @@ -141,11 +166,12 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto @defer.inlineCallbacks -def _get_power_level_for_sender(event_id, event_map, state_res_store): +def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): """Return the power level of the sender of the given event according to their auth events. Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -153,11 +179,11 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store): Returns: Deferred[int] """ - event = yield _get_event(event_id, event_map, state_res_store) + event = yield _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): pl = aev break @@ -165,7 +191,7 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store): if pl is None: # Couldn't find power level. Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.Create, ""): if aev.content.get("creator") == event.sender: return 100 @@ -279,7 +305,7 @@ def _is_power_event(event): @defer.inlineCallbacks def _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, auth_diff ): """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -287,6 +313,7 @@ def _add_event_and_auth_chain_to_graph( Args: graph (dict[str, set[str]]): A map from event ID to the events auth event IDs + room_id (str): the room we are working in event_id (str): Event to add to the graph event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -298,7 +325,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(eid, event_map, state_res_store) + event = yield _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -308,11 +335,14 @@ def _add_event_and_auth_chain_to_graph( @defer.inlineCallbacks -def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff): +def _reverse_topological_power_sort( + room_id, event_ids, event_map, state_res_store, auth_diff +): """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: + room_id (str): the room we are working in event_ids (list[str]): The events to sort event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -325,12 +355,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ graph = {} for event_id in event_ids: yield _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, auth_diff ) event_to_pl = {} for event_id in graph: - pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store) + pl = yield _get_power_level_for_sender( + room_id, event_id, event_map, state_res_store + ) event_to_pl[event_id] = pl def _get_power_order(event_id): @@ -348,12 +380,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ @defer.inlineCallbacks def _iterative_auth_checks( - room_version, event_ids, base_state, event_map, state_res_store + room_id, room_version, event_ids, base_state, event_map, state_res_store ): """Sequentially apply auth checks to each event in given list, updating the state as it goes along. Args: + room_id (str) room_version (str) event_ids (list[str]): Ordered list of events to apply auth checks to base_state (dict[tuple[str, str], str]): The set of state to start with @@ -370,7 +403,7 @@ def _iterative_auth_checks( auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event(aid, event_map, state_res_store) + ev = yield _get_event(room_id, aid, event_map, state_res_store) if ev.rejected_reason is None: auth_events[(ev.type, ev.state_key)] = ev @@ -378,7 +411,7 @@ def _iterative_auth_checks( for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(ev_id, event_map, state_res_store) + ev = yield _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] @@ -400,11 +433,14 @@ def _iterative_auth_checks( @defer.inlineCallbacks -def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): +def _mainline_sort( + room_id, event_ids, resolved_power_event_id, event_map, state_res_store +): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: + room_id (str): room we're working in event_ids (list[str]): Events to sort resolved_power_event_id (str): The final resolved power level event ID event_map (dict[str,FrozenEvent]) @@ -417,11 +453,11 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor pl = resolved_power_event_id while pl: mainline.append(pl) - pl_ev = yield _get_event(pl, event_map, state_res_store) + pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event(aid, event_map, state_res_store) + ev = yield _get_event(room_id, aid, event_map, state_res_store) if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break @@ -457,6 +493,8 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor Deferred[int] """ + room_id = event.room_id + # We do an iterative search, replacing `event with the power level in its # auth events (if any) while event: @@ -468,7 +506,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor event = None for aid in auth_events: - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): event = aev break @@ -478,11 +516,12 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor @defer.inlineCallbacks -def _get_event(event_id, event_map, state_res_store): +def _get_event(room_id, event_id, event_map, state_res_store): """Helper function to look up event in event_map, falling back to looking it up in the store Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -493,7 +532,14 @@ def _get_event(event_id, event_map, state_res_store): if event_id not in event_map: events = yield state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) - return event_map[event_id] + event = event_map[event_id] + assert event is not None + if event.room_id != room_id: + raise Exception( + "In state res for room %s, event %s is in %s" + % (room_id, event_id, event.room_id) + ) + return event def lexicographical_topological_sort(graph, key): diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 8d3845c870..0f341d3ac3 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -58,6 +58,7 @@ class FakeEvent(object): self.type = type self.state_key = state_key self.content = content + self.room_id = ROOM_ID def to_event(self, auth_events, prev_events): """Given the auth_events and prev_events, convert to a Frozen Event @@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase): state_before = dict(state_at_event[prev_events[0]]) else: state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [state_at_event[n] for n in prev_events], event_map=event_map, @@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase): # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [self.state_at_bob, self.state_at_charlie], event_map=None, -- cgit 1.5.1 From 83895316d4e18d4a52c43524942d98f864bac6f9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 13 Dec 2019 12:55:32 +0000 Subject: sanity-checking for events used in state res (#6531) When we perform state resolution, check that all of the events involved are in the right room. --- changelog.d/6531.misc | 1 + synapse/handlers/federation.py | 1 + synapse/state/__init__.py | 32 ++++++++----- synapse/state/v1.py | 34 +++++++++++--- synapse/state/v2.py | 100 ++++++++++++++++++++++++++++++----------- tests/state/test_v2.py | 3 ++ 6 files changed, 128 insertions(+), 43 deletions(-) create mode 100644 changelog.d/6531.misc (limited to 'tests') diff --git a/changelog.d/6531.misc b/changelog.d/6531.misc new file mode 100644 index 0000000000..598efb79fc --- /dev/null +++ b/changelog.d/6531.misc @@ -0,0 +1 @@ +Improve sanity-checking when receiving events over federation. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ebeffbb768..fd3f5ced55 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -397,6 +397,7 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x state_map = yield resolve_events_with_store( + room_id, room_version, state_maps, event_map, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 139beef8ed..0e75e94c6f 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Iterable, Optional +from typing import Dict, Iterable, List, Optional, Tuple from six import iteritems, itervalues @@ -416,6 +416,7 @@ class StateHandler(object): with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + event.room_id, room_version, state_set_ids, event_map=state_map, @@ -461,7 +462,7 @@ class StateResolutionHandler(object): not be called for a single state group Args: - room_id (str): room we are resolving for (used for logging) + room_id (str): room we are resolving for (used for logging and sanity checks) room_version (str): version of the room state_groups_ids (dict[int, dict[(str, str), str]]): map from state group id to the state in that state group @@ -517,6 +518,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + room_id, room_version, list(itervalues(state_groups_ids)), event_map=event_map, @@ -588,36 +590,44 @@ def _make_state_cache_entry(new_state, state_groups_ids): ) -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", +): """ Args: - room_version(str): Version of the room + room_id: the room we are working in + + room_version: Version of the room - state_sets(list): List of dicts of (type, state_key) -> event_id, + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing events will be requested via state_map_factory. - If None, all events will be fetched via state_map_factory. + If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: a place to fetch events from - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( - state_sets, event_map, state_res_store.get_events + room_id, state_sets, event_map, state_res_store.get_events ) else: return v2.resolve_events_with_store( - room_version, state_sets, event_map, state_res_store + room_id, room_version, state_sets, event_map, state_res_store ) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index a2f92d9ff9..b2f9865f39 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,6 +15,7 @@ import hashlib import logging +from typing import Callable, Dict, List, Optional, Tuple from six import iteritems, iterkeys, itervalues @@ -24,6 +25,7 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase logger = logging.getLogger(__name__) @@ -32,13 +34,20 @@ POWER_KEY = (EventTypes.PowerLevels, "") @defer.inlineCallbacks -def resolve_events_with_store(state_sets, event_map, state_map_factory): +def resolve_events_with_store( + room_id: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_map_factory: Callable, +): """ Args: - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_id: the room we are working in + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,11 +55,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): If None, all events will be fetched via state_map_factory. - state_map_factory(func): will be called + state_map_factory: will be called with a list of event_ids that are needed, and should return with a Deferred of dict of event_id to event. - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -76,6 +85,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): if event_map is not None: state_map.update(event_map) + # everything in the state map should be in the right room + for event in state_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. # @@ -95,6 +112,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): ) state_map_new = yield state_map_factory(new_needed_events) + for event in state_map_new.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + state_map.update(state_map_new) return _resolve_with_state( diff --git a/synapse/state/v2.py b/synapse/state/v2.py index b327c86f40..cb77ed5b78 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,29 +16,40 @@ import heapq import itertools import logging +from typing import Dict, List, Optional, Tuple from six import iteritems, itervalues from twisted.internet import defer +import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError +from synapse.events import EventBase logger = logging.getLogger(__name__) @defer.inlineCallbacks -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "synapse.state.StateResolutionStore", +): """Resolves the state using the v2 state resolution algorithm Args: - room_version (str): The room version + room_id: the room we are working in + + room_version: The room version - state_sets(list): List of dicts of (type, state_key) -> event_id, + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,9 +57,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -84,6 +95,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) event_map.update(events) + # everything in the event map should be in the right room + for event in event_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) @@ -94,13 +113,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) sorted_power_events = yield _reverse_topological_power_sort( - power_events, event_map, state_res_store, full_conflicted_set + room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one resolved_state = yield _iterative_auth_checks( + room_id, room_version, sorted_power_events, unconflicted_state, @@ -121,13 +141,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - leftover_events, pl, event_map, state_res_store + room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") resolved_state = yield _iterative_auth_checks( - room_version, leftover_events, resolved_state, event_map, state_res_store + room_id, + room_version, + leftover_events, + resolved_state, + event_map, + state_res_store, ) logger.debug("resolved") @@ -141,11 +166,12 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto @defer.inlineCallbacks -def _get_power_level_for_sender(event_id, event_map, state_res_store): +def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): """Return the power level of the sender of the given event according to their auth events. Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -153,11 +179,11 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store): Returns: Deferred[int] """ - event = yield _get_event(event_id, event_map, state_res_store) + event = yield _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): pl = aev break @@ -165,7 +191,7 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store): if pl is None: # Couldn't find power level. Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.Create, ""): if aev.content.get("creator") == event.sender: return 100 @@ -279,7 +305,7 @@ def _is_power_event(event): @defer.inlineCallbacks def _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, auth_diff ): """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -287,6 +313,7 @@ def _add_event_and_auth_chain_to_graph( Args: graph (dict[str, set[str]]): A map from event ID to the events auth event IDs + room_id (str): the room we are working in event_id (str): Event to add to the graph event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -298,7 +325,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(eid, event_map, state_res_store) + event = yield _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -308,11 +335,14 @@ def _add_event_and_auth_chain_to_graph( @defer.inlineCallbacks -def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff): +def _reverse_topological_power_sort( + room_id, event_ids, event_map, state_res_store, auth_diff +): """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: + room_id (str): the room we are working in event_ids (list[str]): The events to sort event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -325,12 +355,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ graph = {} for event_id in event_ids: yield _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, auth_diff ) event_to_pl = {} for event_id in graph: - pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store) + pl = yield _get_power_level_for_sender( + room_id, event_id, event_map, state_res_store + ) event_to_pl[event_id] = pl def _get_power_order(event_id): @@ -348,12 +380,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ @defer.inlineCallbacks def _iterative_auth_checks( - room_version, event_ids, base_state, event_map, state_res_store + room_id, room_version, event_ids, base_state, event_map, state_res_store ): """Sequentially apply auth checks to each event in given list, updating the state as it goes along. Args: + room_id (str) room_version (str) event_ids (list[str]): Ordered list of events to apply auth checks to base_state (dict[tuple[str, str], str]): The set of state to start with @@ -370,7 +403,7 @@ def _iterative_auth_checks( auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event(aid, event_map, state_res_store) + ev = yield _get_event(room_id, aid, event_map, state_res_store) if ev.rejected_reason is None: auth_events[(ev.type, ev.state_key)] = ev @@ -378,7 +411,7 @@ def _iterative_auth_checks( for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(ev_id, event_map, state_res_store) + ev = yield _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] @@ -400,11 +433,14 @@ def _iterative_auth_checks( @defer.inlineCallbacks -def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): +def _mainline_sort( + room_id, event_ids, resolved_power_event_id, event_map, state_res_store +): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: + room_id (str): room we're working in event_ids (list[str]): Events to sort resolved_power_event_id (str): The final resolved power level event ID event_map (dict[str,FrozenEvent]) @@ -417,11 +453,11 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor pl = resolved_power_event_id while pl: mainline.append(pl) - pl_ev = yield _get_event(pl, event_map, state_res_store) + pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event(aid, event_map, state_res_store) + ev = yield _get_event(room_id, aid, event_map, state_res_store) if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break @@ -457,6 +493,8 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor Deferred[int] """ + room_id = event.room_id + # We do an iterative search, replacing `event with the power level in its # auth events (if any) while event: @@ -468,7 +506,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor event = None for aid in auth_events: - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): event = aev break @@ -478,11 +516,12 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor @defer.inlineCallbacks -def _get_event(event_id, event_map, state_res_store): +def _get_event(room_id, event_id, event_map, state_res_store): """Helper function to look up event in event_map, falling back to looking it up in the store Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -493,7 +532,14 @@ def _get_event(event_id, event_map, state_res_store): if event_id not in event_map: events = yield state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) - return event_map[event_id] + event = event_map[event_id] + assert event is not None + if event.room_id != room_id: + raise Exception( + "In state res for room %s, event %s is in %s" + % (room_id, event_id, event.room_id) + ) + return event def lexicographical_topological_sort(graph, key): diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 8d3845c870..0f341d3ac3 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -58,6 +58,7 @@ class FakeEvent(object): self.type = type self.state_key = state_key self.content = content + self.room_id = ROOM_ID def to_event(self, auth_events, prev_events): """Given the auth_events and prev_events, convert to a Frozen Event @@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase): state_before = dict(state_at_event[prev_events[0]]) else: state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [state_at_event[n] for n in prev_events], event_map=event_map, @@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase): # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [self.state_at_bob, self.state_at_charlie], event_map=None, -- cgit 1.5.1 From 596dd9914dad1933ded1426bdec1e2b1e6874e39 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 16 Dec 2019 14:53:21 +0000 Subject: Add test case --- tests/rest/client/v1/test_rooms.py | 133 +++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 1ca7fa742f..9cb505f316 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -29,6 +29,7 @@ import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import login, profile, room +from synapse.rest.client.v2_alpha import account from synapse.util.stringutils import random_string from tests import unittest @@ -1597,3 +1598,135 @@ class LabelsTestCase(unittest.HomeserverTestCase): ) return event_id + + +class ContextTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.tok = self.login("user", "password") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + self.other_user_id = self.register_user("user2", "password") + self.other_tok = self.login("user2", "password") + + self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) + self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) + + def test_erased_sender(self): + """Test that an erasure request results in the requester's events being hidden + from any new member of the room. + """ + + # Send a bunch of events in the room. + + self.helper.send(self.room_id, "message 1", tok=self.tok) + self.helper.send(self.room_id, "message 2", tok=self.tok) + event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"] + self.helper.send(self.room_id, "message 4", tok=self.tok) + self.helper.send(self.room_id, "message 5", tok=self.tok) + + # Check that we can still see the messages before the erasure request. + + request, channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertEqual( + events_before[0].get("content", {}).get("body"), + "message 2", + events_before[0], + ) + self.assertEqual( + events_before[1].get("content", {}).get("body"), + "message 1", + events_before[1], + ) + + self.assertEqual( + channel.json_body["event"].get("content", {}).get("body"), + "message 3", + channel.json_body["event"], + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertEqual( + events_after[0].get("content", {}).get("body"), + "message 4", + events_after[0], + ) + self.assertEqual( + events_after[1].get("content", {}).get("body"), + "message 5", + events_after[1], + ) + + # Deactivate the first account and erase the user's data. + + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account(self.user_id, erase_data=True) + ) + + # Invite another user in the room. This is needed because messages will be + # pruned only if the user wasn't a member of the room when the messages were + # sent. + + invited_user_id = self.register_user("user3", "password") + invited_tok = self.login("user3", "password") + + self.helper.invite( + self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok + ) + self.helper.join(self.room_id, invited_user_id, tok=invited_tok) + + # Check that a user that joined the room after the erasure request can't see + # the messages anymore. + + request, channel = self.make_request( + 'GET', + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=invited_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertDictEqual(events_before[0].get("content"), {}, events_before[0]) + self.assertDictEqual(events_before[1].get("content"), {}, events_before[1]) + + self.assertDictEqual( + channel.json_body["event"].get("content"), {}, channel.json_body["event"] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) + self.assertEqual(events_after[1].get("content"), {}, events_after[1]) + -- cgit 1.5.1 From a29420f9f449da72d1d38bcab4cedc182e9f2ba0 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 16 Dec 2019 14:55:50 +0000 Subject: Lint --- tests/rest/client/v1/test_rooms.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 9cb505f316..dca0fef97a 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1706,7 +1706,7 @@ class ContextTestCase(unittest.HomeserverTestCase): # the messages anymore. request, channel = self.make_request( - 'GET', + "GET", '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' % (self.room_id, event_id), access_token=invited_tok, @@ -1729,4 +1729,3 @@ class ContextTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events_after), 2, events_after) self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) self.assertEqual(events_after[1].get("content"), {}, events_after[1]) - -- cgit 1.5.1 From a82006954912ed96b0d47db43db44e76e5b052d6 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 16 Dec 2019 16:00:18 +0000 Subject: Incorporate review --- synapse/handlers/room.py | 2 +- tests/rest/client/v1/test_rooms.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7f979e5812..60b8bbc7a5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -941,7 +941,7 @@ class RoomContextHandler(object): if event_filter: state_events = event_filter.filter(state_events) - results["state"] = state_events + results["state"] = yield filter_evts(state_events) # We use a dummy token here as we only care about the room portion of # the token, which we replace. diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index dca0fef97a..e3af280ba6 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1609,11 +1609,6 @@ class ContextTestCase(unittest.HomeserverTestCase): account.register_servlets, ] - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - - return self.hs - def prepare(self, reactor, clock, homeserver): self.user_id = self.register_user("user", "password") self.tok = self.login("user", "password") -- cgit 1.5.1 From bfb95654c97a8d3aa164eff96ecc13755c1c326d Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Mon, 16 Dec 2019 16:11:55 +0000 Subject: Add option to allow profile queries without sharing a room (#6523) --- changelog.d/6523.feature | 1 + docs/sample_config.yaml | 7 +++++++ synapse/config/server.py | 13 +++++++++++++ synapse/handlers/profile.py | 6 +++++- tests/rest/client/v1/test_profile.py | 2 ++ 5 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6523.feature (limited to 'tests') diff --git a/changelog.d/6523.feature b/changelog.d/6523.feature new file mode 100644 index 0000000000..798fa143df --- /dev/null +++ b/changelog.d/6523.feature @@ -0,0 +1 @@ +Add option `limit_profile_requests_to_users_who_share_rooms` to prevent requirement of a local user sharing a room with another user to query their profile information. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 4d44e631d1..1787248f53 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -54,6 +54,13 @@ pid_file: DATADIR/homeserver.pid # #require_auth_for_profile_requests: true +# Uncomment to require a user to share a room with another user in order +# to retrieve their profile information. Only checked on Client-Server +# requests. Profile requests from other servers should be checked by the +# requesting server. Defaults to 'false'. +# +#limit_profile_requests_to_users_who_share_rooms: true + # If set to 'true', removes the need for authentication to access the server's # public rooms directory through the client API, meaning that anyone can # query the room directory. Defaults to 'false'. diff --git a/synapse/config/server.py b/synapse/config/server.py index 50af858c76..38f6ff9edc 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -102,6 +102,12 @@ class ServerConfig(Config): "require_auth_for_profile_requests", False ) + # Whether to require sharing a room with a user to retrieve their + # profile data + self.limit_profile_requests_to_users_who_share_rooms = config.get( + "limit_profile_requests_to_users_who_share_rooms", False, + ) + if "restrict_public_rooms_to_local_users" in config and ( "allow_public_rooms_without_auth" in config or "allow_public_rooms_over_federation" in config @@ -621,6 +627,13 @@ class ServerConfig(Config): # #require_auth_for_profile_requests: true + # Uncomment to require a user to share a room with another user in order + # to retrieve their profile information. Only checked on Client-Server + # requests. Profile requests from other servers should be checked by the + # requesting server. Defaults to 'false'. + # + #limit_profile_requests_to_users_who_share_rooms: true + # If set to 'true', removes the need for authentication to access the server's # public rooms directory through the client API, meaning that anyone can # query the room directory. Defaults to 'false'. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 1e5a4613c9..f9579d69ee 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -295,12 +295,16 @@ class BaseProfileHandler(BaseHandler): be found to be in any room the server is in, and therefore the query is denied. """ + # Implementation of MSC1301: don't allow looking up profiles if the # requester isn't in the same room as the target. We expect requester to # be None when this function is called outside of a profile query, e.g. # when building a membership event. In this case, we must allow the # lookup. - if not self.hs.config.require_auth_for_profile_requests or not requester: + if ( + not self.hs.config.limit_profile_requests_to_users_who_share_rooms + or not requester + ): return # Always allow the user to query their own profile. diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 12c5e95cb5..8df58b4a63 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -237,6 +237,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): config = self.default_config() config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs @@ -309,6 +310,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs -- cgit 1.5.1 From 2284eb3a533a2df04784df08da28e67d6588a5ea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Dec 2019 10:45:12 +0000 Subject: Add database config class (#6513) This encapsulates config for a given database and is the way to get new connections. --- changelog.d/6513.misc | 1 + scripts-dev/update_database | 9 +-- scripts/synapse_port_db | 58 ++++++++----------- synapse/config/database.py | 78 ++++++++++++++++++++------ synapse/handlers/presence.py | 2 +- synapse/server.py | 41 ++------------ synapse/storage/_base.py | 2 +- synapse/storage/data_stores/__init__.py | 40 ++++++++++--- synapse/storage/data_stores/main/client_ips.py | 2 +- synapse/storage/database.py | 45 ++++++++++++++- synapse/storage/engines/sqlite.py | 16 +++++- synapse/storage/prepare_database.py | 7 +-- tests/handlers/test_typing.py | 39 ++++++------- tests/replication/slave/storage/_base.py | 6 +- tests/server.py | 55 +++++++++--------- tests/storage/test_appservice.py | 37 ++++++++---- tests/storage/test_base.py | 14 +++-- tests/storage/test_registration.py | 1 - tests/utils.py | 43 +++++--------- 19 files changed, 287 insertions(+), 209 deletions(-) create mode 100644 changelog.d/6513.misc (limited to 'tests') diff --git a/changelog.d/6513.misc b/changelog.d/6513.misc new file mode 100644 index 0000000000..36700f5657 --- /dev/null +++ b/changelog.d/6513.misc @@ -0,0 +1 @@ +Remove all assumptions of there being a single phyiscal DB apart from the `synapse.config`. diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 23017c21f8..1d62f0403a 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.prepare_database import prepare_database logger = logging.getLogger("update_database") @@ -77,12 +76,8 @@ if __name__ == "__main__": # Instantiate and initialise the homeserver object. hs = MockHomeserver(config) - db_conn = hs.get_db_conn() - # Update the database to the latest schema. - prepare_database(db_conn, hs.database_engine, config=config) - db_conn.commit() - - # setup instantiates the store within the homeserver object. + # Setup instantiates the store within the homeserver object and updates the + # DB. hs.setup() store = hs.get_datastore() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index e393a9b2f7..5b5368988c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -30,6 +30,7 @@ import yaml from twisted.enterprise import adbapi from twisted.internet import defer, reactor +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.logging.context import PreserveLoggingContext from synapse.storage._base import LoggingTransaction @@ -55,7 +56,7 @@ from synapse.storage.data_stores.main.stats import StatsStore from synapse.storage.data_stores.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock @@ -165,23 +166,17 @@ class Store( class MockHomeserver: - def __init__(self, config, database_engine, db_conn, db_pool): - self.database_engine = database_engine - self.db_conn = db_conn - self.db_pool = db_pool + def __init__(self, config): self.clock = Clock(reactor) self.config = config self.hostname = config.server_name - def get_db_conn(self): - return self.db_conn - - def get_db_pool(self): - return self.db_pool - def get_clock(self): return self.clock + def get_reactor(self): + return reactor + class Porter(object): def __init__(self, **kwargs): @@ -445,45 +440,36 @@ class Porter(object): else: return - def setup_db(self, db_config, database_engine): - db_conn = database_engine.module.connect( - **{ - k: v - for k, v in db_config.get("args", {}).items() - if not k.startswith("cp_") - } - ) - - prepare_database(db_conn, database_engine, config=None) + def setup_db(self, db_config: DatabaseConnectionConfig, engine): + db_conn = make_conn(db_config, engine) + prepare_database(db_conn, engine, config=None) db_conn.commit() return db_conn @defer.inlineCallbacks - def build_db_store(self, config): + def build_db_store(self, db_config: DatabaseConnectionConfig): """Builds and returns a database store using the provided configuration. Args: - config: The database configuration, i.e. a dict following the structure of - the "database" section of Synapse's configuration file. + config: The database configuration Returns: The built Store object. """ - engine = create_engine(config) - - self.progress.set_state("Preparing %s" % config["name"]) - conn = self.setup_db(config, engine) + self.progress.set_state("Preparing %s" % db_config.config["name"]) - db_pool = adbapi.ConnectionPool(config["name"], **config["args"]) + engine = create_engine(db_config.config) + conn = self.setup_db(db_config, engine) - hs = MockHomeserver(self.hs_config, engine, conn, db_pool) + hs = MockHomeserver(self.hs_config) - store = Store(Database(hs), conn, hs) + store = Store(Database(hs, db_config, engine), conn, hs) yield store.db.runInteraction( - "%s_engine.check_database" % config["name"], engine.check_database, + "%s_engine.check_database" % db_config.config["name"], + engine.check_database, ) return store @@ -509,7 +495,11 @@ class Porter(object): @defer.inlineCallbacks def run(self): try: - self.sqlite_store = yield self.build_db_store(self.sqlite_config) + self.sqlite_store = yield self.build_db_store( + DatabaseConnectionConfig( + "master", self.sqlite_config, data_stores=["main"] + ) + ) # Check if all background updates are done, abort if not. updates_complete = ( @@ -524,7 +514,7 @@ class Porter(object): defer.returnValue(None) self.postgres_store = yield self.build_db_store( - self.hs_config.database_config + self.hs_config.get_single_database() ) yield self.run_background_updates_on_postgres() diff --git a/synapse/config/database.py b/synapse/config/database.py index 0e2509f0b1..5f2f3c7cfd 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -12,12 +12,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from textwrap import indent +from typing import List import yaml -from ._base import Config +from synapse.config._base import Config, ConfigError + +logger = logging.getLogger(__name__) + + +class DatabaseConnectionConfig: + """Contains the connection config for a particular database. + + Args: + name: A label for the database, used for logging. + db_config: The config for a particular database, as per `database` + section of main config. Has two fields: `name` for database + module name, and `args` for the args to give to the database + connector. + data_stores: The list of data stores that should be provisioned on the + database. + """ + + def __init__(self, name: str, db_config: dict, data_stores: List[str]): + if db_config["name"] not in ("sqlite3", "psycopg2"): + raise ConfigError("Unsupported database type %r" % (db_config["name"],)) + + if db_config["name"] == "sqlite3": + db_config.setdefault("args", {}).update( + {"cp_min": 1, "cp_max": 1, "check_same_thread": False} + ) + + self.name = name + self.config = db_config + self.data_stores = data_stores class DatabaseConfig(Config): @@ -26,20 +57,14 @@ class DatabaseConfig(Config): def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) - self.database_config = config.get("database") + database_config = config.get("database") - if self.database_config is None: - self.database_config = {"name": "sqlite3", "args": {}} + if database_config is None: + database_config = {"name": "sqlite3", "args": {}} - name = self.database_config.get("name", None) - if name == "psycopg2": - pass - elif name == "sqlite3": - self.database_config.setdefault("args", {}).update( - {"cp_min": 1, "cp_max": 1, "check_same_thread": False} - ) - else: - raise RuntimeError("Unsupported database type '%s'" % (name,)) + self.databases = [ + DatabaseConnectionConfig("master", database_config, data_stores=["main"]) + ] self.set_databasepath(config.get("database_path")) @@ -76,11 +101,24 @@ class DatabaseConfig(Config): self.set_databasepath(args.database_path) def set_databasepath(self, database_path): + if database_path is None: + return + if database_path != ":memory:": database_path = self.abspath(database_path) - if self.database_config.get("name", None) == "sqlite3": - if database_path is not None: - self.database_config["args"]["database"] = database_path + + # We only support setting a database path if we have a single sqlite3 + # database. + if len(self.databases) != 1: + raise ConfigError("Cannot specify 'database_path' with multiple databases") + + database = self.get_single_database() + if database.config["name"] != "sqlite3": + # We don't raise here as we haven't done so before for this case. + logger.warn("Ignoring 'database_path' for non-sqlite3 database") + return + + database.config["args"]["database"] = database_path @staticmethod def add_arguments(parser): @@ -91,3 +129,11 @@ class DatabaseConfig(Config): metavar="SQLITE_DATABASE_PATH", help="The path to a sqlite database to use.", ) + + def get_single_database(self) -> DatabaseConnectionConfig: + """Returns the database if there is only one, useful for e.g. tests + """ + if len(self.databases) != 1: + raise Exception("More than one database exists") + + return self.databases[0] diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index eda15bc623..240c4add12 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -230,7 +230,7 @@ class PresenceHandler(object): is some spurious presence changes that will self-correct. """ # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.store.database.is_running(): return logger.info( diff --git a/synapse/server.py b/synapse/server.py index 5021068ce0..7926867b77 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,7 +25,6 @@ import abc import logging import os -from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail from twisted.web.client import BrowserLikePolicyForHTTPS @@ -98,7 +97,6 @@ from synapse.server_notices.worker_server_notices_sender import ( ) from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import DataStores, Storage -from synapse.storage.engines import create_engine from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -134,7 +132,6 @@ class HomeServer(object): DEPENDENCIES = [ "http_client", - "db_pool", "federation_client", "federation_server", "handlers", @@ -233,12 +230,6 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() - self.database_engine = create_engine(config.database_config) - config.database_config.setdefault("args", {})[ - "cp_openfun" - ] = self.database_engine.on_new_connection - self.db_config = config.database_config - self.datastores = None # Other kwargs are explicit dependencies @@ -247,10 +238,8 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") - with self.get_db_conn() as conn: - self.datastores = DataStores(self.DATASTORE_CLASS, conn, self) - conn.commit() self.start_time = int(self.get_clock().time()) + self.datastores = DataStores(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") def setup_master(self): @@ -284,6 +273,9 @@ class HomeServer(object): def get_datastore(self): return self.datastores.main + def get_datastores(self): + return self.datastores + def get_config(self): return self.config @@ -433,31 +425,6 @@ class HomeServer(object): ) return MatrixFederationHttpClient(self, tls_client_options_factory) - def build_db_pool(self): - name = self.db_config["name"] - - return adbapi.ConnectionPool( - name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) - ) - - def get_db_conn(self, run_new_connection=True): - """Makes a new connection to the database, skipping the db pool - - Returns: - Connection: a connection object implementing the PEP-249 spec - """ - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v - for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def build_media_repository_resource(self): # build the media repo resource. This indirects through the HomeServer # to ensure that we only have a single instance of diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b7637b5dc0..88546ad614 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -40,7 +40,7 @@ class SQLBaseStore(object): def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self.database_engine = hs.database_engine + self.database_engine = database.engine self.db = database self.rand = random.SystemRandom() diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index cafedd5c0d..0983e059c0 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -13,24 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import Database +import logging + +from synapse.storage.database import Database, make_conn +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +logger = logging.getLogger(__name__) + class DataStores(object): """The various data stores. These are low level interfaces to physical databases. + + Attributes: + main (DataStore) """ - def __init__(self, main_store_class, db_conn, hs): + def __init__(self, main_store_class, hs): # Note we pass in the main store class here as workers use a different main # store. - database = Database(hs) - # Check that db is correctly configured. - database.engine.check_database(db_conn.cursor()) + self.databases = [] + + for database_config in hs.config.database.databases: + db_name = database_config.name + engine = create_engine(database_config.config) + + with make_conn(database_config, engine) as db_conn: + logger.info("Preparing database %r...", db_name) + + engine.check_database(db_conn.cursor()) + prepare_database( + db_conn, engine, hs.config, data_stores=database_config.data_stores, + ) + + database = Database(hs, database_config, engine) + + if "main" in database_config.data_stores: + logger.info("Starting 'main' data store") + self.main = main_store_class(database, db_conn, hs) + + db_conn.commit() - prepare_database(db_conn, database.engine, config=hs.config) + self.databases.append(database) - self.main = main_store_class(database, db_conn, hs) + logger.info("Database %r prepared", db_name) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index add3037b69..13f4c9c72e 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def _update_client_ips_batch(self): # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.db.is_running(): return to_update = self._batch_row_update diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ec19ae1d9d..1003dd84a5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -24,9 +24,11 @@ from six.moves import intern, range from prometheus_client import Histogram +from twisted.enterprise import adbapi from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater @@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { } +def make_pool( + reactor, db_config: DatabaseConnectionConfig, engine +) -> adbapi.ConnectionPool: + """Get the connection pool for the database. + """ + + return adbapi.ConnectionPool( + db_config.config["name"], + cp_reactor=reactor, + cp_openfun=engine.on_new_connection, + **db_config.config.get("args", {}) + ) + + +def make_conn(db_config: DatabaseConnectionConfig, engine): + """Make a new connection to the database and return it. + + Returns: + Connection + """ + + db_params = { + k: v + for k, v in db_config.config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = engine.module.connect(**db_params) + engine.on_new_connection(db_conn) + return db_conn + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() @@ -218,10 +251,11 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs): + def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() + self._database_config = database_config + self._db_pool = make_pool(hs.get_reactor(), database_config, engine) self.updates = BackgroundUpdater(hs, self) @@ -234,7 +268,7 @@ class Database(object): # to watch it self._txn_perf_counters = PerformanceCounters() - self.engine = hs.database_engine + self.engine = engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) @@ -255,6 +289,11 @@ class Database(object): self._check_safe_to_upsert, ) + def is_running(self): + """Is the database pool currently running + """ + return self._db_pool.running + @defer.inlineCallbacks def _check_safe_to_upsert(self): """ diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index ddad17dc5a..df039a072d 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -16,8 +16,6 @@ import struct import threading -from synapse.storage.prepare_database import prepare_database - class Sqlite3Engine(object): single_threaded = True @@ -25,6 +23,9 @@ class Sqlite3Engine(object): def __init__(self, database_module, database_config): self.module = database_module + database = database_config.get("args", {}).get("database") + self._is_in_memory = database in (None, ":memory:",) + # The current max state_group, or None if we haven't looked # in the DB yet. self._current_state_group_id = None @@ -59,7 +60,16 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - prepare_database(db_conn, self, config=None) + + # We need to import here to avoid an import loop. + from synapse.storage.prepare_database import prepare_database + + if self._is_in_memory: + # In memory databases need to be rebuilt each time. Ideally we'd + # reuse the same connection as we do when starting up, but that + # would involve using adbapi before we have started the reactor. + prepare_database(db_conn, self, config=None) + db_conn.create_function("rank", 1, _rank) def is_deadlock(self, error): diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 731e1c9d9c..b4194b44ee 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -41,7 +41,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config): +def prepare_database(db_conn, database_engine, config, data_stores=["main"]): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -54,11 +54,10 @@ def prepare_database(db_conn, database_engine, config): config (synapse.config.homeserver.HomeServerConfig|None): application config, or None if we are connecting to an existing database which we expect to be configured already + data_stores (list[str]): The name of the data stores that will be used + with this database. Defaults to all data stores. """ - # For now we only have the one datastore. - data_stores = ["main"] - try: cur = db_conn.cursor() version_info = _get_or_create_schema_state(cur, database_engine) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 92b8726093..596ddc6970 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -64,28 +64,29 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["put_json"]) mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) + datastores = Mock() + datastores.main = Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + "get_device_updates_by_remote", + ] + ) + hs = self.setup_test_homeserver( - datastore=( - Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_device_updates_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ] - ) - ), - notifier=Mock(), - http_client=mock_federation_client, - keyring=mock_keyring, + notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring ) + hs.datastores = datastores + return hs def prepare(self, reactor, clock, hs): diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 3dae83c543..2a1e7c7166 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -20,7 +20,7 @@ from synapse.replication.tcp.client import ( ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory -from synapse.storage.database import Database +from synapse.storage.database import make_conn from tests import unittest from tests.server import FakeTransport @@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): + db_config = hs.config.database.get_single_database() self.master_store = self.hs.get_datastore() self.storage = hs.get_storage() + database = hs.get_datastores().databases[0] self.slaved_store = self.STORE_TYPE( - Database(hs), self.hs.get_db_conn(), self.hs + database, make_conn(db_config, database.engine), self.hs ) self.event_id = 0 diff --git a/tests/server.py b/tests/server.py index 2b7cf4242e..a554dfdd57 100644 --- a/tests/server.py +++ b/tests/server.py @@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): Set up a synchronous test server, driven by the reactor used by the homeserver. """ - d = _sth(cleanup_func, *args, **kwargs).result + server = _sth(cleanup_func, *args, **kwargs) - if isinstance(d, Failure): - d.raiseException() + database = server.config.database.get_single_database() # Make the thread pool synchronous. - clock = d.get_clock() - pool = d.get_db_pool() - - def runWithConnection(func, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runWithConnection, - func, - *args, - **kwargs - ) - - def runInteraction(interaction, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runInteraction, - interaction, - *args, - **kwargs - ) + clock = server.get_clock() + + for database in server.get_datastores().databases: + pool = database._db_pool + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) - if pool: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction pool.threadpool = ThreadPool(clock._reactor) pool.running = True - return d + + return server def get_clock(): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 2e521e9ab7..fd52512696 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - database = Database(hs) - self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + self.store = ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - self.db_pool = hs.get_db_pool() - self.engine = hs.database_engine - self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"}, @@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - database = Database(hs) - self.store = TestTransactionStore(database, hs.get_db_conn(), hs) + # We assume there is only one database in these tests + database = hs.get_datastores().databases[0] + self.db_pool = database._db_pool + self.engine = database.engine + + db_config = hs.config.get_single_database() + self.store = TestTransactionStore( + database, make_conn(db_config, self.engine), hs + ) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) @@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 537cfe9f64..cdee0a9e60 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True config.event_cache_size = 1 - config.database_config = {"name": "sqlite3"} - engine = create_engine(config.database_config) + hs = TestHomeServer("test", config=config) + + sqlite_config = {"name": "sqlite3"} + engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - hs = TestHomeServer( - "test", db_pool=self.db_pool, config=config, database_engine=fake_engine - ) - self.datastore = SQLBaseStore(Database(hs), None, hs) + db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db._db_pool = self.db_pool + + self.datastore = SQLBaseStore(db, None, hs) @defer.inlineCallbacks def test_insert_1col(self): diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 4578cc3b60..ed5786865a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() diff --git a/tests/utils.py b/tests/utils.py index 585f305b9a..9f5bf40b4b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -30,6 +30,7 @@ from twisted.internet import defer, reactor from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server @@ -177,7 +178,6 @@ class TestHomeServer(HomeServer): DATASTORE_CLASS = DataStore -@defer.inlineCallbacks def setup_test_homeserver( cleanup_func, name="test", @@ -214,7 +214,7 @@ def setup_test_homeserver( if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex - config.database_config = { + database_config = { "name": "psycopg2", "args": { "database": test_db, @@ -226,12 +226,15 @@ def setup_test_homeserver( }, } else: - config.database_config = { + database_config = { "name": "sqlite3", "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, } - db_engine = create_engine(config.database_config) + database = DatabaseConnectionConfig("master", database_config, ["main"]) + config.database.databases = [database] + + db_engine = create_engine(database.config) # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() @@ -251,11 +254,6 @@ def setup_test_homeserver( cur.close() db_conn.close() - # we need to configure the connection pool to run the on_new_connection - # function, so that we can test code that uses custom sqlite functions - # (like rank). - config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection - if datastore is None: hs = homeserverToUse( name, @@ -267,21 +265,19 @@ def setup_test_homeserver( **kargs ) - # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to - # date db - if not isinstance(db_engine, PostgresEngine): - db_conn = hs.get_db_conn() - yield prepare_database(db_conn, db_engine, config) - db_conn.commit() - db_conn.close() + hs.setup() + if homeserverToUse.__name__ == "TestHomeServer": + hs.setup_master() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] - else: # We need to do cleanup on PostgreSQL def cleanup(): import psycopg2 # Close all the db pools - hs.get_db_pool().close() + database._db_pool.close() dropped = False @@ -320,23 +316,12 @@ def setup_test_homeserver( # Register the cleanup hook cleanup_func(cleanup) - hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": - hs.setup_master() else: - # If we have been given an explicit datastore we probably want to mock - # out the DataStores somehow too. This all feels a bit wrong, but then - # mocking the stores feels wrong too. - datastores = Mock(datastore=datastore) - hs = homeserverToUse( name, - db_pool=None, datastore=datastore, - datastores=datastores, config=config, version_string="Synapse/tests", - database_engine=db_engine, tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, -- cgit 1.5.1 From d6752ce5da38d35857fe324800d76a86ee1e64f1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 18 Dec 2019 14:26:58 +0000 Subject: Clean up startup for the pusher (#6558) * Remove redundant python2 support code `str.decode()` doesn't exist on python3, so presumably this code was doing nothing * Filter out pushers with corrupt data When we get a row with unparsable json, drop the row, rather than returning a row with null `data`, which will then cause an explosion later on. * Improve logging when we can't start a pusher Log the ID to help us understand the problem * Make email pusher setup more robust We know we'll have a `data` member, since that comes from the database. What we *don't* know is if that is a dict, and if that has a `brand` member, and if that member is a string. --- changelog.d/6558.misc | 1 + synapse/push/pusher.py | 12 ++++++----- synapse/push/pusherpool.py | 10 +++++---- synapse/rest/client/v1/pusher.py | 33 +++++++++++++++--------------- synapse/storage/data_stores/main/pusher.py | 25 ++++++++-------------- tests/push/test_email.py | 3 +++ tests/push/test_http.py | 4 ++++ 7 files changed, 45 insertions(+), 43 deletions(-) create mode 100644 changelog.d/6558.misc (limited to 'tests') diff --git a/changelog.d/6558.misc b/changelog.d/6558.misc new file mode 100644 index 0000000000..a7572f1a85 --- /dev/null +++ b/changelog.d/6558.misc @@ -0,0 +1 @@ +Clean up logs from the push notifier at startup. \ No newline at end of file diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index f277aeb131..8ad0bf5936 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -80,9 +80,11 @@ class PusherFactory(object): return EmailPusher(self.hs, pusherdict, mailer) def _app_name_from_pusherdict(self, pusherdict): - if "data" in pusherdict and "brand" in pusherdict["data"]: - app_name = pusherdict["data"]["brand"] - else: - app_name = self.config.email_app_name + data = pusherdict["data"] - return app_name + if isinstance(data, dict): + brand = data.get("brand") + if isinstance(brand, str): + return brand + + return self.config.email_app_name diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 0f6992202d..b9dca5bc63 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -232,7 +232,6 @@ class PusherPool: Deferred """ pushers = yield self.store.get_all_pushers() - logger.info("Starting %d pushers", len(pushers)) # Stagger starting up the pushers so we don't completely drown the # process on start up. @@ -245,7 +244,7 @@ class PusherPool: """Start the given pusher Args: - pusherdict (dict): + pusherdict (dict): dict with the values pulled from the db table Returns: Deferred[EmailPusher|HttpPusher] @@ -254,7 +253,8 @@ class PusherPool: p = self.pusher_factory.create_pusher(pusherdict) except PusherConfigException as e: logger.warning( - "Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s", + "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s", + pusherdict["id"], pusherdict.get("user_name"), pusherdict.get("app_id"), pusherdict.get("pushkey"), @@ -262,7 +262,9 @@ class PusherPool: ) return except Exception: - logger.exception("Couldn't start a pusher: caught Exception") + logger.exception( + "Couldn't start pusher id %i: caught Exception", pusherdict["id"], + ) return if not p: diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 0791866f55..6f6b7aed6e 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -28,6 +28,17 @@ from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) +ALLOWED_KEYS = { + "app_display_name", + "app_id", + "data", + "device_display_name", + "kind", + "lang", + "profile_tag", + "pushkey", +} + class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) @@ -43,23 +54,11 @@ class PushersRestServlet(RestServlet): pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - allowed_keys = [ - "app_display_name", - "app_id", - "data", - "device_display_name", - "kind", - "lang", - "profile_tag", - "pushkey", - ] - - for p in pushers: - for k, v in list(p.items()): - if k not in allowed_keys: - del p[k] - - return 200, {"pushers": pushers} + filtered_pushers = list( + {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers + ) + + return 200, {"pushers": filtered_pushers} def on_OPTIONS(self, _): return 200, {} diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index f07309ef09..6b03233262 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -15,8 +15,7 @@ # limitations under the License. import logging - -import six +from typing import Iterable, Iterator from canonicaljson import encode_canonical_json, json @@ -27,21 +26,16 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList logger = logging.getLogger(__name__) -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - class PusherWorkerStore(SQLBaseStore): - def _decode_pushers_rows(self, rows): + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: + """JSON-decode the data in the rows returned from the `pushers` table + + Drops any rows whose data cannot be decoded + """ for r in rows: dataJson = r["data"] - r["data"] = None try: - if isinstance(dataJson, db_binary_type): - dataJson = str(dataJson).decode("UTF8") - r["data"] = json.loads(dataJson) except Exception as e: logger.warning( @@ -50,12 +44,9 @@ class PusherWorkerStore(SQLBaseStore): dataJson, e.args[0], ) - pass - - if isinstance(r["pushkey"], db_binary_type): - r["pushkey"] = str(r["pushkey"]).decode("UTF8") + continue - return rows + yield r @defer.inlineCallbacks def user_has_pusher(self, user_id): diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 358b593cd4..80187406bc 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -165,6 +165,7 @@ class EmailPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -175,6 +176,7 @@ class EmailPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) @@ -192,5 +194,6 @@ class EmailPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index af2327fb66..fe3441f081 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -104,6 +104,7 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -114,6 +115,7 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) @@ -132,6 +134,7 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -151,5 +154,6 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) -- cgit 1.5.1 From fa780e9721c940479a72eed9877ccad4fef78160 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Dec 2019 10:32:02 +0000 Subject: Change EventContext to use the Storage class (#6564) --- changelog.d/6564.misc | 1 + synapse/api/auth.py | 2 +- synapse/events/snapshot.py | 36 +++++++++++++++----------- synapse/events/third_party_rules.py | 2 +- synapse/handlers/_base.py | 2 +- synapse/handlers/federation.py | 14 +++++----- synapse/handlers/message.py | 10 +++---- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 4 +-- synapse/push/bulk_push_rule_evaluator.py | 4 +-- synapse/replication/http/federation.py | 5 +++- synapse/replication/http/send_event.py | 3 ++- synapse/storage/data_stores/main/push_rule.py | 2 +- synapse/storage/data_stores/main/roommember.py | 2 +- tests/test_state.py | 28 ++++++++++---------- 15 files changed, 64 insertions(+), 53 deletions(-) create mode 100644 changelog.d/6564.misc (limited to 'tests') diff --git a/changelog.d/6564.misc b/changelog.d/6564.misc new file mode 100644 index 0000000000..f644f5868b --- /dev/null +++ b/changelog.d/6564.misc @@ -0,0 +1 @@ +Change `EventContext` to use the `Storage` class, in preparation for moving state database queries to a separate data store. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9fd52a8c77..abbc7079a3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -79,7 +79,7 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, room_version, event, context, do_sig_check=True): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 64e898f40c..a44baea365 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -149,7 +149,7 @@ class EventContext: # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_ids = yield self.get_prev_state_ids(store) + prev_state_ids = yield self.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None @@ -167,12 +167,13 @@ class EventContext: } @staticmethod - def deserialize(store, input): + def deserialize(storage, input): """Converts a dict that was produced by `serialize` back into a EventContext. Args: - store (DataStore): Used to convert AS ID to AS object + storage (Storage): Used to convert AS ID to AS object and fetch + state. input (dict): A dict produced by `serialize` Returns: @@ -181,6 +182,7 @@ class EventContext: context = _AsyncEventContextImpl( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. + storage=storage, prev_state_id=input["prev_state_id"], event_type=input["event_type"], event_state_key=input["event_state_key"], @@ -193,7 +195,7 @@ class EventContext: app_service_id = input["app_service_id"] if app_service_id: - context.app_service = store.get_app_service_by_id(app_service_id) + context.app_service = storage.main.get_app_service_by_id(app_service_id) return context @@ -216,7 +218,7 @@ class EventContext: return self._state_group @defer.inlineCallbacks - def get_current_state_ids(self, store): + def get_current_state_ids(self): """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -234,11 +236,11 @@ class EventContext: if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - yield self._ensure_fetched(store) + yield self._ensure_fetched() return self._current_state_ids @defer.inlineCallbacks - def get_prev_state_ids(self, store): + def get_prev_state_ids(self): """ Gets the room state map, excluding this event. @@ -250,7 +252,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - yield self._ensure_fetched(store) + yield self._ensure_fetched() return self._prev_state_ids def get_cached_current_state_ids(self): @@ -270,7 +272,7 @@ class EventContext: return self._current_state_ids - def _ensure_fetched(self, store): + def _ensure_fetched(self): return defer.succeed(None) @@ -282,6 +284,8 @@ class _AsyncEventContextImpl(EventContext): Attributes: + _storage (Storage) + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have been calculated. None if we haven't started calculating yet @@ -295,28 +299,30 @@ class _AsyncEventContextImpl(EventContext): that was replaced. """ + # This needs to have a default as we're inheriting + _storage = attr.ib(default=None) _prev_state_id = attr.ib(default=None) _event_type = attr.ib(default=None) _event_state_key = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None) - def _ensure_fetched(self, store): + def _ensure_fetched(self): if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) + self._fetching_state_deferred = run_in_background(self._fill_out_state) return make_deferred_yieldable(self._fetching_state_deferred) @defer.inlineCallbacks - def _fill_out_state(self, store): + def _fill_out_state(self): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield store.get_state_ids_for_group(self.state_group) + self._current_state_ids = yield self._storage.state.get_state_ids_for_group( + self.state_group + ) if self._prev_state_id and self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 714a9b1579..86f7e5f8aa 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -53,7 +53,7 @@ class ThirdPartyEventRules(object): if self.third_party_rules is None: return True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() # Retrieve the state events from the database. state_events = {} diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d15c6282fb..51413d910e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -134,7 +134,7 @@ class BaseHandler(object): guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() current_state = yield self.store.get_events( list(current_state_ids.values()) ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 60bb00fc6a..05ae40dde7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -718,7 +718,7 @@ class FederationHandler(BaseHandler): # changing their profile info. newly_joined = True - prev_state_ids = await context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: @@ -1418,7 +1418,7 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield self.user_joined_room(user, event.room_id) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) auth_chain = yield self.store.get_auth_chain(state_ids) @@ -1927,7 +1927,7 @@ class FederationHandler(BaseHandler): context = yield self.state_handler.compute_event_context(event, old_state=state) if not auth_events: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -2336,12 +2336,12 @@ class FederationHandler(BaseHandler): k: a.event_id for k, a in iteritems(auth_events) if k != event_key } - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() current_state_ids = dict(current_state_ids) current_state_ids.update(state_updates) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = dict(prev_state_ids) prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) @@ -2625,7 +2625,7 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"], ) original_invite = None - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() original_invite_id = prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( @@ -2673,7 +2673,7 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) invite_event = None diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index bf9add7fe2..4ad752205f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -515,7 +515,7 @@ class EventCreationHandler(object): # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( yield self.store.get_event(prev_event_id, allow_none=True) @@ -665,7 +665,7 @@ class EventCreationHandler(object): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: return @@ -914,7 +914,7 @@ class EventCreationHandler(object): def is_inviter_member_event(e): return e.type == EventTypes.Member and e.sender == event.sender - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() state_to_include_ids = [ e_id @@ -967,7 +967,7 @@ class EventCreationHandler(object): if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -989,7 +989,7 @@ class EventCreationHandler(object): event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d3a1a7b4a6..89c9118b26 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -184,7 +184,7 @@ class RoomCreationHandler(BaseHandler): requester, tombstone_event, tombstone_context ) - old_room_state = yield tombstone_context.get_current_state_ids(self.store) + old_room_state = yield tombstone_context.get_current_state_ids() # update any aliases yield self._move_aliases_to_new_room( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7b7270fc61..44c5e3239c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -193,7 +193,7 @@ class RoomMemberHandler(object): requester, event, context, extra_users=[target], ratelimit=ratelimit ) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -601,7 +601,7 @@ class RoomMemberHandler(object): if prev_event is not None: return - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() if event.membership == Membership.JOIN: if requester.is_guest: guest_can_join = yield self._can_guest_join(prev_state_ids) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7881780760..7d9f5a38d9 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -116,7 +116,7 @@ class BulkPushRuleEvaluator(object): @defer.inlineCallbacks def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and @@ -304,7 +304,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 9af4e7e173..49a3251372 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -51,6 +51,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): super(ReplicationFederationSendEventsRestServlet, self).__init__(hs) self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() self.federation_handler = hs.get_handlers().federation_handler @@ -100,7 +101,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): EventType = event_type_from_format_version(format_ver) event = EventType(event_dict, internal_metadata, rejected_reason) - context = EventContext.deserialize(self.store, event_payload["context"]) + context = EventContext.deserialize( + self.storage, event_payload["context"] + ) event_and_contexts.append((event, context)) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 9bafd60b14..84b92f16ad 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -54,6 +54,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() @staticmethod @@ -100,7 +101,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) requester = Requester.deserialize(self.store, content["requester"]) - context = EventContext.deserialize(self.store, content["context"]) + context = EventContext.deserialize(self.storage, content["context"]) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 5ba13aa973..e2673ae073 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -244,7 +244,7 @@ class PushRulesWorkerStore( # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 92e3b9c512..70ff5751b6 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -477,7 +477,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) diff --git a/tests/test_state.py b/tests/test_state.py index 176535947a..e0aae06be4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -209,7 +209,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertEqual(2, len(prev_state_ids)) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -253,7 +253,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertSetEqual( {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} ) @@ -312,7 +312,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_e = context_store["E"] - prev_state_ids = yield ctx_e.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_e.get_prev_state_ids() self.assertSetEqual( {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} ) @@ -387,7 +387,7 @@ class StateTestCase(unittest.TestCase): ctx_b = context_store["B"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} ) @@ -419,10 +419,10 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() ) @@ -442,10 +442,10 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() ) @@ -479,7 +479,7 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual( set([e.event_id for e in old_state]), set(current_state_ids.values()) @@ -511,7 +511,7 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertEqual( set([e.event_id for e in old_state]), set(prev_state_ids.values()) @@ -552,7 +552,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(len(current_state_ids), 6) @@ -594,7 +594,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(len(current_state_ids), 6) @@ -649,7 +649,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -677,7 +677,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) -- cgit 1.5.1 From 75d8f26ac85efd3816d454927f40b6e4c3032df1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Dec 2019 10:48:24 +0000 Subject: Split state groups into a separate data store (#6296) --- changelog.d/6245.misc | 1 + scripts/synapse_port_db | 8 +- synapse/config/database.py | 10 +- synapse/storage/data_stores/__init__.py | 5 + synapse/storage/data_stores/main/events.py | 157 ---- .../main/schema/delta/23/drop_state_index.sql | 16 - .../main/schema/delta/30/state_stream.sql | 33 - .../main/schema/delta/32/remove_indices.sql | 1 - .../main/schema/delta/35/add_state_index.sql | 17 - .../data_stores/main/schema/delta/35/state.sql | 22 - .../main/schema/delta/35/state_dedupe.sql | 17 - .../main/schema/delta/47/state_group_seq.py | 34 - .../main/schema/full_schemas/54/full.sql.postgres | 52 -- .../main/schema/full_schemas/54/full.sql.sqlite | 6 - synapse/storage/data_stores/main/state.py | 937 +-------------------- synapse/storage/data_stores/state/__init__.py | 16 + synapse/storage/data_stores/state/bg_updates.py | 374 ++++++++ .../state/schema/delta/23/drop_state_index.sql | 16 + .../state/schema/delta/30/state_stream.sql | 33 + .../state/schema/delta/32/remove_state_indices.sql | 19 + .../state/schema/delta/35/add_state_index.sql | 17 + .../data_stores/state/schema/delta/35/state.sql | 22 + .../state/schema/delta/35/state_dedupe.sql | 17 + .../state/schema/delta/47/state_group_seq.py | 34 + .../state/schema/delta/56/state_group_room_idx.sql | 17 + .../state/schema/full_schemas/54/full.sql | 37 + .../schema/full_schemas/54/sequence.sql.postgres | 21 + synapse/storage/data_stores/state/store.py | 640 ++++++++++++++ synapse/storage/persist_events.py | 2 +- synapse/storage/prepare_database.py | 2 +- synapse/storage/purge_events.py | 4 +- synapse/storage/state.py | 14 +- tests/storage/test_state.py | 2 +- tests/utils.py | 2 +- 34 files changed, 1298 insertions(+), 1307 deletions(-) create mode 100644 changelog.d/6245.misc delete mode 100644 synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql delete mode 100644 synapse/storage/data_stores/main/schema/delta/30/state_stream.sql delete mode 100644 synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql delete mode 100644 synapse/storage/data_stores/main/schema/delta/35/state.sql delete mode 100644 synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql delete mode 100644 synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py create mode 100644 synapse/storage/data_stores/state/__init__.py create mode 100644 synapse/storage/data_stores/state/bg_updates.py create mode 100644 synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql create mode 100644 synapse/storage/data_stores/state/schema/delta/30/state_stream.sql create mode 100644 synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql create mode 100644 synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql create mode 100644 synapse/storage/data_stores/state/schema/delta/35/state.sql create mode 100644 synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql create mode 100644 synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py create mode 100644 synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql create mode 100644 synapse/storage/data_stores/state/schema/full_schemas/54/full.sql create mode 100644 synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres create mode 100644 synapse/storage/data_stores/state/store.py (limited to 'tests') diff --git a/changelog.d/6245.misc b/changelog.d/6245.misc new file mode 100644 index 0000000000..a3e6b8296e --- /dev/null +++ b/changelog.d/6245.misc @@ -0,0 +1 @@ +Split out state storage into separate data store. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 5b5368988c..eb927f2094 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -51,11 +51,12 @@ from synapse.storage.data_stores.main.registration import ( from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore -from synapse.storage.data_stores.main.state import StateBackgroundUpdateStore +from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore from synapse.storage.data_stores.main.stats import StatsStore from synapse.storage.data_stores.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) +from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database @@ -138,6 +139,7 @@ class Store( RoomMemberBackgroundUpdateStore, SearchBackgroundUpdateStore, StateBackgroundUpdateStore, + MainStateBackgroundUpdateStore, UserDirectoryBackgroundUpdateStore, StatsStore, ): @@ -496,9 +498,7 @@ class Porter(object): def run(self): try: self.sqlite_store = yield self.build_db_store( - DatabaseConnectionConfig( - "master", self.sqlite_config, data_stores=["main"] - ) + DatabaseConnectionConfig("master-sqlite", self.sqlite_config) ) # Check if all background updates are done, abort if not. diff --git a/synapse/config/database.py b/synapse/config/database.py index 5f2f3c7cfd..134824789c 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -34,10 +34,12 @@ class DatabaseConnectionConfig: module name, and `args` for the args to give to the database connector. data_stores: The list of data stores that should be provisioned on the - database. + database. Defaults to all data stores. """ - def __init__(self, name: str, db_config: dict, data_stores: List[str]): + def __init__( + self, name: str, db_config: dict, data_stores: List[str] = ["main", "state"] + ): if db_config["name"] not in ("sqlite3", "psycopg2"): raise ConfigError("Unsupported database type %r" % (db_config["name"],)) @@ -62,9 +64,7 @@ class DatabaseConfig(Config): if database_config is None: database_config = {"name": "sqlite3", "args": {}} - self.databases = [ - DatabaseConnectionConfig("master", database_config, data_stores=["main"]) - ] + self.databases = [DatabaseConnectionConfig("master", database_config)] self.set_databasepath(config.get("database_path")) diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 0983e059c0..d20df5f076 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -15,6 +15,7 @@ import logging +from synapse.storage.data_stores.state import StateGroupDataStore from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database @@ -55,6 +56,10 @@ class DataStores(object): logger.info("Starting 'main' data store") self.main = main_store_class(database, db_conn, hs) + if "state" in database_config.data_stores: + logger.info("Starting 'state' data store") + self.state = StateGroupDataStore(database, db_conn, hs) + db_conn.commit() self.databases.append(database) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 998bba1aad..58f35d7f56 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1757,163 +1757,6 @@ class EventsStore( return state_groups - def purge_unreferenced_state_groups( - self, room_id: str, state_groups_to_delete - ) -> defer.Deferred: - """Deletes no longer referenced state groups and de-deltas any state - groups that reference them. - - Args: - room_id: The room the state groups belong to (must all be in the - same room). - state_groups_to_delete (Collection[int]): Set of all state groups - to delete. - """ - - return self.db.runInteraction( - "purge_unreferenced_state_groups", - self._purge_unreferenced_state_groups, - room_id, - state_groups_to_delete, - ) - - def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): - logger.info( - "[purge] found %i state groups to delete", len(state_groups_to_delete) - ) - - rows = self.db.simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=state_groups_to_delete, - keyvalues={}, - retcols=("state_group",), - ) - - remaining_state_groups = set( - row["state_group"] - for row in rows - if row["state_group"] not in state_groups_to_delete - ) - - logger.info( - "[purge] de-delta-ing %i remaining state groups", - len(remaining_state_groups), - ) - - # Now we turn the state groups that reference to-be-deleted state - # groups to non delta versions. - for sg in remaining_state_groups: - logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] - - self.db.simple_delete_txn( - txn, table="state_groups_state", keyvalues={"state_group": sg} - ) - - self.db.simple_delete_txn( - txn, table="state_group_edges", keyvalues={"state_group": sg} - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": sg, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(curr_state) - ], - ) - - logger.info("[purge] removing redundant state groups") - txn.executemany( - "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), - ) - txn.executemany( - "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), - ) - - @defer.inlineCallbacks - def get_previous_state_groups(self, state_groups): - """Fetch the previous groups of the given state groups. - - Args: - state_groups (Iterable[int]) - - Returns: - Deferred[dict[int, int]]: mapping from state group to previous - state group. - """ - - rows = yield self.db.simple_select_many_batch( - table="state_group_edges", - column="prev_state_group", - iterable=state_groups, - keyvalues={}, - retcols=("prev_state_group", "state_group"), - desc="get_previous_state_groups", - ) - - return {row["state_group"]: row["prev_state_group"] for row in rows} - - def purge_room_state(self, room_id, state_groups_to_delete): - """Deletes all record of a room from state tables - - Args: - room_id (str): - state_groups_to_delete (list[int]): State groups to delete - """ - - return self.db.runInteraction( - "purge_room_state", - self._purge_room_state_txn, - room_id, - state_groups_to_delete, - ) - - def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups_state", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_group_edges", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state groups - logger.info("[purge] removing %s from state_groups", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups", - column="id", - iterable=state_groups_to_delete, - keyvalues={}, - ) - async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ diff --git a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql deleted file mode 100644 index ae09fa0065..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2015, 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -DROP INDEX IF EXISTS state_groups_state_tuple; diff --git a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql deleted file mode 100644 index e85699e82e..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -/* We used to create a table called current_state_resets, but this is no - * longer used and is removed in delta 54. - */ - -/* The outlier events that have aquired a state group typically through - * backfill. This is tracked separately to the events table, as assigning a - * state group change the position of the existing event in the stream - * ordering. - * However since a stream_ordering is assigned in persist_event for the - * (event, state) pair, we can use that stream_ordering to identify when - * the new state was assigned for the event. - */ -CREATE TABLE IF NOT EXISTS ex_outlier_stream( - event_stream_ordering BIGINT PRIMARY KEY NOT NULL, - event_id TEXT NOT NULL, - state_group BIGINT NOT NULL -); diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql index 4219cdd06a..2de50d408c 100644 --- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql @@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY -DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT diff --git a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql deleted file mode 100644 index 33980d02f0..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json, depends_on) - VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication'); diff --git a/synapse/storage/data_stores/main/schema/delta/35/state.sql b/synapse/storage/data_stores/main/schema/delta/35/state.sql deleted file mode 100644 index 0f1fa68a89..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/35/state.sql +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE state_group_edges( - state_group BIGINT NOT NULL, - prev_state_group BIGINT NOT NULL -); - -CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); -CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); diff --git a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql deleted file mode 100644 index 97e5067ef4..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright 2016 OpenMarket Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -INSERT into background_updates (update_name, progress_json) - VALUES ('state_group_state_deduplication', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py deleted file mode 100644 index 9fd1ccf6f7..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.storage.engines import PostgresEngine - - -def run_create(cur, database_engine, *args, **kwargs): - if isinstance(database_engine, PostgresEngine): - # if we already have some state groups, we want to start making new - # ones with a higher id. - cur.execute("SELECT max(id) FROM state_groups") - row = cur.fetchone() - - if row[0] is None: - start_val = 1 - else: - start_val = row[0] + 1 - - cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,)) - - -def run_upgrade(*args, **kwargs): - pass diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres index 4ad2929f32..889a9a0ce4 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres @@ -975,40 +975,6 @@ CREATE TABLE state_events ( -CREATE TABLE state_group_edges ( - state_group bigint NOT NULL, - prev_state_group bigint NOT NULL -); - - - -CREATE SEQUENCE state_group_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; - - - -CREATE TABLE state_groups ( - id bigint NOT NULL, - room_id text NOT NULL, - event_id text NOT NULL -); - - - -CREATE TABLE state_groups_state ( - state_group bigint NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL, - event_id text NOT NULL -); - - - CREATE TABLE stats_stream_pos ( lock character(1) DEFAULT 'X'::bpchar NOT NULL, stream_id bigint, @@ -1482,12 +1448,6 @@ ALTER TABLE ONLY state_events ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id); - -ALTER TABLE ONLY state_groups - ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id); - - - ALTER TABLE ONLY stats_stream_pos ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock); @@ -1928,18 +1888,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts); -CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group); - - - -CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group); - - - -CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key); - - - CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite index bad33291e7..a0411ede7e 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite @@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id); CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) ); CREATE INDEX room_depth_room ON room_depth(room_id); -CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); -CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) ); CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) ); CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); @@ -120,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL ); CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT); CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id ); CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id ); -CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL ); -CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); -CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering ); CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering ); @@ -254,6 +249,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen); CREATE INDEX users_creation_ts ON users (creation_ts); CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group); CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id); -CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key); CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id); CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip); diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index dcc6b43cdf..0dc39f139c 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -17,8 +17,7 @@ import logging from collections import namedtuple from typing import Iterable, Tuple -from six import iteritems, itervalues -from six.moves import range +from six import iteritems from twisted.internet import defer @@ -29,11 +28,9 @@ from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import Database -from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter -from synapse.util.caches import get_cache_factor_for, intern_string +from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -55,207 +52,14 @@ class _GetStateGroupDelta( return len(self.delta_ids) if self.delta_ids else 0 -class StateGroupBackgroundUpdateStore(SQLBaseStore): - """Defines functions related to state groups needed to run the state backgroud - updates. - """ - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """ - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() - ): - results = {group: {} for group in groups} - - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - - if isinstance(self.database_engine, PostgresEngine): - # Temporarily disable sequential scans in this transaction. This is - # a temporary hack until we can add the right indices in - txn.execute("SET LOCAL enable_seqscan=off") - - # The below query walks the state_group tree so that the "state" - # table includes all state_groups in the tree. It then joins - # against `state_groups_state` to fetch the latest state. - # It assumes that previous state groups are always numerically - # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. - # This may return multiple rows per (type, state_key), but last_value - # should be the same. - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT DISTINCT type, state_key, last_value(event_id) OVER ( - PARTITION BY type, state_key ORDER BY state_group ASC - ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - ) AS event_id FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) - """ - - for group in groups: - args = [group] - args.extend(where_args) - - txn.execute(sql + where_clause, args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id - else: - max_entries_returned = state_filter.max_entries_returned() - - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - for group in groups: - next_group = group - - while next_group: - # We did this before by getting the list of group ids, and - # then passing that list to sqlite to get latest event for - # each (type, state_key). However, that was terribly slow - # without the right indices (which we can't add until - # after we finish deduping state, which requires this func) - args = [next_group] - args.extend(where_args) - - txn.execute( - "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? " + where_clause, - args, - ) - results[group].update( - ((typ, state_key), event_id) - for typ, state_key, event_id in txn - if (typ, state_key) not in results[group] - ) - - # If the number of entries in the (type,state_key)->event_id dict - # matches the number of (type,state_keys) types we were searching - # for, then we must have found them all, so no need to go walk - # further down the tree... UNLESS our types filter contained - # wildcards (i.e. Nones) in which case we have to do an exhaustive - # search - if ( - max_entries_returned is not None - and len(results[group]) == max_entries_returned - ): - break - - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - - return results - - # this inherits from EventsWorkerStore because it calls self.get_events -class StateGroupWorkerStore( - EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore -): +class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. """ - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, database: Database, db_conn, hs): super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) - # Originally the state store used a single DictionaryCache to cache the - # event IDs for the state types in a given state group to avoid hammering - # on the state_group* tables. - # - # The point of using a DictionaryCache is that it can cache a subset - # of the state events for a given state group (i.e. a subset of the keys for a - # given dict which is an entry in the cache for a given state group ID). - # - # However, this poses problems when performing complicated queries - # on the store - for instance: "give me all the state for this group, but - # limit members to this subset of users", as DictionaryCache's API isn't - # rich enough to say "please cache any of these fields, apart from this subset". - # This is problematic when lazy loading members, which requires this behaviour, - # as without it the cache has no choice but to speculatively load all - # state events for the group, which negates the efficiency being sought. - # - # Rather than overcomplicating DictionaryCache's API, we instead split the - # state_group_cache into two halves - one for tracking non-member events, - # and the other for tracking member_events. This means that lazy loading - # queries can be made in a cache-friendly manner by querying both caches - # separately and then merging the result. So for the example above, you - # would query the members cache for a specific subset of state keys - # (which DictionaryCache will handle efficiently and fine) and the non-members - # cache for all state (which DictionaryCache will similarly handle fine) - # and then just merge the results together. - # - # We size the non-members cache to be smaller than the members cache as the - # vast majority of state in Matrix (today) is member events. - - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", - # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), - ) - self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), - ) - @defer.inlineCallbacks def get_room_version(self, room_id): """Get the room_version of a given room @@ -431,229 +235,6 @@ class StateGroupWorkerStore( return event.content.get("canonical_alias") - @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): - """Given a state group try to return a previous group and a delta between - the old and the new. - - Returns: - (prev_group, delta_ids), where both may be None. - """ - - def _get_state_group_delta_txn(txn): - prev_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - retcol="prev_state_group", - allow_none=True, - ) - - if not prev_group: - return _GetStateGroupDelta(None, None) - - delta_ids = self.db.simple_select_list_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - retcols=("type", "state_key", "event_id"), - ) - - return _GetStateGroupDelta( - prev_group, - {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, - ) - - return self.db.runInteraction( - "get_state_group_delta", _get_state_group_delta_txn - ) - - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): - """Get the event IDs of all the state for the state groups for the given events - - Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events - - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - if not event_ids: - return {} - - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups) - - return group_to_state - - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): - """Get the event IDs of all the state in the given state group - - Args: - state_group (int) - - Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id - """ - group_to_state = yield self._get_state_for_groups((state_group,)) - - return group_to_state[state_group] - - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): - """ Get the state groups for the given list of event_ids - - Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. - """ - if not event_ids: - return {} - - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) - - state_event_map = yield self.get_events( - [ - ev_id - for group_ids in itervalues(group_to_ids) - for ev_id in itervalues(group_ids) - ], - get_prev_content=False, - ) - - return { - group: [ - state_event_map[v] - for v in itervalues(event_id_map) - if v in state_event_map - ] - for group, event_id_map in iteritems(group_to_ids) - } - - @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, state_filter): - """Returns the state groups for a given set of groups, filtering on - types of state events. - - Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - results = {} - - chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] - for chunk in chunks: - res = yield self.db.runInteraction( - "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, - chunk, - state_filter, - ) - results.update(res) - - return results - - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): - """Given a list of event_ids and type tuples, return a list of state - dicts for each event. - - Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - state_event_map = yield self.get_events( - [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], - get_prev_content=False, - ) - - event_to_state = { - event_id: { - k: state_event_map[v] - for k, v in iteritems(group_to_state[group]) - if v in state_event_map - } - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): - """ - Get the state dicts corresponding to a list of events, containing the event_ids - of the state events (as opposed to the events themselves) - - Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from event_id -> (type, state_key) -> event_id - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - event_to_state = { - event_id: group_to_state[group] - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_for_events([event_id], state_filter) - return state_map[event_id] - - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) - return state_map[event_id] - @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): return self.db.simple_select_one_onecol( @@ -684,329 +265,6 @@ class StateGroupWorkerStore( return {row["event_id"]: row["state_group"] for row in rows} - def _get_state_for_group_using_cache(self, cache, group, state_filter): - """Checks if group is in cache. See `_get_state_for_groups` - - Args: - cache(DictionaryCache): the state group cache to use - group(int): The state group to lookup - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns 2-tuple (`state_dict`, `got_all`). - `got_all` is a bool indicating if we successfully retrieved all - requests state from the cache, if False we need to query the DB for the - missing state. - """ - is_all, known_absent, state_dict_ids = cache.get(group) - - if is_all or state_filter.is_full(): - # Either we have everything or want everything, either way - # `is_all` tells us whether we've gotten everything. - return state_filter.filter_state(state_dict_ids), is_all - - # tracks whether any of our requested types are missing from the cache - missing_types = False - - if state_filter.has_wildcards(): - # We don't know if we fetched all the state keys for the types in - # the filter that are wildcards, so we have to assume that we may - # have missed some. - missing_types = True - else: - # There aren't any wild cards, so `concrete_types()` returns the - # complete list of event types we're wanting. - for key in state_filter.concrete_types(): - if key not in state_dict_ids and key not in known_absent: - missing_types = True - break - - return state_filter.filter_state(state_dict_ids), not missing_types - - @defer.inlineCallbacks - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - - member_filter, non_member_filter = state_filter.get_member_split() - - # Now we look them up in the member and non-member caches - ( - non_member_state, - incomplete_groups_nm, - ) = yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) - - ( - member_state, - incomplete_groups_m, - ) = yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) - - state = dict(non_member_state) - for group in groups: - state[group].update(member_state[group]) - - # Now fetch any missing groups from the database - - incomplete_groups = incomplete_groups_m | incomplete_groups_nm - - if not incomplete_groups: - return state - - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - group_to_state_dict = yield self._get_state_groups_from_groups( - list(incomplete_groups), state_filter=db_state_filter - ) - - # Now lets update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # And finally update the result dict, by filtering out any extra - # stuff we pulled out of the database. - for group, group_state_dict in iteritems(group_to_state_dict): - # We just replace any existing entries, as we will have loaded - # everything we need from the database anyway. - state[group] = state_filter.filter_state(group_state_dict) - - return state - - def _get_state_for_groups_using_cache(self, groups, cache, state_filter): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key, querying from a specific cache. - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - cache (DictionaryCache): the cache of group ids to state dicts which - we will pass through - either the normal state cache or the specific - members state cache. - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of - dict of state_group_id -> (dict of (type, state_key) -> event id) - of entries in the cache, and the state group ids either missing - from the cache or incomplete. - """ - results = {} - incomplete_groups = set() - for group in set(groups): - state_dict_ids, got_all = self._get_state_for_group_using_cache( - cache, group, state_filter - ) - results[group] = state_dict_ids - - if not got_all: - incomplete_groups.add(group) - - return results, incomplete_groups - - def _insert_into_cache( - self, - group_to_state_dict, - state_filter, - cache_seq_num_members, - cache_seq_num_non_members, - ): - """Inserts results from querying the database into the relevant cache. - - Args: - group_to_state_dict (dict): The new entries pulled from database. - Map from state group to state dict - state_filter (StateFilter): The state filter used to fetch state - from the database. - cache_seq_num_members (int): Sequence number of member cache since - last lookup in cache - cache_seq_num_non_members (int): Sequence number of member cache since - last lookup in cache - """ - - # We need to work out which types we've fetched from the DB for the - # member vs non-member caches. This should be as accurate as possible, - # but can be an underestimate (e.g. when we have wild cards) - - member_filter, non_member_filter = state_filter.get_member_split() - if member_filter.is_full(): - # We fetched all member events - member_types = None - else: - # `concrete_types()` will only return a subset when there are wild - # cards in the filter, but that's fine. - member_types = member_filter.concrete_types() - - if non_member_filter.is_full(): - # We fetched all non member events - non_member_types = None - else: - non_member_types = non_member_filter.concrete_types() - - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict_members = {} - state_dict_non_members = {} - - for k, v in iteritems(group_state_dict): - if k[0] == EventTypes.Member: - state_dict_members[k] = v - else: - state_dict_non_members[k] = v - - self._state_group_members_cache.update( - cache_seq_num_members, - key=group, - value=state_dict_members, - fetched_keys=member_types, - ) - - self._state_group_cache.update( - cache_seq_num_non_members, - key=group, - value=state_dict_non_members, - fetched_keys=non_member_types, - ) - - def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): - """Store a new set of state, returning a newly assigned state group. - - Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and - `current_state_ids`, if `prev_group` was given. Same format as - `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) - to event_id. - - Returns: - Deferred[int]: The state group ID - """ - - def _store_state_group_txn(txn): - if current_state_ids is None: - # AFAIK, this can never happen - raise Exception("current_state_ids cannot be None") - - state_group = self.database_engine.get_next_state_group_id(txn) - - self.db.simple_insert_txn( - txn, - table="state_groups", - values={"id": state_group, "room_id": room_id, "event_id": event_id}, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if prev_group: - is_in_db = self.db.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={"state_group": state_group, "prev_state_group": prev_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_ids) - ], - ) - else: - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(current_state_ids) - ], - ) - - # Prefill the state group caches with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - - current_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] == EventTypes.Member - } - txn.call_after( - self._state_group_members_cache.update, - self._state_group_members_cache.sequence, - key=state_group, - value=dict(current_member_state_ids), - ) - - current_non_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] != EventTypes.Member - } - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=state_group, - value=dict(current_non_member_state_ids), - ) - - return state_group - - return self.db.runInteraction("store_state_group", _store_state_group_txn) - @defer.inlineCallbacks def get_referenced_state_groups(self, state_groups): """Check if the state groups are referenced by events. @@ -1031,22 +289,14 @@ class StateGroupWorkerStore( return set(row["state_group"] for row in rows) -class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): +class MainStateBackgroundUpdateStore(SQLBaseStore): - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" def __init__(self, database: Database, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state - ) + super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", @@ -1061,181 +311,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): columns=["state_group"], ) - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): - """This background update will slowly deduplicate state by reencoding - them as deltas. - """ - last_state_group = progress.get("last_state_group", 0) - rows_inserted = progress.get("rows_inserted", 0) - max_group = progress.get("max_group", None) - - BATCH_SIZE_SCALE_FACTOR = 100 - batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) - - if max_group is None: - rows = yield self.db.execute( - "_background_deduplicate_state", - None, - "SELECT coalesce(max(id), 0) FROM state_groups", - ) - max_group = rows[0][0] - - def reindex_txn(txn): - new_last_state_group = last_state_group - for count in range(batch_size): - txn.execute( - "SELECT id, room_id FROM state_groups" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC" - " LIMIT 1", - (new_last_state_group, max_group), - ) - row = txn.fetchone() - if row: - state_group, room_id = row - - if not row or not state_group: - return True, count - - txn.execute( - "SELECT state_group FROM state_group_edges" - " WHERE state_group = ?", - (state_group,), - ) - - # If we reach a point where we've already started inserting - # edges we should stop. - if txn.fetchall(): - return True, count - - txn.execute( - "SELECT coalesce(max(id), 0) FROM state_groups" - " WHERE id < ? AND room_id = ?", - (state_group, room_id), - ) - (prev_group,) = txn.fetchone() - new_last_state_group = state_group - - if prev_group: - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if potential_hops >= MAX_STATE_DELTA_HOPS: - # We want to ensure chains are at most this long,# - # otherwise read performance degrades. - continue - - prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group] - ) - prev_state = prev_state[prev_group] - - curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group] - ) - curr_state = curr_state[state_group] - - if not set(prev_state.keys()) - set(curr_state.keys()): - # We can only do a delta if the current has a strict super set - # of keys - - delta_state = { - key: value - for key, value in iteritems(curr_state) - if prev_state.get(key, None) != value - } - - self.db.simple_delete_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": state_group, - "prev_state_group": prev_group, - }, - ) - - self.db.simple_delete_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_state) - ], - ) - - progress = { - "last_state_group": state_group, - "rows_inserted": rows_inserted + batch_size, - "max_group": max_group, - } - - self.db.updates._background_update_progress_txn( - txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress - ) - - return False, batch_size - - finished, result = yield self.db.runInteraction( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn - ) - - if finished: - yield self.db.updates._end_background_update( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME - ) - - return result * BATCH_SIZE_SCALE_FACTOR - - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): - def reindex_txn(conn): - conn.rollback() - if isinstance(self.database_engine, PostgresEngine): - # postgres insists on autocommit for the index - conn.set_session(autocommit=True) - try: - txn = conn.cursor() - txn.execute( - "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - finally: - conn.set_session(autocommit=False) - else: - txn = conn.cursor() - txn.execute( - "CREATE INDEX state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - - yield self.db.runWithConnection(reindex_txn) - - yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - - return 1 - - -class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): +class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): """ Keeps track of the state at a given event. This is done by the concept of `state groups`. Every event is a assigned diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py new file mode 100644 index 0000000000..86e09f6229 --- /dev/null +++ b/synapse/storage/data_stores/state/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401 diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py new file mode 100644 index 0000000000..e8edaf9f7b --- /dev/null +++ b/synapse/storage/data_stores/state/bg_updates.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database +from synapse.storage.engines import PostgresEngine +from synapse.storage.state import StateFilter + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class StateGroupBackgroundUpdateStore(SQLBaseStore): + """Defines functions related to state groups needed to run the state backgroud + updates. + """ + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """ + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + def _get_state_groups_from_groups_txn( + self, txn, groups, state_filter=StateFilter.all() + ): + results = {group: {} for group in groups} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( + PARTITION BY type, state_key ORDER BY state_group ASC + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS event_id FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) + """ + + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in results[group] + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(results[group]) == max_entries_returned + ): + break + + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return results + + +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" + + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state + ) + self.db.updates.register_background_index_update( + self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME, + index_name="state_groups_room_id_idx", + table="state_groups", + columns=["room_id"], + ) + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. + """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self.db.execute( + "_background_deduplicate_state", + None, + "SELECT coalesce(max(id), 0) FROM state_groups", + ) + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in range(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group), + ) + row = txn.fetchone() + if row: + state_group, room_id = row + + if not row or not state_group: + return True, count + + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,), + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id), + ) + (prev_group,) = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. + continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group] + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group] + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value + for key, value in iteritems(curr_state) + if prev_state.get(key, None) != value + } + + self.db.simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self.db.simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_state) + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self.db.updates._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress + ) + + return False, batch_size + + finished, result = yield self.db.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn + ) + + if finished: + yield self.db.updates._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) + + return result * BATCH_SIZE_SCALE_FACTOR + + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + + yield self.db.runWithConnection(reindex_txn) + + yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + + return 1 diff --git a/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql new file mode 100644 index 0000000000..ae09fa0065 --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2015, 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP INDEX IF EXISTS state_groups_state_tuple; diff --git a/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql new file mode 100644 index 0000000000..e85699e82e --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql @@ -0,0 +1,33 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/* We used to create a table called current_state_resets, but this is no + * longer used and is removed in delta 54. + */ + +/* The outlier events that have aquired a state group typically through + * backfill. This is tracked separately to the events table, as assigning a + * state group change the position of the existing event in the stream + * ordering. + * However since a stream_ordering is assigned in persist_event for the + * (event, state) pair, we can use that stream_ordering to identify when + * the new state was assigned for the event. + */ +CREATE TABLE IF NOT EXISTS ex_outlier_stream( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL, + event_id TEXT NOT NULL, + state_group BIGINT NOT NULL +); diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql new file mode 100644 index 0000000000..1450313bfa --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql @@ -0,0 +1,19 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- The following indices are redundant, other indices are equivalent or +-- supersets +DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY diff --git a/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql new file mode 100644 index 0000000000..33980d02f0 --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication'); diff --git a/synapse/storage/data_stores/state/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql new file mode 100644 index 0000000000..0f1fa68a89 --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/35/state.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_group_edges( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); diff --git a/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql new file mode 100644 index 0000000000..97e5067ef4 --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('state_group_state_deduplication', '{}'); diff --git a/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py new file mode 100644 index 0000000000..9fd1ccf6f7 --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py @@ -0,0 +1,34 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + # if we already have some state groups, we want to start making new + # ones with a higher id. + cur.execute("SELECT max(id) FROM state_groups") + row = cur.fetchone() + + if row[0] is None: + start_val = 1 + else: + start_val = row[0] + 1 + + cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql new file mode 100644 index 0000000000..7916ef18b2 --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('state_groups_room_id_idx', '{}'); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql new file mode 100644 index 0000000000..35f97d6b3d --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql @@ -0,0 +1,37 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_groups ( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_groups_state ( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_group_edges ( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges (state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group); +CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres new file mode 100644 index 0000000000..fcd926c9fb --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE state_group_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py new file mode 100644 index 0000000000..d53695f238 --- /dev/null +++ b/synapse/storage/data_stores/state/store.py @@ -0,0 +1,640 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import namedtuple + +from six import iteritems +from six.moves import range + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.database import Database +from synapse.storage.state import StateFilter +from synapse.util.caches import get_cache_factor_for +from synapse.util.caches.descriptors import cached +from synapse.util.caches.dictionary_cache import DictionaryCache + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): + """A data store for fetching/storing state groups. + """ + + def __init__(self, database: Database, db_conn, hs): + super(StateGroupDataStore, self).__init__(database, db_conn, hs) + + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache"), + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache"), + ) + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + + def _get_state_group_delta_txn(txn): + prev_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self.db.simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ) + + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.db.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) + + @defer.inlineCallbacks + def _get_state_groups_from_groups(self, groups, state_filter): + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups(list[int]): list of state group IDs to query + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + results = {} + + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] + for chunk in chunks: + res = yield self.db.runInteraction( + "_get_state_groups_from_groups", + self._get_state_groups_from_groups_txn, + chunk, + state_filter, + ) + results.update(res) + + return results + + def _get_state_for_group_using_cache(self, cache, group, state_filter): + """Checks if group is in cache. See `_get_state_for_groups` + + Args: + cache(DictionaryCache): the state group cache to use + group(int): The state group to lookup + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + """ + is_all, known_absent, state_dict_ids = cache.get(group) + + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all + + # tracks whether any of our requested types are missing from the cache + missing_types = False + + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): + if key not in state_dict_ids and key not in known_absent: + missing_types = True + break + + return state_filter.filter_state(state_dict_ids), not missing_types + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + ( + non_member_state, + incomplete_groups_nm, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter + ) + + ( + member_state, + incomplete_groups_m, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + return state + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + return state + + def _get_state_for_groups_using_cache(self, groups, cache, state_filter): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. + """ + results = {} + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids + + if not got_all: + incomplete_groups.add(group) + + return results, incomplete_groups + + def _insert_into_cache( + self, + group_to_state_dict, + state_filter, + cache_seq_num_members, + cache_seq_num_non_members, + ): + """Inserts results from querying the database into the relevant cache. + + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ + + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) + + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() + + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() + + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} + + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v + else: + state_dict_non_members[k] = v + + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self.db.simple_insert_txn( + txn, + table="state_groups", + values={"id": state_group, "room_id": room_id, "event_id": event_id}, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self.db.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_ids) + ], + ) + else: + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(current_state_ids) + ], + ) + + # Prefill the state group caches with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_non_member_state_ids), + ) + + return state_group + + return self.db.runInteraction("store_state_group", _store_state_group_txn) + + def purge_unreferenced_state_groups( + self, room_id: str, state_groups_to_delete + ) -> defer.Deferred: + """Deletes no longer referenced state groups and de-deltas any state + groups that reference them. + + Args: + room_id: The room the state groups belong to (must all be in the + same room). + state_groups_to_delete (Collection[int]): Set of all state groups + to delete. + """ + + return self.db.runInteraction( + "purge_unreferenced_state_groups", + self._purge_unreferenced_state_groups, + room_id, + state_groups_to_delete, + ) + + def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + logger.info( + "[purge] found %i state groups to delete", len(state_groups_to_delete) + ) + + rows = self.db.simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=state_groups_to_delete, + keyvalues={}, + retcols=("state_group",), + ) + + remaining_state_groups = set( + row["state_group"] + for row in rows + if row["state_group"] not in state_groups_to_delete + ) + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. + for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state[sg] + + self.db.simple_delete_txn( + txn, table="state_groups_state", keyvalues={"state_group": sg} + ) + + self.db.simple_delete_txn( + txn, table="state_group_edges", keyvalues={"state_group": sg} + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": sg, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(curr_state) + ], + ) + + logger.info("[purge] removing redundant state groups") + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + ((sg,) for sg in state_groups_to_delete), + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + ((sg,) for sg in state_groups_to_delete), + ) + + @defer.inlineCallbacks + def get_previous_state_groups(self, state_groups): + """Fetch the previous groups of the given state groups. + + Args: + state_groups (Iterable[int]) + + Returns: + Deferred[dict[int, int]]: mapping from state group to previous + state group. + """ + + rows = yield self.db.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("prev_state_group", "state_group"), + desc="get_previous_state_groups", + ) + + return {row["state_group"]: row["prev_state_group"] for row in rows} + + def purge_room_state(self, room_id, state_groups_to_delete): + """Deletes all record of a room from state tables + + Args: + room_id (str): + state_groups_to_delete (list[int]): State groups to delete + """ + + return self.db.runInteraction( + "purge_room_state", + self._purge_room_state_txn, + room_id, + state_groups_to_delete, + ) + + def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): + # first we have to delete the state groups states + logger.info("[purge] removing %s from state_groups_state", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups_state", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state group edges + logger.info("[purge] removing %s from state_group_edges", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_group_edges", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state groups + logger.info("[purge] removing %s from state_groups", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups", + column="id", + iterable=state_groups_to_delete, + keyvalues={}, + ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index fa03ca9ff7..1ed44925fc 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -183,7 +183,7 @@ class EventsPersistenceStorage(object): # so we use separate variables here even though they point to the same # store for now. self.main_store = stores.main - self.state_store = stores.main + self.state_store = stores.state self._clock = hs.get_clock() self.is_mine_id = hs.is_mine_id diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 403848ad03..e70026b80a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -42,7 +42,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config, data_stores=["main"]): +def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index a368182034..d6a7bd7834 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -58,7 +58,7 @@ class PurgeEventsStorage(object): sg_to_delete = yield self._find_unreferenced_groups(state_groups) - yield self.stores.main.purge_unreferenced_state_groups(room_id, sg_to_delete) + yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) @defer.inlineCallbacks def _find_unreferenced_groups(self, state_groups): @@ -102,7 +102,7 @@ class PurgeEventsStorage(object): # groups that are referenced. current_search -= referenced - edges = yield self.stores.main.get_previous_state_groups(current_search) + edges = yield self.stores.state.get_previous_state_groups(current_search) prevs = set(edges.values()) # We don't bother re-handling groups we've already seen diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 3735846899..cbeb586014 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -342,7 +342,7 @@ class StateGroupStorage(object): (prev_group, delta_ids) """ - return self.stores.main.get_state_group_delta(state_group) + return self.stores.state.get_state_group_delta(state_group) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -362,7 +362,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups(groups) + group_to_state = yield self.stores.state._get_state_for_groups(groups) return group_to_state @@ -423,7 +423,7 @@ class StateGroupStorage(object): dict of state_group_id -> (dict of (type, state_key) -> event id) """ - return self.stores.main._get_state_groups_from_groups(groups, state_filter) + return self.stores.state._get_state_groups_from_groups(groups, state_filter) @defer.inlineCallbacks def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): @@ -439,7 +439,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups( + group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) @@ -476,7 +476,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups( + group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) @@ -532,7 +532,7 @@ class StateGroupStorage(object): Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ - return self.stores.main._get_state_for_groups(groups, state_filter) + return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids @@ -552,6 +552,6 @@ class StateGroupStorage(object): Returns: Deferred[int]: The state group ID """ - return self.stores.main.store_state_group( + return self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 43200654f1..d6ecf102f8 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -35,7 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store = hs.get_datastore() self.storage = hs.get_storage() - self.state_datastore = self.store + self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/utils.py b/tests/utils.py index 9f5bf40b4b..e2e9cafd79 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -231,7 +231,7 @@ def setup_test_homeserver( "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, } - database = DatabaseConnectionConfig("master", database_config, ["main"]) + database = DatabaseConnectionConfig("master", database_config) config.database.databases = [database] db_engine = create_engine(database.config) -- cgit 1.5.1 From b6b57ecb4e845490fc26a537ff57df8cae1587b9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 3 Jan 2020 14:19:48 +0000 Subject: Kill off redundant SynapseRequestFactory (#6619) We already get the Site via the Channel, so there's no need for a dedicated RequestFactory: we can just use the right constructor. --- changelog.d/6619.misc | 1 + synapse/http/site.py | 18 +++--------------- tests/server.py | 6 ++++-- 3 files changed, 8 insertions(+), 17 deletions(-) create mode 100644 changelog.d/6619.misc (limited to 'tests') diff --git a/changelog.d/6619.misc b/changelog.d/6619.misc new file mode 100644 index 0000000000..b608133219 --- /dev/null +++ b/changelog.d/6619.misc @@ -0,0 +1 @@ +Simplify http handling by removing redundant SynapseRequestFactory. diff --git a/synapse/http/site.py b/synapse/http/site.py index ff8184a3d0..9f2d035fa0 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -47,9 +47,9 @@ class SynapseRequest(Request): logcontext(LoggingContext) : the log context for this request """ - def __init__(self, site, channel, *args, **kw): + def __init__(self, channel, *args, **kw): Request.__init__(self, channel, *args, **kw) - self.site = site + self.site = channel.site self._channel = channel # this is used by the tests self.authenticated_entity = None self.start_time = 0 @@ -331,18 +331,6 @@ class XForwardedForRequest(SynapseRequest): ) -class SynapseRequestFactory(object): - def __init__(self, site, x_forwarded_for): - self.site = site - self.x_forwarded_for = x_forwarded_for - - def __call__(self, *args, **kwargs): - if self.x_forwarded_for: - return XForwardedForRequest(self.site, *args, **kwargs) - else: - return SynapseRequest(self.site, *args, **kwargs) - - class SynapseSite(Site): """ Subclass of a twisted http Site that does access logging with python's @@ -364,7 +352,7 @@ class SynapseSite(Site): self.site_tag = site_tag proxied = config.get("x_forwarded", False) - self.requestFactory = SynapseRequestFactory(self, proxied) + self.requestFactory = XForwardedForRequest if proxied else SynapseRequest self.access_logger = logging.getLogger(logger_name) self.server_version_string = server_version_string.encode("ascii") diff --git a/tests/server.py b/tests/server.py index a554dfdd57..1644710aa0 100644 --- a/tests/server.py +++ b/tests/server.py @@ -20,6 +20,7 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http import unquote from twisted.web.http_headers import Headers +from twisted.web.server import Site from synapse.http.site import SynapseRequest from synapse.util import Clock @@ -42,6 +43,7 @@ class FakeChannel(object): wire). """ + site = attr.ib(type=Site) _reactor = attr.ib() result = attr.ib(default=attr.Factory(dict)) _producer = None @@ -176,9 +178,9 @@ def make_request( content = content.encode("utf8") site = FakeSite() - channel = FakeChannel(reactor) + channel = FakeChannel(site, reactor) - req = request(site, channel) + req = request(channel) req.process = lambda: b"" req.content = BytesIO(content) req.postpath = list(map(unquote, path[1:].split(b"/"))) -- cgit 1.5.1 From 18674eebb1fa5d7445952d7e201afe33bd040523 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 6 Jan 2020 12:28:58 +0000 Subject: Workaround for error when fetching notary's own key (#6620) * Kill off redundant SynapseRequestFactory We already get the Site via the Channel, so there's no need for a dedicated RequestFactory: we can just use the right constructor. * Workaround for error when fetching notary's own key As a notary server, when we return our own keys, include all of our signing keys in verify_keys. This is a workaround for #6596. --- changelog.d/6620.misc | 1 + synapse/rest/key/v2/remote_key_resource.py | 30 ++++-- tests/rest/key/v2/test_remote_key_resource.py | 130 ++++++++++++++++++++++++++ tests/unittest.py | 11 ++- 4 files changed, 163 insertions(+), 9 deletions(-) create mode 100644 changelog.d/6620.misc create mode 100644 tests/rest/key/v2/test_remote_key_resource.py (limited to 'tests') diff --git a/changelog.d/6620.misc b/changelog.d/6620.misc new file mode 100644 index 0000000000..8bfb78fb20 --- /dev/null +++ b/changelog.d/6620.misc @@ -0,0 +1 @@ +Add a workaround for synapse raising exceptions when fetching the notary's own key from the notary. diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index e7fc3f0431..bf5e0eb844 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,6 +15,7 @@ import logging from canonicaljson import encode_canonical_json, json +from signedjson.key import encode_verify_key_base64 from signedjson.sign import sign_json from twisted.internet import defer @@ -216,15 +217,28 @@ class RemoteKey(DirectServeResource): if cache_misses and query_remote_on_cache_miss: yield self.fetcher.get_keys(cache_misses) yield self.query_keys(request, query, query_remote_on_cache_miss=False) - else: - signed_keys = [] - for key_json in json_results: - key_json = json.loads(key_json) + return + + signed_keys = [] + for key_json in json_results: + key_json = json.loads(key_json) + + # backwards-compatibility hack for #6596: if the requested key belongs + # to us, make sure that all of the signing keys appear in the + # "verify_keys" section. + if key_json["server_name"] == self.config.server_name: + verify_keys = key_json["verify_keys"] for signing_key in self.config.key_server_signing_keys: - key_json = sign_json(key_json, self.config.server_name, signing_key) + key_id = "%s:%s" % (signing_key.alg, signing_key.version) + verify_keys[key_id] = { + "key": encode_verify_key_base64(signing_key.verify_key) + } + + for signing_key in self.config.key_server_signing_keys: + key_json = sign_json(key_json, self.config.server_name, signing_key) - signed_keys.append(key_json) + signed_keys.append(key_json) - results = {"server_keys": signed_keys} + results = {"server_keys": signed_keys} - respond_with_json_bytes(request, 200, encode_canonical_json(results)) + respond_with_json_bytes(request, 200, encode_canonical_json(results)) diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py new file mode 100644 index 0000000000..d8246b4e78 --- /dev/null +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import urllib.parse +from io import BytesIO + +from mock import Mock + +import signedjson.key +from nacl.signing import SigningKey +from signedjson.sign import sign_json + +from twisted.web.resource import NoResource + +from synapse.http.site import SynapseRequest +from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.util.httpresourcetree import create_resource_tree + +from tests import unittest +from tests.server import FakeChannel, wait_until_result + + +class RemoteKeyResourceTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + self.http_client = Mock() + return self.setup_test_homeserver(http_client=self.http_client) + + def create_test_json_resource(self): + return create_resource_tree( + {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() + ) + + def expect_outgoing_key_request( + self, server_name: str, signing_key: SigningKey + ) -> None: + """ + Tell the mock http client to expect an outgoing GET request for the given key + """ + + def get_json(destination, path, ignore_backoff=False, **kwargs): + self.assertTrue(ignore_backoff) + self.assertEqual(destination, server_name) + key_id = "%s:%s" % (signing_key.alg, signing_key.version) + self.assertEqual( + path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),) + ) + + response = { + "server_name": server_name, + "old_verify_keys": {}, + "valid_until_ts": 200 * 1000, + "verify_keys": { + key_id: { + "key": signedjson.key.encode_verify_key_base64( + signing_key.verify_key + ) + } + }, + } + sign_json(response, server_name, signing_key) + return response + + self.http_client.get_json.side_effect = get_json + + def make_notary_request(self, server_name: str, key_id: str) -> dict: + """Send a GET request to the test server requesting the given key. + + Checks that the response is a 200 and returns the decoded json body. + """ + channel = FakeChannel(self.site, self.reactor) + req = SynapseRequest(channel) + req.content = BytesIO(b"") + req.requestReceived( + b"GET", + b"/_matrix/key/v2/query/%s/%s" + % (server_name.encode("utf-8"), key_id.encode("utf-8")), + b"1.1", + ) + wait_until_result(self.reactor, req) + self.assertEqual(channel.code, 200) + resp = channel.json_body + return resp + + def test_get_key(self): + """Fetch a remote key""" + SERVER_NAME = "remote.server" + testkey = signedjson.key.generate_signing_key("ver1") + self.expect_outgoing_key_request(SERVER_NAME, testkey) + + resp = self.make_notary_request(SERVER_NAME, "ed25519:ver1") + keys = resp["server_keys"] + self.assertEqual(len(keys), 1) + + self.assertIn("ed25519:ver1", keys[0]["verify_keys"]) + self.assertEqual(len(keys[0]["verify_keys"]), 1) + + # it should be signed by both the origin server and the notary + self.assertIn(SERVER_NAME, keys[0]["signatures"]) + self.assertIn(self.hs.hostname, keys[0]["signatures"]) + + def test_get_own_key(self): + """Fetch our own key""" + testkey = signedjson.key.generate_signing_key("ver1") + self.expect_outgoing_key_request(self.hs.hostname, testkey) + + resp = self.make_notary_request(self.hs.hostname, "ed25519:ver1") + keys = resp["server_keys"] + self.assertEqual(len(keys), 1) + + # it should be signed by both itself, and the notary signing key + sigs = keys[0]["signatures"] + self.assertEqual(len(sigs), 1) + self.assertIn(self.hs.hostname, sigs) + oursigs = sigs[self.hs.hostname] + self.assertEqual(len(oursigs), 2) + + # and both keys should be present in the verify_keys section + self.assertIn("ed25519:ver1", keys[0]["verify_keys"]) + self.assertIn("ed25519:a_lPym", keys[0]["verify_keys"]) diff --git a/tests/unittest.py b/tests/unittest.py index b30b7d1718..cbda237278 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -36,7 +36,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource -from synapse.http.site import SynapseRequest +from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import LoggingContext from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester @@ -210,6 +210,15 @@ class HomeserverTestCase(TestCase): # Register the resources self.resource = self.create_test_json_resource() + # create a site to wrap the resource. + self.site = SynapseSite( + logger_name="synapse.access.http.fake", + site_tag="test", + config={}, + resource=self.resource, + server_version_string="1", + ) + from tests.rest.client.v1.utils import RestHelper self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) -- cgit 1.5.1 From 4b36b482e0cc1a63db27534c4ea5d9608cdb6a79 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 6 Jan 2020 12:33:56 +0000 Subject: Fix exception when fetching notary server's old keys (#6625) Lift the restriction that *all* the keys used for signing v2 key responses be present in verify_keys. Fixes #6596. --- changelog.d/6625.bugfix | 1 + synapse/crypto/keyring.py | 13 ++-- tests/crypto/test_keyring.py | 139 +++++++++++++++++++++++++++++-------------- 3 files changed, 103 insertions(+), 50 deletions(-) create mode 100644 changelog.d/6625.bugfix (limited to 'tests') diff --git a/changelog.d/6625.bugfix b/changelog.d/6625.bugfix new file mode 100644 index 0000000000..a8dc5587dc --- /dev/null +++ b/changelog.d/6625.bugfix @@ -0,0 +1 @@ +Fix exception when fetching the `matrix.org:ed25519:auto` key. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 7cfad192e8..6fe5a6a26a 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -511,17 +511,18 @@ class BaseV2KeyFetcher(object): server_name = response_json["server_name"] verified = False for key_id in response_json["signatures"].get(server_name, {}): - # each of the keys used for the signature must be present in the response - # json. key = verify_keys.get(key_id) if not key: - raise KeyLookupError( - "Key response is signed by key id %s:%s but that key is not " - "present in the response" % (server_name, key_id) - ) + # the key may not be present in verify_keys if: + # * we got the key from the notary server, and: + # * the key belongs to the notary server, and: + # * the notary server is using a different key to sign notary + # responses. + continue verify_signed_json(response_json, server_name, key.verify_key) verified = True + break if not verified: raise KeyLookupError( diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 8efd39c7f7..34d5895f18 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,6 +19,7 @@ from mock import Mock import canonicaljson import signedjson.key import signedjson.sign +from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key from twisted.internet import defer @@ -412,34 +413,37 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): handlers=None, http_client=self.http_client, config=config ) - def test_get_keys_from_perspectives(self): - # arbitrarily advance the clock a bit - self.reactor.advance(100) - - fetcher = PerspectivesKeyFetcher(self.hs) - - SERVER_NAME = "server2" - testkey = signedjson.key.generate_signing_key("ver1") - testverifykey = signedjson.key.get_verify_key(testkey) - testverifykey_id = "ed25519:ver1" - VALID_UNTIL_TS = 200 * 1000 + def build_perspectives_response( + self, server_name: str, signing_key: SigningKey, valid_until_ts: int, + ) -> dict: + """ + Build a valid perspectives server response to a request for the given key + """ + verify_key = signedjson.key.get_verify_key(signing_key) + verifykey_id = "%s:%s" % (verify_key.alg, verify_key.version) - # valid response response = { - "server_name": SERVER_NAME, + "server_name": server_name, "old_verify_keys": {}, - "valid_until_ts": VALID_UNTIL_TS, + "valid_until_ts": valid_until_ts, "verify_keys": { - testverifykey_id: { - "key": signedjson.key.encode_verify_key_base64(testverifykey) + verifykey_id: { + "key": signedjson.key.encode_verify_key_base64(verify_key) } }, } - # the response must be signed by both the origin server and the perspectives # server. - signedjson.sign.sign_json(response, SERVER_NAME, testkey) + signedjson.sign.sign_json(response, server_name, signing_key) self.mock_perspective_server.sign_response(response) + return response + + def expect_outgoing_key_query( + self, expected_server_name: str, expected_key_id: str, response: dict + ) -> None: + """ + Tell the mock http client to expect a perspectives-server key query + """ def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) @@ -447,11 +451,79 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the request is for the expected key q = data["server_keys"] - self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"]) + self.assertEqual(list(q[expected_server_name].keys()), [expected_key_id]) return {"server_keys": [response]} self.http_client.post_json.side_effect = post_json + def test_get_keys_from_perspectives(self): + # arbitrarily advance the clock a bit + self.reactor.advance(100) + + fetcher = PerspectivesKeyFetcher(self.hs) + + SERVER_NAME = "server2" + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 200 * 1000 + + response = self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS, + ) + + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) + + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) + self.assertIn(SERVER_NAME, keys) + k = keys[SERVER_NAME][testverifykey_id] + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") + + # check that the perspectives store is correctly updated + lookup_triplet = (SERVER_NAME, testverifykey_id, None) + key_json = self.get_success( + self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + ) + res = key_json[lookup_triplet] + self.assertEqual(len(res), 1) + res = res[0] + self.assertEqual(res["key_id"], testverifykey_id) + self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) + self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) + self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + + self.assertEqual( + bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) + ) + + def test_get_perspectives_own_key(self): + """Check that we can get the perspectives server's own keys + + This is slightly complicated by the fact that the perspectives server may + use different keys for signing notary responses. + """ + + # arbitrarily advance the clock a bit + self.reactor.advance(100) + + fetcher = PerspectivesKeyFetcher(self.hs) + + SERVER_NAME = self.mock_perspective_server.server_name + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 200 * 1000 + + response = self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS + ) + + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) + keys_to_fetch = {SERVER_NAME: {"key1": 0}} keys = self.get_success(fetcher.get_keys(keys_to_fetch)) self.assertIn(SERVER_NAME, keys) @@ -490,35 +562,14 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): VALID_UNTIL_TS = 200 * 1000 def build_response(): - # valid response - response = { - "server_name": SERVER_NAME, - "old_verify_keys": {}, - "valid_until_ts": VALID_UNTIL_TS, - "verify_keys": { - testverifykey_id: { - "key": signedjson.key.encode_verify_key_base64(testverifykey) - } - }, - } - - # the response must be signed by both the origin server and the perspectives - # server. - signedjson.sign.sign_json(response, SERVER_NAME, testkey) - self.mock_perspective_server.sign_response(response) - return response + return self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS + ) def get_key_from_perspectives(response): fetcher = PerspectivesKeyFetcher(self.hs) keys_to_fetch = {SERVER_NAME: {"key1": 0}} - - def post_json(destination, path, data, **kwargs): - self.assertEqual(destination, self.mock_perspective_server.server_name) - self.assertEqual(path, "/_matrix/key/v2/query") - return {"server_keys": [response]} - - self.http_client.post_json.side_effect = post_json - + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) return self.get_success(fetcher.get_keys(keys_to_fetch)) # start with a valid response so we can check we are testing the right thing -- cgit 1.5.1 From 5a047816434e2ce2df8b80eb63a49c17dc3085fb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Jan 2020 15:31:09 +0000 Subject: rename get_prev_events_for_room to get_prev_events_and_hashes_for_room ... to make way for a new method which just returns the event ids --- synapse/handlers/message.py | 6 ++++-- synapse/handlers/room_member.py | 4 +++- synapse/storage/data_stores/main/event_federation.py | 5 +++-- tests/storage/test_event_federation.py | 4 ++-- 4 files changed, 12 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4ad752205f..2695975a16 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -740,7 +740,7 @@ class EventCreationHandler(object): % (len(prev_events_and_hashes),) ) else: - prev_events_and_hashes = yield self.store.get_prev_events_for_room( + prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( builder.room_id ) @@ -1042,7 +1042,9 @@ class EventCreationHandler(object): # For each room we need to find a joined member we can use to send # the dummy event with. - prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) + prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( + room_id + ) latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 44c5e3239c..91bb34cd55 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -370,7 +370,9 @@ class RoomMemberHandler(object): if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) + prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( + room_id + ) latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) current_state_ids = yield self.state_handler.get_current_state_ids( diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 1f517e8fad..266fc9715f 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -149,9 +149,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) @defer.inlineCallbacks - def get_prev_events_for_room(self, room_id): + def get_prev_events_and_hashes_for_room(self, room_id): """ - Gets a subset of the current forward extremities in the given room. + Gets a subset of the current forward extremities in the given room, + along with their depths and hashes. Limits the result to 10 extremities, so that we can avoid creating events which refer to hundreds of prev_events. diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index eadfb90a22..3a68bf3274 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -26,7 +26,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): self.store = hs.get_datastore() @defer.inlineCallbacks - def test_get_prev_events_for_room(self): + def test_get_prev_events_and_hashes_for_room(self): room_id = "@ROOM:local" # add a bunch of events and hashes to act as forward extremities @@ -64,7 +64,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): yield self.store.db.runInteraction("insert", insert_event, i) # this should get the last five and five others - r = yield self.store.get_prev_events_for_room(room_id) + r = yield self.store.get_prev_events_and_hashes_for_room(room_id) self.assertEqual(10, len(r)) for i in range(0, 5): el = r[i] -- cgit 1.5.1 From 3bef62488e5cff4dfb33454f2f2e18cc928f319b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Jan 2020 16:19:55 +0000 Subject: Remove unused hashes and depths from create_event params --- synapse/handlers/message.py | 21 +++++---------------- synapse/handlers/room_member.py | 8 +++++++- tests/unittest.py | 6 +----- 3 files changed, 13 insertions(+), 22 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5415b0c9ee..8ea3aca2f4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -422,7 +422,7 @@ class EventCreationHandler(object): event_dict, token_id=None, txn_id=None, - prev_events_and_hashes=None, + prev_event_ids: Optional[Collection[str]] = None, require_consent=True, ): """ @@ -439,10 +439,9 @@ class EventCreationHandler(object): token_id (str) txn_id (str) - prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + prev_event_ids: the forward extremities to use as the prev_events for the - new event. For each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. + new event. If None, they will be requested from the database. @@ -497,12 +496,6 @@ class EventCreationHandler(object): if txn_id is not None: builder.internal_metadata.txn_id = txn_id - prev_event_ids = ( - None - if prev_events_and_hashes is None - else [event_id for event_id, _, _ in prev_events_and_hashes] - ) - event, context = yield self.create_new_client_event( builder=builder, requester=requester, prev_event_ids=prev_event_ids, ) @@ -1038,11 +1031,7 @@ class EventCreationHandler(object): # For each room we need to find a joined member we can use to send # the dummy event with. - prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( - room_id - ) - - latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) + latest_event_ids = yield self.store.get_prev_events_for_room(room_id) members = yield self.state.get_current_users_in_room( room_id, latest_event_ids=latest_event_ids @@ -1061,7 +1050,7 @@ class EventCreationHandler(object): "room_id": room_id, "sender": user_id, }, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=latest_event_ids, ) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 91bb34cd55..d550ba8ab4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -164,6 +164,12 @@ class RoomMemberHandler(object): if requester.is_guest: content["kind"] = "guest" + prev_event_ids = ( + None + if prev_events_and_hashes is None + else [event_id for event_id, _, _ in prev_events_and_hashes] + ) + event, context = yield self.event_creation_handler.create_event( requester, { @@ -177,7 +183,7 @@ class RoomMemberHandler(object): }, token_id=requester.access_token_id, txn_id=txn_id, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=prev_event_ids, require_consent=require_consent, ) diff --git a/tests/unittest.py b/tests/unittest.py index b30b7d1718..07b50c0ccd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -522,10 +522,6 @@ class HomeserverTestCase(TestCase): secrets = self.hs.get_secrets() requester = Requester(user, None, False, None, None) - prev_events_and_hashes = None - if prev_event_ids: - prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] - event, context = self.get_success( event_creator.create_event( requester, @@ -535,7 +531,7 @@ class HomeserverTestCase(TestCase): "sender": user.to_string(), "content": {"body": secrets.token_hex(), "msgtype": "m.text"}, }, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=prev_event_ids, ) ) -- cgit 1.5.1 From dc41fbf0dda981df117d8cf1938e023a38836cda Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Jan 2020 16:30:51 +0000 Subject: Remove unused get_prev_events_and_hashes_for_room --- .../storage/data_stores/main/event_federation.py | 30 ---------------------- tests/storage/test_event_federation.py | 19 +++++--------- 2 files changed, 6 insertions(+), 43 deletions(-) (limited to 'tests') diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 88e6489576..32e76621a7 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -14,7 +14,6 @@ # limitations under the License. import itertools import logging -import random from six.moves import range from six.moves.queue import Empty, PriorityQueue @@ -148,35 +147,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas retcol="event_id", ) - @defer.inlineCallbacks - def get_prev_events_and_hashes_for_room(self, room_id): - """ - Gets a subset of the current forward extremities in the given room, - along with their depths and hashes. - - Limits the result to 10 extremities, so that we can avoid creating - events which refer to hundreds of prev_events. - - Args: - room_id (str): room_id - - Returns: - Deferred[list[(str, dict[str, str], int)]] - for each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. - """ - res = yield self.get_latest_event_ids_and_hashes_in_room(room_id) - if len(res) > 10: - # Sort by reverse depth, so we point to the most recent. - res.sort(key=lambda a: -a[2]) - - # we use half of the limit for the actual most recent events, and - # the other half to randomly point to some of the older events, to - # make sure that we don't completely ignore the older events. - res = res[0:5] + random.sample(res[5:], 5) - - return res - def get_prev_events_for_room(self, room_id: str): """ Gets a subset of the current forward extremities in the given room. diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 3a68bf3274..a331517f4d 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -26,7 +26,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): self.store = hs.get_datastore() @defer.inlineCallbacks - def test_get_prev_events_and_hashes_for_room(self): + def test_get_prev_events_for_room(self): room_id = "@ROOM:local" # add a bunch of events and hashes to act as forward extremities @@ -60,21 +60,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): (event_id, bytearray(b"ffff")), ) - for i in range(0, 11): + for i in range(0, 20): yield self.store.db.runInteraction("insert", insert_event, i) - # this should get the last five and five others - r = yield self.store.get_prev_events_and_hashes_for_room(room_id) + # this should get the last ten + r = yield self.store.get_prev_events_for_room(room_id) self.assertEqual(10, len(r)) - for i in range(0, 5): - el = r[i] - depth = el[2] - self.assertEqual(10 - i, depth) - - for i in range(5, 5): - el = r[i] - depth = el[2] - self.assertLessEqual(5, depth) + for i in range(0, 10): + self.assertEqual("$event_%i:local" % (19 - i), r[i]) @defer.inlineCallbacks def test_get_rooms_with_many_extremities(self): -- cgit 1.5.1 From d20c3465441cd64ba3a1e84ee399bbadc0997bdf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 7 Jan 2020 14:09:07 +0000 Subject: port BackgroundUpdateTestCase to HomeserverTestCase (#6653) --- changelog.d/6653.misc | 1 + tests/storage/test_background_update.py | 72 +++++++++++++++++---------------- 2 files changed, 38 insertions(+), 35 deletions(-) create mode 100644 changelog.d/6653.misc (limited to 'tests') diff --git a/changelog.d/6653.misc b/changelog.d/6653.misc new file mode 100644 index 0000000000..fbe7c0e7db --- /dev/null +++ b/changelog.d/6653.misc @@ -0,0 +1 @@ +Port core background update routines to async/await. diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index aec76f4ab1..ae14fb407d 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -2,44 +2,37 @@ from mock import Mock from twisted.internet import defer +from synapse.storage.background_updates import BackgroundUpdater + from tests import unittest -from tests.utils import setup_test_homeserver -class BackgroundUpdateTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) - self.store = hs.get_datastore() - self.clock = hs.get_clock() +class BackgroundUpdateTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater + # the base test class should have run the real bg updates for us + self.assertTrue(self.updates.has_completed_background_updates()) self.update_handler = Mock() - - yield self.store.db.updates.register_background_update_handler( + self.updates.register_background_update_handler( "test_update", self.update_handler ) - # run the real background updates, to get them out the way - # (perhaps we should run them as part of the test HS setup, since we - # run all of the other schema setup stuff there?) - while True: - res = yield self.store.db.updates.do_next_background_update(1000) - if res is None: - break - - @defer.inlineCallbacks def test_do_background_update(self): - desired_count = 1000 + # the time we claim each update takes duration_ms = 42 + # the target runtime for each bg update + target_background_update_duration_ms = 50000 + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): - self.clock.advance_time_msec(count * duration_ms) + yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield self.store.db.runInteraction( + yield self.hs.get_datastore().db.runInteraction( "update_progress", - self.store.db.updates._background_update_progress_txn, + self.updates._background_update_progress_txn, "test_update", progress, ) @@ -47,37 +40,46 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler.side_effect = update - yield self.store.db.updates.start_background_update( - "test_update", {"my_key": 1} + self.get_success( + self.updates.start_background_update("test_update", {"my_key": 1}) ) - self.update_handler.reset_mock() - result = yield self.store.db.updates.do_next_background_update( - duration_ms * desired_count + res = self.get_success( + self.updates.do_next_background_update( + target_background_update_duration_ms + ), + by=0.1, ) - self.assertIsNotNone(result) + self.assertIsNotNone(res) + + # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update + # we should now get run with a much bigger number of items to update @defer.inlineCallbacks def update(progress, count): - yield self.store.db.updates._end_background_update("test_update") + self.assertEqual(progress, {"my_key": 2}) + self.assertAlmostEqual( + count, target_background_update_duration_ms / duration_ms, places=0, + ) + yield self.updates._end_background_update("test_update") return count self.update_handler.side_effect = update self.update_handler.reset_mock() - result = yield self.store.db.updates.do_next_background_update( - duration_ms * desired_count + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) ) self.assertIsNotNone(result) - self.update_handler.assert_called_once_with({"my_key": 2}, desired_count) + self.update_handler.assert_called_once() # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.db.updates.do_next_background_update( - duration_ms * desired_count + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) ) self.assertIsNone(result) self.assertFalse(self.update_handler.called) -- cgit 1.5.1 From 573fee759cbd76fca93bf90783cd013a11b9b4e5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 8 Jan 2020 13:24:10 +0000 Subject: Back out ill-advised notary server hackery (#6657) This was ill-advised. We can't modify verify_keys here, because the response object has already been signed by the requested key. Furthermore, it's somewhat unnecessary because existing versions of Synapse (which get upset that the notary key isn't present in verify_keys) will fall back to a direct fetch via `/key/v2/server`. Also: more tests for fetching keys via perspectives: it would be nice if we actually tested when our fetcher can't talk to our notary impl. --- changelog.d/6657.bugfix | 1 + synapse/rest/key/v2/remote_key_resource.py | 30 ++---- tests/rest/key/__init__.py | 0 tests/rest/key/v2/__init__.py | 0 tests/rest/key/v2/test_remote_key_resource.py | 135 +++++++++++++++++++++++++- 5 files changed, 140 insertions(+), 26 deletions(-) create mode 100644 changelog.d/6657.bugfix create mode 100644 tests/rest/key/__init__.py create mode 100644 tests/rest/key/v2/__init__.py (limited to 'tests') diff --git a/changelog.d/6657.bugfix b/changelog.d/6657.bugfix new file mode 100644 index 0000000000..94e51a9896 --- /dev/null +++ b/changelog.d/6657.bugfix @@ -0,0 +1 @@ +Fix incorrect signing of responses from the key server implementation. \ No newline at end of file diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index bf5e0eb844..e7fc3f0431 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,7 +15,6 @@ import logging from canonicaljson import encode_canonical_json, json -from signedjson.key import encode_verify_key_base64 from signedjson.sign import sign_json from twisted.internet import defer @@ -217,28 +216,15 @@ class RemoteKey(DirectServeResource): if cache_misses and query_remote_on_cache_miss: yield self.fetcher.get_keys(cache_misses) yield self.query_keys(request, query, query_remote_on_cache_miss=False) - return - - signed_keys = [] - for key_json in json_results: - key_json = json.loads(key_json) - - # backwards-compatibility hack for #6596: if the requested key belongs - # to us, make sure that all of the signing keys appear in the - # "verify_keys" section. - if key_json["server_name"] == self.config.server_name: - verify_keys = key_json["verify_keys"] + else: + signed_keys = [] + for key_json in json_results: + key_json = json.loads(key_json) for signing_key in self.config.key_server_signing_keys: - key_id = "%s:%s" % (signing_key.alg, signing_key.version) - verify_keys[key_id] = { - "key": encode_verify_key_base64(signing_key.verify_key) - } - - for signing_key in self.config.key_server_signing_keys: - key_json = sign_json(key_json, self.config.server_name, signing_key) + key_json = sign_json(key_json, self.config.server_name, signing_key) - signed_keys.append(key_json) + signed_keys.append(key_json) - results = {"server_keys": signed_keys} + results = {"server_keys": signed_keys} - respond_with_json_bytes(request, 200, encode_canonical_json(results)) + respond_with_json_bytes(request, 200, encode_canonical_json(results)) diff --git a/tests/rest/key/__init__.py b/tests/rest/key/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/rest/key/v2/__init__.py b/tests/rest/key/v2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index d8246b4e78..6776a56cad 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -13,25 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. import urllib.parse -from io import BytesIO +from io import BytesIO, StringIO from mock import Mock import signedjson.key +from canonicaljson import encode_canonical_json from nacl.signing import SigningKey from signedjson.sign import sign_json from twisted.web.resource import NoResource +from synapse.crypto.keyring import PerspectivesKeyFetcher from synapse.http.site import SynapseRequest from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.storage.keys import FetchKeyResult from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.stringutils import random_string from tests import unittest from tests.server import FakeChannel, wait_until_result +from tests.utils import default_config -class RemoteKeyResourceTestCase(unittest.HomeserverTestCase): +class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.http_client = Mock() return self.setup_test_homeserver(http_client=self.http_client) @@ -73,6 +78,8 @@ class RemoteKeyResourceTestCase(unittest.HomeserverTestCase): self.http_client.get_json.side_effect = get_json + +class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): def make_notary_request(self, server_name: str, key_id: str) -> dict: """Send a GET request to the test server requesting the given key. @@ -125,6 +132,126 @@ class RemoteKeyResourceTestCase(unittest.HomeserverTestCase): oursigs = sigs[self.hs.hostname] self.assertEqual(len(oursigs), 2) - # and both keys should be present in the verify_keys section + # the requested key should be present in the verify_keys section self.assertIn("ed25519:ver1", keys[0]["verify_keys"]) - self.assertIn("ed25519:a_lPym", keys[0]["verify_keys"]) + + +class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): + """End-to-end tests of the perspectives fetch case + + The idea here is to actually wire up a PerspectivesKeyFetcher to the notary + endpoint, to check that the two implementations are compatible. + """ + + def default_config(self, *args, **kwargs): + config = super().default_config(*args, **kwargs) + + # replace the signing key with our own + self.hs_signing_key = signedjson.key.generate_signing_key("kssk") + strm = StringIO() + signedjson.key.write_signing_keys(strm, [self.hs_signing_key]) + config["signing_key"] = strm.getvalue() + + return config + + def prepare(self, reactor, clock, homeserver): + # make a second homeserver, configured to use the first one as a key notary + self.http_client2 = Mock() + config = default_config(name="keyclient") + config["trusted_key_servers"] = [ + { + "server_name": self.hs.hostname, + "verify_keys": { + "ed25519:%s" + % ( + self.hs_signing_key.version, + ): signedjson.key.encode_verify_key_base64( + self.hs_signing_key.verify_key + ) + }, + } + ] + self.hs2 = self.setup_test_homeserver( + http_client=self.http_client2, config=config + ) + + # wire up outbound POST /key/v2/query requests from hs2 so that they + # will be forwarded to hs1 + def post_json(destination, path, data): + self.assertEqual(destination, self.hs.hostname) + self.assertEqual( + path, "/_matrix/key/v2/query", + ) + + channel = FakeChannel(self.site, self.reactor) + req = SynapseRequest(channel) + req.content = BytesIO(encode_canonical_json(data)) + + req.requestReceived( + b"POST", path.encode("utf-8"), b"1.1", + ) + wait_until_result(self.reactor, req) + self.assertEqual(channel.code, 200) + resp = channel.json_body + return resp + + self.http_client2.post_json.side_effect = post_json + + def test_get_key(self): + """Fetch a key belonging to a random server""" + # make up a key to be fetched. + testkey = signedjson.key.generate_signing_key("abc") + + # we expect hs1 to make a regular key request to the target server + self.expect_outgoing_key_request("targetserver", testkey) + keyid = "ed25519:%s" % (testkey.version,) + + fetcher = PerspectivesKeyFetcher(self.hs2) + d = fetcher.get_keys({"targetserver": {keyid: 1000}}) + res = self.get_success(d) + self.assertIn("targetserver", res) + keyres = res["targetserver"][keyid] + assert isinstance(keyres, FetchKeyResult) + self.assertEqual( + signedjson.key.encode_verify_key_base64(keyres.verify_key), + signedjson.key.encode_verify_key_base64(testkey.verify_key), + ) + + def test_get_notary_key(self): + """Fetch a key belonging to the notary server""" + # make up a key to be fetched. We randomise the keyid to try to get it to + # appear before the key server signing key sometimes (otherwise we bail out + # before fetching its signature) + testkey = signedjson.key.generate_signing_key(random_string(5)) + + # we expect hs1 to make a regular key request to itself + self.expect_outgoing_key_request(self.hs.hostname, testkey) + keyid = "ed25519:%s" % (testkey.version,) + + fetcher = PerspectivesKeyFetcher(self.hs2) + d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}}) + res = self.get_success(d) + self.assertIn(self.hs.hostname, res) + keyres = res[self.hs.hostname][keyid] + assert isinstance(keyres, FetchKeyResult) + self.assertEqual( + signedjson.key.encode_verify_key_base64(keyres.verify_key), + signedjson.key.encode_verify_key_base64(testkey.verify_key), + ) + + def test_get_notary_keyserver_key(self): + """Fetch the notary's keyserver key""" + # we expect hs1 to make a regular key request to itself + self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key) + keyid = "ed25519:%s" % (self.hs_signing_key.version,) + + fetcher = PerspectivesKeyFetcher(self.hs2) + d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}}) + res = self.get_success(d) + self.assertIn(self.hs.hostname, res) + keyres = res[self.hs.hostname][keyid] + assert isinstance(keyres, FetchKeyResult) + self.assertEqual( + signedjson.key.encode_verify_key_base64(keyres.verify_key), + signedjson.key.encode_verify_key_base64(self.hs_signing_key.verify_key), + ) -- cgit 1.5.1 From 7caaa29daab7bde331dcec0bb760a8fe5870f18e Mon Sep 17 00:00:00 2001 From: Manuel Stahl <37705355+awesome-manuel@users.noreply.github.com> Date: Wed, 8 Jan 2020 14:26:40 +0100 Subject: Fix GET request on /_synapse/admin/v2/users endpoint (#6563) Fixes #6552 --- changelog.d/6563.bugfix | 1 + synapse/storage/data_stores/main/__init__.py | 4 +-- tests/rest/admin/test_admin.py | 41 ++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6563.bugfix (limited to 'tests') diff --git a/changelog.d/6563.bugfix b/changelog.d/6563.bugfix new file mode 100644 index 0000000000..3325fb1dcf --- /dev/null +++ b/changelog.d/6563.bugfix @@ -0,0 +1 @@ +Fix GET request on /_synapse/admin/v2/users endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. \ No newline at end of file diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index c577c0df5f..2700cca822 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -526,9 +526,9 @@ class DataStore( attr_filter = {} if not guests: - attr_filter["is_guest"] = False + attr_filter["is_guest"] = 0 if not deactivated: - attr_filter["deactivated"] = False + attr_filter["deactivated"] = 0 return self.db.simple_select_list_paginate( desc="get_users_paginate", diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 0ed2594381..325bd6a608 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -341,6 +341,47 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual("Invalid user type", channel.json_body["error"]) +class UsersListTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + url = "/_synapse/admin/v2/users" + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.register_user("user1", "pass1", admin=False) + self.register_user("user2", "pass2", admin=False) + + def test_no_auth(self): + """ + Try to list users without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"]) + + def test_all_users(self): + """ + List all users, including deactivated users. + """ + request, channel = self.make_request( + "GET", + self.url + "?deactivated=true", + b"{}", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(3, len(channel.json_body["users"])) + + class ShutdownRoomTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, -- cgit 1.5.1 From d2906fe6667d3384f37ef03ca87172d643d49587 Mon Sep 17 00:00:00 2001 From: Manuel Stahl <37705355+awesome-manuel@users.noreply.github.com> Date: Thu, 9 Jan 2020 14:31:00 +0100 Subject: Allow admin users to create or modify users without a shared secret (#6495) Signed-off-by: Manuel Stahl --- changelog.d/5742.feature | 1 + docs/admin_api/user_admin_api.rst | 33 +- synapse/handlers/admin.py | 9 + synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 142 +++++++ synapse/storage/data_stores/main/registration.py | 2 + tests/rest/admin/test_admin.py | 338 ---------------- tests/rest/admin/test_user.py | 465 +++++++++++++++++++++++ tests/storage/test_registration.py | 2 + 9 files changed, 655 insertions(+), 339 deletions(-) create mode 100644 changelog.d/5742.feature create mode 100644 tests/rest/admin/test_user.py (limited to 'tests') diff --git a/changelog.d/5742.feature b/changelog.d/5742.feature new file mode 100644 index 0000000000..de10302275 --- /dev/null +++ b/changelog.d/5742.feature @@ -0,0 +1 @@ +Allow admin to create or modify a user. Contributed by Awesome Technologies Innovationslabor GmbH. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index b451dc5014..0b3d09d694 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -1,3 +1,33 @@ +Create or modify Account +======================== + +This API allows an administrator to create or modify a user account with a +specific ``user_id``. + +This api is:: + + PUT /_synapse/admin/v2/users/ + +with a body of: + +.. code:: json + + { + "password": "user_password", + "displayname": "User", + "avatar_url": "", + "admin": false, + "deactivated": false + } + +including an ``access_token`` of a server admin. + +The parameter ``displayname`` is optional and defaults to ``user_id``. +The parameter ``avatar_url`` is optional. +The parameter ``admin`` is optional and defaults to 'false'. +The parameter ``deactivated`` is optional and defaults to 'false'. +If the user already exists then optional parameters default to the current value. + List Accounts ============= @@ -50,7 +80,8 @@ This API returns information about a specific user account. The api is:: - GET /_synapse/admin/v1/whois/ + GET /_synapse/admin/v1/whois/ (deprecated) + GET /_synapse/admin/v2/users/ including an ``access_token`` of a server admin. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1a4ba12385..76d18a8ba8 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -51,6 +51,15 @@ class AdminHandler(BaseHandler): return ret + async def get_user(self, user): + """Function to get user details""" + ret = await self.store.get_user_by_id(user.to_string()) + if ret: + profile = await self.store.get_profileinfo(user.localpart) + ret["displayname"] = profile.display_name + ret["avatar_url"] = profile.avatar_url + return ret + async def get_users(self): """Function to retrieve a list of users in users table. diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c122c449f4..a10b4a9b72 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -38,6 +38,7 @@ from synapse.rest.admin.users import ( SearchUsersRestServlet, UserAdminServlet, UserRegisterServlet, + UserRestServletV2, UsersRestServlet, UsersRestServletV2, WhoisRestServlet, @@ -191,6 +192,7 @@ def register_servlets(hs, http_server): SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) + UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 1937879dbe..574cb90c74 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -102,6 +102,148 @@ class UsersRestServletV2(RestServlet): return 200, ret +class UserRestServletV2(RestServlet): + PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P@[^/]+)$"),) + + """Get request to list user details. + This needs user to have administrator access in Synapse. + + GET /_synapse/admin/v2/users/ + + returns: + 200 OK with user details if success otherwise an error. + + Put request to allow an administrator to add or modify a user. + This needs user to have administrator access in Synapse. + We use PUT instead of POST since we already know the id of the user + object to create. POST could be used to create guests. + + PUT /_synapse/admin/v2/users/ + { + "password": "secret", + "displayname": "User" + } + + returns: + 201 OK with new user object if user was created or + 200 OK with modified user object if user was modified + otherwise an error. + """ + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.admin_handler = hs.get_handlers().admin_handler + self.profile_handler = hs.get_profile_handler() + self.set_password_handler = hs.get_set_password_handler() + self.deactivate_account_handler = hs.get_deactivate_account_handler() + self.registration_handler = hs.get_registration_handler() + + async def on_GET(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + ret = await self.admin_handler.get_user(target_user) + + return 200, ret + + async def on_PUT(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + body = parse_json_object_from_request(request) + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "This endpoint can only be used with local users") + + user = await self.admin_handler.get_user(target_user) + + if user: # modify user + requester = await self.auth.get_user_by_req(request) + + if "displayname" in body: + await self.profile_handler.set_displayname( + target_user, requester, body["displayname"], True + ) + + if "avatar_url" in body: + await self.profile_handler.set_avatar_url( + target_user, requester, body["avatar_url"], True + ) + + if "admin" in body: + set_admin_to = bool(body["admin"]) + if set_admin_to != user["admin"]: + auth_user = requester.user + if target_user == auth_user and not set_admin_to: + raise SynapseError(400, "You may not demote yourself.") + + await self.admin_handler.set_user_server_admin( + target_user, set_admin_to + ) + + if "password" in body: + if ( + not isinstance(body["password"], text_type) + or len(body["password"]) > 512 + ): + raise SynapseError(400, "Invalid password") + else: + new_password = body["password"] + await self._set_password_handler.set_password( + target_user, new_password, requester + ) + + if "deactivated" in body: + deactivate = bool(body["deactivated"]) + if deactivate and not user["deactivated"]: + result = await self.deactivate_account_handler.deactivate_account( + target_user.to_string(), False + ) + if not result: + raise SynapseError(500, "Could not deactivate user") + + user = await self.admin_handler.get_user(target_user) + return 200, user + + else: # create user + if "password" not in body: + raise SynapseError( + 400, "password must be specified", errcode=Codes.BAD_JSON + ) + elif ( + not isinstance(body["password"], text_type) + or len(body["password"]) > 512 + ): + raise SynapseError(400, "Invalid password") + + admin = body.get("admin", None) + user_type = body.get("user_type", None) + displayname = body.get("displayname", None) + + if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: + raise SynapseError(400, "Invalid user type") + + user_id = await self.registration_handler.register_user( + localpart=target_user.localpart, + password=body["password"], + admin=bool(admin), + default_display_name=displayname, + user_type=user_type, + ) + if "avatar_url" in body: + await self.profile_handler.set_avatar_url( + user_id, requester, body["avatar_url"], True + ) + + ret = await self.admin_handler.get_user(target_user) + + return 201, ret + + class UserRegisterServlet(RestServlet): """ Attributes: diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 5e8ecac0ea..cb4b2b39a0 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -52,11 +52,13 @@ class RegistrationWorkerStore(SQLBaseStore): "name", "password_hash", "is_guest", + "admin", "consent_version", "consent_server_notice_sent", "appservice_id", "creation_ts", "user_type", + "deactivated", ], allow_none=True, desc="get_user_by_id", diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 325bd6a608..6ceb483aa8 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -13,14 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib -import hmac import json from mock import Mock import synapse.rest.admin -from synapse.api.constants import UserTypes from synapse.http.server import JsonResource from synapse.rest.admin import VersionServlet from synapse.rest.client.v1 import events, login, room @@ -47,341 +44,6 @@ class VersionTestCase(unittest.HomeserverTestCase): ) -class UserRegisterTestCase(unittest.HomeserverTestCase): - - servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] - - def make_homeserver(self, reactor, clock): - - self.url = "/_matrix/client/r0/admin/register" - - self.registration_handler = Mock() - self.identity_handler = Mock() - self.login_handler = Mock() - self.device_handler = Mock() - self.device_handler.check_device_registered = Mock(return_value="FAKE") - - self.datastore = Mock(return_value=Mock()) - self.datastore.get_current_state_deltas = Mock(return_value=(0, [])) - - self.secrets = Mock() - - self.hs = self.setup_test_homeserver() - - self.hs.config.registration_shared_secret = "shared" - - self.hs.get_media_repository = Mock() - self.hs.get_deactivate_account_handler = Mock() - - return self.hs - - def test_disabled(self): - """ - If there is no shared secret, registration through this method will be - prevented. - """ - self.hs.config.registration_shared_secret = None - - request, channel = self.make_request("POST", self.url, b"{}") - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "Shared secret registration is not enabled", channel.json_body["error"] - ) - - def test_get_nonce(self): - """ - Calling GET on the endpoint will return a randomised nonce, using the - homeserver's secrets provider. - """ - secrets = Mock() - secrets.token_hex = Mock(return_value="abcd") - - self.hs.get_secrets = Mock(return_value=secrets) - - request, channel = self.make_request("GET", self.url) - self.render(request) - - self.assertEqual(channel.json_body, {"nonce": "abcd"}) - - def test_expired_nonce(self): - """ - Calling GET on the endpoint will return a randomised nonce, which will - only last for SALT_TIMEOUT (60s). - """ - request, channel = self.make_request("GET", self.url) - self.render(request) - nonce = channel.json_body["nonce"] - - # 59 seconds - self.reactor.advance(59) - - body = json.dumps({"nonce": nonce}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("username must be specified", channel.json_body["error"]) - - # 61 seconds - self.reactor.advance(2) - - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("unrecognised nonce", channel.json_body["error"]) - - def test_register_incorrect_nonce(self): - """ - Only the provided nonce can be used, as it's checked in the MAC. - """ - request, channel = self.make_request("GET", self.url) - self.render(request) - nonce = channel.json_body["nonce"] - - want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") - want_mac = want_mac.hexdigest() - - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "mac": want_mac, - } - ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("HMAC incorrect", channel.json_body["error"]) - - def test_register_correct_nonce(self): - """ - When the correct nonce is provided, and the right key is provided, the - user is registered. - """ - request, channel = self.make_request("GET", self.url) - self.render(request) - nonce = channel.json_body["nonce"] - - want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update( - nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" - ) - want_mac = want_mac.hexdigest() - - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "user_type": UserTypes.SUPPORT, - "mac": want_mac, - } - ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["user_id"]) - - def test_nonce_reuse(self): - """ - A valid unrecognised nonce. - """ - request, channel = self.make_request("GET", self.url) - self.render(request) - nonce = channel.json_body["nonce"] - - want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") - want_mac = want_mac.hexdigest() - - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "mac": want_mac, - } - ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["user_id"]) - - # Now, try and reuse it - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("unrecognised nonce", channel.json_body["error"]) - - def test_missing_parts(self): - """ - Synapse will complain if you don't give nonce, username, password, and - mac. Admin and user_types are optional. Additional checks are done for length - and type. - """ - - def nonce(): - request, channel = self.make_request("GET", self.url) - self.render(request) - return channel.json_body["nonce"] - - # - # Nonce check - # - - # Must be present - body = json.dumps({}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("nonce must be specified", channel.json_body["error"]) - - # - # Username checks - # - - # Must be present - body = json.dumps({"nonce": nonce()}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("username must be specified", channel.json_body["error"]) - - # Must be a string - body = json.dumps({"nonce": nonce(), "username": 1234}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid username", channel.json_body["error"]) - - # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid username", channel.json_body["error"]) - - # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid username", channel.json_body["error"]) - - # - # Password checks - # - - # Must be present - body = json.dumps({"nonce": nonce(), "username": "a"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("password must be specified", channel.json_body["error"]) - - # Must be a string - body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid password", channel.json_body["error"]) - - # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid password", channel.json_body["error"]) - - # Super long - body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid password", channel.json_body["error"]) - - # - # user_type check - # - - # Invalid user_type - body = json.dumps( - { - "nonce": nonce(), - "username": "a", - "password": "1234", - "user_type": "invalid", - } - ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid user type", channel.json_body["error"]) - - -class UsersListTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - ] - url = "/_synapse/admin/v2/users" - - def prepare(self, reactor, clock, hs): - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.register_user("user1", "pass1", admin=False) - self.register_user("user2", "pass2", admin=False) - - def test_no_auth(self): - """ - Try to list users without authentication. - """ - request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"]) - - def test_all_users(self): - """ - List all users, including deactivated users. - """ - request, channel = self.make_request( - "GET", - self.url + "?deactivated=true", - b"{}", - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(3, len(channel.json_body["users"])) - - class ShutdownRoomTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py new file mode 100644 index 0000000000..7352d609e6 --- /dev/null +++ b/tests/rest/admin/test_user.py @@ -0,0 +1,465 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import hmac +import json + +from mock import Mock + +import synapse.rest.admin +from synapse.api.constants import UserTypes +from synapse.rest.client.v1 import login + +from tests import unittest + + +class UserRegisterTestCase(unittest.HomeserverTestCase): + + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] + + def make_homeserver(self, reactor, clock): + + self.url = "/_matrix/client/r0/admin/register" + + self.registration_handler = Mock() + self.identity_handler = Mock() + self.login_handler = Mock() + self.device_handler = Mock() + self.device_handler.check_device_registered = Mock(return_value="FAKE") + + self.datastore = Mock(return_value=Mock()) + self.datastore.get_current_state_deltas = Mock(return_value=(0, [])) + + self.secrets = Mock() + + self.hs = self.setup_test_homeserver() + + self.hs.config.registration_shared_secret = "shared" + + self.hs.get_media_repository = Mock() + self.hs.get_deactivate_account_handler = Mock() + + return self.hs + + def test_disabled(self): + """ + If there is no shared secret, registration through this method will be + prevented. + """ + self.hs.config.registration_shared_secret = None + + request, channel = self.make_request("POST", self.url, b"{}") + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "Shared secret registration is not enabled", channel.json_body["error"] + ) + + def test_get_nonce(self): + """ + Calling GET on the endpoint will return a randomised nonce, using the + homeserver's secrets provider. + """ + secrets = Mock() + secrets.token_hex = Mock(return_value="abcd") + + self.hs.get_secrets = Mock(return_value=secrets) + + request, channel = self.make_request("GET", self.url) + self.render(request) + + self.assertEqual(channel.json_body, {"nonce": "abcd"}) + + def test_expired_nonce(self): + """ + Calling GET on the endpoint will return a randomised nonce, which will + only last for SALT_TIMEOUT (60s). + """ + request, channel = self.make_request("GET", self.url) + self.render(request) + nonce = channel.json_body["nonce"] + + # 59 seconds + self.reactor.advance(59) + + body = json.dumps({"nonce": nonce}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("username must be specified", channel.json_body["error"]) + + # 61 seconds + self.reactor.advance(2) + + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("unrecognised nonce", channel.json_body["error"]) + + def test_register_incorrect_nonce(self): + """ + Only the provided nonce can be used, as it's checked in the MAC. + """ + request, channel = self.make_request("GET", self.url) + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("HMAC incorrect", channel.json_body["error"]) + + def test_register_correct_nonce(self): + """ + When the correct nonce is provided, and the right key is provided, the + user is registered. + """ + request, channel = self.make_request("GET", self.url) + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update( + nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" + ) + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "user_type": UserTypes.SUPPORT, + "mac": want_mac, + } + ) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + + def test_nonce_reuse(self): + """ + A valid unrecognised nonce. + """ + request, channel = self.make_request("GET", self.url) + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + + # Now, try and reuse it + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("unrecognised nonce", channel.json_body["error"]) + + def test_missing_parts(self): + """ + Synapse will complain if you don't give nonce, username, password, and + mac. Admin and user_types are optional. Additional checks are done for length + and type. + """ + + def nonce(): + request, channel = self.make_request("GET", self.url) + self.render(request) + return channel.json_body["nonce"] + + # + # Nonce check + # + + # Must be present + body = json.dumps({}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("nonce must be specified", channel.json_body["error"]) + + # + # Username checks + # + + # Must be present + body = json.dumps({"nonce": nonce()}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("username must be specified", channel.json_body["error"]) + + # Must be a string + body = json.dumps({"nonce": nonce(), "username": 1234}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("Invalid username", channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("Invalid username", channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("Invalid username", channel.json_body["error"]) + + # + # Password checks + # + + # Must be present + body = json.dumps({"nonce": nonce(), "username": "a"}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("password must be specified", channel.json_body["error"]) + + # Must be a string + body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("Invalid password", channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("Invalid password", channel.json_body["error"]) + + # Super long + body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("Invalid password", channel.json_body["error"]) + + # + # user_type check + # + + # Invalid user_type + body = json.dumps( + { + "nonce": nonce(), + "username": "a", + "password": "1234", + "user_type": "invalid", + } + ) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("Invalid user type", channel.json_body["error"]) + + +class UsersListTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + url = "/_synapse/admin/v2/users" + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.register_user("user1", "pass1", admin=False) + self.register_user("user2", "pass2", admin=False) + + def test_no_auth(self): + """ + Try to list users without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"]) + + def test_all_users(self): + """ + List all users, including deactivated users. + """ + request, channel = self.make_request( + "GET", + self.url + "?deactivated=true", + b"{}", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(3, len(channel.json_body["users"])) + + +class UserRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.url = "/_synapse/admin/v2/users/@bob:test" + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + self.hs.config.registration_shared_secret = None + + request, channel = self.make_request( + "GET", self.url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + request, channel = self.make_request( + "PUT", self.url, access_token=self.other_user_token, content=b"{}", + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + def test_requester_is_admin(self): + """ + If the user is a server admin, a new user is created. + """ + self.hs.config.registration_shared_secret = None + + body = json.dumps({"password": "abc123", "admin": True}) + + # Create user + request, channel = self.make_request( + "PUT", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("bob", channel.json_body["displayname"]) + + # Get user + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("bob", channel.json_body["displayname"]) + self.assertEqual(1, channel.json_body["admin"]) + self.assertEqual(0, channel.json_body["is_guest"]) + self.assertEqual(0, channel.json_body["deactivated"]) + + # Modify user + body = json.dumps({"displayname": "foobar", "deactivated": True}) + + request, channel = self.make_request( + "PUT", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("foobar", channel.json_body["displayname"]) + self.assertEqual(True, channel.json_body["deactivated"]) + + # Get user + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("foobar", channel.json_body["displayname"]) + self.assertEqual(1, channel.json_body["admin"]) + self.assertEqual(0, channel.json_body["is_guest"]) + self.assertEqual(1, channel.json_body["deactivated"]) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index ed5786865a..71a40a0a49 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -43,12 +43,14 @@ class RegistrationStoreTestCase(unittest.TestCase): # TODO(paul): Surely this field should be 'user_id', not 'name' "name": self.user_id, "password_hash": self.pwhash, + "admin": 0, "is_guest": 0, "consent_version": None, "consent_server_notice_sent": None, "appservice_id": None, "creation_ts": 1000, "user_type": None, + "deactivated": 0, }, (yield self.store.get_user_by_id(self.user_id)), ) -- cgit 1.5.1 From 326c893d24e0f03b62bd9d1136a335f329bb8528 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 13 Jan 2020 12:48:22 +0000 Subject: Kill off RegistrationError (#6691) This is pretty pointless. Let's just use SynapseError. --- changelog.d/6691.misc | 1 + synapse/api/errors.py | 6 ------ synapse/handlers/register.py | 12 +++--------- tests/handlers/test_register.py | 2 -- 4 files changed, 4 insertions(+), 17 deletions(-) create mode 100644 changelog.d/6691.misc (limited to 'tests') diff --git a/changelog.d/6691.misc b/changelog.d/6691.misc new file mode 100644 index 0000000000..104e9ce648 --- /dev/null +++ b/changelog.d/6691.misc @@ -0,0 +1 @@ +Remove redundant RegistrationError class. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 5853a54c95..9e9844b47c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -158,12 +158,6 @@ class UserDeactivatedError(SynapseError): ) -class RegistrationError(SynapseError): - """An error raised when a registration event fails.""" - - pass - - class FederationDeniedError(SynapseError): """An error raised when the server tries to federate with a server which is not on its federation whitelist. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 885da82985..7ffc194f0c 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -20,13 +20,7 @@ from twisted.internet import defer from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, LoginType -from synapse.api.errors import ( - AuthError, - Codes, - ConsentNotGivenError, - RegistrationError, - SynapseError, -) +from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet @@ -165,7 +159,7 @@ class RegistrationHandler(BaseHandler): Returns: Deferred[str]: user_id Raises: - RegistrationError if there was a problem registering. + SynapseError if there was a problem registering. """ yield self.check_registration_ratelimit(address) @@ -182,7 +176,7 @@ class RegistrationHandler(BaseHandler): if not was_guest: try: int(localpart) - raise RegistrationError( + raise SynapseError( 400, "Numeric user IDs are reserved for guest users." ) except ValueError: diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1e9ba3a201..e2915eb7b1 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -269,8 +269,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase): one will be randomly generated. Returns: A tuple of (user_id, access_token). - Raises: - RegistrationError if there was a problem registering. """ if localpart is None: raise SynapseError(400, "Request must include user id") -- cgit 1.5.1 From 1177d3f3a33bd3ae1eef46fba360d319598359ad Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 13 Jan 2020 18:10:43 +0000 Subject: Quarantine media by ID or user ID (#6681) --- changelog.d/6681.feature | 1 + docs/admin_api/media_admin_api.md | 76 ++++++- docs/workers.md | 4 +- synapse/rest/admin/media.py | 68 +++++- synapse/storage/data_stores/main/room.py | 116 ++++++++++- tests/rest/admin/test_admin.py | 341 +++++++++++++++++++++++++++++++ tests/rest/client/v1/utils.py | 37 ++++ 7 files changed, 632 insertions(+), 11 deletions(-) create mode 100644 changelog.d/6681.feature (limited to 'tests') diff --git a/changelog.d/6681.feature b/changelog.d/6681.feature new file mode 100644 index 0000000000..5cf19a4e0e --- /dev/null +++ b/changelog.d/6681.feature @@ -0,0 +1 @@ +Add new quarantine media admin APIs to quarantine by media ID or by user who uploaded the media. diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 8b3666d5f5..46ba7a1a71 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -22,19 +22,81 @@ It returns a JSON body like the following: } ``` -# Quarantine media in a room +# Quarantine media -This API 'quarantines' all the media in a room. +Quarantining media means that it is marked as inaccessible by users. It applies +to any local media, and any locally-cached copies of remote media. -The API is: +The media file itself (and any thumbnails) is not deleted from the server. + +## Quarantining media by ID + +This API quarantines a single piece of local or remote media. + +Request: ``` -POST /_synapse/admin/v1/quarantine_media/ +POST /_synapse/admin/v1/media/quarantine// {} ``` -Quarantining media means that it is marked as inaccessible by users. It applies -to any local media, and any locally-cached copies of remote media. +Where `server_name` is in the form of `example.org`, and `media_id` is in the +form of `abcdefg12345...`. + +Response: + +``` +{} +``` + +## Quarantining media in a room + +This API quarantines all local and remote media in a room. + +Request: + +``` +POST /_synapse/admin/v1/room//media/quarantine + +{} +``` + +Where `room_id` is in the form of `!roomid12345:example.org`. + +Response: + +``` +{ + "num_quarantined": 10 # The number of media items successfully quarantined +} +``` + +Note that there is a legacy endpoint, `POST +/_synapse/admin/v1/quarantine_media/`, that operates the same. +However, it is deprecated and may be removed in a future release. + +## Quarantining all media of a user + +This API quarantines all *local* media that a *local* user has uploaded. That is to say, if +you would like to quarantine media uploaded by a user on a remote homeserver, you should +instead use one of the other APIs. + +Request: + +``` +POST /_synapse/admin/v1/user//media/quarantine + +{} +``` + +Where `user_id` is in the form of `@bob:example.org`. + +Response: + +``` +{ + "num_quarantined": 10 # The number of media items successfully quarantined +} +``` -The media file itself (and any thumbnails) is not deleted from the server. diff --git a/docs/workers.md b/docs/workers.md index f4283aeb05..0ab269fd96 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -202,7 +202,9 @@ Handles the media repository. It can handle all endpoints starting with: ... and the following regular expressions matching media-specific administration APIs: ^/_synapse/admin/v1/purge_media_cache$ - ^/_synapse/admin/v1/room/.*/media$ + ^/_synapse/admin/v1/room/.*/media.*$ + ^/_synapse/admin/v1/user/.*/media.*$ + ^/_synapse/admin/v1/media/.*$ ^/_synapse/admin/v1/quarantine_media/.*$ You should also set `enable_media_repo: False` in the shared configuration diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index fa833e54cf..3a445d6eed 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -32,16 +32,24 @@ class QuarantineMediaInRoom(RestServlet): this server. """ - PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P[^/]+)") + PATTERNS = ( + historical_admin_path_patterns("/room/(?P[^/]+)/media/quarantine") + + + # This path kept around for legacy reasons + historical_admin_path_patterns("/quarantine_media/(?P![^/]+)") + ) def __init__(self, hs): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, room_id): + async def on_POST(self, request, room_id: str): requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) + logging.info("Quarantining room: %s", room_id) + + # Quarantine all media in this room num_quarantined = await self.store.quarantine_media_ids_in_room( room_id, requester.user.to_string() ) @@ -49,6 +57,60 @@ class QuarantineMediaInRoom(RestServlet): return 200, {"num_quarantined": num_quarantined} +class QuarantineMediaByUser(RestServlet): + """Quarantines all local media by a given user so that no one can download it via + this server. + """ + + PATTERNS = historical_admin_path_patterns( + "/user/(?P[^/]+)/media/quarantine" + ) + + def __init__(self, hs): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request, user_id: str): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info("Quarantining local media by user: %s", user_id) + + # Quarantine all media this user has uploaded + num_quarantined = await self.store.quarantine_media_ids_by_user( + user_id, requester.user.to_string() + ) + + return 200, {"num_quarantined": num_quarantined} + + +class QuarantineMediaByID(RestServlet): + """Quarantines local or remote media by a given ID so that no one can download + it via this server. + """ + + PATTERNS = historical_admin_path_patterns( + "/media/quarantine/(?P[^/]+)/(?P[^/]+)" + ) + + def __init__(self, hs): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request, server_name: str, media_id: str): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info("Quarantining local media by ID: %s/%s", server_name, media_id) + + # Quarantine this media id + await self.store.quarantine_media_by_id( + server_name, media_id, requester.user.to_string() + ) + + return 200, {} + + class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ @@ -94,4 +156,6 @@ def register_servlets_for_media_repo(hs, http_server): """ PurgeMediaCacheRestServlet(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server) + QuarantineMediaByID(hs).register(http_server) + QuarantineMediaByUser(hs).register(http_server) ListMediaInRoom(hs).register(http_server) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 8636d75030..49bab62be3 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -18,7 +18,7 @@ import collections import logging import re from abc import abstractmethod -from typing import Optional, Tuple +from typing import List, Optional, Tuple from six import integer_types @@ -399,6 +399,8 @@ class RoomWorkerStore(SQLBaseStore): the associated media """ + logger.info("Quarantining media in room: %s", room_id) + def _quarantine_media_in_room_txn(txn): local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) total_media_quarantined = 0 @@ -494,6 +496,118 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs + def quarantine_media_by_id( + self, server_name: str, media_id: str, quarantined_by: str, + ): + """quarantines a single local or remote media id + + Args: + server_name: The name of the server that holds this media + media_id: The ID of the media to be quarantined + quarantined_by: The user ID that initiated the quarantine request + """ + logger.info("Quarantining media: %s/%s", server_name, media_id) + is_local = server_name == self.config.server_name + + def _quarantine_media_by_id_txn(txn): + local_mxcs = [media_id] if is_local else [] + remote_mxcs = [(server_name, media_id)] if not is_local else [] + + return self._quarantine_media_txn( + txn, local_mxcs, remote_mxcs, quarantined_by + ) + + return self.db.runInteraction( + "quarantine_media_by_user", _quarantine_media_by_id_txn + ) + + def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): + """quarantines all local media associated with a single user + + Args: + user_id: The ID of the user to quarantine media of + quarantined_by: The ID of the user who made the quarantine request + """ + + def _quarantine_media_by_user_txn(txn): + local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) + return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) + + return self.db.runInteraction( + "quarantine_media_by_user", _quarantine_media_by_user_txn + ) + + def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True): + """Retrieves local media IDs by a given user + + Args: + txn (cursor) + user_id: The ID of the user to retrieve media IDs of + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + # Local media + sql = """ + SELECT media_id + FROM local_media_repository + WHERE user_id = ? + """ + if filter_quarantined: + sql += "AND quarantined_by IS NULL" + txn.execute(sql, (user_id,)) + + local_media_ids = [row[0] for row in txn] + + # TODO: Figure out all remote media a user has referenced in a message + + return local_media_ids + + def _quarantine_media_txn( + self, + txn, + local_mxcs: List[str], + remote_mxcs: List[Tuple[str, str]], + quarantined_by: str, + ) -> int: + """Quarantine local and remote media items + + Args: + txn (cursor) + local_mxcs: A list of local mxc URLs + remote_mxcs: A list of (remote server, media id) tuples representing + remote mxc URLs + quarantined_by: The ID of the user who initiated the quarantine request + Returns: + The total number of media items quarantined + """ + total_media_quarantined = 0 + + # Update all the tables to set the quarantined_by flag + txn.executemany( + """ + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? + """, + ((quarantined_by, media_id) for media_id in local_mxcs), + ) + + txn.executemany( + """ + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE media_origin = ? AND media_id = ? + """, + ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs), + ) + + total_media_quarantined += len(local_mxcs) + total_media_quarantined += len(remote_mxcs) + + return total_media_quarantined + class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 6ceb483aa8..7a7e898843 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -14,11 +14,17 @@ # limitations under the License. import json +import os +import urllib.parse +from binascii import unhexlify from mock import Mock +from twisted.internet.defer import Deferred + import synapse.rest.admin from synapse.http.server import JsonResource +from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet from synapse.rest.client.v1 import events, login, room from synapse.rest.client.v2_alpha import groups @@ -346,3 +352,338 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) test_purge_room.skip = "Disabled because it's currently broken" + + +class QuarantineMediaTestCase(unittest.HomeserverTestCase): + """Test /quarantine_media admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.hs = hs + + # Allow for uploading and downloading to/from the media repo + self.media_repo = hs.get_media_repository_resource() + self.download_resource = self.media_repo.children[b"download"] + self.upload_resource = self.media_repo.children[b"upload"] + self.image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + def make_homeserver(self, reactor, clock): + + self.fetches = [] + + def get_file(destination, path, output_stream, args=None, max_size=None): + """ + Returns tuple[int,dict,str,int] of file length, response headers, + absolute URI, and response code. + """ + + def write_to(r): + data, response = r + output_stream.write(data) + return response + + d = Deferred() + d.addCallback(write_to) + self.fetches.append((d, destination, path, args)) + return make_deferred_yieldable(d) + + client = Mock() + client.get_file = get_file + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + + config = self.default_config() + config["media_store_path"] = self.media_store_path + config["thumbnail_requirements"] = {} + config["max_image_pixels"] = 2000000 + + provider_config = { + "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + config["media_storage_providers"] = [provider_config] + + hs = self.setup_test_homeserver(config=config, http_client=client) + + return hs + + def test_quarantine_media_requires_admin(self): + self.register_user("nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("nonadmin", "pass") + + # Attempt quarantine media APIs as non-admin + url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345" + request, channel = self.make_request( + "POST", url.encode("ascii"), access_token=non_admin_user_tok, + ) + self.render(request) + + # Expect a forbidden error + self.assertEqual( + 403, + int(channel.result["code"]), + msg="Expected forbidden on quarantining media as a non-admin", + ) + + # And the roomID/userID endpoint + url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine" + request, channel = self.make_request( + "POST", url.encode("ascii"), access_token=non_admin_user_tok, + ) + self.render(request) + + # Expect a forbidden error + self.assertEqual( + 403, + int(channel.result["code"]), + msg="Expected forbidden on quarantining media as a non-admin", + ) + + def test_quarantine_media_by_id(self): + self.register_user("id_admin", "pass", admin=True) + admin_user_tok = self.login("id_admin", "pass") + + self.register_user("id_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("id_nonadmin", "pass") + + # Upload some media into the room + response = self.helper.upload_media( + self.upload_resource, self.image_data, tok=admin_user_tok + ) + + # Extract media ID from the response + server_name_and_media_id = response["content_uri"][ + 6: + ] # Cut off the 'mxc://' bit + server_name, media_id = server_name_and_media_id.split("/") + + # Attempt to access the media + request, channel = self.make_request( + "GET", + server_name_and_media_id, + shorthand=False, + access_token=non_admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be successful + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + + # Quarantine the media + url = "/_synapse/admin/v1/media/quarantine/%s/%s" % ( + urllib.parse.quote(server_name), + urllib.parse.quote(media_id), + ) + request, channel = self.make_request("POST", url, access_token=admin_user_tok,) + self.render(request) + self.pump(1.0) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + + # Attempt to access the media + request, channel = self.make_request( + "GET", + server_name_and_media_id, + shorthand=False, + access_token=admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_name_and_media_id + ), + ) + + def test_quarantine_all_media_in_room(self): + self.register_user("room_admin", "pass", admin=True) + admin_user_tok = self.login("room_admin", "pass") + + non_admin_user = self.register_user("room_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("room_nonadmin", "pass") + + room_id = self.helper.create_room_as(non_admin_user, tok=admin_user_tok) + self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok) + + # Upload some media + response_1 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + response_2 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + + # Extract mxcs + mxc_1 = response_1["content_uri"] + mxc_2 = response_2["content_uri"] + + # Send it into the room + self.helper.send_event( + room_id, + "m.room.message", + content={"body": "image-1", "msgtype": "m.image", "url": mxc_1}, + txn_id="111", + tok=non_admin_user_tok, + ) + self.helper.send_event( + room_id, + "m.room.message", + content={"body": "image-2", "msgtype": "m.image", "url": mxc_2}, + txn_id="222", + tok=non_admin_user_tok, + ) + + # Quarantine all media in the room + url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote( + room_id + ) + request, channel = self.make_request("POST", url, access_token=admin_user_tok,) + self.render(request) + self.pump(1.0) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + self.assertEqual( + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 2}, + "Expected 2 quarantined items", + ) + + # Convert mxc URLs to server/media_id strings + server_and_media_id_1 = mxc_1[6:] + server_and_media_id_2 = mxc_2[6:] + + # Test that we cannot download any of the media anymore + request, channel = self.make_request( + "GET", + server_and_media_id_1, + shorthand=False, + access_token=non_admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_1 + ), + ) + + request, channel = self.make_request( + "GET", + server_and_media_id_2, + shorthand=False, + access_token=non_admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_2 + ), + ) + + def test_quarantine_all_media_by_user(self): + self.register_user("user_admin", "pass", admin=True) + admin_user_tok = self.login("user_admin", "pass") + + non_admin_user = self.register_user("user_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("user_nonadmin", "pass") + + # Upload some media + response_1 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + response_2 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + + # Extract media IDs + server_and_media_id_1 = response_1["content_uri"][6:] + server_and_media_id_2 = response_2["content_uri"][6:] + + # Quarantine all media by this user + url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( + non_admin_user + ) + request, channel = self.make_request( + "POST", url.encode("ascii"), access_token=admin_user_tok, + ) + self.render(request) + self.pump(1.0) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 2}, + "Expected 2 quarantined items", + ) + + # Attempt to access each piece of media + request, channel = self.make_request( + "GET", + server_and_media_id_1, + shorthand=False, + access_token=non_admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_1, + ), + ) + + # Attempt to access each piece of media + request, channel = self.make_request( + "GET", + server_and_media_id_2, + shorthand=False, + access_token=non_admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_2 + ), + ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index e7417b3d14..873d5ef99c 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -21,6 +21,8 @@ import time import attr +from twisted.web.resource import Resource + from synapse.api.constants import Membership from tests.server import make_request, render @@ -160,3 +162,38 @@ class RestHelper(object): ) return channel.json_body + + def upload_media( + self, + resource: Resource, + image_data: bytes, + tok: str, + filename: str = "test.png", + expect_code: int = 200, + ) -> dict: + """Upload a piece of test media to the media repo + Args: + resource: The resource that will handle the upload request + image_data: The image data to upload + tok: The user token to use during the upload + filename: The filename of the media to be uploaded + expect_code: The return code to expect from attempting to upload the media + """ + image_length = len(image_data) + path = "/_matrix/media/r0/upload?filename=%s" % (filename,) + request, channel = make_request( + self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok + ) + request.requestHeaders.addRawHeader( + b"Content-Length", str(image_length).encode("UTF-8") + ) + request.render(resource) + self.hs.get_reactor().pump([100]) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body -- cgit 1.5.1 From 28c98e51ffa166bd717646b0b34228e59f253485 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 Jan 2020 14:59:33 +0000 Subject: Add `local_current_membership` table (#6655) Currently we rely on `current_state_events` to figure out what rooms a user was in and their last membership event in there. However, if the server leaves the room then the table may be cleaned up and that information is lost. So lets add a table that separately holds that information. --- changelog.d/6655.misc | 1 + scripts/synapse_port_db | 2 +- synapse/handlers/admin.py | 2 +- synapse/handlers/deactivate_account.py | 2 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/push/push_tools.py | 2 +- synapse/replication/slave/storage/events.py | 2 +- synapse/server_notices/server_notices_manager.py | 2 +- synapse/storage/data_stores/main/events.py | 30 ++++ synapse/storage/data_stores/main/roommember.py | 189 ++++++++++++--------- .../schema/delta/57/local_current_membership.py | 97 +++++++++++ synapse/storage/prepare_database.py | 2 +- tests/handlers/test_sync.py | 4 +- tests/replication/slave/storage/test_events.py | 4 +- tests/rest/client/v2_alpha/test_account.py | 12 +- tests/rest/client/v2_alpha/test_sync.py | 9 - tests/storage/test_roommember.py | 2 +- 20 files changed, 263 insertions(+), 107 deletions(-) create mode 100644 changelog.d/6655.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py (limited to 'tests') diff --git a/changelog.d/6655.misc b/changelog.d/6655.misc new file mode 100644 index 0000000000..01e78bc84e --- /dev/null +++ b/changelog.d/6655.misc @@ -0,0 +1 @@ +Add `local_current_membership` table for tracking local user membership state in rooms. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index f135c8bc54..5e69104b97 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -470,7 +470,7 @@ class Porter(object): engine.check_database( db_conn, allow_outdated_version=allow_outdated_version ) - prepare_database(db_conn, engine, config=None) + prepare_database(db_conn, engine, config=self.hs_config) store = Store(Database(hs, db_config, engine), db_conn, hs) db_conn.commit() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 76d18a8ba8..a9407553b4 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -134,7 +134,7 @@ class AdminHandler(BaseHandler): The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in - rooms = await self.store.get_rooms_for_user_where_membership_is( + rooms = await self.store.get_rooms_for_local_user_where_membership_is( user_id, membership_list=( Membership.JOIN, diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 4426967f88..2afb390a92 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -140,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler): user_id (str): The user ID to reject pending invites for. """ user = UserID.from_string(user_id) - pending_invites = await self.store.get_invited_rooms_for_user(user_id) + pending_invites = await self.store.get_invited_rooms_for_local_user(user_id) for room in pending_invites: try: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 44ec3e66ae..2e6755f19c 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -101,7 +101,7 @@ class InitialSyncHandler(BaseHandler): if include_archived: memberships.append(Membership.LEAVE) - room_list = await self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=memberships ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 03bb52ccfb..15e8aa5249 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -690,7 +690,7 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _get_inviter(self, user_id, room_id): - invite = yield self.store.get_invite_for_user_in_room( + invite = yield self.store.get_invite_for_local_user_in_room( user_id=user_id, room_id=room_id ) if invite: diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index ef750d1497..110097eab9 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -179,7 +179,7 @@ class SearchHandler(BaseHandler): search_filter = Filter(filter_dict) # TODO: Search through left rooms too - rooms = yield self.store.get_rooms_for_user_where_membership_is( + rooms = yield self.store.get_rooms_for_local_user_where_membership_is( user.to_string(), membership_list=[Membership.JOIN], # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 2d3b8ba73c..cd95f85e3f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1662,7 +1662,7 @@ class SyncHandler(object): Membership.BAN, ) - room_list = await self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=membership_list ) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index de5c101a58..5dae4648c0 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,7 +21,7 @@ from synapse.storage import Storage @defer.inlineCallbacks def get_badge_count(store, user_id): - invites = yield store.get_invited_rooms_for_user(user_id) + invites = yield store.get_invited_rooms_for_local_user(user_id) joins = yield store.get_rooms_for_user(user_id) my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 29f35b9915..3aa6cb8b96 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -152,7 +152,7 @@ class SlavedEventStore( if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) - self.get_invited_rooms_for_user.invalidate((state_key,)) + self.get_invited_rooms_for_local_user.invalidate((state_key,)) if relates_to: self.get_relations_for_event.invalidate_many((relates_to,)) diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 2dac90578c..f7432c8d2f 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -105,7 +105,7 @@ class ServerNoticesManager(object): assert self._is_mine_id(user_id), "Cannot send server notices to remote users" - rooms = yield self._store.get_rooms_for_user_where_membership_is( + rooms = yield self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) system_mxid = self._config.server_notices_mxid diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 58f35d7f56..e9fe63037b 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -128,6 +128,7 @@ class EventsStore( hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + self.is_mine_id = hs.is_mine_id @defer.inlineCallbacks def _read_forward_extremities(self): @@ -547,6 +548,34 @@ class EventsStore( ], ) + # Note: Do we really want to delete rows here (that we do not + # subsequently reinsert below)? While technically correct it means + # we have no record of the fact the user *was* a member of the + # room but got, say, state reset out of it. + if to_delete or to_insert: + txn.executemany( + "DELETE FROM local_current_membership" + " WHERE room_id = ? AND user_id = ?", + ( + (room_id, state_key) + for etype, state_key in itertools.chain(to_delete, to_insert) + if etype == EventTypes.Member and self.is_mine_id(state_key) + ), + ) + + if to_insert: + txn.executemany( + """INSERT INTO local_current_membership + (room_id, user_id, event_id, membership) + VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + """, + [ + (room_id, key[1], ev_id, ev_id) + for key, ev_id in to_insert.items() + if key[0] == EventTypes.Member and self.is_mine_id(key[1]) + ], + ) + txn.call_after( self._curr_state_delta_stream_cache.entity_has_changed, room_id, @@ -1724,6 +1753,7 @@ class EventsStore( "local_invites", "room_account_data", "room_tags", + "local_current_membership", ): logger.info("[purge] removing %s from %s", room_id, table) txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,)) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 70ff5751b6..9acef7c950 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -297,19 +297,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {row[0]: row[1] for row in txn} @cached() - def get_invited_rooms_for_user(self, user_id): - """ Get all the rooms the user is invited to + def get_invited_rooms_for_local_user(self, user_id): + """ Get all the rooms the *local* user is invited to + Args: user_id (str): The user ID. Returns: A deferred list of RoomsForUser. """ - return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE]) + return self.get_rooms_for_local_user_where_membership_is( + user_id, [Membership.INVITE] + ) @defer.inlineCallbacks - def get_invite_for_user_in_room(self, user_id, room_id): - """Gets the invite for the given user and room + def get_invite_for_local_user_in_room(self, user_id, room_id): + """Gets the invite for the given *local* user and room Args: user_id (str) @@ -319,15 +322,15 @@ class RoomMemberWorkerStore(EventsWorkerStore): Deferred: Resolves to either a RoomsForUser or None if no invite was found. """ - invites = yield self.get_invited_rooms_for_user(user_id) + invites = yield self.get_invited_rooms_for_local_user(user_id) for invite in invites: if invite.room_id == room_id: return invite return None @defer.inlineCallbacks - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): - """ Get all the rooms for this user where the membership for this user + def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list): + """ Get all the rooms for this *local* user where the membership for this user matches one in the membership list. Filters out forgotten rooms. @@ -344,8 +347,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): return defer.succeed(None) rooms = yield self.db.runInteraction( - "get_rooms_for_user_where_membership_is", - self._get_rooms_for_user_where_membership_is_txn, + "get_rooms_for_local_user_where_membership_is", + self._get_rooms_for_local_user_where_membership_is_txn, user_id, membership_list, ) @@ -354,76 +357,42 @@ class RoomMemberWorkerStore(EventsWorkerStore): forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) return [room for room in rooms if room.room_id not in forgotten_rooms] - def _get_rooms_for_user_where_membership_is_txn( + def _get_rooms_for_local_user_where_membership_is_txn( self, txn, user_id, membership_list ): + # Paranoia check. + if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r" + % (user_id,), + ) - do_invite = Membership.INVITE in membership_list - membership_list = [m for m in membership_list if m != Membership.INVITE] - - results = [] - if membership_list: - if self._current_state_events_membership_up_to_date: - clause, args = make_in_list_sql_clause( - self.database_engine, "c.membership", membership_list - ) - sql = """ - SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND %s - """ % ( - clause, - ) - else: - clause, args = make_in_list_sql_clause( - self.database_engine, "m.membership", membership_list - ) - sql = """ - SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (room_id, event_id) - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND %s - """ % ( - clause, - ) - - txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] + clause, args = make_in_list_sql_clause( + self.database_engine, "c.membership", membership_list + ) - if do_invite: - sql = ( - "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" - " FROM local_invites as i" - " INNER JOIN events as e USING (event_id)" - " WHERE invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) + sql = """ + SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering + FROM local_current_membership AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + user_id = ? + AND %s + """ % ( + clause, + ) - txn.execute(sql, (user_id,)) - results.extend( - RoomsForUser( - room_id=r["room_id"], - sender=r["inviter"], - event_id=r["event_id"], - stream_ordering=r["stream_ordering"], - membership=Membership.INVITE, - ) - for r in self.db.cursor_to_dict(txn) - ) + txn.execute(sql, (user_id, *args)) + results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] return results - @cachedInlineCallbacks(max_entries=500000, iterable=True) + @cached(max_entries=500000, iterable=True) def get_rooms_for_user_with_stream_ordering(self, user_id): - """Returns a set of room_ids the user is currently joined to + """Returns a set of room_ids the user is currently joined to. + + If a remote user only returns rooms this server is currently + participating in. Args: user_id (str) @@ -433,17 +402,49 @@ class RoomMemberWorkerStore(EventsWorkerStore): the rooms the user is in currently, along with the stream ordering of the most recent join for that user and room. """ - rooms = yield self.get_rooms_for_user_where_membership_is( - user_id, membership_list=[Membership.JOIN] - ) - return frozenset( - GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) - for r in rooms + return self.db.runInteraction( + "get_rooms_for_user_with_stream_ordering", + self._get_rooms_for_user_with_stream_ordering_txn, + user_id, ) + def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id): + # We use `current_state_events` here and not `local_current_membership` + # as a) this gets called with remote users and b) this only gets called + # for rooms the server is participating in. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT room_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND c.membership = ? + """ + else: + sql = """ + SELECT room_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND m.membership = ? + """ + + txn.execute(sql, (user_id, Membership.JOIN)) + results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) + + return results + @defer.inlineCallbacks def get_rooms_for_user(self, user_id, on_invalidate=None): - """Returns a set of room_ids the user is currently joined to + """Returns a set of room_ids the user is currently joined to. + + If a remote user only returns rooms this server is currently + participating in. """ rooms = yield self.get_rooms_for_user_with_stream_ordering( user_id, on_invalidate=on_invalidate @@ -1022,7 +1023,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): event.internal_metadata.stream_ordering, ) txn.call_after( - self.get_invited_rooms_for_user.invalidate, (event.state_key,) + self.get_invited_rooms_for_local_user.invalidate, (event.state_key,) ) # We update the local_invites table only if the event is "current", @@ -1064,6 +1065,27 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): ), ) + # We also update the `local_current_membership` table with + # latest invite info. This will usually get updated by the + # `current_state_events` handling, unless its an outlier. + if event.internal_metadata.is_outlier(): + # This should only happen for out of band memberships, so + # we add a paranoia check. + assert event.internal_metadata.is_out_of_band_membership() + + self.db.simple_upsert_txn( + txn, + table="local_current_membership", + keyvalues={ + "room_id": event.room_id, + "user_id": event.state_key, + }, + values={ + "event_id": event.event_id, + "membership": event.membership, + }, + ) + @defer.inlineCallbacks def locally_reject_invite(self, user_id, room_id): sql = ( @@ -1075,6 +1097,15 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def f(txn, stream_ordering): txn.execute(sql, (stream_ordering, True, room_id, user_id)) + # We also clear this entry from `local_current_membership`. + # Ideally we'd point to a leave event, but we don't have one, so + # nevermind. + self.db.simple_delete_txn( + txn, + table="local_current_membership", + keyvalues={"room_id": room_id, "user_id": user_id}, + ) + with self._stream_id_gen.get_next() as stream_ordering: yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py new file mode 100644 index 0000000000..601c236c4a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# We create a new table called `local_current_membership` that stores the latest +# membership state of local users in rooms, which helps track leaves/bans/etc +# even if the server has left the room (and so has deleted the room from +# `current_state_events`). This will also include outstanding invites for local +# users for rooms the server isn't in. +# +# If the server isn't and hasn't been in the room then it will only include +# outsstanding invites, and not e.g. pre-emptive bans of local users. +# +# If the server later rejoins a room `local_current_membership` can simply be +# replaced with the new current state of the room (which results in the +# equivalent behaviour as if the server had remained in the room). + + +def run_upgrade(cur, database_engine, config, *args, **kwargs): + # We need to do the insert in `run_upgrade` section as we don't have access + # to `config` in `run_create`. + + # This upgrade may take a bit of time for large servers (e.g. one minute for + # matrix.org) but means we avoid a lots of book keeping required to do it as + # a background update. + + # We check if the `current_state_events.membership` is up to date by + # checking if the relevant background update has finished. If it has + # finished we can avoid doing a join against `room_memberships`, which + # speesd things up. + cur.execute( + """SELECT 1 FROM background_updates + WHERE update_name = 'current_state_events_membership' + """ + ) + current_state_membership_up_to_date = not bool(cur.fetchone()) + + # Cheekily drop and recreate indices, as that is faster. + cur.execute("DROP INDEX local_current_membership_idx") + cur.execute("DROP INDEX local_current_membership_room_idx") + + if current_state_membership_up_to_date: + sql = """ + INSERT INTO local_current_membership (room_id, user_id, event_id, membership) + SELECT c.room_id, state_key AS user_id, event_id, c.membership + FROM current_state_events AS c + WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key like '%' || ? + """ + else: + # We can't rely on the membership column, so we need to join against + # `room_memberships`. + sql = """ + INSERT INTO local_current_membership (room_id, user_id, event_id, membership) + SELECT c.room_id, state_key AS user_id, event_id, r.membership + FROM current_state_events AS c + INNER JOIN room_memberships AS r USING (event_id) + WHERE type = 'm.room.member' and state_key like '%' || ? + """ + cur.execute(sql, (config.server_name,)) + + cur.execute( + "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" + ) + cur.execute( + "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)" + ) + + +def run_create(cur, database_engine, *args, **kwargs): + cur.execute( + """ + CREATE TABLE local_current_membership ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + membership TEXT NOT NULL + )""" + ) + + cur.execute( + "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" + ) + cur.execute( + "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)" + ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e70026b80a..e86984cd50 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 56 +SCHEMA_VERSION = 57 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 758ee071a5..4cbe9784ed 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -32,8 +32,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def test_wait_for_sync_for_user_auth_blocking(self): - user_id1 = "@user1:server" - user_id2 = "@user2:server" + user_id1 = "@user1:test" + user_id2 = "@user2:test" sync_config = self._generate_sync_config(user_id1) self.reactor.advance(100) # So we get not 0 time diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index b68e9fe082..b1b037006d 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -115,13 +115,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def test_invites(self): self.persist(type="m.room.create", key="", creator=USER_ID) - self.check("get_invited_rooms_for_user", [USER_ID_2], []) + self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") self.replicate() self.check( - "get_invited_rooms_for_user", + "get_invited_rooms_for_local_user", [USER_ID_2], [ RoomsForUser( diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 0f51895b81..c3facc00eb 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -285,7 +285,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) # Make sure the invite is here. - pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id)) + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) self.assertEqual(len(pending_invites), 1, pending_invites) self.assertEqual(pending_invites[0].room_id, room_id, pending_invites) @@ -293,12 +295,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.deactivate(invitee_id, invitee_tok) # Check that the invite isn't there anymore. - pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id)) + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) self.assertEqual(len(pending_invites), 0, pending_invites) # Check that the membership of @invitee:test in the room is now "leave". memberships = self.get_success( - store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE]) + store.get_rooms_for_local_user_where_membership_is( + invitee_id, [Membership.LEAVE] + ) ) self.assertEqual(len(memberships), 1, memberships) self.assertEqual(memberships[0].room_id, room_id, memberships) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 661c1f88b9..9c13a13786 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -15,8 +15,6 @@ # limitations under the License. import json -from mock import Mock - import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room @@ -36,13 +34,6 @@ class FilterTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def make_homeserver(self, reactor, clock): - - hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock() - ) - return hs - def test_sync_argless(self): request, channel = self.make_request("GET", "/sync") self.render(request) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 7840f63fe3..00df0ea68e 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -57,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) rooms_for_user = self.get_success( - self.store.get_rooms_for_user_where_membership_is( + self.store.get_rooms_for_local_user_where_membership_is( self.u_alice, [Membership.JOIN] ) ) -- cgit 1.5.1 From 8f5d7302acb7f6d15ba7051df7fd7fda7375a29e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 15 Jan 2020 15:58:55 +0000 Subject: Implement RedirectException (#6687) Allow REST endpoint implemnentations to raise a RedirectException, which will redirect the user's browser to a given location. --- changelog.d/6687.misc | 1 + synapse/api/errors.py | 27 ++++++++++++++++- synapse/http/server.py | 13 ++++++--- tests/test_server.py | 79 ++++++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 113 insertions(+), 7 deletions(-) create mode 100644 changelog.d/6687.misc (limited to 'tests') diff --git a/changelog.d/6687.misc b/changelog.d/6687.misc new file mode 100644 index 0000000000..deb0454602 --- /dev/null +++ b/changelog.d/6687.misc @@ -0,0 +1 @@ +Allow REST endpoint implementations to raise a RedirectException, which will redirect the user's browser to a given location. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 9e9844b47c..1c9456e583 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -17,13 +17,15 @@ """Contains exceptions and error codes.""" import logging -from typing import Dict +from typing import Dict, List from six import iteritems from six.moves import http_client from canonicaljson import json +from twisted.web import http + logger = logging.getLogger(__name__) @@ -80,6 +82,29 @@ class CodeMessageException(RuntimeError): self.msg = msg +class RedirectException(CodeMessageException): + """A pseudo-error indicating that we want to redirect the client to a different + location + + Attributes: + cookies: a list of set-cookies values to add to the response. For example: + b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT" + """ + + def __init__(self, location: bytes, http_code: int = http.FOUND): + """ + + Args: + location: the URI to redirect to + http_code: the HTTP response code + """ + msg = "Redirect to %s" % (location.decode("utf-8"),) + super().__init__(code=http_code, msg=msg) + self.location = location + + self.cookies = [] # type: List[bytes] + + class SynapseError(CodeMessageException): """A base exception type for matrix errors which have an errcode and error message (as well as an HTTP status code). diff --git a/synapse/http/server.py b/synapse/http/server.py index 943d12c907..04bc2385a2 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cgi import collections +import html import http.client import logging import types @@ -36,6 +36,7 @@ import synapse.metrics from synapse.api.errors import ( CodeMessageException, Codes, + RedirectException, SynapseError, UnrecognizedRequestError, ) @@ -153,14 +154,18 @@ def _return_html_error(f, request): Args: f (twisted.python.failure.Failure): - request (twisted.web.iweb.IRequest): + request (twisted.web.server.Request): """ if f.check(CodeMessageException): cme = f.value code = cme.code msg = cme.msg - if isinstance(cme, SynapseError): + if isinstance(cme, RedirectException): + logger.info("%s redirect to %s", request, cme.location) + request.setHeader(b"location", cme.location) + request.cookies.extend(cme.cookies) + elif isinstance(cme, SynapseError): logger.info("%s SynapseError: %s - %s", request, code, msg) else: logger.error( @@ -178,7 +183,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) - body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8") + body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8") request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%i" % (len(body),)) diff --git a/tests/test_server.py b/tests/test_server.py index 98fef21d55..0d57eed268 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -23,8 +23,12 @@ from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET -from synapse.api.errors import Codes, SynapseError -from synapse.http.server import JsonResource +from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.http.server import ( + DirectServeResource, + JsonResource, + wrap_html_request_handler, +) from synapse.http.site import SynapseSite, logger from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock @@ -164,6 +168,77 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") +class WrapHtmlRequestHandlerTests(unittest.TestCase): + class TestResource(DirectServeResource): + callback = None + + @wrap_html_request_handler + async def _async_render_GET(self, request): + return await self.callback(request) + + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + + def test_good_response(self): + def callback(request): + request.write(b"response") + request.finish() + + res = WrapHtmlRequestHandlerTests.TestResource() + res.callback = callback + + request, channel = make_request(self.reactor, b"GET", b"/path") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"200") + body = channel.result["body"] + self.assertEqual(body, b"response") + + def test_redirect_exception(self): + """ + If the callback raises a RedirectException, it is turned into a 30x + with the right location. + """ + + def callback(request, **kwargs): + raise RedirectException(b"/look/an/eagle", 301) + + res = WrapHtmlRequestHandlerTests.TestResource() + res.callback = callback + + request, channel = make_request(self.reactor, b"GET", b"/path") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"301") + headers = channel.result["headers"] + location_headers = [v for k, v in headers if k == b"Location"] + self.assertEqual(location_headers, [b"/look/an/eagle"]) + + def test_redirect_exception_with_cookie(self): + """ + If the callback raises a RedirectException which sets a cookie, that is + returned too + """ + + def callback(request, **kwargs): + e = RedirectException(b"/no/over/there", 304) + e.cookies.append(b"session=yespls") + raise e + + res = WrapHtmlRequestHandlerTests.TestResource() + res.callback = callback + + request, channel = make_request(self.reactor, b"GET", b"/path") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"304") + headers = channel.result["headers"] + location_headers = [v for k, v in headers if k == b"Location"] + self.assertEqual(location_headers, [b"/no/over/there"]) + cookies_headers = [v for k, v in headers if k == b"Set-Cookie"] + self.assertEqual(cookies_headers, [b"session=yespls"]) + + class SiteTestCase(unittest.HomeserverTestCase): def test_lose_connection(self): """ -- cgit 1.5.1 From 19a1aac48cc83fe41287a97bb0a96280a0e8c565 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 Jan 2020 18:13:47 +0000 Subject: Fix purge_room admin API (#6711) --- changelog.d/6711.bugfix | 1 + synapse/storage/purge_events.py | 2 +- tests/rest/admin/test_admin.py | 4 +--- 3 files changed, 3 insertions(+), 4 deletions(-) create mode 100644 changelog.d/6711.bugfix (limited to 'tests') diff --git a/changelog.d/6711.bugfix b/changelog.d/6711.bugfix new file mode 100644 index 0000000000..c70506bd88 --- /dev/null +++ b/changelog.d/6711.bugfix @@ -0,0 +1 @@ +Fix `purge_room` admin API. diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index d6a7bd7834..fdc0abf5cf 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -34,7 +34,7 @@ class PurgeEventsStorage(object): """ state_groups_to_delete = yield self.stores.main.purge_room(room_id) - yield self.stores.main.purge_room_state(room_id, state_groups_to_delete) + yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) @defer.inlineCallbacks def purge_history(self, room_id, token, delete_local_events): diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 7a7e898843..f3b4a31e21 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -337,7 +337,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "local_invites", "room_account_data", "room_tags", - "state_groups", + # "state_groups", # Current impl leaves orphaned state groups around. "state_groups_state", ): count = self.get_success( @@ -351,8 +351,6 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) - test_purge_room.skip = "Disabled because it's currently broken" - class QuarantineMediaTestCase(unittest.HomeserverTestCase): """Test /quarantine_media admin API. -- cgit 1.5.1 From 48c3a96886de64f3141ad68b8163cd2fc0c197ff Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 Jan 2020 09:16:12 +0000 Subject: Port synapse.replication.tcp to async/await (#6666) * Port synapse.replication.tcp to async/await * Newsfile * Correctly document type of on_ functions as async * Don't be overenthusiastic with the asyncing.... --- changelog.d/6666.misc | 1 + synapse/app/admin_cmd.py | 3 +- synapse/app/appservice.py | 5 +-- synapse/app/federation_sender.py | 5 +-- synapse/app/pusher.py | 5 +-- synapse/app/synchrotron.py | 5 +-- synapse/app/user_dir.py | 5 +-- synapse/federation/send_queue.py | 4 +- synapse/handlers/typing.py | 2 +- synapse/replication/tcp/client.py | 11 ++--- synapse/replication/tcp/protocol.py | 72 ++++++++++++++----------------- synapse/replication/tcp/resource.py | 31 ++++++------- synapse/replication/tcp/streams/_base.py | 25 +++++------ synapse/replication/tcp/streams/events.py | 9 ++-- tests/replication/tcp/streams/_base.py | 2 +- 15 files changed, 80 insertions(+), 105 deletions(-) create mode 100644 changelog.d/6666.misc (limited to 'tests') diff --git a/changelog.d/6666.misc b/changelog.d/6666.misc new file mode 100644 index 0000000000..e79c23d2d2 --- /dev/null +++ b/changelog.d/6666.misc @@ -0,0 +1 @@ +Port `synapse.replication.tcp` to async/await. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 8e36bc57d3..1c7c6ec0c8 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -84,8 +84,7 @@ class AdminCmdServer(HomeServer): class AdminCmdReplicationHandler(ReplicationClientHandler): - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): pass def get_streams_to_replicate(self): diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index e82e0f11e3..2217d4a4fb 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -115,9 +115,8 @@ class ASReplicationHandler(ReplicationClientHandler): super(ASReplicationHandler, self).__init__(hs.get_datastore()) self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) if stream_name == "events": max_stream_id = self.store.get_room_max_stream_ordering() diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 83c436229c..a57cf991ac 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -145,9 +145,8 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) self.send_handler = FederationSenderHandler(hs, self) - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(FederationSenderReplicationHandler, self).on_rdata( + async def on_rdata(self, stream_name, token, rows): + await super(FederationSenderReplicationHandler, self).on_rdata( stream_name, token, rows ) self.send_handler.process_replication_rows(stream_name, token, rows) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 09e639040a..e46b6ac598 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -141,9 +141,8 @@ class PusherReplicationHandler(ReplicationClientHandler): self.pusher_pool = hs.get_pusherpool() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.poke_pushers, stream_name, token, rows) @defer.inlineCallbacks diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 03031ee34d..3218da07bd 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -358,9 +358,8 @@ class SyncReplicationHandler(ReplicationClientHandler): self.presence_handler = hs.get_presence_handler() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.process_and_notify, stream_name, token, rows) def get_streams_to_replicate(self): diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 1257098f92..ba536d6f04 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -172,9 +172,8 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) self.user_directory = hs.get_user_directory_handler() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(UserDirectoryReplicationHandler, self).on_rdata( + async def on_rdata(self, stream_name, token, rows): + await super(UserDirectoryReplicationHandler, self).on_rdata( stream_name, token, rows ) if stream_name == EventsStream.NAME: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index ced4925a98..174f6e42be 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -259,7 +259,9 @@ class FederationRemoteSendQueue(object): def federation_ack(self, token): self._clear_queue_before_pos(token) - def get_replication_rows(self, from_token, to_token, limit, federation_ack=None): + async def get_replication_rows( + self, from_token, to_token, limit, federation_ack=None + ): """Get rows to be sent over federation between the two tokens Args: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index b635c339ed..d5ca9cb07b 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -257,7 +257,7 @@ class TypingHandler(object): "typing_key", self._latest_room_serial, rooms=[member.room_id] ) - def get_all_typing_updates(self, last_id, current_id): + async def get_all_typing_updates(self, last_id, current_id): if last_id == current_id: return [] diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index aa7fd90e26..52a0aefe68 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -110,7 +110,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self.factory) - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -121,20 +121,17 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): token (int): stream token for this batch of rows rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. - - Returns: - Deferred|None """ logger.debug("Received rdata %s -> %s", stream_name, token) - return self.store.process_replication_rows(stream_name, token, rows) + self.store.process_replication_rows(stream_name, token, rows) - def on_position(self, stream_name, token): + async def on_position(self, stream_name, token): """Called when we get new position data. By default this just pokes the slave store. Can be overriden in subclasses to handle more. """ - return self.store.process_replication_rows(stream_name, token, []) + self.store.process_replication_rows(stream_name, token, []) def on_sync(self, data): """When we received a SYNC we wake up any deferreds that were waiting diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index db0353c996..5f4bdf84d2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -81,12 +81,11 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) +from synapse.replication.tcp.streams import STREAMS_MAP from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string -from .streams import STREAMS_MAP - connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -241,19 +240,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd ) - def handle_command(self, cmd): + async def handle_command(self, cmd: Command): """Handle a command we have received over the replication stream. - By default delegates to on_ + By default delegates to on_, which should return an awaitable. Args: - cmd (synapse.replication.tcp.commands.Command): received command - - Returns: - Deferred + cmd: received command """ handler = getattr(self, "on_%s" % (cmd.NAME,)) - return handler(cmd) + await handler(cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -326,10 +322,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - def on_PING(self, line): + async def on_PING(self, line): self.received_ping = True - def on_ERROR(self, cmd): + async def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -429,16 +425,16 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): BaseReplicationStreamProtocol.connectionMade(self) self.streamer.new_connection(self) - def on_NAME(self, cmd): + async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - def on_USER_SYNC(self, cmd): - return self.streamer.on_user_sync( + async def on_USER_SYNC(self, cmd): + await self.streamer.on_user_sync( self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) - def on_REPLICATE(self, cmd): + async def on_REPLICATE(self, cmd): stream_name = cmd.stream_name token = cmd.token @@ -449,23 +445,23 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): for stream in iterkeys(self.streamer.streams_by_name) ] - return make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) else: - return self.subscribe_to_stream(stream_name, token) + await self.subscribe_to_stream(stream_name, token) - def on_FEDERATION_ACK(self, cmd): - return self.streamer.federation_ack(cmd.token) + async def on_FEDERATION_ACK(self, cmd): + self.streamer.federation_ack(cmd.token) - def on_REMOVE_PUSHER(self, cmd): - return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) + async def on_REMOVE_PUSHER(self, cmd): + await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - def on_INVALIDATE_CACHE(self, cmd): - return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + async def on_INVALIDATE_CACHE(self, cmd): + self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - def on_USER_IP(self, cmd): - return self.streamer.on_user_ip( + async def on_USER_IP(self, cmd): + self.streamer.on_user_ip( cmd.user_id, cmd.access_token, cmd.ip, @@ -474,8 +470,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): cmd.last_seen, ) - @defer.inlineCallbacks - def subscribe_to_stream(self, stream_name, token): + async def subscribe_to_stream(self, stream_name, token): """Subscribe the remote to a stream. This invloves checking if they've missed anything and sending those @@ -487,7 +482,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): try: # Get missing updates - updates, current_token = yield self.streamer.get_stream_updates( + updates, current_token = await self.streamer.get_stream_updates( stream_name, token ) @@ -572,7 +567,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): """Called to handle a batch of replication data with a given stream token. Args: @@ -580,14 +575,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): token (int): stream token for this batch of rows rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. - - Returns: - Deferred|None """ raise NotImplementedError() @abc.abstractmethod - def on_position(self, stream_name, token): + async def on_position(self, stream_name, token): """Called when we get new position data.""" raise NotImplementedError() @@ -676,12 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): if not self.streams_connecting: self.handler.finished_connecting() - def on_SERVER(self, cmd): + async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - def on_RDATA(self, cmd): + async def on_RDATA(self, cmd): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -701,19 +693,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Check if this is the last of a batch of updates rows = self.pending_batches.pop(stream_name, []) rows.append(row) - return self.handler.on_rdata(stream_name, cmd.token, rows) + await self.handler.on_rdata(stream_name, cmd.token, rows) - def on_POSITION(self, cmd): + async def on_POSITION(self, cmd): # When we get a `POSITION` command it means we've finished getting # missing updates for the given stream, and are now up to date. self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - return self.handler.on_position(cmd.stream_name, cmd.token) + await self.handler.on_position(cmd.stream_name, cmd.token) - def on_SYNC(self, cmd): - return self.handler.on_sync(cmd.data) + async def on_SYNC(self, cmd): + self.handler.on_sync(cmd.data) def replicate(self, stream_name, token): """Send the subscription request to the server diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index cbfdaf5773..b1752e88cd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -23,7 +23,6 @@ from six import itervalues from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.protocol import Factory from synapse.metrics import LaterGauge @@ -155,8 +154,7 @@ class ReplicationStreamer(object): run_as_background_process("replication_notifier", self._run_notifier_loop) - @defer.inlineCallbacks - def _run_notifier_loop(self): + async def _run_notifier_loop(self): self.is_looping = True try: @@ -185,7 +183,7 @@ class ReplicationStreamer(object): continue if self._replication_torture_level: - yield self.clock.sleep( + await self.clock.sleep( self._replication_torture_level / 1000.0 ) @@ -196,7 +194,7 @@ class ReplicationStreamer(object): stream.upto_token, ) try: - updates, current_token = yield stream.get_updates() + updates, current_token = await stream.get_updates() except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise @@ -233,7 +231,7 @@ class ReplicationStreamer(object): self.is_looping = False @measure_func("repl.get_stream_updates") - def get_stream_updates(self, stream_name, token): + async def get_stream_updates(self, stream_name, token): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. """ @@ -241,7 +239,7 @@ class ReplicationStreamer(object): if not stream: raise Exception("unknown stream %s", stream_name) - return stream.get_updates_since(token) + return await stream.get_updates_since(token) @measure_func("repl.federation_ack") def federation_ack(self, token): @@ -252,22 +250,20 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") - @defer.inlineCallbacks - def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() - yield self.presence_handler.update_external_syncs_row( + await self.presence_handler.update_external_syncs_row( conn_id, user_id, is_syncing, last_sync_ms ) @measure_func("repl.on_remove_pusher") - @defer.inlineCallbacks - def on_remove_pusher(self, app_id, push_key, user_id): + async def on_remove_pusher(self, app_id, push_key, user_id): """A client has asked us to remove a pusher """ remove_pusher_counter.inc() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id=app_id, pushkey=push_key, user_id=user_id ) @@ -281,15 +277,16 @@ class ReplicationStreamer(object): getattr(self.store, cache_func).invalidate(tuple(keys)) @measure_func("repl.on_user_ip") - @defer.inlineCallbacks - def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): + async def on_user_ip( + self, user_id, access_token, ip, user_agent, device_id, last_seen + ): """The client saw a user request """ user_ip_cache_counter.inc() - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id, access_token, ip, user_agent, device_id, last_seen ) - yield self._server_notices_sender.on_user_ip(user_id) + await self._server_notices_sender.on_user_ip(user_id) def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4ab0334fc1..e03e77199b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -19,8 +19,6 @@ import logging from collections import namedtuple from typing import Any -from twisted.internet import defer - logger = logging.getLogger(__name__) @@ -144,8 +142,7 @@ class Stream(object): self.upto_token = self.current_token() self.last_token = self.upto_token - @defer.inlineCallbacks - def get_updates(self): + async def get_updates(self): """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before), until the `upto_token` @@ -156,13 +153,12 @@ class Stream(object): list of ``(token, row)`` entries. ``row`` will be json-serialised and sent over the replication steam. """ - updates, current_token = yield self.get_updates_since(self.last_token) + updates, current_token = await self.get_updates_since(self.last_token) self.last_token = current_token return updates, current_token - @defer.inlineCallbacks - def get_updates_since(self, from_token): + async def get_updates_since(self, from_token): """Like get_updates except allows specifying from when we should stream updates @@ -182,15 +178,16 @@ class Stream(object): if from_token == current_token: return [], current_token + logger.info("get_updates_since: %s", self.__class__) if self._LIMITED: - rows = yield self.update_function( + rows = await self.update_function( from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 ) # never turn more than MAX_EVENTS_BEHIND + 1 into updates. rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: - rows = yield self.update_function(from_token, current_token) + rows = await self.update_function(from_token, current_token) updates = [(row[0], row[1:]) for row in rows] @@ -295,9 +292,8 @@ class PushRulesStream(Stream): push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) + async def update_function(self, from_token, to_token, limit): + rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) return [(row[0], row[2]) for row in rows] @@ -413,9 +409,8 @@ class AccountDataStream(Stream): super(AccountDataStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - global_results, room_results = yield self.store.get_all_updated_account_data( + async def update_function(self, from_token, to_token, limit): + global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 0843e5aa90..b3afabb8cd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,8 +19,6 @@ from typing import Tuple, Type import attr -from twisted.internet import defer - from ._base import Stream @@ -122,16 +120,15 @@ class EventsStream(Stream): super(EventsStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, current_token, limit=None): - event_rows = yield self._store.get_all_new_forward_event_rows( + async def update_function(self, from_token, current_token, limit=None): + event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) event_updates = ( (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows ) - state_rows = yield self._store.get_all_updated_current_state_deltas( + state_rows = await self._store.get_all_updated_current_state_deltas( from_token, current_token, limit ) state_updates = ( diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 1d14e77255..e96ad4ca4e 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -73,6 +73,6 @@ class TestReplicationClientHandler(object): def finished_connecting(self): pass - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): for r in rows: self.received_rdata_rows.append((stream_name, token, r)) -- cgit 1.5.1 From acc7820574426cf27673d941b1b0362272113351 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 16 Jan 2020 22:26:34 +0000 Subject: Log saml assertions rather than the whole response ... since the whole response is huge. We even need to break up the assertions, since kibana otherwise truncates them. --- synapse/handlers/saml_handler.py | 13 ++++++++++- synapse/util/iterutils.py | 13 +++++++++++ tests/util/test_itertools.py | 47 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 tests/util/test_itertools.py (limited to 'tests') diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 107f97032b..32638671c9 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -32,6 +32,7 @@ from synapse.types import ( mxid_localpart_allowed_characters, ) from synapse.util.async_helpers import Linearizer +from synapse.util.iterutils import chunk_seq logger = logging.getLogger(__name__) @@ -132,7 +133,17 @@ class SamlHandler: logger.warning("SAML2 response was not signed") raise SynapseError(400, "SAML2 response was not signed") - logger.info("SAML2 response: %s", saml2_auth.origxml) + logger.debug("SAML2 response: %s", saml2_auth.origxml) + for assertion in saml2_auth.assertions: + # kibana limits the length of a log field, whereas this is all rather + # useful, so split it up. + count = 0 + for part in chunk_seq(str(assertion), 10000): + logger.info( + "SAML2 assertion: %s%s", "(%i)..." % (count,) if count else "", part + ) + count += 1 + logger.info("SAML2 mapped attributes: %s", saml2_auth.ava) try: diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index c10016fbc5..06faeebe7f 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -33,3 +33,16 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]: sourceiter = iter(iterable) # call islice until it returns an empty tuple return iter(lambda: tuple(islice(sourceiter, size)), ()) + + +ISeq = TypeVar("ISeq", bound=Sequence, covariant=True) + + +def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: + """Split the given sequence into chunks of the given size + + The last chunk may be shorter than the given size. + + If the input is empty, no chunks are returned. + """ + return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen)) diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py new file mode 100644 index 0000000000..0ab0a91483 --- /dev/null +++ b/tests/util/test_itertools.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.util.iterutils import chunk_seq + +from tests.unittest import TestCase + + +class ChunkSeqTests(TestCase): + def test_short_seq(self): + parts = chunk_seq("123", 8) + + self.assertEqual( + list(parts), ["123"], + ) + + def test_long_seq(self): + parts = chunk_seq("abcdefghijklmnop", 8) + + self.assertEqual( + list(parts), ["abcdefgh", "ijklmnop"], + ) + + def test_uneven_parts(self): + parts = chunk_seq("abcdefghijklmnop", 5) + + self.assertEqual( + list(parts), ["abcde", "fghij", "klmno", "p"], + ) + + def test_empty_input(self): + parts = chunk_seq([], 5) + + self.assertEqual( + list(parts), [], + ) -- cgit 1.5.1 From ceecedc68ba1af25b0ee60c5cf927fd1fd245b9f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 20 Jan 2020 17:23:59 +0000 Subject: Fix changing password via user admin API. (#6730) --- changelog.d/6730.bugfix | 1 + synapse/rest/admin/users.py | 4 ++-- tests/rest/admin/test_user.py | 13 +++++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6730.bugfix (limited to 'tests') diff --git a/changelog.d/6730.bugfix b/changelog.d/6730.bugfix new file mode 100644 index 0000000000..beb444ca66 --- /dev/null +++ b/changelog.d/6730.bugfix @@ -0,0 +1 @@ +Fix changing password via user admin API. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 574cb90c74..c178c960c5 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -193,8 +193,8 @@ class UserRestServletV2(RestServlet): raise SynapseError(400, "Invalid password") else: new_password = body["password"] - await self._set_password_handler.set_password( - target_user, new_password, requester + await self.set_password_handler.set_password( + target_user.to_string(), new_password, requester ) if "deactivated" in body: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 7352d609e6..8f09f51c61 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -435,6 +435,19 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["is_guest"]) self.assertEqual(0, channel.json_body["deactivated"]) + # Change password + body = json.dumps({"password": "hahaha"}) + + request, channel = self.make_request( + "PUT", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + # Modify user body = json.dumps({"displayname": "foobar", "deactivated": True}) -- cgit 1.5.1 From 74b74462f1c8b2db9b0995cbf64d879cbfce0dc4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 20 Jan 2020 17:38:09 +0000 Subject: Fix `/events/:event_id` deprecated API. (#6731) --- changelog.d/6731.bugfix | 1 + synapse/rest/client/v1/events.py | 2 +- tests/rest/client/v1/test_events.py | 27 +++++++++++++++++++++++++++ tests/unittest.py | 2 +- 4 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6731.bugfix (limited to 'tests') diff --git a/changelog.d/6731.bugfix b/changelog.d/6731.bugfix new file mode 100644 index 0000000000..21f6e15cbd --- /dev/null +++ b/changelog.d/6731.bugfix @@ -0,0 +1 @@ +Fix `/events/:event_id` deprecated API. diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 4beb617733..25effd0261 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -70,7 +70,6 @@ class EventStreamRestServlet(RestServlet): return 200, {} -# TODO: Unit test gets, with and without auth, with different kinds of events. class EventRestServlet(RestServlet): PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True) @@ -78,6 +77,7 @@ class EventRestServlet(RestServlet): super(EventRestServlet, self).__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() + self.auth = hs.get_auth() self._event_serializer = hs.get_event_client_serializer() async def on_GET(self, request, event_id): diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index f340b7e851..ffb2de1505 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -134,3 +134,30 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): # someone else set topic, expect 6 (join,send,topic,join,send,topic) pass + + +class GetEventsTestCase(unittest.HomeserverTestCase): + servlets = [ + events.register_servlets, + room.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, hs, reactor, clock): + + # register an account + self.user_id = self.register_user("sid1", "pass") + self.token = self.login(self.user_id, "pass") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def test_get_event_via_events(self): + resp = self.helper.send(self.room_id, tok=self.token) + event_id = resp["event_id"] + + request, channel = self.make_request( + "GET", "/events/" + event_id, access_token=self.token, + ) + self.render(request) + self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/unittest.py b/tests/unittest.py index ddcd4becfe..b56e249386 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -463,7 +463,7 @@ class HomeserverTestCase(TestCase): # Create the user request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") self.render(request) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, 200, msg=channel.result) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) -- cgit 1.5.1 From aa9b00fb2f9a7718d67fb11621a83035492ed9fb Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 22 Jan 2020 11:05:50 +0000 Subject: Fix and add test to deprecated quarantine media admin api (#6756) --- changelog.d/6756.feature | 1 + synapse/rest/admin/media.py | 2 +- tests/rest/admin/test_admin.py | 15 +++++++++++---- 3 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 changelog.d/6756.feature (limited to 'tests') diff --git a/changelog.d/6756.feature b/changelog.d/6756.feature new file mode 100644 index 0000000000..6328c868f2 --- /dev/null +++ b/changelog.d/6756.feature @@ -0,0 +1 @@ +Add new quarantine media admin APIs to quarantine by media ID or by user who uploaded the media. \ No newline at end of file diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 3a445d6eed..ee75095c0e 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -36,7 +36,7 @@ class QuarantineMediaInRoom(RestServlet): historical_admin_path_patterns("/room/(?P[^/]+)/media/quarantine") + # This path kept around for legacy reasons - historical_admin_path_patterns("/quarantine_media/(?P![^/]+)") + historical_admin_path_patterns("/quarantine_media/(?P[^/]+)") ) def __init__(self, hs): diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index f3b4a31e21..af4d604e50 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -516,7 +516,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ), ) - def test_quarantine_all_media_in_room(self): + def test_quarantine_all_media_in_room(self, override_url_template=None): self.register_user("room_admin", "pass", admin=True) admin_user_tok = self.login("room_admin", "pass") @@ -555,9 +555,12 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ) # Quarantine all media in the room - url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote( - room_id - ) + if override_url_template: + url = override_url_template % urllib.parse.quote(room_id) + else: + url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote( + room_id + ) request, channel = self.make_request("POST", url, access_token=admin_user_tok,) self.render(request) self.pump(1.0) @@ -611,6 +614,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ), ) + def test_quaraantine_all_media_in_room_deprecated_api_path(self): + # Perform the above test with the deprecated API path + self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s") + def test_quarantine_all_media_by_user(self): self.register_user("user_admin", "pass", admin=True) admin_user_tok = self.login("user_admin", "pass") -- cgit 1.5.1 From 67aa18e8dc81d6d5c0645fc330e049ed7d2e786e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 22 Jan 2020 12:28:07 +0000 Subject: Add tests for thumbnailing --- tests/rest/media/v1/test_media_storage.py | 48 +++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index bc662b61db..d80a7daed3 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -17,7 +17,7 @@ import os import shutil import tempfile -from binascii import unhexlify +from binascii import unhexlify, hexlify from mock import Mock from six.moves.urllib import parse @@ -149,6 +149,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.media_repo = hs.get_media_repository_resource() self.download_resource = self.media_repo.children[b"download"] + self.thumbnail_resource = self.media_repo.children[b"thumbnail"] # smol png self.end_content = unhexlify( @@ -157,10 +158,12 @@ class MediaRepoTests(unittest.HomeserverTestCase): b"0a2db40000000049454e44ae426082" ) + self.media_id = "example.com/12345" + def _req(self, content_disposition): request, channel = self.make_request( - "GET", "example.com/12345", shorthand=False + "GET", self.media_id, shorthand=False ) request.render(self.download_resource) self.pump() @@ -170,7 +173,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(len(self.fetches), 1) self.assertEqual(self.fetches[0][1], "example.com") self.assertEqual( - self.fetches[0][2], "/_matrix/media/v1/download/example.com/12345" + self.fetches[0][2], "/_matrix/media/v1/download/" + self.media_id ) self.assertEqual(self.fetches[0][3], {"allow_remote": "false"}) @@ -229,3 +232,42 @@ class MediaRepoTests(unittest.HomeserverTestCase): headers = channel.headers self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"]) self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) + + def test_thumbnail_crop(self): + expected_body = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000020000000200806" + b"000000737a7af40000001a49444154789cedc101010000008220" + b"ffaf6e484001000000ef0610200001194334ee0000000049454e" + b"44ae426082" + ) + + self._test_thumbnail("crop", expected_body) + + def test_thumbnail_scale(self): + expected_body = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000d49444154789c636060606000000005" + b"0001a5f645400000000049454e44ae426082" + ) + + self._test_thumbnail("scale", expected_body) + + def _test_thumbnail(self, method, expected_body): + params = "?width=32&height=32&method=" + method + request, channel = self.make_request( + "GET", self.media_id + params, shorthand=False + ) + request.render(self.thumbnail_resource) + self.pump() + + headers = { + b"Content-Length": [b"%d" % (len(self.end_content))], + b"Content-Type": [b"image/png"], + } + self.fetches[0][0].callback( + (self.end_content, (len(self.end_content), headers)) + ) + self.pump() + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.result["body"], expected_body, channel.result["body"]) -- cgit 1.5.1 From d9a8728b1183c416a89a79856f3b5185386600b2 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 22 Jan 2020 12:30:49 +0000 Subject: Remove unused import --- tests/rest/media/v1/test_media_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index d80a7daed3..6345cc7637 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -17,7 +17,7 @@ import os import shutil import tempfile -from binascii import unhexlify, hexlify +from binascii import unhexlify from mock import Mock from six.moves.urllib import parse -- cgit 1.5.1 From 6ae0c8db3335faa9b5f0e4407f7a4a3713c84062 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 22 Jan 2020 12:38:18 +0000 Subject: Lint + changelog --- changelog.d/6764.misc | 1 + tests/rest/media/v1/test_media_storage.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) create mode 100644 changelog.d/6764.misc (limited to 'tests') diff --git a/changelog.d/6764.misc b/changelog.d/6764.misc new file mode 100644 index 0000000000..8edd767405 --- /dev/null +++ b/changelog.d/6764.misc @@ -0,0 +1 @@ +Fixup `synapse.rest` to pass mypy. diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 6345cc7637..1809ceb839 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -162,9 +162,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): def _req(self, content_disposition): - request, channel = self.make_request( - "GET", self.media_id, shorthand=False - ) + request, channel = self.make_request("GET", self.media_id, shorthand=False) request.render(self.download_resource) self.pump() -- cgit 1.5.1 From 90a28fb475a29daa9e7a9ee7204f6f76cc8af441 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 22 Jan 2020 13:36:43 +0000 Subject: Admin API to list, filter and sort rooms (#6720) --- changelog.d/6720.feature | 1 + docs/admin_api/rooms.md | 173 ++++++++++++++ synapse/rest/admin/__init__.py | 3 +- synapse/rest/admin/_base.py | 15 ++ synapse/rest/admin/rooms.py | 82 +++++++ synapse/rest/client/v2_alpha/_base.py | 2 +- synapse/storage/data_stores/main/room.py | 125 +++++++++- tests/rest/admin/test_admin.py | 393 ++++++++++++++++++++++++++++++- 8 files changed, 787 insertions(+), 7 deletions(-) create mode 100644 changelog.d/6720.feature create mode 100644 docs/admin_api/rooms.md (limited to 'tests') diff --git a/changelog.d/6720.feature b/changelog.d/6720.feature new file mode 100644 index 0000000000..dfc1b74d62 --- /dev/null +++ b/changelog.d/6720.feature @@ -0,0 +1 @@ +Add a new admin API to list and filter rooms on the server. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md new file mode 100644 index 0000000000..082721ea95 --- /dev/null +++ b/docs/admin_api/rooms.md @@ -0,0 +1,173 @@ +# List Room API + +The List Room admin API allows server admins to get a list of rooms on their +server. There are various parameters available that allow for filtering and +sorting the returned list. This API supports pagination. + +## Parameters + +The following query parameters are available: + +* `from` - Offset in the returned list. Defaults to `0`. +* `limit` - Maximum amount of rooms to return. Defaults to `100`. +* `order_by` - The method in which to sort the returned list of rooms. Valid values are: + - `alphabetical` - Rooms are ordered alphabetically by room name. This is the default. + - `size` - Rooms are ordered by the number of members. Largest to smallest. +* `dir` - Direction of room order. Either `f` for forwards or `b` for backwards. Setting + this value to `b` will reverse the above sort order. Defaults to `f`. +* `search_term` - Filter rooms by their room name. Search term can be contained in any + part of the room name. Defaults to no filtering. + +The following fields are possible in the JSON response body: + +* `rooms` - An array of objects, each containing information about a room. + - Room objects contain the following fields: + - `room_id` - The ID of the room. + - `name` - The name of the room. + - `canonical_alias` - The canonical (main) alias address of the room. + - `joined_members` - How many users are currently in the room. +* `offset` - The current pagination offset in rooms. This parameter should be + used instead of `next_token` for room offset as `next_token` is + not intended to be parsed. +* `total_rooms` - The total number of rooms this query can return. Using this + and `offset`, you have enough information to know the current + progression through the list. +* `next_batch` - If this field is present, we know that there are potentially + more rooms on the server that did not all fit into this response. + We can use `next_batch` to get the "next page" of results. To do + so, simply repeat your request, setting the `from` parameter to + the value of `next_batch`. +* `prev_batch` - If this field is present, it is possible to paginate backwards. + Use `prev_batch` for the `from` value in the next request to + get the "previous page" of results. + +## Usage + +A standard request with no filtering: + +``` +GET /_synapse/admin/rooms + +{} +``` + +Response: + +``` +{ + "rooms": [ + { + "room_id": "!OGEhHVWSdvArJzumhm:matrix.org", + "name": "Matrix HQ", + "canonical_alias": "#matrix:matrix.org", + "joined_members": 8326 + }, + ... (8 hidden items) ... + { + "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org", + "name": "This Week In Matrix (TWIM)", + "canonical_alias": "#twim:matrix.org", + "joined_members": 314 + } + ], + "offset": 0, + "total_rooms": 10 +} +``` + +Filtering by room name: + +``` +GET /_synapse/admin/rooms?search_term=TWIM + +{} +``` + +Response: + +``` +{ + "rooms": [ + { + "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org", + "name": "This Week In Matrix (TWIM)", + "canonical_alias": "#twim:matrix.org", + "joined_members": 314 + } + ], + "offset": 0, + "total_rooms": 1 +} +``` + +Paginating through a list of rooms: + +``` +GET /_synapse/admin/rooms?order_by=size + +{} +``` + +Response: + +``` +{ + "rooms": [ + { + "room_id": "!OGEhHVWSdvArJzumhm:matrix.org", + "name": "Matrix HQ", + "canonical_alias": "#matrix:matrix.org", + "joined_members": 8326 + }, + ... (98 hidden items) ... + { + "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org", + "name": "This Week In Matrix (TWIM)", + "canonical_alias": "#twim:matrix.org", + "joined_members": 314 + } + ], + "offset": 0, + "total_rooms": 150 + "next_token": 100 +} +``` + +The presence of the `next_token` parameter tells us that there are more rooms +than returned in this request, and we need to make another request to get them. +To get the next batch of room results, we repeat our request, setting the `from` +parameter to the value of `next_token`. + +``` +GET /_synapse/admin/rooms?order_by=size&from=100 + +{} +``` + +Response: + +``` +{ + "rooms": [ + { + "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", + "name": "Music Theory", + "canonical_alias": "#musictheory:matrix.org", + "joined_members": 127 + }, + ... (48 hidden items) ... + { + "room_id": "!twcBhHVdZlQWuuxBhN:termina.org.uk", + "name": "weechat-matrix", + "canonical_alias": "#weechat-matrix:termina.org.uk", + "joined_members": 137 + } + ], + "offset": 100, + "prev_batch": 0, + "total_rooms": 150 +} +``` + +Once the `next_token` parameter is no longer present, we know we've reached the +end of the list. diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 2932fe2123..42cc2b062a 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -29,7 +29,7 @@ from synapse.rest.admin._base import ( from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet -from synapse.rest.admin.rooms import ShutdownRoomRestServlet +from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.users import ( AccountValidityRenewServlet, @@ -188,6 +188,7 @@ def register_servlets(hs, http_server): Register all the admin servlets. """ register_servlets_for_client_rest_resource(hs, http_server) + ListRoomRestServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index afd0647205..459482eb6d 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -40,6 +40,21 @@ def historical_admin_path_patterns(path_regex): ) +def admin_patterns(path_regex: str): + """Returns the list of patterns for an admin endpoint + + Args: + path_regex: The regex string to match. This should NOT have a ^ + as this will be prefixed. + + Returns: + A list of regex patterns. + """ + admin_prefix = "^/_synapse/admin/v1" + patterns = [re.compile(admin_prefix + path_regex)] + return patterns + + async def assert_requester_is_admin(auth, request): """Verify that the requester is an admin user diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f7cc5e9be9..f9b8c0a4f0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -15,15 +15,20 @@ import logging from synapse.api.constants import Membership +from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_integer, parse_json_object_from_request, + parse_string, ) from synapse.rest.admin._base import ( + admin_patterns, assert_user_is_admin, historical_admin_path_patterns, ) +from synapse.storage.data_stores.main.room import RoomSortOrder from synapse.types import create_requester from synapse.util.async_helpers import maybe_awaitable @@ -155,3 +160,80 @@ class ShutdownRoomRestServlet(RestServlet): "new_room_id": new_room_id, }, ) + + +class ListRoomRestServlet(RestServlet): + """ + List all rooms that are known to the homeserver. Results are returned + in a dictionary containing room information. Supports pagination. + """ + + PATTERNS = admin_patterns("/rooms") + + def __init__(self, hs): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.admin_handler = hs.get_handlers().admin_handler + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + # Extract query parameters + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + order_by = parse_string(request, "order_by", default="alphabetical") + if order_by not in ( + RoomSortOrder.ALPHABETICAL.value, + RoomSortOrder.SIZE.value, + ): + raise SynapseError( + 400, + "Unknown value for order_by: %s" % (order_by,), + errcode=Codes.INVALID_PARAM, + ) + + search_term = parse_string(request, "search_term") + if search_term == "": + raise SynapseError( + 400, + "search_term cannot be an empty string", + errcode=Codes.INVALID_PARAM, + ) + + direction = parse_string(request, "dir", default="f") + if direction not in ("f", "b"): + raise SynapseError( + 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + ) + + reverse_order = True if direction == "b" else False + + # Return list of rooms according to parameters + rooms, total_rooms = await self.store.get_rooms_paginate( + start, limit, order_by, reverse_order, search_term + ) + response = { + # next_token should be opaque, so return a value the client can parse + "offset": start, + "rooms": rooms, + "total_rooms": total_rooms, + } + + # Are there more rooms to paginate through after this? + if (start + limit) < total_rooms: + # There are. Calculate where the query should start from next time + # to get the next part of the list + response["next_batch"] = start + limit + + # Is it possible to paginate backwards? Check if we currently have an + # offset + if start > 0: + if start > limit: + # Going back one iteration won't take us to the start. + # Calculate new offset + response["prev_batch"] = start - limit + else: + response["prev_batch"] = 0 + + return 200, response diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 2a3f4dd58f..bc11b4dda4 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -32,7 +32,7 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): Args: path_regex (str): The regex string to match. This should NOT have a ^ - as this will be prefixed. + as this will be prefixed. Returns: SRE_Pattern """ diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 49bab62be3..d968803ad2 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -18,7 +18,8 @@ import collections import logging import re from abc import abstractmethod -from typing import List, Optional, Tuple +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple from six import integer_types @@ -46,6 +47,18 @@ RatelimitOverride = collections.namedtuple( ) +class RoomSortOrder(Enum): + """ + Enum to define the sorting method used when returning rooms with get_rooms_paginate + + ALPHABETICAL = sort rooms alphabetically by name + SIZE = sort rooms by membership size, highest to lowest + """ + + ALPHABETICAL = "alphabetical" + SIZE = "size" + + class RoomWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(RoomWorkerStore, self).__init__(database, db_conn, hs) @@ -281,6 +294,116 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + async def get_rooms_paginate( + self, + start: int, + limit: int, + order_by: RoomSortOrder, + reverse_order: bool, + search_term: Optional[str], + ) -> Tuple[List[Dict[str, Any]], int]: + """Function to retrieve a paginated list of rooms as json. + + Args: + start: offset in the list + limit: maximum amount of rooms to retrieve + order_by: the sort order of the returned list + reverse_order: whether to reverse the room list + search_term: a string to filter room names by + Returns: + A list of room dicts and an integer representing the total number of + rooms that exist given this query + """ + # Filter room names by a string + where_statement = "" + if search_term: + where_statement = "WHERE state.name LIKE ?" + + # Our postgres db driver converts ? -> %s in SQL strings as that's the + # placeholder for postgres. + # HOWEVER, if you put a % into your SQL then everything goes wibbly. + # To get around this, we're going to surround search_term with %'s + # before giving it to the database in python instead + search_term = "%" + search_term + "%" + + # Set ordering + if RoomSortOrder(order_by) == RoomSortOrder.SIZE: + order_by_column = "curr.joined_members" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL: + # Sort alphabetically + order_by_column = "state.name" + order_by_asc = True + else: + raise StoreError( + 500, "Incorrect value for order_by provided: %s" % order_by + ) + + # Whether to return the list in reverse order + if reverse_order: + # Flip the boolean + order_by_asc = not order_by_asc + + # Create one query for getting the limited number of events that the user asked + # for, and another query for getting the total number of events that could be + # returned. Thus allowing us to see if there are more events to paginate through + info_sql = """ + SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members + FROM room_stats_state state + INNER JOIN room_stats_current curr USING (room_id) + %s + ORDER BY %s %s + LIMIT ? + OFFSET ? + """ % ( + where_statement, + order_by_column, + "ASC" if order_by_asc else "DESC", + ) + + # Use a nested SELECT statement as SQL can't count(*) with an OFFSET + count_sql = """ + SELECT count(*) FROM ( + SELECT room_id FROM room_stats_state state + %s + ) AS get_room_ids + """ % ( + where_statement, + ) + + def _get_rooms_paginate_txn(txn): + # Execute the data query + sql_values = (limit, start) + if search_term: + # Add the search term into the WHERE clause + sql_values = (search_term,) + sql_values + txn.execute(info_sql, sql_values) + + # Refactor room query data into a structured dictionary + rooms = [] + for room in txn: + rooms.append( + { + "room_id": room[0], + "name": room[1], + "canonical_alias": room[2], + "joined_members": room[3], + } + ) + + # Execute the count query + + # Add the search term into the WHERE clause if present + sql_values = (search_term,) if search_term else () + txn.execute(count_sql, sql_values) + + room_count = txn.fetchone() + return rooms, room_count[0] + + return await self.db.runInteraction( + "get_rooms_paginate", _get_rooms_paginate_txn, + ) + @cachedInlineCallbacks(max_entries=10000) def get_ratelimit_for_user(self, user_id): """Check if there are any overrides for ratelimiting for the given diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index af4d604e50..0342aed416 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -17,6 +17,7 @@ import json import os import urllib.parse from binascii import unhexlify +from typing import List, Optional from mock import Mock @@ -26,7 +27,7 @@ import synapse.rest.admin from synapse.http.server import JsonResource from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet -from synapse.rest.client.v1 import events, login, room +from synapse.rest.client.v1 import directory, events, login, room from synapse.rest.client.v2_alpha import groups from tests import unittest @@ -468,9 +469,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ) # Extract media ID from the response - server_name_and_media_id = response["content_uri"][ - 6: - ] # Cut off the 'mxc://' bit + server_name_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' server_name, media_id = server_name_and_media_id.split("/") # Attempt to access the media @@ -692,3 +691,389 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): % server_and_media_id_2 ), ) + + +class RoomTestCase(unittest.HomeserverTestCase): + """Test /room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_list_rooms(self): + """Test that we can list rooms""" + # Create 3 test rooms + total_rooms = 3 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + + # Check request completed successfully + self.assertEqual(200, int(channel.code), msg=channel.json_body) + + # Check that response json body contains a "rooms" key + self.assertTrue( + "rooms" in channel.json_body, + msg="Response body does not " "contain a 'rooms' key", + ) + + # Check that 3 rooms were returned + self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) + + # Check their room_ids match + returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] + self.assertEqual(room_ids, returned_room_ids) + + # Check that all fields are available + for r in channel.json_body["rooms"]: + self.assertIn("name", r) + self.assertIn("canonical_alias", r) + self.assertIn("joined_members", r) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # Should be 0 as we aren't paginating + self.assertEqual(channel.json_body["offset"], 0) + + # Check that the prev_batch parameter is not present + self.assertNotIn("prev_batch", channel.json_body) + + # We shouldn't receive a next token here as there's no further rooms to show + self.assertNotIn("next_batch", channel.json_body) + + def test_list_rooms_pagination(self): + """Test that we can get a full list of rooms through pagination""" + # Create 5 test rooms + total_rooms = 5 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Set the name of the rooms so we get a consistent returned ordering + for idx, room_id in enumerate(room_ids): + self.helper.send_state( + room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + returned_room_ids = [] + start = 0 + limit = 2 + + run_count = 0 + should_repeat = True + while should_repeat: + run_count += 1 + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( + start, + limit, + "alphabetical", + ) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + self.assertTrue("rooms" in channel.json_body) + for r in channel.json_body["rooms"]: + returned_room_ids.append(r["room_id"]) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # We're only getting 2 rooms each page, so should be 2 * last run_count + self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) + + if run_count > 1: + # Check the value of prev_batch is correct + self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) + + if "next_batch" not in channel.json_body: + # We have reached the end of the list + should_repeat = False + else: + # Make another query with an updated start value + start = channel.json_body["next_batch"] + + # We should've queried the endpoint 3 times + self.assertEqual( + run_count, + 3, + msg="Should've queried 3 times for 5 rooms with limit 2 per query", + ) + + # Check that we received all of the room ids + self.assertEqual(room_ids, returned_room_ids) + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def test_correct_room_attributes(self): + """Test the correct attributes for a room are returned""" + # Create a test room + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + test_alias = "#test:test" + test_room_name = "something" + + # Have another user join the room + user_2 = self.register_user("user4", "pass") + user_tok_2 = self.login("user4", "pass") + self.helper.join(room_id, user_2, tok=user_tok_2) + + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=self.admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=self.admin_user_tok, + ) + + # Set a name for the room + self.helper.send_state( + room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that only one room was returned + self.assertEqual(len(rooms), 1) + + # And that the value of the total_rooms key was correct + self.assertEqual(channel.json_body["total_rooms"], 1) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that all provided attributes are set + r = rooms[0] + self.assertEqual(room_id, r["room_id"]) + self.assertEqual(test_room_name, r["name"]) + self.assertEqual(test_alias, r["canonical_alias"]) + + def test_room_list_sort_order(self): + """Test room list sort ordering. alphabetical versus number of members, + reversing the order, etc. + """ + # Create 3 test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C + self.helper.send_state( + room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, + ) + + # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3 + user_1 = self.register_user("bob1", "pass") + user_1_tok = self.login("bob1", "pass") + self.helper.join(room_id_2, user_1, tok=user_1_tok) + + user_2 = self.register_user("bob2", "pass") + user_2_tok = self.login("bob2", "pass") + self.helper.join(room_id_3, user_2, tok=user_2_tok) + + user_3 = self.register_user("bob3", "pass") + user_3_tok = self.login("bob3", "pass") + self.helper.join(room_id_3, user_3, tok=user_3_tok) + + def _order_test( + order_type: str, expected_room_list: List[str], reverse: bool = False, + ): + """Request the list of rooms in a certain order. Assert that order is what + we expect + + Args: + order_type: The type of ordering to give the server + expected_room_list: The list of room_ids in the order we expect to get + back from the server + """ + # Request the list of rooms in the given order + url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) + if reverse: + url += "&dir=b" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check for the correct total_rooms value + self.assertEqual(channel.json_body["total_rooms"], 3) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that rooms were returned in alphabetical order + returned_order = [r["room_id"] for r in rooms] + self.assertListEqual(expected_room_list, returned_order) # order is checked + + # Test different sort orders, with forward and reverse directions + _order_test("alphabetical", [room_id_1, room_id_2, room_id_3]) + _order_test("alphabetical", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("size", [room_id_3, room_id_2, room_id_1]) + _order_test("size", [room_id_1, room_id_2, room_id_3], reverse=True) + + def test_search_term(self): + """Test that searching for a room works correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + def _search_test( + expected_room_id: Optional[str], + search_term: str, + expected_http_code: int = 200, + ): + """Search for a room and check that the returned room's id is a match + + Args: + expected_room_id: The room_id expected to be returned by the API. Set + to None to expect zero results for the search + search_term: The term to search for room names with + expected_http_code: The expected http code for the request + """ + url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) + + if expected_http_code != 200: + return + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that the expected number of rooms were returned + expected_room_count = 1 if expected_room_id else 0 + self.assertEqual(len(rooms), expected_room_count) + self.assertEqual(channel.json_body["total_rooms"], expected_room_count) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + if expected_room_id: + # Check that the first returned room id is correct + r = rooms[0] + self.assertEqual(expected_room_id, r["room_id"]) + + # Perform search tests + _search_test(room_id_1, "something") + _search_test(room_id_1, "thing") + + _search_test(room_id_2, "else") + _search_test(room_id_2, "se") + + _search_test(None, "foo") + _search_test(None, "bar") + _search_test(None, "", expected_http_code=400) -- cgit 1.5.1 From fa4d609e20318821e2ffbeb35bfddbc86be81be0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 23 Jan 2020 15:19:03 +0000 Subject: Make 'event.redacts' never raise. (#6771) There are quite a few places that we assume that a redaction event has a corresponding `redacts` key, which is not always the case. So lets cheekily make it so that event.redacts just returns None instead. --- changelog.d/6771.bugfix | 1 + synapse/events/__init__.py | 28 +++++++++++++++--- synapse/storage/data_stores/main/events.py | 2 +- synapse/storage/data_stores/main/events_worker.py | 2 +- tests/storage/test_redaction.py | 35 +++++++++++++++++++++++ 5 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 changelog.d/6771.bugfix (limited to 'tests') diff --git a/changelog.d/6771.bugfix b/changelog.d/6771.bugfix new file mode 100644 index 0000000000..623ba24acb --- /dev/null +++ b/changelog.d/6771.bugfix @@ -0,0 +1 @@ +Fix persisting redaction events that have been redacted (or otherwise don't have a redacts key). diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 88ed6d764f..72c09327f4 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -116,16 +116,32 @@ class _EventInternalMetadata(object): return getattr(self, "redacted", False) -def _event_dict_property(key): +_SENTINEL = object() + + +def _event_dict_property(key, default=_SENTINEL): + """Creates a new property for the given key that delegates access to + `self._event_dict`. + + The default is used if the key is missing from the `_event_dict`, if given, + otherwise an AttributeError will be raised. + + Note: If a default is given then `hasattr` will always return true. + """ + # We want to be able to use hasattr with the event dict properties. # However, (on python3) hasattr expects AttributeError to be raised. Hence, # we need to transform the KeyError into an AttributeError - def getter(self): + + def getter_raises(self): try: return self._event_dict[key] except KeyError: raise AttributeError(key) + def getter_default(self): + return self._event_dict.get(key, default) + def setter(self, v): try: self._event_dict[key] = v @@ -138,7 +154,11 @@ def _event_dict_property(key): except KeyError: raise AttributeError(key) - return property(getter, setter, delete) + if default is _SENTINEL: + # No default given, so use the getter that raises + return property(getter_raises, setter, delete) + else: + return property(getter_default, setter, delete) class EventBase(object): @@ -165,7 +185,7 @@ class EventBase(object): origin = _event_dict_property("origin") origin_server_ts = _event_dict_property("origin_server_ts") prev_events = _event_dict_property("prev_events") - redacts = _event_dict_property("redacts") + redacts = _event_dict_property("redacts", None) room_id = _event_dict_property("room_id") sender = _event_dict_property("sender") user_id = _event_dict_property("sender") diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 596daf8909..ce553566a5 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -951,7 +951,7 @@ class EventsStore( elif event.type == EventTypes.Message: # Insert into the event_search table. self._store_room_message_txn(txn, event) - elif event.type == EventTypes.Redaction: + elif event.type == EventTypes.Redaction and event.redacts is not None: # Insert into the redactions table. self._store_redaction(txn, event) elif event.type == EventTypes.Retention: diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 3b93e0597a..7251e819f5 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -287,7 +287,7 @@ class EventsWorkerStore(SQLBaseStore): # we have to recheck auth now. if not allow_rejected and entry.event.type == EventTypes.Redaction: - if not hasattr(entry.event, "redacts"): + if entry.event.redacts is None: # A redacted redaction doesn't have a `redacts` key, in # which case lets just withhold the event. # diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index dc45173355..feb1c07cb2 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -398,3 +398,38 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.get_success( self.store.get_event(first_redact_event.event_id, allow_none=True) ) + + def test_store_redacted_redaction(self): + """Tests that we can store a redacted redaction. + """ + + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) + + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "foo"}, + }, + ) + + redaction_event, context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) + ) + + self.get_success( + self.storage.persistence.persist_event(redaction_event, context) + ) + + # Now lets jump to the future where we have censored the redaction event + # in the DB. + self.reactor.advance(60 * 60 * 24 * 31) + + # We just want to check that fetching the event doesn't raise an exception. + self.get_success( + self.store.get_event(redaction_event.event_id, allow_none=True) + ) -- cgit 1.5.1 From 9f7aaf90b5ef76416852f35201a851d45eccc0a1 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 24 Jan 2020 14:28:40 +0000 Subject: Validate client_secret parameter (#6767) --- changelog.d/6767.bugfix | 1 + synapse/handlers/identity.py | 4 ++- synapse/rest/client/v2_alpha/account.py | 23 ++++++++++---- synapse/rest/client/v2_alpha/register.py | 3 ++ synapse/util/stringutils.py | 17 +++++++++++ tests/util/test_stringutils.py | 51 ++++++++++++++++++++++++++++++++ 6 files changed, 93 insertions(+), 6 deletions(-) create mode 100644 changelog.d/6767.bugfix create mode 100644 tests/util/test_stringutils.py (limited to 'tests') diff --git a/changelog.d/6767.bugfix b/changelog.d/6767.bugfix new file mode 100644 index 0000000000..63c7c63315 --- /dev/null +++ b/changelog.d/6767.bugfix @@ -0,0 +1 @@ +Validate `client_secret` parameter using the regex provided by the Client-Server API, temporarily allowing `:` characters for older clients. The `:` character will be removed in a future release. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 000fbf090f..23f07832e7 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -38,7 +38,7 @@ from synapse.api.errors import ( from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.client import SimpleHttpClient from synapse.util.hash import sha256_and_url_safe_base64 -from synapse.util.stringutils import random_string +from synapse.util.stringutils import assert_valid_client_secret, random_string from ._base import BaseHandler @@ -84,6 +84,8 @@ class IdentityHandler(BaseHandler): raise SynapseError( 400, "Missing param client_secret in creds", errcode=Codes.MISSING_PARAM ) + assert_valid_client_secret(client_secret) + session_id = creds.get("sid") if not session_id: raise SynapseError( diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index fc240f5cf8..dc837d6c75 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -30,6 +30,7 @@ from synapse.http.servlet import ( ) from synapse.push.mailer import Mailer, load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.stringutils import assert_valid_client_secret from synapse.util.threepids import check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -81,6 +82,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Extract params from body client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + email = body["email"] send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -166,8 +169,9 @@ class PasswordResetSubmitTokenServlet(RestServlet): ) sid = parse_string(request, "sid", required=True) - client_secret = parse_string(request, "client_secret", required=True) token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) # Attempt to validate a 3PID session try: @@ -353,6 +357,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + email = body["email"] send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -413,6 +419,8 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): body, ["client_secret", "country", "phone_number", "send_attempt"] ) client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + country = body["country"] phone_number = body["phone_number"] send_attempt = body["send_attempt"] @@ -493,8 +501,9 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): ) sid = parse_string(request, "sid", required=True) - client_secret = parse_string(request, "client_secret", required=True) token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) # Attempt to validate a 3PID session try: @@ -559,6 +568,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["client_secret", "sid", "token"]) + assert_valid_client_secret(body["client_secret"]) # Proxy submit_token request to msisdn threepid delegate response = await self.identity_handler.proxy_msisdn_submit_token( @@ -600,8 +610,9 @@ class ThreepidRestServlet(RestServlet): ) assert_params_in_dict(threepid_creds, ["client_secret", "sid"]) - client_secret = threepid_creds["client_secret"] sid = threepid_creds["sid"] + client_secret = threepid_creds["client_secret"] + assert_valid_client_secret(client_secret) validation_session = await self.identity_handler.validate_threepid_session( client_secret, sid @@ -637,8 +648,9 @@ class ThreepidAddRestServlet(RestServlet): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["client_secret", "sid"]) - client_secret = body["client_secret"] sid = body["sid"] + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) @@ -676,8 +688,9 @@ class ThreepidBindRestServlet(RestServlet): assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) id_server = body["id_server"] sid = body["sid"] - client_secret = body["client_secret"] id_access_token = body.get("id_access_token") # optional + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1bda9aec7e..a09189b1b4 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -49,6 +49,7 @@ from synapse.http.servlet import ( from synapse.push.mailer import load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter +from synapse.util.stringutils import assert_valid_client_secret from synapse.util.threepids import check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -116,6 +117,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): # Extract params from body client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + email = body["email"] send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 982c6d81ca..2c0dcb5208 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,14 +15,22 @@ # limitations under the License. import random +import re import string import six from six import PY2, PY3 from six.moves import range +from synapse.api.errors import Codes, SynapseError + _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" +# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken +# Note: The : character is allowed here for older clients, but will be removed in a +# future release. Context: https://github.com/matrix-org/synapse/issues/6766 +client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$") + # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure # we get cryptographically-secure randoms. @@ -109,3 +118,11 @@ def exception_to_unicode(e): return msg.decode("utf-8", errors="replace") else: return msg + + +def assert_valid_client_secret(client_secret): + """Validate that a given string matches the client_secret regex defined by the spec""" + if client_secret_regex.match(client_secret) is None: + raise SynapseError( + 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM + ) diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py new file mode 100644 index 0000000000..4f4da29a98 --- /dev/null +++ b/tests/util/test_stringutils.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.errors import SynapseError +from synapse.util.stringutils import assert_valid_client_secret + +from .. import unittest + + +class StringUtilsTestCase(unittest.TestCase): + def test_client_secret_regex(self): + """Ensure that client_secret does not contain illegal characters""" + good = [ + "abcde12345", + "ABCabc123", + "_--something==_", + "...--==-18913", + "8Dj2odd-e9asd.cd==_--ddas-secret-", + # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766 + # To be removed in a future release + "SECRET:1234567890", + ] + + bad = [ + "--+-/secret", + "\\dx--dsa288", + "", + "AAS//", + "asdj**", + ">X> Date: Mon, 27 Jan 2020 14:30:57 +0000 Subject: Add `rooms.room_version` column (#6729) This is so that we don't have to rely on pulling it out from `current_state_events` table. --- changelog.d/6729.misc | 1 + synapse/federation/federation_client.py | 50 ++++++++---- synapse/handlers/federation.py | 65 +++++++++++---- synapse/handlers/room.py | 52 +++++++----- .../client/v2_alpha/room_upgrade_rest_servlet.py | 3 +- synapse/storage/data_stores/main/room.py | 94 ++++++++++++++++++++-- .../main/schema/delta/57/rooms_version_column.sql | 24 ++++++ synapse/storage/data_stores/main/state.py | 34 +++++--- tests/storage/test_room.py | 7 +- tests/storage/test_state.py | 5 +- tests/utils.py | 8 ++ 11 files changed, 270 insertions(+), 73 deletions(-) create mode 100644 changelog.d/6729.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql (limited to 'tests') diff --git a/changelog.d/6729.misc b/changelog.d/6729.misc new file mode 100644 index 0000000000..5537355bea --- /dev/null +++ b/changelog.d/6729.misc @@ -0,0 +1 @@ +Record room versions in the `rooms` table. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index af652a7659..d57e8ca7a2 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -17,6 +17,7 @@ import copy import itertools import logging +from typing import Dict, Iterable from prometheus_client import Counter @@ -29,6 +30,7 @@ from synapse.api.errors import ( FederationDeniedError, HttpResponseException, SynapseError, + UnsupportedRoomVersionError, ) from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, @@ -385,6 +387,8 @@ class FederationClient(FederationBase): return res except InvalidResponseError as e: logger.warning("Failed to %s via %s: %s", description, destination, e) + except UnsupportedRoomVersionError: + raise except HttpResponseException as e: if not 500 <= e.code < 600: raise e.to_synapse_error() @@ -404,7 +408,13 @@ class FederationClient(FederationBase): raise SynapseError(502, "Failed to %s via any server" % (description,)) def make_membership_event( - self, destinations, room_id, user_id, membership, content, params + self, + destinations: Iterable[str], + room_id: str, + user_id: str, + membership: str, + content: dict, + params: Dict[str, str], ): """ Creates an m.room.member event, with context, without participating in the room. @@ -417,21 +427,23 @@ class FederationClient(FederationBase): Note that this does not append any events to any graphs. Args: - destinations (Iterable[str]): Candidate homeservers which are probably + destinations: Candidate homeservers which are probably participating in the room. - room_id (str): The room in which the event will happen. - user_id (str): The user whose membership is being evented. - membership (str): The "membership" property of the event. Must be - one of "join" or "leave". - content (dict): Any additional data to put into the content field - of the event. - params (dict[str, str|Iterable[str]]): Query parameters to include in the - request. + room_id: The room in which the event will happen. + user_id: The user whose membership is being evented. + membership: The "membership" property of the event. Must be one of + "join" or "leave". + content: Any additional data to put into the content field of the + event. + params: Query parameters to include in the request. Return: - Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of - `(origin, event, event_format)` where origin is the remote - homeserver which generated the event, and event_format is one of - `synapse.api.room_versions.EventFormatVersions`. + Deferred[Tuple[str, FrozenEvent, RoomVersion]]: resolves to a tuple of + `(origin, event, room_version)` where origin is the remote + homeserver which generated the event, and room_version is the + version of the room. + + Fails with a `UnsupportedRoomVersionError` if remote responds with + a room version we don't understand. Fails with a ``SynapseError`` if the chosen remote server returns a 300/400 code. @@ -453,8 +465,12 @@ class FederationClient(FederationBase): # Note: If not supplied, the room version may be either v1 or v2, # however either way the event format version will be v1. - room_version = ret.get("room_version", RoomVersions.V1.identifier) - event_format = room_version_to_event_format(room_version) + room_version_id = ret.get("room_version", RoomVersions.V1.identifier) + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version: + raise UnsupportedRoomVersionError() + + event_format = room_version_to_event_format(room_version_id) pdu_dict = ret.get("event", None) if not isinstance(pdu_dict, dict): @@ -478,7 +494,7 @@ class FederationClient(FederationBase): event_dict=pdu_dict, ) - return (destination, ev, event_format) + return (destination, ev, room_version) return self._try_destination_list( "make_" + membership, destinations, send_request diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d4f9a792fc..f824ee79a0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -44,10 +44,10 @@ from synapse.api.errors import ( StoreError, SynapseError, ) -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event -from synapse.events import EventBase +from synapse.events import EventBase, room_version_to_event_format from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( @@ -703,8 +703,20 @@ class FederationHandler(BaseHandler): if not room: try: + prev_state_ids = await context.get_prev_state_ids() + create_event = await self.store.get_event( + prev_state_ids[(EventTypes.Create, "")] + ) + + room_version_id = create_event.content.get( + "room_version", RoomVersions.V1.identifier + ) + await self.store.store_room( - room_id=room_id, room_creator_user_id="", is_public=False + room_id=room_id, + room_creator_user_id="", + is_public=False, + room_version=KNOWN_ROOM_VERSIONS[room_version_id], ) except StoreError: logger.exception("Failed to store room.") @@ -1186,7 +1198,7 @@ class FederationHandler(BaseHandler): """ logger.debug("Joining %s to %s", joinee, room_id) - origin, event, event_format_version = yield self._make_and_verify_event( + origin, event, room_version = yield self._make_and_verify_event( target_hosts, room_id, joinee, @@ -1214,6 +1226,8 @@ class FederationHandler(BaseHandler): target_hosts.insert(0, origin) except ValueError: pass + + event_format_version = room_version_to_event_format(room_version.identifier) ret = yield self.federation_client.send_join( target_hosts, event, event_format_version ) @@ -1234,13 +1248,18 @@ class FederationHandler(BaseHandler): try: yield self.store.store_room( - room_id=room_id, room_creator_user_id="", is_public=False + room_id=room_id, + room_creator_user_id="", + is_public=False, + room_version=room_version, ) except Exception: # FIXME pass - yield self._persist_auth_tree(origin, auth_chain, state, event) + yield self._persist_auth_tree( + origin, auth_chain, state, event, room_version + ) # Check whether this room is the result of an upgrade of a room we already know # about. If so, migrate over user information @@ -1486,7 +1505,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content): - origin, event, event_format_version = yield self._make_and_verify_event( + origin, event, room_version = yield self._make_and_verify_event( target_hosts, room_id, user_id, "leave", content=content ) # Mark as outlier as we don't have any state for this event; we're not @@ -1513,7 +1532,11 @@ class FederationHandler(BaseHandler): def _make_and_verify_event( self, target_hosts, room_id, user_id, membership, content={}, params=None ): - origin, event, format_ver = yield self.federation_client.make_membership_event( + ( + origin, + event, + room_version, + ) = yield self.federation_client.make_membership_event( target_hosts, room_id, user_id, membership, content, params=params ) @@ -1525,7 +1548,7 @@ class FederationHandler(BaseHandler): assert event.user_id == user_id assert event.state_key == user_id assert event.room_id == room_id - return origin, event, format_ver + return origin, event, room_version @defer.inlineCallbacks @log_function @@ -1810,7 +1833,14 @@ class FederationHandler(BaseHandler): ) @defer.inlineCallbacks - def _persist_auth_tree(self, origin, auth_events, state, event): + def _persist_auth_tree( + self, + origin: str, + auth_events: List[EventBase], + state: List[EventBase], + event: EventBase, + room_version: RoomVersion, + ): """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event separately. Notifies about the persisted events @@ -1819,10 +1849,12 @@ class FederationHandler(BaseHandler): Will attempt to fetch missing auth events. Args: - origin (str): Where the events came from - auth_events (list) - state (list) - event (Event) + origin: Where the events came from + auth_events + state + event + room_version: The room version we expect this room to have, and + will raise if it doesn't match the version in the create event. Returns: Deferred @@ -1848,10 +1880,13 @@ class FederationHandler(BaseHandler): # invalid, and it would fail auth checks anyway. raise SynapseError(400, "No create event in state") - room_version = create_event.content.get( + room_version_id = create_event.content.get( "room_version", RoomVersions.V1.identifier ) + if room_version.identifier != room_version_id: + raise SynapseError(400, "Room version mismatch") + missing_auth_events = set() for e in itertools.chain(auth_events, state, [event]): for e_id in e.auth_event_ids(): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9f50196ea7..a9490782b7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -29,7 +29,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import ( @@ -100,13 +100,15 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() @defer.inlineCallbacks - def upgrade_room(self, requester, old_room_id, new_version): + def upgrade_room( + self, requester: Requester, old_room_id: str, new_version: RoomVersion + ): """Replace a room with a new room with a different version Args: - requester (synapse.types.Requester): the user requesting the upgrade - old_room_id (unicode): the id of the room to be replaced - new_version (unicode): the new room version to use + requester: the user requesting the upgrade + old_room_id: the id of the room to be replaced + new_version: the new room version to use Returns: Deferred[unicode]: the new room id @@ -151,7 +153,7 @@ class RoomCreationHandler(BaseHandler): if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) new_room_id = yield self._generate_room_id( - creator_id=user_id, is_public=r["is_public"] + creator_id=user_id, is_public=r["is_public"], room_version=new_version, ) logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) @@ -299,18 +301,22 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def clone_existing_room( - self, requester, old_room_id, new_room_id, new_room_version, tombstone_event_id + self, + requester: Requester, + old_room_id: str, + new_room_id: str, + new_room_version: RoomVersion, + tombstone_event_id: str, ): """Populate a new room based on an old room Args: - requester (synapse.types.Requester): the user requesting the upgrade - old_room_id (unicode): the id of the room to be replaced - new_room_id (unicode): the id to give the new room (should already have been + requester: the user requesting the upgrade + old_room_id : the id of the room to be replaced + new_room_id: the id to give the new room (should already have been created with _gemerate_room_id()) - new_room_version (unicode): the new room version to use - tombstone_event_id (unicode|str): the ID of the tombstone event in the old - room. + new_room_version: the new room version to use + tombstone_event_id: the ID of the tombstone event in the old room. Returns: Deferred """ @@ -320,7 +326,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(403, "You are not permitted to create rooms") creation_content = { - "room_version": new_room_version, + "room_version": new_room_version.identifier, "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, } @@ -577,14 +583,15 @@ class RoomCreationHandler(BaseHandler): if ratelimit: yield self.ratelimit(requester) - room_version = config.get( + room_version_id = config.get( "room_version", self.config.default_room_version.identifier ) - if not isinstance(room_version, string_types): + if not isinstance(room_version_id, string_types): raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON) - if room_version not in KNOWN_ROOM_VERSIONS: + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if room_version is None: raise SynapseError( 400, "Your homeserver does not support this room version", @@ -631,7 +638,9 @@ class RoomCreationHandler(BaseHandler): visibility = config.get("visibility", None) is_public = visibility == "public" - room_id = yield self._generate_room_id(creator_id=user_id, is_public=is_public) + room_id = yield self._generate_room_id( + creator_id=user_id, is_public=is_public, room_version=room_version, + ) directory_handler = self.hs.get_handlers().directory_handler if room_alias: @@ -660,7 +669,7 @@ class RoomCreationHandler(BaseHandler): creation_content = config.get("creation_content", {}) # override any attempt to set room versions via the creation_content - creation_content["room_version"] = room_version + creation_content["room_version"] = room_version.identifier yield self._send_events_for_new_room( requester, @@ -849,7 +858,9 @@ class RoomCreationHandler(BaseHandler): yield send(etype=etype, state_key=state_key, content=content) @defer.inlineCallbacks - def _generate_room_id(self, creator_id, is_public): + def _generate_room_id( + self, creator_id: str, is_public: str, room_version: RoomVersion, + ): # autogen room IDs and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. attempts = 0 @@ -863,6 +874,7 @@ class RoomCreationHandler(BaseHandler): room_id=gen_room_id, room_creator_user_id=creator_id, is_public=is_public, + room_version=room_version, ) return gen_room_id except StoreError: diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index ca97330797..f357015a70 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -64,7 +64,8 @@ class RoomUpgradeRestServlet(RestServlet): assert_params_in_dict(content, ("new_version",)) new_version = content["new_version"] - if new_version not in KNOWN_ROOM_VERSIONS: + new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"]) + if new_version is None: raise SynapseError( 400, "Your homeserver does not support this room version", diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index d968803ad2..9a17e336ba 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -29,9 +29,10 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import StoreError +from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.search import SearchStore -from synapse.storage.database import Database +from synapse.storage.database import Database, LoggingTransaction from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -734,6 +735,7 @@ class RoomWorkerStore(SQLBaseStore): class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" + ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" def __init__(self, database: Database, db_conn, hs): super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) @@ -749,6 +751,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._remove_tombstoned_rooms_from_directory, ) + self.db.updates.register_background_update_handler( + self.ADD_ROOMS_ROOM_VERSION_COLUMN, + self._background_add_rooms_room_version_column, + ) + @defer.inlineCallbacks def _background_insert_retention(self, progress, batch_size): """Retrieves a list of all rooms within a range and inserts an entry for each of @@ -817,6 +824,73 @@ class RoomBackgroundUpdateStore(SQLBaseStore): defer.returnValue(batch_size) + async def _background_add_rooms_room_version_column( + self, progress: dict, batch_size: int + ): + """Background update to go and add room version inforamtion to `rooms` + table from `current_state_events` table. + """ + + last_room_id = progress.get("room_id", "") + + def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction): + sql = """ + SELECT room_id, json FROM current_state_events + INNER JOIN event_json USING (room_id, event_id) + WHERE room_id > ? AND type = 'm.room.create' AND state_key = '' + ORDER BY room_id + LIMIT ? + """ + + txn.execute(sql, (last_room_id, batch_size)) + + updates = [] + for room_id, event_json in txn: + event_dict = json.loads(event_json) + room_version_id = event_dict.get("content", {}).get( + "room_version", RoomVersions.V1.identifier + ) + + creator = event_dict.get("content").get("creator") + + updates.append((room_id, creator, room_version_id)) + + if not updates: + return True + + new_last_room_id = "" + for room_id, creator, room_version_id in updates: + # We upsert here just in case we don't already have a row, + # mainly for paranoia as much badness would happen if we don't + # insert the row and then try and get the room version for the + # room. + self.db.simple_upsert_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + values={"room_version": room_version_id}, + insertion_values={"is_public": False, "creator": creator}, + ) + new_last_room_id = room_id + + self.db.updates._background_update_progress_txn( + txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id} + ) + + return False + + end = await self.db.runInteraction( + "_background_add_rooms_room_version_column", + _background_add_rooms_room_version_column_txn, + ) + + if end: + await self.db.updates._end_background_update( + self.ADD_ROOMS_ROOM_VERSION_COLUMN + ) + + return batch_size + async def _remove_tombstoned_rooms_from_directory( self, progress, batch_size ) -> int: @@ -881,14 +955,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.config = hs.config @defer.inlineCallbacks - def store_room(self, room_id, room_creator_user_id, is_public): + def store_room( + self, + room_id: str, + room_creator_user_id: str, + is_public: bool, + room_version: RoomVersion, + ): """Stores a room. Args: - room_id (str): The desired room ID, can be None. - room_creator_user_id (str): The user ID of the room creator. - is_public (bool): True to indicate that this room should appear in - public room lists. + room_id: The desired room ID, can be None. + room_creator_user_id: The user ID of the room creator. + is_public: True to indicate that this room should appear in + public room lists. + room_version: The version of the room Raises: StoreError if the room could not be stored. """ @@ -902,6 +983,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "room_id": room_id, "creator": room_creator_user_id, "is_public": is_public, + "room_version": room_version.identifier, }, ) if is_public: diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql new file mode 100644 index 0000000000..352a66f5b0 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql @@ -0,0 +1,24 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- We want to start storing the room version independently of +-- `current_state_events` so that we can delete stale entries from it without +-- losing the information. +ALTER TABLE rooms ADD COLUMN room_version TEXT; + + +INSERT into background_updates (update_name, progress_json) + VALUES ('add_rooms_room_version_column', '{}'); diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 33bebd1c48..bd7b0276f1 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -60,24 +60,34 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) - @defer.inlineCallbacks - def get_room_version(self, room_id): + @cached(max_entries=10000) + async def get_room_version(self, room_id: str) -> str: """Get the room_version of a given room - Args: - room_id (str) - - Returns: - Deferred[str] - Raises: - NotFoundError if the room is unknown + NotFoundError: if the room is unknown """ - # for now we do this by looking at the create event. We may want to cache this - # more intelligently in future. + + # First we try looking up room version from the database, but for old + # rooms we might not have added the room version to it yet so we fall + # back to previous behaviour and look in current state events. + + # We really should have an entry in the rooms table for every room we + # care about, but let's be a bit paranoid (at least while the background + # update is happening) to avoid breaking existing rooms. + version = await self.db.simple_select_one_onecol( + table="rooms", + keyvalues={"room_id": room_id}, + retcol="room_version", + desc="get_room_version", + allow_none=True, + ) + + if version is not None: + return version # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) + create_event = await self.get_create_event_for_room(room_id) return create_event.content.get("room_version", "1") @defer.inlineCallbacks diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 3ddaa151fe..086adeb8fd 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes +from synapse.api.room_versions import RoomVersions from synapse.types import RoomAlias, RoomID, UserID from tests import unittest @@ -40,6 +41,7 @@ class RoomStoreTestCase(unittest.TestCase): self.room.to_string(), room_creator_user_id=self.u_creator.to_string(), is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks @@ -68,7 +70,10 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") yield self.store.store_room( - self.room.to_string(), room_creator_user_id="@creator:text", is_public=True + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index d6ecf102f8..04d58fbf24 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -45,7 +45,10 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") yield self.store.store_room( - self.room.to_string(), room_creator_user_id="@creator:text", is_public=True + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks diff --git a/tests/utils.py b/tests/utils.py index e2e9cafd79..513f358f4f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -639,9 +639,17 @@ def create_room(hs, room_id, creator_id): """ persistence_store = hs.get_storage().persistence + store = hs.get_datastore() event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() + yield store.store_room( + room_id=room_id, + room_creator_user_id=creator_id, + is_public=False, + room_version=RoomVersions.V1, + ) + builder = event_builder_factory.for_room_version( RoomVersions.V1, { -- cgit 1.5.1 From a8ce7aeb433e08f46306797a1252668c178a7825 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 28 Jan 2020 14:18:29 +0000 Subject: Pass room version object into event_auth.check and check_redaction (#6788) These are easier to work with than the strings and we normally have one around. This fixes `FederationHander._persist_auth_tree` which was passing a RoomVersion object into event_auth.check instead of a string. --- changelog.d/6788.misc | 1 + synapse/api/auth.py | 7 +++++-- synapse/event_auth.py | 34 +++++++++++++++++++++------------- synapse/handlers/federation.py | 18 +++++++++++------- synapse/handlers/message.py | 8 ++++++-- synapse/state/v1.py | 4 ++-- synapse/state/v2.py | 4 +++- tests/test_event_auth.py | 11 ++++------- 8 files changed, 53 insertions(+), 34 deletions(-) create mode 100644 changelog.d/6788.misc (limited to 'tests') diff --git a/changelog.d/6788.misc b/changelog.d/6788.misc new file mode 100644 index 0000000000..5537355bea --- /dev/null +++ b/changelog.d/6788.misc @@ -0,0 +1 @@ +Record room versions in the `rooms` table. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 2cbfab2569..8b1277ad02 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( MissingClientTokenError, ResourceLimitError, ) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.server import is_threepid_reserved from synapse.types import StateMap, UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache @@ -77,15 +78,17 @@ class Auth(object): self._account_validity = hs.config.account_validity @defer.inlineCallbacks - def check_from_context(self, room_version, event, context, do_sig_check=True): + def check_from_context(self, room_version: str, event, context, do_sig_check=True): prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} + + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] event_auth.check( - room_version, event, auth_events=auth_events, do_sig_check=do_sig_check + room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check ) @defer.inlineCallbacks diff --git a/synapse/event_auth.py b/synapse/event_auth.py index e3a1ba47a0..016d5678e5 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,17 +24,27 @@ from unpaddedbase64 import decode_base64 from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, EventSizeError, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersion, +) from synapse.types import UserID, get_domain_from_id logger = logging.getLogger(__name__) -def check(room_version, event, auth_events, do_sig_check=True, do_size_check=True): +def check( + room_version_obj: RoomVersion, + event, + auth_events, + do_sig_check=True, + do_size_check=True, +): """ Checks if this event is correctly authed. Args: - room_version (str): the version of the room + room_version_obj: the version of the room event: the event being checked. auth_events (dict: event-key -> event): the existing room state. @@ -97,10 +108,11 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru 403, "Creation event's room_id domain does not match sender's" ) - room_version = event.content.get("room_version", "1") - if room_version not in KNOWN_ROOM_VERSIONS: + room_version_prop = event.content.get("room_version", "1") + if room_version_prop not in KNOWN_ROOM_VERSIONS: raise AuthError( - 403, "room appears to have unsupported version %s" % (room_version,) + 403, + "room appears to have unsupported version %s" % (room_version_prop,), ) # FIXME logger.debug("Allowing! %s", event) @@ -160,7 +172,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru _check_power_levels(event, auth_events) if event.type == EventTypes.Redaction: - check_redaction(room_version, event, auth_events) + check_redaction(room_version_obj, event, auth_events) logger.debug("Allowing! %s", event) @@ -386,7 +398,7 @@ def _can_send_event(event, auth_events): return True -def check_redaction(room_version, event, auth_events): +def check_redaction(room_version_obj: RoomVersion, event, auth_events): """Check whether the event sender is allowed to redact the target event. Returns: @@ -406,11 +418,7 @@ def check_redaction(room_version, event, auth_events): if user_level >= redact_level: return False - v = KNOWN_ROOM_VERSIONS.get(room_version) - if not v: - raise RuntimeError("Unrecognized room version %r" % (room_version,)) - - if v.event_format == EventFormatVersions.V1: + if room_version_obj.event_format == EventFormatVersions.V1: redacter_domain = get_domain_from_id(event.event_id) redactee_domain = get_domain_from_id(event.redacts) if redacter_domain == redactee_domain: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f824ee79a0..180f165a7a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -47,7 +47,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event -from synapse.events import EventBase, room_version_to_event_format +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( @@ -1198,7 +1198,7 @@ class FederationHandler(BaseHandler): """ logger.debug("Joining %s to %s", joinee, room_id) - origin, event, room_version = yield self._make_and_verify_event( + origin, event, room_version_obj = yield self._make_and_verify_event( target_hosts, room_id, joinee, @@ -1227,7 +1227,7 @@ class FederationHandler(BaseHandler): except ValueError: pass - event_format_version = room_version_to_event_format(room_version.identifier) + event_format_version = room_version_obj.event_format ret = yield self.federation_client.send_join( target_hosts, event, event_format_version ) @@ -1251,14 +1251,14 @@ class FederationHandler(BaseHandler): room_id=room_id, room_creator_user_id="", is_public=False, - room_version=room_version, + room_version=room_version_obj, ) except Exception: # FIXME pass yield self._persist_auth_tree( - origin, auth_chain, state, event, room_version + origin, auth_chain, state, event, room_version_obj ) # Check whether this room is the result of an upgrade of a room we already know @@ -2022,6 +2022,7 @@ class FederationHandler(BaseHandler): if do_soft_fail_check: room_version = yield self.store.get_room_version(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] # Calculate the "current state". if state is not None: @@ -2071,7 +2072,9 @@ class FederationHandler(BaseHandler): } try: - event_auth.check(room_version, event, auth_events=current_auth_events) + event_auth.check( + room_version_obj, event, auth_events=current_auth_events + ) except AuthError as e: logger.warning("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True @@ -2155,6 +2158,7 @@ class FederationHandler(BaseHandler): defer.Deferred[EventContext]: updated context object """ room_version = yield self.store.get_room_version(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] try: context = yield self._update_auth_events_and_context_for_auth( @@ -2172,7 +2176,7 @@ class FederationHandler(BaseHandler): ) try: - event_auth.check(room_version, event, auth_events=auth_events) + event_auth.check(room_version_obj, event, auth_events=auth_events) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8ea3aca2f4..9a0f661b9b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -40,7 +40,7 @@ from synapse.api.errors import ( NotFoundError, SynapseError, ) -from synapse.api.room_versions import RoomVersions +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background @@ -962,9 +962,13 @@ class EventCreationHandler(object): ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + room_version = yield self.store.get_room_version(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if event_auth.check_redaction(room_version, event, auth_events=auth_events): + if event_auth.check_redaction( + room_version_obj, event, auth_events=auth_events + ): # this user doesn't have 'redact' rights, so we need to do some more # checks on the original event. Let's start by checking the original # event exists. diff --git a/synapse/state/v1.py b/synapse/state/v1.py index d6c34ce3b7..24b7c0faef 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -281,7 +281,7 @@ def _resolve_auth_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, event, auth_events, do_sig_check=False, @@ -299,7 +299,7 @@ def _resolve_normal_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, event, auth_events, do_sig_check=False, diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 6216fdd204..531018c6a5 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -26,6 +26,7 @@ import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.types import StateMap @@ -402,6 +403,7 @@ def _iterative_auth_checks( Deferred[StateMap[str]]: Returns the final updated state """ resolved_state = base_state.copy() + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] for event_id in event_ids: event = event_map[event_id] @@ -430,7 +432,7 @@ def _iterative_auth_checks( try: event_auth.check( - room_version, + room_version_obj, event, auth_events, do_sig_check=False, diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 8b2741d277..ca20b085a2 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -37,7 +37,7 @@ class EventAuthTestCase(unittest.TestCase): # creator should be able to send state event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, _random_state_event(creator), auth_events, do_sig_check=False, @@ -47,7 +47,7 @@ class EventAuthTestCase(unittest.TestCase): self.assertRaises( AuthError, event_auth.check, - RoomVersions.V1.identifier, + RoomVersions.V1, _random_state_event(joiner), auth_events, do_sig_check=False, @@ -76,7 +76,7 @@ class EventAuthTestCase(unittest.TestCase): self.assertRaises( AuthError, event_auth.check, - RoomVersions.V1.identifier, + RoomVersions.V1, _random_state_event(pleb), auth_events, do_sig_check=False, @@ -84,10 +84,7 @@ class EventAuthTestCase(unittest.TestCase): # king should be able to send state event_auth.check( - RoomVersions.V1.identifier, - _random_state_event(king), - auth_events, - do_sig_check=False, + RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False, ) -- cgit 1.5.1 From ee42a5513e020424daa736962fee7fb69bd2373a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 28 Jan 2020 11:02:55 +0000 Subject: Factor out a `copy_power_levels_contents` method I'm going to need another copy (hah!) of this. --- synapse/events/utils.py | 37 ++++++++++++++++++++++++++++++++++++- synapse/handlers/room.py | 23 +++++++++++------------ tests/events/test_utils.py | 45 +++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 90 insertions(+), 15 deletions(-) (limited to 'tests') diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 07d1c5bcf0..be57c6d9be 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -12,8 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import collections import re +from typing import Mapping, Union from six import string_types @@ -422,3 +423,37 @@ class EventClientSerializer(object): return yieldable_gather_results( self.serialize_event, events, time_now=time_now, **kwargs ) + + +def copy_power_levels_contents( + old_power_levels: Mapping[str, Union[int, Mapping[str, int]]] +): + """Copy the content of a power_levels event, unfreezing frozendicts along the way + + Raises: + TypeError if the input does not look like a valid power levels event content + """ + if not isinstance(old_power_levels, collections.Mapping): + raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) + + power_levels = {} + for k, v in old_power_levels.items(): + + if isinstance(v, int): + power_levels[k] = v + continue + + if isinstance(v, collections.Mapping): + power_levels[k] = h = {} + for k1, v1 in v.items(): + # we should only have one level of nesting + if not isinstance(v1, int): + raise TypeError( + "Invalid power_levels value for %s.%s: %r" % (k, k1, v) + ) + h[k1] = v1 + continue + + raise TypeError("Invalid power_levels value for %s: %r" % (k, v)) + + return power_levels diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a9490782b7..532ee22fa4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -30,6 +30,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.events.utils import copy_power_levels_contents from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import ( @@ -367,6 +368,15 @@ class RoomCreationHandler(BaseHandler): if old_event: initial_state[k] = old_event.content + # deep-copy the power-levels event before we start modifying it + # note that if frozen_dicts are enabled, `power_levels` will be a frozen + # dict so we can't just copy.deepcopy it. + initial_state[ + (EventTypes.PowerLevels, "") + ] = power_levels = copy_power_levels_contents( + initial_state[(EventTypes.PowerLevels, "")] + ) + # Resolve the minimum power level required to send any state event # We will give the upgrading user this power level temporarily (if necessary) such that # they are able to copy all of the state events over, then revert them back to their @@ -375,8 +385,6 @@ class RoomCreationHandler(BaseHandler): # Copy over user power levels now as this will not be possible with >100PL users once # the room has been created - power_levels = initial_state[(EventTypes.PowerLevels, "")] - # Calculate the minimum power level needed to clone the room event_power_levels = power_levels.get("events", {}) state_default = power_levels.get("state_default", 0) @@ -386,16 +394,7 @@ class RoomCreationHandler(BaseHandler): # Raise the requester's power level in the new room if necessary current_power_level = power_levels["users"][user_id] if current_power_level < needed_power_level: - # make sure we copy the event content rather than overwriting it. - # note that if frozen_dicts are enabled, `power_levels` will be a frozen - # dict so we can't just copy.deepcopy it. - - new_power_levels = {k: v for k, v in power_levels.items() if k != "users"} - new_power_levels["users"] = { - k: v for k, v in power_levels.get("users", {}).items() if k != user_id - } - new_power_levels["users"][user_id] = needed_power_level - initial_state[(EventTypes.PowerLevels, "")] = new_power_levels + power_levels["users"][user_id] = needed_power_level yield self._send_events_for_new_room( requester, diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 9e3d4d0f47..2b13980dfd 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -15,9 +15,14 @@ from synapse.events import FrozenEvent -from synapse.events.utils import prune_event, serialize_event +from synapse.events.utils import ( + copy_power_levels_contents, + prune_event, + serialize_event, +) +from synapse.util.frozenutils import freeze -from .. import unittest +from tests import unittest def MockEvent(**kwargs): @@ -241,3 +246,39 @@ class SerializeEventTestCase(unittest.TestCase): self.serialize( MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] ) + + +class CopyPowerLevelsContentTestCase(unittest.TestCase): + def setUp(self) -> None: + self.test_content = { + "ban": 50, + "events": {"m.room.name": 100, "m.room.power_levels": 100}, + "events_default": 0, + "invite": 50, + "kick": 50, + "notifications": {"room": 20}, + "redact": 50, + "state_default": 50, + "users": {"@example:localhost": 100}, + "users_default": 0, + } + + def _test(self, input): + a = copy_power_levels_contents(input) + + self.assertEqual(a["ban"], 50) + self.assertEqual(a["events"]["m.room.name"], 100) + + # make sure that changing the copy changes the copy and not the orig + a["ban"] = 10 + a["events"]["m.room.power_levels"] = 20 + + self.assertEqual(input["ban"], 50) + self.assertEqual(input["events"]["m.room.power_levels"], 100) + + def test_unfrozen(self): + self._test(self.test_content) + + def test_frozen(self): + input = freeze(self.test_content) + self._test(input) -- cgit 1.5.1 From 5a246611e3cbf27cf1dd7e4453adc8040cddd6a2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 30 Jan 2020 11:25:59 +0000 Subject: Type defintions for use in refactoring for redaction changes (#6803) * Bump signedjson to 1.1 ... so that we can use the type definitions * Fix breakage caused by upgrade to signedjson 1.1 Thanks, @illicitonion... --- changelog.d/6803.misc | 1 + synapse/events/__init__.py | 5 +++-- synapse/python_dependencies.py | 4 +++- synapse/types.py | 7 ++++++- tests/storage/test_keys.py | 15 +++++++++++---- 5 files changed, 24 insertions(+), 8 deletions(-) create mode 100644 changelog.d/6803.misc (limited to 'tests') diff --git a/changelog.d/6803.misc b/changelog.d/6803.misc new file mode 100644 index 0000000000..08aa80bcd9 --- /dev/null +++ b/changelog.d/6803.misc @@ -0,0 +1 @@ +Refactoring work in preparation for changing the event redaction algorithm. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 72c09327f4..f813fa2fe7 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -23,6 +23,7 @@ from unpaddedbase64 import encode_base64 from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions +from synapse.types import JsonDict from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -197,7 +198,7 @@ class EventBase(object): def is_state(self): return hasattr(self, "state_key") and self.state_key is not None - def get_dict(self): + def get_dict(self) -> JsonDict: d = dict(self._event_dict) d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)}) @@ -209,7 +210,7 @@ class EventBase(object): def get_internal_metadata_dict(self): return self.internal_metadata.get_dict() - def get_pdu_json(self, time_now=None): + def get_pdu_json(self, time_now=None) -> JsonDict: pdu_json = self.get_dict() if time_now is not None and "age_ts" in pdu_json["unsigned"]: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 5871feaafd..8de8cb2c12 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -1,6 +1,7 @@ # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,7 +44,8 @@ REQUIREMENTS = [ "frozendict>=1", "unpaddedbase64>=1.1.0", "canonicaljson>=1.1.3", - "signedjson>=1.0.0", + # we use the type definitions added in signedjson 1.1. + "signedjson>=1.1.0", "pynacl>=1.2.1", "idna>=2.5", # validating SSL certs for IP addresses requires service_identity 18.1. diff --git a/synapse/types.py b/synapse/types.py index 65e4d8c181..f3cd465735 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -17,7 +17,7 @@ import re import string import sys from collections import namedtuple -from typing import Dict, Tuple, TypeVar +from typing import Any, Dict, Tuple, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -43,6 +43,11 @@ T = TypeVar("T") StateMap = Dict[Tuple[str, str], T] +# the type of a JSON-serialisable dict. This could be made stronger, but it will +# do for now. +JsonDict = Dict[str, Any] + + class Requester( namedtuple( "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"] diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index e07ff01201..95f309fbbc 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -14,6 +14,7 @@ # limitations under the License. import signedjson.key +import unpaddedbase64 from twisted.internet.defer import Deferred @@ -21,11 +22,17 @@ from synapse.storage.keys import FetchKeyResult import tests.unittest -KEY_1 = signedjson.key.decode_verify_key_base64( - "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" + +def decode_verify_key_base64(key_id: str, key_base64: str): + key_bytes = unpaddedbase64.decode_base64(key_base64) + return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) + + +KEY_1 = decode_verify_key_base64( + "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" ) -KEY_2 = signedjson.key.decode_verify_key_base64( - "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" +KEY_2 = decode_verify_key_base64( + "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" ) -- cgit 1.5.1 From b660327056cdced860d532ab2404a26946da7ef5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 Jan 2020 17:06:38 +0000 Subject: Resync remote device list when detected as stale. (#6786) --- changelog.d/6786.misc | 1 + synapse/handlers/devicemessage.py | 10 ++++++++-- synapse/handlers/federation.py | 18 ++++++++++++++++-- tests/handlers/test_typing.py | 6 +++--- 4 files changed, 28 insertions(+), 7 deletions(-) create mode 100644 changelog.d/6786.misc (limited to 'tests') diff --git a/changelog.d/6786.misc b/changelog.d/6786.misc new file mode 100644 index 0000000000..94c692e53a --- /dev/null +++ b/changelog.d/6786.misc @@ -0,0 +1 @@ +Attempt to resync remote users' devices when detected as stale. diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 5c5fe77be2..05c4b3eec0 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -21,6 +21,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( get_active_span_text_map, log_kv, @@ -48,6 +49,8 @@ class DeviceMessageHandler(object): "m.direct_to_device", self.on_direct_to_device_edu ) + self._device_list_updater = hs.get_device_handler().device_list_updater + @defer.inlineCallbacks def on_direct_to_device_edu(self, origin, content): local_messages = {} @@ -134,8 +137,11 @@ class DeviceMessageHandler(object): unknown_devices, ) yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id) - # TODO: Poke something to start trying to refetch user's - # keys. + + # Immediately attempt a resync in the background + run_in_background( + self._device_list_updater.user_device_resync, sender_user_id + ) @defer.inlineCallbacks def send_device_message(self, sender_user_id, message_type, messages): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a67020a259..ca484e5458 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -57,6 +57,7 @@ from synapse.logging.context import ( run_in_background, ) from synapse.logging.utils import log_function +from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationFederationSendEventsRestServlet, @@ -156,6 +157,13 @@ class FederationHandler(BaseHandler): hs ) + if hs.config.worker_app: + self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( + hs + ) + else: + self._device_list_updater = hs.get_device_handler().device_list_updater + # When joining a room we need to queue any events for that room up self.room_queues = {} self._room_pdu_linearizer = Linearizer("fed_room_pdu") @@ -759,8 +767,14 @@ class FederationHandler(BaseHandler): await self.store.mark_remote_user_device_cache_as_stale( event.sender ) - # TODO: Poke something to start trying to refetch user's - # keys. + + # Immediately attempt a resync in the background + if self.config.worker_app: + return run_in_background(self._user_device_resync, event.sender) + else: + return run_in_background( + self._device_list_updater.user_device_resync, event.sender + ) @log_function async def backfill(self, dest, room_id, limit, extremities): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 596ddc6970..68b9847bd2 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -81,6 +81,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ] ) + # the tests assume that we are starting at unix time 1000 + reactor.pump((1000,)) + hs = self.setup_test_homeserver( notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring ) @@ -90,9 +93,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): - # the tests assume that we are starting at unix time 1000 - reactor.pump((1000,)) - mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event -- cgit 1.5.1 From 184303b8650a90256f84bc9801b749a5b81b6d4b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 30 Jan 2020 17:20:55 +0000 Subject: MSC2260: Block direct sends of m.room.aliases events (#6794) as per MSC2260 --- changelog.d/6794.feature | 1 + synapse/rest/client/v1/room.py | 12 ++++++++++ tests/rest/admin/test_admin.py | 7 ------ tests/rest/client/v1/test_directory.py | 41 +++++++++++++--------------------- 4 files changed, 28 insertions(+), 33 deletions(-) create mode 100644 changelog.d/6794.feature (limited to 'tests') diff --git a/changelog.d/6794.feature b/changelog.d/6794.feature new file mode 100644 index 0000000000..df9e4b77ab --- /dev/null +++ b/changelog.d/6794.feature @@ -0,0 +1 @@ +Implement updated authorization rules for aliases events, from [MSC2260](https://github.com/matrix-org/matrix-doc/pull/2260). diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 5aef8238b8..6f31584c51 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -184,6 +184,12 @@ class RoomStateEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) + if event_type == EventTypes.Aliases: + # MSC2260 + raise SynapseError( + 400, "Cannot send m.room.aliases events via /rooms/{room_id}/state" + ) + event_dict = { "type": event_type, "content": content, @@ -231,6 +237,12 @@ class RoomSendEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) + if event_type == EventTypes.Aliases: + # MSC2260 + raise SynapseError( + 400, "Cannot send m.room.aliases events via /rooms/{room_id}/send" + ) + event_dict = { "type": event_type, "content": content, diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 0342aed416..e5984aaad8 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -868,13 +868,6 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Set this new alias as the canonical alias for this room - self.helper.send_state( - room_id, - "m.room.aliases", - {"aliases": [test_alias]}, - tok=self.admin_user_tok, - state_key="test", - ) self.helper.send_state( room_id, "m.room.canonical_alias", diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index 633b7dbda0..914cf54927 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -51,26 +51,30 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.user = self.register_user("user", "test") self.user_tok = self.login("user", "test") - def test_state_event_not_in_room(self): - self.ensure_user_left_room() - self.set_alias_via_state_event(403) + def test_cannot_set_alias_via_state_event(self): + self.ensure_user_joined_room() + url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( + self.room_id, + self.hs.hostname, + ) + + data = {"aliases": [self.random_alias(5)]} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, 400, channel.result) def test_directory_endpoint_not_in_room(self): self.ensure_user_left_room() self.set_alias_via_directory(403) - def test_state_event_in_room_too_long(self): - self.ensure_user_joined_room() - self.set_alias_via_state_event(400, alias_length=256) - def test_directory_in_room_too_long(self): self.ensure_user_joined_room() self.set_alias_via_directory(400, alias_length=256) - def test_state_event_in_room(self): - self.ensure_user_joined_room() - self.set_alias_via_state_event(200) - def test_directory_in_room(self): self.ensure_user_joined_room() self.set_alias_via_directory(200) @@ -102,21 +106,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(channel.code, 200, channel.result) - def set_alias_via_state_event(self, expected_code, alias_length=5): - url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( - self.room_id, - self.hs.hostname, - ) - - data = {"aliases": [self.random_alias(alias_length)]} - request_data = json.dumps(data) - - request, channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok - ) - self.render(request) - self.assertEqual(channel.code, expected_code, channel.result) - def set_alias_via_directory(self, expected_code, alias_length=5): url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) data = {"room_id": self.room_id} -- cgit 1.5.1 From ef6bdafb29abe14cb40f6b83a46e70d82cd3e041 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 29 Jan 2020 17:55:48 +0000 Subject: Store the room version in EventBuilder --- synapse/events/builder.py | 12 +++++++----- tests/handlers/test_presence.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 3997751337..291fb38a26 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -23,6 +23,7 @@ from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, KNOWN_ROOM_VERSIONS, EventFormatVersions, + RoomVersion, ) from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.types import EventID @@ -40,7 +41,7 @@ class EventBuilder(object): content/unsigned/internal_metadata fields are still mutable) Attributes: - format_version (int): Event format version + room_version: Version of the target room room_id (str) type (str) sender (str) @@ -63,7 +64,7 @@ class EventBuilder(object): _hostname = attr.ib() _signing_key = attr.ib() - format_version = attr.ib() + room_version = attr.ib(type=RoomVersion) room_id = attr.ib() type = attr.ib() @@ -108,7 +109,8 @@ class EventBuilder(object): ) auth_ids = yield self._auth.compute_auth_events(self, state_ids) - if self.format_version == EventFormatVersions.V1: + format_version = self.room_version.event_format + if format_version == EventFormatVersions.V1: auth_events = yield self._store.add_event_hashes(auth_ids) prev_events = yield self._store.add_event_hashes(prev_event_ids) else: @@ -148,7 +150,7 @@ class EventBuilder(object): clock=self._clock, hostname=self._hostname, signing_key=self._signing_key, - format_version=self.format_version, + format_version=format_version, event_dict=event_dict, internal_metadata_dict=self.internal_metadata.get_dict(), ) @@ -201,7 +203,7 @@ class EventBuilderFactory(object): clock=self.clock, hostname=self.hostname, signing_key=self.signing_key, - format_version=room_version.event_format, + room_version=room_version, type=key_values["type"], state_key=key_values.get("state_key"), room_id=key_values["room_id"], diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index d4293b4312..69914428e2 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -19,7 +19,7 @@ from mock import Mock, call from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState -from synapse.events import room_version_to_event_format +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events.builder import EventBuilder from synapse.handlers.presence import ( EXTERNAL_PROCESS_EXPIRY, @@ -597,7 +597,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): clock=self.clock, hostname=hostname, signing_key=self.random_signing_key, - format_version=room_version_to_event_format(room_version), + room_version=KNOWN_ROOM_VERSIONS[room_version], room_id=room_id, type=EventTypes.Member, sender=user_id, -- cgit 1.5.1 From 2a81393a4b905c8bd4c31da04a8b4407462948b9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 29 Jan 2020 17:40:33 +0000 Subject: Pass room_version into add_hashes_and_signatures --- synapse/crypto/event_signing.py | 20 +++++++++++++------- synapse/events/builder.py | 2 +- tests/crypto/test_event_signing.py | 9 +++++++-- 3 files changed, 21 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index e65bd61d97..1f2bccf700 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -20,10 +20,13 @@ import logging from canonicaljson import encode_canonical_json from signedjson.sign import sign_json +from signedjson.types import SigningKey from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import RoomVersion from synapse.events.utils import prune_event, prune_event_dict +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -137,20 +140,23 @@ def compute_event_signature(event_dict, signature_name, signing_key): def add_hashes_and_signatures( - event_dict, signature_name, signing_key, hash_algorithm=hashlib.sha256 + room_version: RoomVersion, + event_dict: JsonDict, + signature_name: str, + signing_key: SigningKey, ): """Add content hash and sign the event Args: - event_dict (dict): The event to add hashes to and sign - signature_name (str): The name of the entity signing the event + room_version: the version of the room this event is in + + event_dict: The event to add hashes to and sign + signature_name: The name of the entity signing the event (typically the server's hostname). - signing_key (syutil.crypto.SigningKey): The key to sign with - hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use - to hash the event + signing_key: The key to sign with """ - name, digest = compute_content_hash(event_dict, hash_algorithm=hash_algorithm) + name, digest = compute_content_hash(event_dict, hash_algorithm=hashlib.sha256) event_dict.setdefault("hashes", {})[name] = encode_base64(digest) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index a26f4c9044..8d63ad6dc3 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -255,7 +255,7 @@ def create_local_event_from_event_dict( event_dict.setdefault("signatures", {}) - add_hashes_and_signatures(event_dict, hostname, signing_key) + add_hashes_and_signatures(room_version, event_dict, hostname, signing_key) return event_type_from_format_version(format_version)( event_dict, internal_metadata_dict=internal_metadata_dict ) diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 126e176004..6143a50ab2 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -17,6 +17,7 @@ import nacl.signing from unpaddedbase64 import decode_base64 +from synapse.api.room_versions import RoomVersions from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import FrozenEvent @@ -49,7 +50,9 @@ class EventSigningTestCase(unittest.TestCase): "unsigned": {"age_ts": 1000000}, } - add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key) + add_hashes_and_signatures( + RoomVersions.V1, event_dict, HOSTNAME, self.signing_key + ) event = FrozenEvent(event_dict) @@ -81,7 +84,9 @@ class EventSigningTestCase(unittest.TestCase): "unsigned": {"age_ts": 1000000}, } - add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key) + add_hashes_and_signatures( + RoomVersions.V1, event_dict, HOSTNAME, self.signing_key + ) event = FrozenEvent(event_dict) -- cgit 1.5.1 From d7bf793cc1e3f5268285286341835ac54753eff6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 31 Jan 2020 10:06:21 +0000 Subject: s/get_room_version/get_room_version_id/ ... to make way for a forthcoming get_room_version which returns a RoomVersion object. --- synapse/federation/federation_client.py | 10 +++++----- synapse/federation/federation_server.py | 16 ++++++++-------- synapse/handlers/federation.py | 18 +++++++++--------- synapse/handlers/message.py | 8 +++++--- synapse/handlers/pagination.py | 2 +- synapse/handlers/room.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/data_stores/main/state.py | 2 +- synapse/storage/persist_events.py | 2 +- tests/handlers/test_presence.py | 2 +- tests/test_state.py | 2 +- tests/unittest.py | 4 +++- 12 files changed, 37 insertions(+), 33 deletions(-) (limited to 'tests') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d57e8ca7a2..4ac3d81cba 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -198,7 +198,7 @@ class FederationClient(FederationBase): logger.debug("backfill transaction_data=%r", transaction_data) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) pdus = [ @@ -336,7 +336,7 @@ class FederationClient(FederationBase): def get_event_auth(self, destination, room_id, event_id): res = yield self.transport_layer.get_event_auth(destination, room_id, event_id) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ @@ -649,7 +649,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_invite(self, destination, room_id, event_id, pdu): - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) content = yield self._do_send_invite(destination, pdu, room_version) @@ -657,7 +657,7 @@ class FederationClient(FederationBase): logger.debug("Got response to send_invite: %s", pdu_dict) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(pdu_dict, format_ver) @@ -859,7 +859,7 @@ class FederationClient(FederationBase): timeout=timeout, ) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) events = [ diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 9562faa3ee..a4c97ed458 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -234,7 +234,7 @@ class FederationServer(FederationBase): continue try: - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue @@ -334,7 +334,7 @@ class FederationServer(FederationBase): ) ) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) resp["room_version"] = room_version return 200, resp @@ -385,7 +385,7 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) if room_version not in supported_versions: logger.warning( "Room version %s not in %s", room_version, supported_versions @@ -417,7 +417,7 @@ class FederationServer(FederationBase): async def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) @@ -440,7 +440,7 @@ class FederationServer(FederationBase): await self.check_server_matches_acl(origin_host, room_id) pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} @@ -448,7 +448,7 @@ class FederationServer(FederationBase): async def on_send_leave_request(self, origin, content, room_id): logger.debug("on_send_leave_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) @@ -495,7 +495,7 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ @@ -664,7 +664,7 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = await self.store.get_room_version(pdu.room_id) + room_version = await self.store.get_room_version_id(pdu.room_id) # Check signature. try: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 01372f6d47..30c720f093 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -388,7 +388,7 @@ class FederationHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) state_map = await resolve_events_with_store( room_id, room_version, @@ -1110,7 +1110,7 @@ class FederationHandler(BaseHandler): Logs a warning if we can't find the given event. """ - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) event_infos = [] @@ -1373,7 +1373,7 @@ class FederationHandler(BaseHandler): event_content = {"membership": Membership.JOIN} - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new( room_version, @@ -1607,7 +1607,7 @@ class FederationHandler(BaseHandler): ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new( room_version, { @@ -2055,7 +2055,7 @@ class FederationHandler(BaseHandler): do_soft_fail_check = False if do_soft_fail_check: - room_version = yield self.store.get_room_version(event.room_id) + room_version = yield self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] # Calculate the "current state". @@ -2191,7 +2191,7 @@ class FederationHandler(BaseHandler): Returns: defer.Deferred[EventContext]: updated context object """ - room_version = yield self.store.get_room_version(event.room_id) + room_version = yield self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] try: @@ -2363,7 +2363,7 @@ class FederationHandler(BaseHandler): remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) remote_state = remote_auth_events.values() - room_version = yield self.store.get_room_version(event.room_id) + room_version = yield self.store.get_room_version_id(event.room_id) new_state = yield self.state_handler.resolve_events( room_version, (local_state, remote_state), event ) @@ -2587,7 +2587,7 @@ class FederationHandler(BaseHandler): } if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) @@ -2650,7 +2650,7 @@ class FederationHandler(BaseHandler): Returns: Deferred: resolves (to None) """ - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) # NB: event_dict has a particular specced format we might need to fudge # if we change event formats too much. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9a0f661b9b..bdf16c84d3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -459,7 +459,9 @@ class EventCreationHandler(object): room_version = event_dict["content"]["room_version"] else: try: - room_version = yield self.store.get_room_version(event_dict["room_id"]) + room_version = yield self.store.get_room_version_id( + event_dict["room_id"] + ) except NotFoundError: raise AuthError(403, "Unknown room") @@ -788,7 +790,7 @@ class EventCreationHandler(object): ): room_version = event.content.get("room_version", RoomVersions.V1.identifier) else: - room_version = yield self.store.get_room_version(event.room_id) + room_version = yield self.store.get_room_version_id(event.room_id) event_allowed = yield self.third_party_event_rules.check_event_allowed( event, context @@ -963,7 +965,7 @@ class EventCreationHandler(object): auth_events = yield self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} - room_version = yield self.store.get_room_version(event.room_id) + room_version = yield self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] if event_auth.check_redaction( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 71d76202c9..caf841a643 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -281,7 +281,7 @@ class PaginationHandler(object): """Purge the given room from the database""" with (await self.pagination_lock.write(room_id)): # check we know about the room - await self.store.get_room_version(room_id) + await self.store.get_room_version_id(room_id) # first check that we have no users in this room joined = await defer.maybeDeferred( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a95b45d791..1382399557 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -178,7 +178,7 @@ class RoomCreationHandler(BaseHandler): }, token_id=requester.access_token_id, ) - old_room_version = yield self.store.get_room_version(old_room_id) + old_room_version = yield self.store.get_room_version_id(old_room_id) yield self.auth.check_from_context( old_room_version, tombstone_event, tombstone_context ) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index cacd0c0c2b..fdd6bef6b4 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -394,7 +394,7 @@ class StateHandler(object): delta_ids=delta_ids, ) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) result = yield self._state_resolution_handler.resolve_state_groups( room_id, diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 4167f83c9b..6700942523 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -62,7 +62,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) @cached(max_entries=10000) - async def get_room_version(self, room_id: str) -> str: + async def get_room_version_id(self, room_id: str) -> str: """Get the room_version of a given room Raises: diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 86166fd4c1..af3fd67ab9 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -661,7 +661,7 @@ class EventsPersistenceStorage(object): break if not room_version: - room_version = await self.main_store.get_room_version(room_id) + room_version = await self.main_store.get_room_version_id(room_id) logger.debug("calling resolve_state_groups from preserve_events") res = await self._state_resolution_handler.resolve_state_groups( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index d4293b4312..e92e090c3c 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -588,7 +588,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): hostname = get_domain_from_id(user_id) - room_version = self.get_success(self.store.get_room_version(room_id)) + room_version = self.get_success(self.store.get_room_version_id(room_id)) builder = EventBuilder( state=self.state, diff --git a/tests/test_state.py b/tests/test_state.py index e0aae06be4..1e4449fa1c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -119,7 +119,7 @@ class StateGroupStore(object): def register_event_id_state_group(self, event_id, state_group): self._event_to_state_group[event_id] = state_group - def get_room_version(self, room_id): + def get_room_version_id(self, room_id): return RoomVersions.V1.identifier diff --git a/tests/unittest.py b/tests/unittest.py index b56e249386..98bf27d39c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -589,7 +589,9 @@ class HomeserverTestCase(TestCase): event_builder_factory = self.hs.get_event_builder_factory() event_creation_handler = self.hs.get_event_creation_handler() - room_version = self.get_success(self.hs.get_datastore().get_room_version(room)) + room_version = self.get_success( + self.hs.get_datastore().get_room_version_id(room) + ) builder = event_builder_factory.for_room_version( KNOWN_ROOM_VERSIONS[room_version], -- cgit 1.5.1 From b9391c957572224c3a7c22870102fcbd24dea4e0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Feb 2020 18:05:44 +0000 Subject: Add typing to SyncHandler (#6821) Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- changelog.d/6821.misc | 1 + synapse/events/__init__.py | 18 +- synapse/handlers/sync.py | 705 +++++++++++++++++++++------------------- tests/storage/test_redaction.py | 5 +- tox.ini | 1 + 5 files changed, 381 insertions(+), 349 deletions(-) create mode 100644 changelog.d/6821.misc (limited to 'tests') diff --git a/changelog.d/6821.misc b/changelog.d/6821.misc new file mode 100644 index 0000000000..1d5265d5e2 --- /dev/null +++ b/changelog.d/6821.misc @@ -0,0 +1 @@ +Add type hints to `SyncHandler`. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index f813fa2fe7..92f76703b3 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -189,8 +189,14 @@ class EventBase(object): redacts = _event_dict_property("redacts", None) room_id = _event_dict_property("room_id") sender = _event_dict_property("sender") + state_key = _event_dict_property("state_key") + type = _event_dict_property("type") user_id = _event_dict_property("sender") + @property + def event_id(self) -> str: + raise NotImplementedError() + @property def membership(self): return self.content["membership"] @@ -281,10 +287,7 @@ class FrozenEvent(EventBase): else: frozen_dict = event_dict - self.event_id = event_dict["event_id"] - self.type = event_dict["type"] - if "state_key" in event_dict: - self.state_key = event_dict["state_key"] + self._event_id = event_dict["event_id"] super(FrozenEvent, self).__init__( frozen_dict, @@ -294,6 +297,10 @@ class FrozenEvent(EventBase): rejected_reason=rejected_reason, ) + @property + def event_id(self) -> str: + return self._event_id + def __str__(self): return self.__repr__() @@ -332,9 +339,6 @@ class FrozenEventV2(EventBase): frozen_dict = event_dict self._event_id = None - self.type = event_dict["type"] - if "state_key" in event_dict: - self.state_key = event_dict["state_key"] super(FrozenEventV2, self).__init__( frozen_dict, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index cd95f85e3f..5f060241b4 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -14,20 +14,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import itertools import logging +from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple from six import iteritems, itervalues +import attr from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership +from synapse.api.filtering import FilterCollection +from synapse.events import EventBase from synapse.logging.context import LoggingContext from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter -from synapse.types import RoomStreamToken +from synapse.types import ( + Collection, + JsonDict, + RoomStreamToken, + StateMap, + StreamToken, + UserID, +) from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache @@ -62,17 +72,22 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000 LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100 -SyncConfig = collections.namedtuple( - "SyncConfig", ["user", "filter_collection", "is_guest", "request_key", "device_id"] -) +@attr.s(slots=True, frozen=True) +class SyncConfig: + user = attr.ib(type=UserID) + filter_collection = attr.ib(type=FilterCollection) + is_guest = attr.ib(type=bool) + request_key = attr.ib(type=Tuple[Any, ...]) + device_id = attr.ib(type=str) -class TimelineBatch( - collections.namedtuple("TimelineBatch", ["prev_batch", "events", "limited"]) -): - __slots__ = [] +@attr.s(slots=True, frozen=True) +class TimelineBatch: + prev_batch = attr.ib(type=StreamToken) + events = attr.ib(type=List[EventBase]) + limited = attr.ib(bool) - def __nonzero__(self): + def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ @@ -81,23 +96,17 @@ class TimelineBatch( __bool__ = __nonzero__ # python3 -class JoinedSyncResult( - collections.namedtuple( - "JoinedSyncResult", - [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "ephemeral", - "account_data", - "unread_notifications", - "summary", - ], - ) -): - __slots__ = [] - - def __nonzero__(self): +@attr.s(slots=True, frozen=True) +class JoinedSyncResult: + room_id = attr.ib(type=str) + timeline = attr.ib(type=TimelineBatch) + state = attr.ib(type=StateMap[EventBase]) + ephemeral = attr.ib(type=List[JsonDict]) + account_data = attr.ib(type=List[JsonDict]) + unread_notifications = attr.ib(type=JsonDict) + summary = attr.ib(type=Optional[JsonDict]) + + def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ @@ -113,20 +122,14 @@ class JoinedSyncResult( __bool__ = __nonzero__ # python3 -class ArchivedSyncResult( - collections.namedtuple( - "ArchivedSyncResult", - [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "account_data", - ], - ) -): - __slots__ = [] - - def __nonzero__(self): +@attr.s(slots=True, frozen=True) +class ArchivedSyncResult: + room_id = attr.ib(type=str) + timeline = attr.ib(type=TimelineBatch) + state = attr.ib(type=StateMap[EventBase]) + account_data = attr.ib(type=List[JsonDict]) + + def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ @@ -135,70 +138,88 @@ class ArchivedSyncResult( __bool__ = __nonzero__ # python3 -class InvitedSyncResult( - collections.namedtuple( - "InvitedSyncResult", - ["room_id", "invite"], # str # FrozenEvent: the invite event - ) -): - __slots__ = [] +@attr.s(slots=True, frozen=True) +class InvitedSyncResult: + room_id = attr.ib(type=str) + invite = attr.ib(type=EventBase) - def __nonzero__(self): + def __nonzero__(self) -> bool: """Invited rooms should always be reported to the client""" return True __bool__ = __nonzero__ # python3 -class GroupsSyncResult( - collections.namedtuple("GroupsSyncResult", ["join", "invite", "leave"]) -): - __slots__ = [] +@attr.s(slots=True, frozen=True) +class GroupsSyncResult: + join = attr.ib(type=JsonDict) + invite = attr.ib(type=JsonDict) + leave = attr.ib(type=JsonDict) - def __nonzero__(self): + def __nonzero__(self) -> bool: return bool(self.join or self.invite or self.leave) __bool__ = __nonzero__ # python3 -class DeviceLists( - collections.namedtuple( - "DeviceLists", - [ - "changed", # list of user_ids whose devices may have changed - "left", # list of user_ids whose devices we no longer track - ], - ) -): - __slots__ = [] +@attr.s(slots=True, frozen=True) +class DeviceLists: + """ + Attributes: + changed: List of user_ids whose devices may have changed + left: List of user_ids whose devices we no longer track + """ + + changed = attr.ib(type=Collection[str]) + left = attr.ib(type=Collection[str]) - def __nonzero__(self): + def __nonzero__(self) -> bool: return bool(self.changed or self.left) __bool__ = __nonzero__ # python3 -class SyncResult( - collections.namedtuple( - "SyncResult", - [ - "next_batch", # Token for the next sync - "presence", # List of presence events for the user. - "account_data", # List of account_data events for the user. - "joined", # JoinedSyncResult for each joined room. - "invited", # InvitedSyncResult for each invited room. - "archived", # ArchivedSyncResult for each archived room. - "to_device", # List of direct messages for the device. - "device_lists", # List of user_ids whose devices have changed - "device_one_time_keys_count", # Dict of algorithm to count for one time keys - # for this device - "groups", - ], - ) -): - __slots__ = [] - - def __nonzero__(self): +@attr.s +class _RoomChanges: + """The set of room entries to include in the sync, plus the set of joined + and left room IDs since last sync. + """ + + room_entries = attr.ib(type=List["RoomSyncResultBuilder"]) + invited = attr.ib(type=List[InvitedSyncResult]) + newly_joined_rooms = attr.ib(type=List[str]) + newly_left_rooms = attr.ib(type=List[str]) + + +@attr.s(slots=True, frozen=True) +class SyncResult: + """ + Attributes: + next_batch: Token for the next sync + presence: List of presence events for the user. + account_data: List of account_data events for the user. + joined: JoinedSyncResult for each joined room. + invited: InvitedSyncResult for each invited room. + archived: ArchivedSyncResult for each archived room. + to_device: List of direct messages for the device. + device_lists: List of user_ids whose devices have changed + device_one_time_keys_count: Dict of algorithm to count for one time keys + for this device + groups: Group updates, if any + """ + + next_batch = attr.ib(type=StreamToken) + presence = attr.ib(type=List[JsonDict]) + account_data = attr.ib(type=List[JsonDict]) + joined = attr.ib(type=List[JoinedSyncResult]) + invited = attr.ib(type=List[InvitedSyncResult]) + archived = attr.ib(type=List[ArchivedSyncResult]) + to_device = attr.ib(type=List[JsonDict]) + device_lists = attr.ib(type=DeviceLists) + device_one_time_keys_count = attr.ib(type=JsonDict) + groups = attr.ib(type=Optional[GroupsSyncResult]) + + def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if the notifier needs to wait for more events when polling for events. @@ -240,13 +261,15 @@ class SyncHandler(object): ) async def wait_for_sync_for_user( - self, sync_config, since_token=None, timeout=0, full_state=False - ): + self, + sync_config: SyncConfig, + since_token: Optional[StreamToken] = None, + timeout: int = 0, + full_state: bool = False, + ) -> SyncResult: """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. - Returns: - Deferred[SyncResult] """ # If the user is not part of the mau group, then check that limits have # not been exceeded (if not part of the group by this point, almost certain @@ -265,8 +288,12 @@ class SyncHandler(object): return res async def _wait_for_sync_for_user( - self, sync_config, since_token, timeout, full_state - ): + self, + sync_config: SyncConfig, + since_token: Optional[StreamToken] = None, + timeout: int = 0, + full_state: bool = False, + ) -> SyncResult: if since_token is None: sync_type = "initial_sync" elif full_state: @@ -305,25 +332,33 @@ class SyncHandler(object): return result - def current_sync_for_user(self, sync_config, since_token=None, full_state=False): + async def current_sync_for_user( + self, + sync_config: SyncConfig, + since_token: Optional[StreamToken] = None, + full_state: bool = False, + ) -> SyncResult: """Get the sync for client needed to match what the server has now. - Returns: - A Deferred SyncResult. """ - return self.generate_sync_result(sync_config, since_token, full_state) + return await self.generate_sync_result(sync_config, since_token, full_state) - async def push_rules_for_user(self, user): + async def push_rules_for_user(self, user: UserID) -> JsonDict: user_id = user.to_string() rules = await self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(user, rules) return rules - async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): + async def ephemeral_by_room( + self, + sync_result_builder: "SyncResultBuilder", + now_token: StreamToken, + since_token: Optional[StreamToken] = None, + ) -> Tuple[StreamToken, Dict[str, List[JsonDict]]]: """Get the ephemeral events for each room the user is in Args: - sync_result_builder(SyncResultBuilder) - now_token (StreamToken): Where the server is currently up to. - since_token (StreamToken): Where the server was when the client + sync_result_builder + now_token: Where the server is currently up to. + since_token: Where the server was when the client last synced. Returns: A tuple of the now StreamToken, updated to reflect the which typing @@ -348,7 +383,7 @@ class SyncHandler(object): ) now_token = now_token.copy_and_replace("typing_key", typing_key) - ephemeral_by_room = {} + ephemeral_by_room = {} # type: JsonDict for event in typing: # we want to exclude the room_id from the event, but modifying the @@ -380,13 +415,13 @@ class SyncHandler(object): async def _load_filtered_recents( self, - room_id, - sync_config, - now_token, - since_token=None, - recents=None, - newly_joined_room=False, - ): + room_id: str, + sync_config: SyncConfig, + now_token: StreamToken, + since_token: Optional[StreamToken] = None, + potential_recents: Optional[List[EventBase]] = None, + newly_joined_room: bool = False, + ) -> TimelineBatch: """ Returns: a Deferred TimelineBatch @@ -397,21 +432,29 @@ class SyncHandler(object): sync_config.filter_collection.blocks_all_room_timeline() ) - if recents is None or newly_joined_room or timeline_limit < len(recents): + if ( + potential_recents is None + or newly_joined_room + or timeline_limit < len(potential_recents) + ): limited = True else: limited = False - if recents: - recents = sync_config.filter_collection.filter_room_timeline(recents) + if potential_recents: + recents = sync_config.filter_collection.filter_room_timeline( + potential_recents + ) # We check if there are any state events, if there are then we pass # all current state events to the filter_events function. This is to # ensure that we always include current state in the timeline - current_state_ids = frozenset() + current_state_ids = frozenset() # type: FrozenSet[str] if any(e.is_state() for e in recents): - current_state_ids = await self.state.get_current_state_ids(room_id) - current_state_ids = frozenset(itervalues(current_state_ids)) + current_state_ids_map = await self.state.get_current_state_ids( + room_id + ) + current_state_ids = frozenset(itervalues(current_state_ids_map)) recents = await filter_events_for_client( self.storage, @@ -463,8 +506,10 @@ class SyncHandler(object): # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in loaded_recents): - current_state_ids = await self.state.get_current_state_ids(room_id) - current_state_ids = frozenset(itervalues(current_state_ids)) + current_state_ids_map = await self.state.get_current_state_ids( + room_id + ) + current_state_ids = frozenset(itervalues(current_state_ids_map)) loaded_recents = await filter_events_for_client( self.storage, @@ -493,17 +538,15 @@ class SyncHandler(object): limited=limited or newly_joined_room, ) - async def get_state_after_event(self, event, state_filter=StateFilter.all()): + async def get_state_after_event( + self, event: EventBase, state_filter: StateFilter = StateFilter.all() + ) -> StateMap[str]: """ Get the room state after the given event Args: - event(synapse.events.EventBase): event of interest - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A Deferred map from ((type, state_key)->Event) + event: event of interest + state_filter: The state filter used to fetch state from the database. """ state_ids = await self.state_store.get_state_ids_for_event( event.event_id, state_filter=state_filter @@ -514,18 +557,17 @@ class SyncHandler(object): return state_ids async def get_state_at( - self, room_id, stream_position, state_filter=StateFilter.all() - ): + self, + room_id: str, + stream_position: StreamToken, + state_filter: StateFilter = StateFilter.all(), + ) -> StateMap[str]: """ Get the room state at a particular stream position Args: - room_id(str): room for which to get state - stream_position(StreamToken): point at which to get state - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A Deferred map from ((type, state_key)->Event) + room_id: room for which to get state + stream_position: point at which to get state + state_filter: The state filter used to fetch state from the database. """ # FIXME this claims to get the state at a stream position, but # get_recent_events_for_room operates by topo ordering. This therefore @@ -546,23 +588,25 @@ class SyncHandler(object): state = {} return state - async def compute_summary(self, room_id, sync_config, batch, state, now_token): + async def compute_summary( + self, + room_id: str, + sync_config: SyncConfig, + batch: TimelineBatch, + state: StateMap[EventBase], + now_token: StreamToken, + ) -> Optional[JsonDict]: """ Works out a room summary block for this room, summarising the number of joined members in the room, and providing the 'hero' members if the room has no name so clients can consistently name rooms. Also adds state events to 'state' if needed to describe the heroes. - Args: - room_id(str): - sync_config(synapse.handlers.sync.SyncConfig): - batch(synapse.handlers.sync.TimelineBatch): The timeline batch for - the room that will be sent to the user. - state(dict): dict of (type, state_key) -> Event as returned by - compute_state_delta - now_token(str): Token of the end of the current batch. - - Returns: - A deferred dict describing the room summary + Args + room_id + sync_config + batch: The timeline batch for the room that will be sent to the user. + state: State as returned by compute_state_delta + now_token: Token of the end of the current batch. """ # FIXME: we could/should get this from room_stats when matthew/stats lands @@ -681,7 +725,7 @@ class SyncHandler(object): return summary - def get_lazy_loaded_members_cache(self, cache_key): + def get_lazy_loaded_members_cache(self, cache_key: Tuple[str, str]) -> LruCache: cache = self.lazy_loaded_members_cache.get(cache_key) if cache is None: logger.debug("creating LruCache for %r", cache_key) @@ -692,23 +736,24 @@ class SyncHandler(object): return cache async def compute_state_delta( - self, room_id, batch, sync_config, since_token, now_token, full_state - ): + self, + room_id: str, + batch: TimelineBatch, + sync_config: SyncConfig, + since_token: Optional[StreamToken], + now_token: StreamToken, + full_state: bool, + ) -> StateMap[EventBase]: """ Works out the difference in state between the start of the timeline and the previous sync. Args: - room_id(str): - batch(synapse.handlers.sync.TimelineBatch): The timeline batch for - the room that will be sent to the user. - sync_config(synapse.handlers.sync.SyncConfig): - since_token(str|None): Token of the end of the previous batch. May - be None. - now_token(str): Token of the end of the current batch. - full_state(bool): Whether to force returning the full state. - - Returns: - A deferred dict of (type, state_key) -> Event + room_id: + batch: The timeline batch for the room that will be sent to the user. + sync_config: + since_token: Token of the end of the previous batch. May be None. + now_token: Token of the end of the current batch. + full_state: Whether to force returning the full state. """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -800,6 +845,10 @@ class SyncHandler(object): # about them). state_filter = StateFilter.all() + # If this is an initial sync then full_state should be set, and + # that case is handled above. We assert here to ensure that this + # is indeed the case. + assert since_token is not None state_at_previous_sync = await self.get_state_at( room_id, stream_position=since_token, state_filter=state_filter ) @@ -874,7 +923,7 @@ class SyncHandler(object): if t[0] == EventTypes.Member: cache.set(t[1], event_id) - state = {} + state = {} # type: Dict[str, EventBase] if state_ids: state = await self.store.get_events(list(state_ids.values())) @@ -885,7 +934,9 @@ class SyncHandler(object): ) } - async def unread_notifs_for_room_id(self, room_id, sync_config): + async def unread_notifs_for_room_id( + self, room_id: str, sync_config: SyncConfig + ) -> Optional[Dict[str, str]]: with Measure(self.clock, "unread_notifs_for_room_id"): last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), @@ -893,7 +944,6 @@ class SyncHandler(object): receipt_type="m.read", ) - notifs = [] if last_unread_event_id: notifs = await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), last_unread_event_id @@ -905,17 +955,12 @@ class SyncHandler(object): return None async def generate_sync_result( - self, sync_config, since_token=None, full_state=False - ): + self, + sync_config: SyncConfig, + since_token: Optional[StreamToken] = None, + full_state: bool = False, + ) -> SyncResult: """Generates a sync result. - - Args: - sync_config (SyncConfig) - since_token (StreamToken) - full_state (bool) - - Returns: - Deferred(SyncResult) """ # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability @@ -977,7 +1022,7 @@ class SyncHandler(object): ) device_id = sync_config.device_id - one_time_key_counts = {} + one_time_key_counts = {} # type: JsonDict if device_id: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id @@ -1007,7 +1052,9 @@ class SyncHandler(object): ) @measure_func("_generate_sync_entry_for_groups") - async def _generate_sync_entry_for_groups(self, sync_result_builder): + async def _generate_sync_entry_for_groups( + self, sync_result_builder: "SyncResultBuilder" + ) -> None: user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token @@ -1052,27 +1099,22 @@ class SyncHandler(object): @measure_func("_generate_sync_entry_for_device_list") async def _generate_sync_entry_for_device_list( self, - sync_result_builder, - newly_joined_rooms, - newly_joined_or_invited_users, - newly_left_rooms, - newly_left_users, - ): + sync_result_builder: "SyncResultBuilder", + newly_joined_rooms: Set[str], + newly_joined_or_invited_users: Set[str], + newly_left_rooms: Set[str], + newly_left_users: Set[str], + ) -> DeviceLists: """Generate the DeviceLists section of sync Args: - sync_result_builder (SyncResultBuilder) - newly_joined_rooms (set[str]): Set of rooms user has joined since - previous sync - newly_joined_or_invited_users (set[str]): Set of users that have - joined or been invited to a room since previous sync. - newly_left_rooms (set[str]): Set of rooms user has left since + sync_result_builder + newly_joined_rooms: Set of rooms user has joined since previous sync + newly_joined_or_invited_users: Set of users that have joined or + been invited to a room since previous sync. + newly_left_rooms: Set of rooms user has left since previous sync + newly_left_users: Set of users that have left a room we're in since previous sync - newly_left_users (set[str]): Set of users that have left a room - we're in since previous sync - - Returns: - Deferred[DeviceLists] """ user_id = sync_result_builder.sync_config.user.to_string() @@ -1133,15 +1175,11 @@ class SyncHandler(object): else: return DeviceLists(changed=[], left=[]) - async def _generate_sync_entry_for_to_device(self, sync_result_builder): + async def _generate_sync_entry_for_to_device( + self, sync_result_builder: "SyncResultBuilder" + ) -> None: """Generates the portion of the sync response. Populates `sync_result_builder` with the result. - - Args: - sync_result_builder(SyncResultBuilder) - - Returns: - Deferred(dict): A dictionary containing the per room account data. """ user_id = sync_result_builder.sync_config.user.to_string() device_id = sync_result_builder.sync_config.device_id @@ -1179,15 +1217,17 @@ class SyncHandler(object): else: sync_result_builder.to_device = [] - async def _generate_sync_entry_for_account_data(self, sync_result_builder): + async def _generate_sync_entry_for_account_data( + self, sync_result_builder: "SyncResultBuilder" + ) -> Dict[str, Dict[str, JsonDict]]: """Generates the account data portion of the sync response. Populates `sync_result_builder` with the result. Args: - sync_result_builder(SyncResultBuilder) + sync_result_builder Returns: - Deferred(dict): A dictionary containing the per room account data. + A dictionary containing the per room account data. """ sync_config = sync_result_builder.sync_config user_id = sync_result_builder.sync_config.user.to_string() @@ -1231,18 +1271,21 @@ class SyncHandler(object): return account_data_by_room async def _generate_sync_entry_for_presence( - self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users - ): + self, + sync_result_builder: "SyncResultBuilder", + newly_joined_rooms: Set[str], + newly_joined_or_invited_users: Set[str], + ) -> None: """Generates the presence portion of the sync response. Populates the `sync_result_builder` with the result. Args: - sync_result_builder(SyncResultBuilder) - newly_joined_rooms(list): List of rooms that the user has joined - since the last sync (or empty if an initial sync) - newly_joined_or_invited_users(list): List of users that have joined - or been invited to rooms since the last sync (or empty if an initial - sync) + sync_result_builder + newly_joined_rooms: Set of rooms that the user has joined since + the last sync (or empty if an initial sync) + newly_joined_or_invited_users: Set of users that have joined or + been invited to rooms since the last sync (or empty if an + initial sync) """ now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config @@ -1286,17 +1329,19 @@ class SyncHandler(object): sync_result_builder.presence = presence async def _generate_sync_entry_for_rooms( - self, sync_result_builder, account_data_by_room - ): + self, + sync_result_builder: "SyncResultBuilder", + account_data_by_room: Dict[str, Dict[str, JsonDict]], + ) -> Tuple[Set[str], Set[str], Set[str], Set[str]]: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. Args: - sync_result_builder(SyncResultBuilder) - account_data_by_room(dict): Dictionary of per room account data + sync_result_builder + account_data_by_room: Dictionary of per room account data Returns: - Deferred(tuple): Returns a 4-tuple of + Returns a 4-tuple of `(newly_joined_rooms, newly_joined_or_invited_users, newly_left_rooms, newly_left_users)` """ @@ -1307,7 +1352,7 @@ class SyncHandler(object): ) if block_all_room_ephemeral: - ephemeral_by_room = {} + ephemeral_by_room = {} # type: Dict[str, List[JsonDict]] else: now_token, ephemeral_by_room = await self.ephemeral_by_room( sync_result_builder, @@ -1328,7 +1373,7 @@ class SyncHandler(object): ) if not tags_by_room: logger.debug("no-oping sync") - return [], [], [], [] + return set(), set(), set(), set() ignored_account_data = await self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id @@ -1340,19 +1385,22 @@ class SyncHandler(object): ignored_users = frozenset() if since_token: - res = await self._get_rooms_changed(sync_result_builder, ignored_users) - room_entries, invited, newly_joined_rooms, newly_left_rooms = res - + room_changes = await self._get_rooms_changed( + sync_result_builder, ignored_users + ) tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - res = await self._get_all_rooms(sync_result_builder, ignored_users) - room_entries, invited, newly_joined_rooms = res - newly_left_rooms = [] + room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) tags_by_room = await self.store.get_tags_for_user(user_id) + room_entries = room_changes.room_entries + invited = room_changes.invited + newly_joined_rooms = room_changes.newly_joined_rooms + newly_left_rooms = room_changes.newly_left_rooms + def handle_room_entries(room_entry): return self._generate_room_entry( sync_result_builder, @@ -1392,13 +1440,15 @@ class SyncHandler(object): newly_left_users -= newly_joined_or_invited_users return ( - newly_joined_rooms, + set(newly_joined_rooms), newly_joined_or_invited_users, - newly_left_rooms, + set(newly_left_rooms), newly_left_users, ) - async def _have_rooms_changed(self, sync_result_builder): + async def _have_rooms_changed( + self, sync_result_builder: "SyncResultBuilder" + ) -> bool: """Returns whether there may be any new events that should be sent down the sync. Returns True if there are. """ @@ -1422,22 +1472,10 @@ class SyncHandler(object): return True return False - async def _get_rooms_changed(self, sync_result_builder, ignored_users): + async def _get_rooms_changed( + self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str] + ) -> _RoomChanges: """Gets the the changes that have happened since the last sync. - - Args: - sync_result_builder(SyncResultBuilder) - ignored_users(set(str)): Set of users ignored by user. - - Returns: - Deferred(tuple): Returns a tuple of the form: - `(room_entries, invited_rooms, newly_joined_rooms, newly_left_rooms)` - - where: - room_entries is a list [RoomSyncResultBuilder] - invited_rooms is a list [InvitedSyncResult] - newly_joined_rooms is a list[str] of room ids - newly_left_rooms is a list[str] of room ids """ user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token @@ -1451,7 +1489,7 @@ class SyncHandler(object): user_id, since_token.room_key, now_token.room_key ) - mem_change_events_by_room_id = {} + mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]] for event in rooms_changed: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) @@ -1570,7 +1608,7 @@ class SyncHandler(object): # This is all screaming out for a refactor, as the logic here is # subtle and the moving parts numerous. if leave_event.internal_metadata.is_out_of_band_membership(): - batch_events = [leave_event] + batch_events = [leave_event] # type: Optional[List[EventBase]] else: batch_events = None @@ -1636,18 +1674,17 @@ class SyncHandler(object): ) room_entries.append(entry) - return room_entries, invited, newly_joined_rooms, newly_left_rooms + return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms) - async def _get_all_rooms(self, sync_result_builder, ignored_users): + async def _get_all_rooms( + self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str] + ) -> _RoomChanges: """Returns entries for all rooms for the user. Args: - sync_result_builder(SyncResultBuilder) - ignored_users(set(str)): Set of users ignored by user. + sync_result_builder + ignored_users: Set of users ignored by user. - Returns: - Deferred(tuple): Returns a tuple of the form: - `([RoomSyncResultBuilder], [InvitedSyncResult], [])` """ user_id = sync_result_builder.sync_config.user.to_string() @@ -1709,30 +1746,30 @@ class SyncHandler(object): ) ) - return room_entries, invited, [] + return _RoomChanges(room_entries, invited, [], []) async def _generate_room_entry( self, - sync_result_builder, - ignored_users, - room_builder, - ephemeral, - tags, - account_data, - always_include=False, + sync_result_builder: "SyncResultBuilder", + ignored_users: Set[str], + room_builder: "RoomSyncResultBuilder", + ephemeral: List[JsonDict], + tags: Optional[List[JsonDict]], + account_data: Dict[str, JsonDict], + always_include: bool = False, ): """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. Args: - sync_result_builder(SyncResultBuilder) - ignored_users(set(str)): Set of users ignored by user. - room_builder(RoomSyncResultBuilder) - ephemeral(list): List of new ephemeral events for room - tags(list): List of *all* tags for room, or None if there has been + sync_result_builder + ignored_users: Set of users ignored by user. + room_builder + ephemeral: List of new ephemeral events for room + tags: List of *all* tags for room, or None if there has been no change. - account_data(list): List of new account data for room - always_include(bool): Always include this room in the sync response, + account_data: List of new account data for room + always_include: Always include this room in the sync response, even if empty. """ newly_joined = room_builder.newly_joined @@ -1758,7 +1795,7 @@ class SyncHandler(object): sync_config, now_token=upto_token, since_token=since_token, - recents=events, + potential_recents=events, newly_joined_room=newly_joined, ) @@ -1809,7 +1846,7 @@ class SyncHandler(object): room_id, batch, sync_config, since_token, now_token, full_state=full_state ) - summary = {} + summary = {} # type: Optional[JsonDict] # we include a summary in room responses when we're lazy loading # members (as the client otherwise doesn't have enough info to form @@ -1833,7 +1870,7 @@ class SyncHandler(object): ) if room_builder.rtype == "joined": - unread_notifications = {} + unread_notifications = {} # type: Dict[str, str] room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1860,18 +1897,20 @@ class SyncHandler(object): % (room_id, user_id, len(state)) ) elif room_builder.rtype == "archived": - room_sync = ArchivedSyncResult( + archived_room_sync = ArchivedSyncResult( room_id=room_id, timeline=batch, state=state, account_data=account_data_events, ) - if room_sync or always_include: - sync_result_builder.archived.append(room_sync) + if archived_room_sync or always_include: + sync_result_builder.archived.append(archived_room_sync) else: raise Exception("Unrecognized rtype: %r", room_builder.rtype) - async def get_rooms_for_user_at(self, user_id, stream_ordering): + async def get_rooms_for_user_at( + self, user_id: str, stream_ordering: int + ) -> FrozenSet[str]: """Get set of joined rooms for a user at the given stream ordering. The stream ordering *must* be recent, otherwise this may throw an @@ -1879,12 +1918,11 @@ class SyncHandler(object): current token, which should be perfectly fine). Args: - user_id (str) - stream_ordering (int) + user_id + stream_ordering ReturnValue: - Deferred[frozenset[str]]: Set of room_ids the user is in at given - stream_ordering. + Set of room_ids the user is in at given stream_ordering. """ joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id) @@ -1911,11 +1949,10 @@ class SyncHandler(object): if user_id in users_in_room: joined_room_ids.add(room_id) - joined_room_ids = frozenset(joined_room_ids) - return joined_room_ids + return frozenset(joined_room_ids) -def _action_has_highlight(actions): +def _action_has_highlight(actions: List[JsonDict]) -> bool: for action in actions: try: if action.get("set_tweak", None) == "highlight": @@ -1927,22 +1964,23 @@ def _action_has_highlight(actions): def _calculate_state( - timeline_contains, timeline_start, previous, current, lazy_load_members -): + timeline_contains: StateMap[str], + timeline_start: StateMap[str], + previous: StateMap[str], + current: StateMap[str], + lazy_load_members: bool, +) -> StateMap[str]: """Works out what state to include in a sync response. Args: - timeline_contains (dict): state in the timeline - timeline_start (dict): state at the start of the timeline - previous (dict): state at the end of the previous sync (or empty dict + timeline_contains: state in the timeline + timeline_start: state at the start of the timeline + previous: state at the end of the previous sync (or empty dict if this is an initial sync) - current (dict): state at the end of the timeline - lazy_load_members (bool): whether to return members from timeline_start + current: state at the end of the timeline + lazy_load_members: whether to return members from timeline_start or not. assumes that timeline_start has already been filtered to include only the members the client needs to know about. - - Returns: - dict """ event_id_to_key = { e: key @@ -1979,15 +2017,16 @@ def _calculate_state( return {event_id_to_key[e]: e for e in state_ids} -class SyncResultBuilder(object): +@attr.s +class SyncResultBuilder: """Used to help build up a new SyncResult for a user Attributes: - sync_config (SyncConfig) - full_state (bool) - since_token (StreamToken) - now_token (StreamToken) - joined_room_ids (list[str]) + sync_config + full_state: The full_state flag as specified by user + since_token: The token supplied by user, or None. + now_token: The token to sync up to. + joined_room_ids: List of rooms the user is joined to # The following mirror the fields in a sync response presence (list) @@ -1995,61 +2034,45 @@ class SyncResultBuilder(object): joined (list[JoinedSyncResult]) invited (list[InvitedSyncResult]) archived (list[ArchivedSyncResult]) - device (list) groups (GroupsSyncResult|None) to_device (list) """ - def __init__( - self, sync_config, full_state, since_token, now_token, joined_room_ids - ): - """ - Args: - sync_config (SyncConfig) - full_state (bool): The full_state flag as specified by user - since_token (StreamToken): The token supplied by user, or None. - now_token (StreamToken): The token to sync up to. - joined_room_ids (list[str]): List of rooms the user is joined to - """ - self.sync_config = sync_config - self.full_state = full_state - self.since_token = since_token - self.now_token = now_token - self.joined_room_ids = joined_room_ids - - self.presence = [] - self.account_data = [] - self.joined = [] - self.invited = [] - self.archived = [] - self.device = [] - self.groups = None - self.to_device = [] + sync_config = attr.ib(type=SyncConfig) + full_state = attr.ib(type=bool) + since_token = attr.ib(type=Optional[StreamToken]) + now_token = attr.ib(type=StreamToken) + joined_room_ids = attr.ib(type=FrozenSet[str]) + + presence = attr.ib(type=List[JsonDict], default=attr.Factory(list)) + account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list)) + joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list)) + invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list)) + archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list)) + groups = attr.ib(type=Optional[GroupsSyncResult], default=None) + to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list)) +@attr.s class RoomSyncResultBuilder(object): """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. + + Attributes: + room_id + rtype: One of `"joined"` or `"archived"` + events: List of events to include in the room (more events may be added + when generating result). + newly_joined: If the user has newly joined the room + full_state: Whether the full state should be sent in result + since_token: Earliest point to return events from, or None + upto_token: Latest point to return events from. """ - def __init__( - self, room_id, rtype, events, newly_joined, full_state, since_token, upto_token - ): - """ - Args: - room_id(str) - rtype(str): One of `"joined"` or `"archived"` - events(list[FrozenEvent]): List of events to include in the room - (more events may be added when generating result). - newly_joined(bool): If the user has newly joined the room - full_state(bool): Whether the full state should be sent in result - since_token(StreamToken): Earliest point to return events from, or None - upto_token(StreamToken): Latest point to return events from. - """ - self.room_id = room_id - self.rtype = rtype - self.events = events - self.newly_joined = newly_joined - self.full_state = full_state - self.since_token = since_token - self.upto_token = upto_token + room_id = attr.ib(type=str) + rtype = attr.ib(type=str) + events = attr.ib(type=Optional[List[EventBase]]) + newly_joined = attr.ib(type=bool) + full_state = attr.ib(type=bool) + since_token = attr.ib(type=Optional[StreamToken]) + upto_token = attr.ib(type=StreamToken) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index feb1c07cb2..b9ee6ec1ec 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -238,8 +238,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def build(self, prev_event_ids): built_event = yield self._base_builder.build(prev_event_ids) - built_event.event_id = self._event_id + + built_event._event_id = self._event_id built_event._event_dict["event_id"] = self._event_id + assert built_event.event_id == self._event_id + return built_event @property diff --git a/tox.ini b/tox.ini index 88ef12bebd..ef22368cf1 100644 --- a/tox.ini +++ b/tox.ini @@ -180,6 +180,7 @@ commands = mypy \ synapse/api \ synapse/config/ \ synapse/federation/transport \ + synapse/handlers/sync.py \ synapse/handlers/ui_auth \ synapse/logging/ \ synapse/module_api \ -- cgit 1.5.1 From 928edef9793bf10fa6156a42c4babbfaaaa17f88 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 31 Jan 2020 16:50:13 +0000 Subject: Pass room_version into `event_from_pdu_json` It's called from all over the shop, so this one's a bit messy. --- changelog.d/6856.misc | 1 + synapse/federation/federation_base.py | 28 ++++++++++++---------- synapse/federation/federation_client.py | 35 +++++++++++++--------------- synapse/federation/federation_server.py | 41 +++++++++++---------------------- tests/handlers/test_federation.py | 6 +++-- 5 files changed, 51 insertions(+), 60 deletions(-) create mode 100644 changelog.d/6856.misc (limited to 'tests') diff --git a/changelog.d/6856.misc b/changelog.d/6856.misc new file mode 100644 index 0000000000..08aa80bcd9 --- /dev/null +++ b/changelog.d/6856.misc @@ -0,0 +1 @@ +Refactoring work in preparation for changing the event redaction algorithm. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 0e22183280..ebe8b8e9fe 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,9 +23,13 @@ from twisted.internet.defer import DeferredList from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersion, +) from synapse.crypto.event_signing import check_event_content_hash -from synapse.events import event_type_from_format_version +from synapse.events import EventBase, event_type_from_format_version from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( @@ -33,7 +38,7 @@ from synapse.logging.context import ( make_deferred_yieldable, preserve_fn, ) -from synapse.types import get_domain_from_id +from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError logger = logging.getLogger(__name__) @@ -342,16 +347,15 @@ def _is_invite_via_3pid(event): ) -def event_from_pdu_json(pdu_json, event_format_version, outlier=False): - """Construct a FrozenEvent from an event json received over federation +def event_from_pdu_json( + pdu_json: JsonDict, room_version: RoomVersion, outlier: bool = False +) -> EventBase: + """Construct an EventBase from an event json received over federation Args: - pdu_json (object): pdu as received over federation - event_format_version (int): The event format version - outlier (bool): True to mark this event as an outlier - - Returns: - FrozenEvent + pdu_json: pdu as received over federation + room_version: The version of the room this event belongs to + outlier: True to mark this event as an outlier Raises: SynapseError: if the pdu is missing required fields or is otherwise @@ -370,7 +374,7 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False): elif depth > MAX_DEPTH: raise SynapseError(400, "Depth too large", Codes.BAD_JSON) - event = event_type_from_format_version(event_format_version)(pdu_json) + event = event_type_from_format_version(room_version.event_format)(pdu_json) event.internal_metadata.outlier = outlier diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 5fb4bd414c..4870e39652 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -49,7 +49,7 @@ from synapse.api.room_versions import ( RoomVersion, RoomVersions, ) -from synapse.events import EventBase, builder, room_version_to_event_format +from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.logging.context import make_deferred_yieldable from synapse.logging.utils import log_function @@ -209,18 +209,18 @@ class FederationClient(FederationBase): logger.debug("backfill transaction_data=%r", transaction_data) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) + room_version = await self.store.get_room_version(room_id) pdus = [ - event_from_pdu_json(p, format_ver, outlier=False) + event_from_pdu_json(p, room_version, outlier=False) for p in transaction_data["pdus"] ] # FIXME: We should handle signature failures more gracefully. pdus[:] = await make_deferred_yieldable( defer.gatherResults( - self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True + self._check_sigs_and_hashes(room_version.identifier, pdus), + consumeErrors=True, ).addErrback(unwrapFirstError) ) @@ -262,8 +262,6 @@ class FederationClient(FederationBase): pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) - format_ver = room_version.event_format - signed_pdu = None for destination in destinations: now = self._clock.time_msec() @@ -284,7 +282,7 @@ class FederationClient(FederationBase): ) pdu_list = [ - event_from_pdu_json(p, format_ver, outlier=outlier) + event_from_pdu_json(p, room_version, outlier=outlier) for p in transaction_data["pdus"] ] @@ -350,15 +348,15 @@ class FederationClient(FederationBase): async def get_event_auth(self, destination, room_id, event_id): res = await self.transport_layer.get_event_auth(destination, room_id, event_id) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) + room_version = await self.store.get_room_version(room_id) auth_chain = [ - event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"] + event_from_pdu_json(p, room_version, outlier=True) + for p in res["auth_chain"] ] signed_auth = await self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version + destination, auth_chain, outlier=True, room_version=room_version.identifier ) signed_auth.sort(key=lambda e: e.depth) @@ -547,12 +545,12 @@ class FederationClient(FederationBase): logger.debug("Got content: %s", content) state = [ - event_from_pdu_json(p, room_version.event_format, outlier=True) + event_from_pdu_json(p, room_version, outlier=True) for p in content.get("state", []) ] auth_chain = [ - event_from_pdu_json(p, room_version.event_format, outlier=True) + event_from_pdu_json(p, room_version, outlier=True) for p in content.get("auth_chain", []) ] @@ -677,7 +675,7 @@ class FederationClient(FederationBase): logger.debug("Got response to send_invite: %s", pdu_dict) - pdu = event_from_pdu_json(pdu_dict, room_version.event_format) + pdu = event_from_pdu_json(pdu_dict, room_version) # Check signatures are correct. pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) @@ -865,15 +863,14 @@ class FederationClient(FederationBase): timeout=timeout, ) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) + room_version = await self.store.get_room_version(room_id) events = [ - event_from_pdu_json(e, format_ver) for e in content.get("events", []) + event_from_pdu_json(e, room_version) for e in content.get("events", []) ] signed_events = await self._check_sigs_and_hash_and_fetch( - destination, events, outlier=False, room_version=room_version + destination, events, outlier=False, room_version=room_version.identifier ) except HttpResponseException as e: if not e.code == 400: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8e3933b6c5..2489832a11 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -38,7 +38,6 @@ from synapse.api.errors import ( UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events import room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction @@ -234,24 +233,17 @@ class FederationServer(FederationBase): continue try: - room_version = await self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue - - try: - format_ver = room_version_to_event_format(room_version) - except UnsupportedRoomVersionError: + except UnsupportedRoomVersionError as e: # this can happen if support for a given room version is withdrawn, # so that we still get events for said room. - logger.info( - "Ignoring PDU for room %s with unknown version %s", - room_id, - room_version, - ) + logger.info("Ignoring PDU: %s", e) continue - event = event_from_pdu_json(p, format_ver) + event = event_from_pdu_json(p, room_version) pdus_by_room.setdefault(room_id, []).append(event) pdu_results = {} @@ -407,9 +399,7 @@ class FederationServer(FederationBase): Codes.UNSUPPORTED_ROOM_VERSION, ) - format_ver = room_version.event_format - - pdu = event_from_pdu_json(content, format_ver) + pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, pdu.room_id) pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) @@ -420,16 +410,15 @@ class FederationServer(FederationBase): async def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: content: %s", content) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) - pdu = event_from_pdu_json(content, format_ver) + room_version = await self.store.get_room_version(room_id) + pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - pdu = await self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() @@ -451,16 +440,15 @@ class FederationServer(FederationBase): async def on_send_leave_request(self, origin, content, room_id): logger.debug("on_send_leave_request: content: %s", content) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) - pdu = event_from_pdu_json(content, format_ver) + room_version = await self.store.get_room_version(room_id) + pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - pdu = await self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) await self.handler.on_send_leave_request(origin, pdu) return {} @@ -498,15 +486,14 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - room_version = await self.store.get_room_version_id(room_id) - format_ver = room_version_to_event_format(room_version) + room_version = await self.store.get_room_version(room_id) auth_chain = [ - event_from_pdu_json(e, format_ver) for e in content["auth_chain"] + event_from_pdu_json(e, room_version) for e in content["auth_chain"] ] signed_auth = await self._check_sigs_and_hash_and_fetch( - origin, auth_chain, outlier=True, room_version=room_version + origin, auth_chain, outlier=True, room_version=room_version.identifier ) ret = await self.handler.on_query_auth( diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index b4d92cf732..132e35651d 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -99,6 +99,7 @@ class FederationTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) # pretend that another server has joined join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) @@ -120,7 +121,7 @@ class FederationTestCase(unittest.HomeserverTestCase): "auth_events": [], "origin_server_ts": self.clock.time_msec(), }, - join_event.format_version, + room_version, ) with LoggingContext(request="send_rejected"): @@ -149,6 +150,7 @@ class FederationTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) # pretend that another server has joined join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) @@ -171,7 +173,7 @@ class FederationTestCase(unittest.HomeserverTestCase): "auth_events": [], "origin_server_ts": self.clock.time_msec(), }, - join_event.format_version, + room_version, ) with LoggingContext(request="send_rejected"): -- cgit 1.5.1 From 56ca93ef5941b5dfcda368f373a6bcd80d177acd Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 7 Feb 2020 11:29:36 +0100 Subject: Admin api to add an email address (#6789) --- changelog.d/6769.feature | 1 + docs/admin_api/user_admin_api.rst | 11 +++++++++++ synapse/handlers/admin.py | 2 ++ synapse/handlers/auth.py | 8 ++++++++ synapse/rest/admin/users.py | 39 +++++++++++++++++++++++++++++++++++++++ tests/rest/admin/test_user.py | 19 +++++++++++++++++-- 6 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6769.feature (limited to 'tests') diff --git a/changelog.d/6769.feature b/changelog.d/6769.feature new file mode 100644 index 0000000000..8a60e12907 --- /dev/null +++ b/changelog.d/6769.feature @@ -0,0 +1 @@ +Admin API to add or modify threepids of user accounts. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 0b3d09d694..eb146095de 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -15,6 +15,16 @@ with a body of: { "password": "user_password", "displayname": "User", + "threepids": [ + { + "medium": "email", + "address": "" + }, + { + "medium": "email", + "address": "" + } + ], "avatar_url": "", "admin": false, "deactivated": false @@ -23,6 +33,7 @@ with a body of: including an ``access_token`` of a server admin. The parameter ``displayname`` is optional and defaults to ``user_id``. +The parameter ``threepids`` is optional. The parameter ``avatar_url`` is optional. The parameter ``admin`` is optional and defaults to 'false'. The parameter ``deactivated`` is optional and defaults to 'false'. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 9205865231..f3c0aeceb6 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -58,8 +58,10 @@ class AdminHandler(BaseHandler): ret = await self.store.get_user_by_id(user.to_string()) if ret: profile = await self.store.get_profileinfo(user.localpart) + threepids = await self.store.user_get_threepids(user.to_string()) ret["displayname"] = profile.display_name ret["avatar_url"] = profile.avatar_url + ret["threepids"] = threepids return ret async def export_user_data(self, user_id, writer): diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 54a71c49d2..48a88d3c2a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -816,6 +816,14 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def add_threepid(self, user_id, medium, address, validated_at): + # check if medium has a valid value + if medium not in ["email", "msisdn"]: + raise SynapseError( + code=400, + msg=("'%s' is not a valid value for 'medium'" % (medium,)), + errcode=Codes.INVALID_PARAM, + ) + # 'Canonicalise' email addresses down to lower case. # We've now moving towards the homeserver being the entity that # is responsible for validating threepids used for resetting passwords diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f1c4434f5c..e75c5f1370 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -136,6 +136,8 @@ class UserRestServletV2(RestServlet): self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_handlers().admin_handler + self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.set_password_handler = hs.get_set_password_handler() self.deactivate_account_handler = hs.get_deactivate_account_handler() @@ -163,6 +165,7 @@ class UserRestServletV2(RestServlet): raise SynapseError(400, "This endpoint can only be used with local users") user = await self.admin_handler.get_user(target_user) + user_id = target_user.to_string() if user: # modify user if "displayname" in body: @@ -170,6 +173,29 @@ class UserRestServletV2(RestServlet): target_user, requester, body["displayname"], True ) + if "threepids" in body: + # check for required parameters for each threepid + for threepid in body["threepids"]: + assert_params_in_dict(threepid, ["medium", "address"]) + + # remove old threepids from user + threepids = await self.store.user_get_threepids(user_id) + for threepid in threepids: + try: + await self.auth_handler.delete_threepid( + user_id, threepid["medium"], threepid["address"], None + ) + except Exception: + logger.exception("Failed to remove threepids") + raise SynapseError(500, "Failed to remove threepids") + + # add new threepids to user + current_time = self.hs.get_clock().time_msec() + for threepid in body["threepids"]: + await self.auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], current_time + ) + if "avatar_url" in body: await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True @@ -221,6 +247,7 @@ class UserRestServletV2(RestServlet): admin = body.get("admin", None) user_type = body.get("user_type", None) displayname = body.get("displayname", None) + threepids = body.get("threepids", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: raise SynapseError(400, "Invalid user type") @@ -232,6 +259,18 @@ class UserRestServletV2(RestServlet): default_display_name=displayname, user_type=user_type, ) + + if "threepids" in body: + # check for required parameters for each threepid + for threepid in body["threepids"]: + assert_params_in_dict(threepid, ["medium", "address"]) + + current_time = self.hs.get_clock().time_msec() + for threepid in body["threepids"]: + await self.auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], current_time + ) + if "avatar_url" in body: await self.profile_handler.set_avatar_url( user_id, requester, body["avatar_url"], True diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 8f09f51c61..3b5169b38d 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -407,7 +407,13 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ self.hs.config.registration_shared_secret = None - body = json.dumps({"password": "abc123", "admin": True}) + body = json.dumps( + { + "password": "abc123", + "admin": True, + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } + ) # Create user request, channel = self.make_request( @@ -421,6 +427,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) # Get user request, channel = self.make_request( @@ -449,7 +457,13 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Modify user - body = json.dumps({"displayname": "foobar", "deactivated": True}) + body = json.dumps( + { + "displayname": "foobar", + "deactivated": True, + "threepids": [{"medium": "email", "address": "bob2@bob.bob"}], + } + ) request, channel = self.make_request( "PUT", @@ -463,6 +477,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) self.assertEqual(True, channel.json_body["deactivated"]) + # the user is deactivated, the threepid will be deleted # Get user request, channel = self.make_request( -- cgit 1.5.1 From b08b0a22d505b1555f511e3f38935a62930ea25d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Feb 2020 13:56:38 +0000 Subject: Add typing to synapse.federation.sender (#6871) --- changelog.d/6871.misc | 1 + synapse/federation/federation_server.py | 7 +- synapse/federation/sender/__init__.py | 99 +++++++++++----------- synapse/federation/sender/per_destination_queue.py | 88 +++++++++---------- synapse/federation/sender/transaction_manager.py | 16 ++-- synapse/federation/units.py | 23 ++++- synapse/server.pyi | 2 + tests/handlers/test_typing.py | 8 +- tox.ini | 1 + 9 files changed, 138 insertions(+), 107 deletions(-) create mode 100644 changelog.d/6871.misc (limited to 'tests') diff --git a/changelog.d/6871.misc b/changelog.d/6871.misc new file mode 100644 index 0000000000..5161af9983 --- /dev/null +++ b/changelog.d/6871.misc @@ -0,0 +1 @@ +Add typing to `synapse.federation.sender` and port to async/await. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2489832a11..a6c966a393 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -294,7 +294,12 @@ class FederationServer(FederationBase): async def _process_edu(edu_dict): received_edus_counter.inc() - edu = Edu(**edu_dict) + edu = Edu( + origin=origin, + destination=self.server_name, + edu_type=edu_dict["edu_type"], + content=edu_dict["content"], + ) await self.registry.on_edu(edu.edu_type, origin, edu.content) await concurrently_execute( diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 36c83c3027..233cb33daf 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict, Hashable, Iterable, List, Optional, Set from six import itervalues @@ -23,6 +24,7 @@ from twisted.internet import defer import synapse import synapse.metrics +from synapse.events import EventBase from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.units import Edu @@ -39,6 +41,8 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.presence import UserPresenceState +from synapse.types import ReadReceipt from synapse.util.metrics import Measure, measure_func logger = logging.getLogger(__name__) @@ -68,7 +72,7 @@ class FederationSender(object): self._transaction_manager = TransactionManager(hs) # map from destination to PerDestinationQueue - self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] + self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] LaterGauge( "synapse_federation_transaction_queue_pending_destinations", @@ -84,7 +88,7 @@ class FederationSender(object): # Map of user_id -> UserPresenceState for all the pending presence # to be sent out by user_id. Entries here get processed and put in # pending_presence_by_dest - self.pending_presence = {} + self.pending_presence = {} # type: Dict[str, UserPresenceState] LaterGauge( "synapse_federation_transaction_queue_pending_pdus", @@ -116,20 +120,17 @@ class FederationSender(object): # and that there is a pending call to _flush_rrs_for_room in the system. self._queues_awaiting_rr_flush_by_room = ( {} - ) # type: dict[str, set[PerDestinationQueue]] + ) # type: Dict[str, Set[PerDestinationQueue]] self._rr_txn_interval_per_room_ms = ( - 1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second + 1000.0 / hs.config.federation_rr_transactions_per_room_per_second ) - def _get_per_destination_queue(self, destination): + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination Args: - destination (str): server_name of remote server - - Returns: - PerDestinationQueue + destination: server_name of remote server """ queue = self._per_destination_queues.get(destination) if not queue: @@ -137,7 +138,7 @@ class FederationSender(object): self._per_destination_queues[destination] = queue return queue - def notify_new_events(self, current_id): + def notify_new_events(self, current_id: int) -> None: """This gets called when we have some new events we might want to send out to other servers. """ @@ -151,13 +152,12 @@ class FederationSender(object): "process_event_queue_for_federation", self._process_event_queue_loop ) - @defer.inlineCallbacks - def _process_event_queue_loop(self): + async def _process_event_queue_loop(self) -> None: try: self._is_processing = True while True: - last_token = yield self.store.get_federation_out_pos("events") - next_token, events = yield self.store.get_all_new_events_stream( + last_token = await self.store.get_federation_out_pos("events") + next_token, events = await self.store.get_all_new_events_stream( last_token, self._last_poked_id, limit=100 ) @@ -166,8 +166,7 @@ class FederationSender(object): if not events and next_token >= self._last_poked_id: break - @defer.inlineCallbacks - def handle_event(event): + async def handle_event(event: EventBase) -> None: # Only send events for this server. send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() is_mine = self.is_mine_id(event.sender) @@ -184,7 +183,7 @@ class FederationSender(object): # Otherwise if the last member on a server in a room is # banned then it won't receive the event because it won't # be in the room after the ban. - destinations = yield self.state.get_hosts_in_room_at_events( + destinations = await self.state.get_hosts_in_room_at_events( event.room_id, event_ids=event.prev_event_ids() ) except Exception: @@ -206,17 +205,16 @@ class FederationSender(object): self._send_pdu(event, destinations) - @defer.inlineCallbacks - def handle_room_events(events): + async def handle_room_events(events: Iterable[EventBase]) -> None: with Measure(self.clock, "handle_room_events"): for event in events: - yield handle_event(event) + await handle_event(event) - events_by_room = {} + events_by_room = {} # type: Dict[str, List[EventBase]] for event in events: events_by_room.setdefault(event.room_id, []).append(event) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(handle_room_events, evs) @@ -226,11 +224,11 @@ class FederationSender(object): ) ) - yield self.store.update_federation_out_pos("events", next_token) + await self.store.update_federation_out_pos("events", next_token) if events: now = self.clock.time_msec() - ts = yield self.store.get_received_ts(events[-1].event_id) + ts = await self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_lag.labels( "federation_sender" @@ -254,7 +252,7 @@ class FederationSender(object): finally: self._is_processing = False - def _send_pdu(self, pdu, destinations): + def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: # We loop through all destinations to see whether we already have # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. @@ -276,11 +274,11 @@ class FederationSender(object): self._get_per_destination_queue(destination).send_pdu(pdu, order) @defer.inlineCallbacks - def send_read_receipt(self, receipt): + def send_read_receipt(self, receipt: ReadReceipt): """Send a RR to any other servers in the room Args: - receipt (synapse.types.ReadReceipt): receipt to be sent + receipt: receipt to be sent """ # Some background on the rate-limiting going on here. @@ -343,7 +341,7 @@ class FederationSender(object): else: queue.flush_read_receipts_for_room(room_id) - def _schedule_rr_flush_for_room(self, room_id, n_domains): + def _schedule_rr_flush_for_room(self, room_id: str, n_domains: int) -> None: # that is going to cause approximately len(domains) transactions, so now back # off for that multiplied by RR_TXN_INTERVAL_PER_ROOM backoff_ms = self._rr_txn_interval_per_room_ms * n_domains @@ -352,7 +350,7 @@ class FederationSender(object): self.clock.call_later(backoff_ms, self._flush_rrs_for_room, room_id) self._queues_awaiting_rr_flush_by_room[room_id] = set() - def _flush_rrs_for_room(self, room_id): + def _flush_rrs_for_room(self, room_id: str) -> None: queues = self._queues_awaiting_rr_flush_by_room.pop(room_id) logger.debug("Flushing RRs in %s to %s", room_id, queues) @@ -368,14 +366,11 @@ class FederationSender(object): @preserve_fn # the caller should not yield on this @defer.inlineCallbacks - def send_presence(self, states): + def send_presence(self, states: List[UserPresenceState]): """Send the new presence states to the appropriate destinations. This actually queues up the presence states ready for sending and triggers a background task to process them and send out the transactions. - - Args: - states (list(UserPresenceState)) """ if not self.hs.config.use_presence: # No-op if presence is disabled. @@ -412,11 +407,10 @@ class FederationSender(object): finally: self._processing_pending_presence = False - def send_presence_to_destinations(self, states, destinations): + def send_presence_to_destinations( + self, states: List[UserPresenceState], destinations: List[str] + ) -> None: """Send the given presence states to the given destinations. - - Args: - states (list[UserPresenceState]) destinations (list[str]) """ @@ -431,12 +425,9 @@ class FederationSender(object): @measure_func("txnqueue._process_presence") @defer.inlineCallbacks - def _process_presence_inner(self, states): + def _process_presence_inner(self, states: List[UserPresenceState]): """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination - - Args: - states (list(UserPresenceState)) """ hosts_and_states = yield get_interested_remotes(self.store, states, self.state) @@ -446,14 +437,20 @@ class FederationSender(object): continue self._get_per_destination_queue(destination).send_presence(states) - def build_and_send_edu(self, destination, edu_type, content, key=None): + def build_and_send_edu( + self, + destination: str, + edu_type: str, + content: dict, + key: Optional[Hashable] = None, + ): """Construct an Edu object, and queue it for sending Args: - destination (str): name of server to send to - edu_type (str): type of EDU to send - content (dict): content of EDU - key (Any|None): clobbering key for this edu + destination: name of server to send to + edu_type: type of EDU to send + content: content of EDU + key: clobbering key for this edu """ if destination == self.server_name: logger.info("Not sending EDU to ourselves") @@ -468,12 +465,12 @@ class FederationSender(object): self.send_edu(edu, key) - def send_edu(self, edu, key): + def send_edu(self, edu: Edu, key: Optional[Hashable]): """Queue an EDU for sending Args: - edu (Edu): edu to send - key (Any|None): clobbering key for this edu + edu: edu to send + key: clobbering key for this edu """ queue = self._get_per_destination_queue(edu.destination) if key: @@ -481,7 +478,7 @@ class FederationSender(object): else: queue.send_edu(edu) - def send_device_messages(self, destination): + def send_device_messages(self, destination: str): if destination == self.server_name: logger.warning("Not sending device update to ourselves") return @@ -501,5 +498,5 @@ class FederationSender(object): self._get_per_destination_queue(destination).attempt_new_transaction() - def get_current_token(self): + def get_current_token(self) -> int: return 0 diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 5012aaea35..e13cd20ffa 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -15,11 +15,11 @@ # limitations under the License. import datetime import logging +from typing import Dict, Hashable, Iterable, List, Tuple from prometheus_client import Counter -from twisted.internet import defer - +import synapse.server from synapse.api.errors import ( FederationDeniedError, HttpResponseException, @@ -31,7 +31,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState -from synapse.types import StateMap +from synapse.types import ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter # This is defined in the Matrix spec and enforced by the receiver. @@ -56,13 +56,18 @@ class PerDestinationQueue(object): Manages the per-destination transmission queues. Args: - hs (synapse.HomeServer): - transaction_sender (TransactionManager): - destination (str): the server_name of the destination that we are managing + hs + transaction_sender + destination: the server_name of the destination that we are managing transmission for. """ - def __init__(self, hs, transaction_manager, destination): + def __init__( + self, + hs: "synapse.server.HomeServer", + transaction_manager: "synapse.federation.sender.TransactionManager", + destination: str, + ): self._server_name = hs.hostname self._clock = hs.get_clock() self._store = hs.get_datastore() @@ -72,20 +77,20 @@ class PerDestinationQueue(object): self.transmission_loop_running = False # a list of tuples of (pending pdu, order) - self._pending_pdus = [] # type: list[tuple[EventBase, int]] - self._pending_edus = [] # type: list[Edu] + self._pending_pdus = [] # type: List[Tuple[EventBase, int]] + self._pending_edus = [] # type: List[Edu] # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # based on their key (e.g. typing events by room_id) # Map of (edu_type, key) -> Edu - self._pending_edus_keyed = {} # type: StateMap[Edu] + self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu] # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination - self._pending_presence = {} # type: dict[str, UserPresenceState] + self._pending_presence = {} # type: Dict[str, UserPresenceState] # room_id -> receipt_type -> user_id -> receipt_dict - self._pending_rrs = {} + self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]] self._rrs_pending_flush = False # stream_id of last successfully sent to-device message. @@ -95,50 +100,50 @@ class PerDestinationQueue(object): # stream_id of last successfully sent device list update. self._last_device_list_stream_id = 0 - def __str__(self): + def __str__(self) -> str: return "PerDestinationQueue[%s]" % self._destination - def pending_pdu_count(self): + def pending_pdu_count(self) -> int: return len(self._pending_pdus) - def pending_edu_count(self): + def pending_edu_count(self) -> int: return ( len(self._pending_edus) + len(self._pending_presence) + len(self._pending_edus_keyed) ) - def send_pdu(self, pdu, order): + def send_pdu(self, pdu: EventBase, order: int) -> None: """Add a PDU to the queue, and start the transmission loop if neccessary Args: - pdu (EventBase): pdu to send - order (int): + pdu: pdu to send + order """ self._pending_pdus.append((pdu, order)) self.attempt_new_transaction() - def send_presence(self, states): + def send_presence(self, states: Iterable[UserPresenceState]) -> None: """Add presence updates to the queue. Start the transmission loop if neccessary. Args: - states (iterable[UserPresenceState]): presence to send + states: presence to send """ self._pending_presence.update({state.user_id: state for state in states}) self.attempt_new_transaction() - def queue_read_receipt(self, receipt): + def queue_read_receipt(self, receipt: ReadReceipt) -> None: """Add a RR to the list to be sent. Doesn't start the transmission loop yet (see flush_read_receipts_for_room) Args: - receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued + receipt: receipt to be queued """ self._pending_rrs.setdefault(receipt.room_id, {}).setdefault( receipt.receipt_type, {} )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data} - def flush_read_receipts_for_room(self, room_id): + def flush_read_receipts_for_room(self, room_id: str) -> None: # if we don't have any read-receipts for this room, it may be that we've already # sent them out, so we don't need to flush. if room_id not in self._pending_rrs: @@ -146,15 +151,15 @@ class PerDestinationQueue(object): self._rrs_pending_flush = True self.attempt_new_transaction() - def send_keyed_edu(self, edu, key): + def send_keyed_edu(self, edu: Edu, key: Hashable) -> None: self._pending_edus_keyed[(edu.edu_type, key)] = edu self.attempt_new_transaction() - def send_edu(self, edu): + def send_edu(self, edu) -> None: self._pending_edus.append(edu) self.attempt_new_transaction() - def attempt_new_transaction(self): + def attempt_new_transaction(self) -> None: """Try to start a new transaction to this destination If there is already a transaction in progress to this destination, @@ -177,23 +182,22 @@ class PerDestinationQueue(object): self._transaction_transmission_loop, ) - @defer.inlineCallbacks - def _transaction_transmission_loop(self): - pending_pdus = [] + async def _transaction_transmission_loop(self) -> None: + pending_pdus = [] # type: List[Tuple[EventBase, int]] try: self.transmission_loop_running = True # This will throw if we wouldn't retry. We do this here so we fail # quickly, but we will later check this again in the http client, # hence why we throw the result away. - yield get_retry_limiter(self._destination, self._clock, self._store) + await get_retry_limiter(self._destination, self._clock, self._store) pending_pdus = [] while True: # We have to keep 2 free slots for presence and rr_edus limit = MAX_EDUS_PER_TRANSACTION - 2 - device_update_edus, dev_list_id = yield self._get_device_update_edus( + device_update_edus, dev_list_id = await self._get_device_update_edus( limit ) @@ -202,7 +206,7 @@ class PerDestinationQueue(object): ( to_device_edus, device_stream_id, - ) = yield self._get_to_device_message_edus(limit) + ) = await self._get_to_device_message_edus(limit) pending_edus = device_update_edus + to_device_edus @@ -269,7 +273,7 @@ class PerDestinationQueue(object): # END CRITICAL SECTION - success = yield self._transaction_manager.send_new_transaction( + success = await self._transaction_manager.send_new_transaction( self._destination, pending_pdus, pending_edus ) if success: @@ -280,7 +284,7 @@ class PerDestinationQueue(object): # Remove the acknowledged device messages from the database # Only bother if we actually sent some device messages if to_device_edus: - yield self._store.delete_device_msgs_for_remote( + await self._store.delete_device_msgs_for_remote( self._destination, device_stream_id ) @@ -289,7 +293,7 @@ class PerDestinationQueue(object): logger.info( "Marking as sent %r %r", self._destination, dev_list_id ) - yield self._store.mark_as_sent_devices_by_remote( + await self._store.mark_as_sent_devices_by_remote( self._destination, dev_list_id ) @@ -334,7 +338,7 @@ class PerDestinationQueue(object): # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False - def _get_rr_edus(self, force_flush): + def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: if not self._pending_rrs: return if not force_flush and not self._rrs_pending_flush: @@ -351,17 +355,16 @@ class PerDestinationQueue(object): self._rrs_pending_flush = False yield edu - def _pop_pending_edus(self, limit): + def _pop_pending_edus(self, limit: int) -> List[Edu]: pending_edus = self._pending_edus pending_edus, self._pending_edus = pending_edus[:limit], pending_edus[limit:] return pending_edus - @defer.inlineCallbacks - def _get_device_update_edus(self, limit): + async def _get_device_update_edus(self, limit: int) -> Tuple[List[Edu], int]: last_device_list = self._last_device_list_stream_id # Retrieve list of new device updates to send to the destination - now_stream_id, results = yield self._store.get_device_updates_by_remote( + now_stream_id, results = await self._store.get_device_updates_by_remote( self._destination, last_device_list, limit=limit ) edus = [ @@ -378,11 +381,10 @@ class PerDestinationQueue(object): return (edus, now_stream_id) - @defer.inlineCallbacks - def _get_to_device_message_edus(self, limit): + async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]: last_device_stream_id = self._last_device_stream_id to_device_stream_id = self._store.get_to_device_stream_token() - contents, stream_id = yield self._store.get_new_device_msgs_for_remote( + contents, stream_id = await self._store.get_new_device_msgs_for_remote( self._destination, last_device_stream_id, to_device_stream_id, limit ) edus = [ diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 5fed626d5b..3c2a02a3b3 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -13,14 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List from canonicaljson import json -from twisted.internet import defer - +import synapse.server from synapse.api.errors import HttpResponseException +from synapse.events import EventBase from synapse.federation.persistence import TransactionActions -from synapse.federation.units import Transaction +from synapse.federation.units import Edu, Transaction from synapse.logging.opentracing import ( extract_text_map, set_tag, @@ -39,7 +40,7 @@ class TransactionManager(object): shared between PerDestinationQueue objects """ - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self._server_name = hs.hostname self.clock = hs.get_clock() # nb must be called this for @measure_func self._store = hs.get_datastore() @@ -50,8 +51,9 @@ class TransactionManager(object): self._next_txn_id = int(self.clock.time_msec()) @measure_func("_send_new_transaction") - @defer.inlineCallbacks - def send_new_transaction(self, destination, pending_pdus, pending_edus): + async def send_new_transaction( + self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu] + ): # Make a transaction-sending opentracing span. This span follows on from # all the edus in that transaction. This needs to be done since there is @@ -127,7 +129,7 @@ class TransactionManager(object): return data try: - response = yield self._transport_layer.send_transaction( + response = await self._transport_layer.send_transaction( transaction, json_data_cb ) code = 200 diff --git a/synapse/federation/units.py b/synapse/federation/units.py index b4d743cde7..6b32e0dcbf 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -19,11 +19,15 @@ server protocol. import logging +import attr + +from synapse.types import JsonDict from synapse.util.jsonobject import JsonEncodedObject logger = logging.getLogger(__name__) +@attr.s(slots=True) class Edu(JsonEncodedObject): """ An Edu represents a piece of data sent from one homeserver to another. @@ -32,11 +36,24 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - valid_keys = ["origin", "destination", "edu_type", "content"] + edu_type = attr.ib(type=str) + content = attr.ib(type=dict) + origin = attr.ib(type=str) + destination = attr.ib(type=str) - required_keys = ["edu_type"] + def get_dict(self) -> JsonDict: + return { + "edu_type": self.edu_type, + "content": self.content, + } - internal_keys = ["origin", "destination"] + def get_internal_dict(self) -> JsonDict: + return { + "edu_type": self.edu_type, + "content": self.content, + "origin": self.origin, + "destination": self.destination, + } def get_context(self): return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") diff --git a/synapse/server.pyi b/synapse/server.pyi index 90347ac23e..40eabfe5d9 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -107,3 +107,5 @@ class HomeServer(object): self, ) -> synapse.replication.tcp.client.ReplicationClientHandler: pass + def is_mine_id(self, domain_id: str) -> bool: + pass diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 68b9847bd2..2767b0497a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -111,7 +111,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_device_updates_by_remote.return_value = (0, []) + self.datastore.get_device_updates_by_remote.return_value = defer.succeed( + (0, []) + ) def get_received_txn_response(*args): return defer.succeed(None) @@ -144,7 +146,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) + self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed( + ([], 0) + ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( None diff --git a/tox.ini b/tox.ini index ef22368cf1..f8229eba88 100644 --- a/tox.ini +++ b/tox.ini @@ -179,6 +179,7 @@ extras = all commands = mypy \ synapse/api \ synapse/config/ \ + synapse/federation/sender \ synapse/federation/transport \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ -- cgit 1.5.1 From 799001f2c0b31d72b95a252a3808da25987e1ed3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 7 Feb 2020 15:30:04 +0000 Subject: Add a `make_event_from_dict` method (#6858) ... and use it in places where it's trivial to do so. This will make it easier to pass room versions into the FrozenEvent constructors. --- changelog.d/6858.misc | 1 + synapse/events/__init__.py | 16 ++++++++++++++-- synapse/events/builder.py | 10 +++------- synapse/federation/federation_base.py | 5 ++--- tests/api/test_filtering.py | 4 ++-- tests/crypto/test_event_signing.py | 6 +++--- tests/events/test_utils.py | 9 +++++---- tests/federation/test_federation_server.py | 4 ++-- tests/replication/slave/storage/test_events.py | 12 ++++++++---- tests/state/test_v2.py | 4 ++-- tests/test_event_auth.py | 10 +++++----- tests/test_federation.py | 6 +++--- tests/test_state.py | 4 ++-- 13 files changed, 52 insertions(+), 39 deletions(-) create mode 100644 changelog.d/6858.misc (limited to 'tests') diff --git a/changelog.d/6858.misc b/changelog.d/6858.misc new file mode 100644 index 0000000000..08aa80bcd9 --- /dev/null +++ b/changelog.d/6858.misc @@ -0,0 +1 @@ +Refactoring work in preparation for changing the event redaction algorithm. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 89d41d82b6..a842661a90 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -16,12 +16,13 @@ import os from distutils.util import strtobool +from typing import Optional, Type import six from unpaddedbase64 import encode_base64 -from synapse.api.room_versions import EventFormatVersions +from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.types import JsonDict from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -407,7 +408,7 @@ class FrozenEventV3(FrozenEventV2): return self._event_id -def event_type_from_format_version(format_version): +def event_type_from_format_version(format_version: int) -> Type[EventBase]: """Returns the python type to use to construct an Event object for the given event format version. @@ -427,3 +428,14 @@ def event_type_from_format_version(format_version): return FrozenEventV3 else: raise Exception("No event format %r" % (format_version,)) + + +def make_event_from_dict( + event_dict: JsonDict, + room_version: RoomVersion = RoomVersions.V1, + internal_metadata_dict: JsonDict = {}, + rejected_reason: Optional[str] = None, +) -> EventBase: + """Construct an EventBase from the given event dict""" + event_type = event_type_from_format_version(room_version.event_format) + return event_type(event_dict, internal_metadata_dict, rejected_reason) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 8d63ad6dc3..a0c4a40c27 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -28,11 +28,7 @@ from synapse.api.room_versions import ( RoomVersion, ) from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.events import ( - EventBase, - _EventInternalMetadata, - event_type_from_format_version, -) +from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict from synapse.types import EventID, JsonDict from synapse.util import Clock from synapse.util.stringutils import random_string @@ -256,8 +252,8 @@ def create_local_event_from_event_dict( event_dict.setdefault("signatures", {}) add_hashes_and_signatures(room_version, event_dict, hostname, signing_key) - return event_type_from_format_version(format_version)( - event_dict, internal_metadata_dict=internal_metadata_dict + return make_event_from_dict( + event_dict, room_version, internal_metadata_dict=internal_metadata_dict ) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index ebe8b8e9fe..eea64c1c9f 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -29,7 +29,7 @@ from synapse.api.room_versions import ( RoomVersion, ) from synapse.crypto.event_signing import check_event_content_hash -from synapse.events import EventBase, event_type_from_format_version +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( @@ -374,8 +374,7 @@ def event_from_pdu_json( elif depth > MAX_DEPTH: raise SynapseError(400, "Depth too large", Codes.BAD_JSON) - event = event_type_from_format_version(room_version.event_format)(pdu_json) - + event = make_event_from_dict(pdu_json, room_version) event.internal_metadata.outlier = outlier return event diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 63d8633582..4e67503cf0 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -25,7 +25,7 @@ from twisted.internet import defer from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from tests import unittest from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver @@ -38,7 +38,7 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" - return FrozenEvent(kwargs) + return make_event_from_dict(kwargs) class FilteringTestCase(unittest.TestCase): diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 6143a50ab2..62f639a18d 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -19,7 +19,7 @@ from unpaddedbase64 import decode_base64 from synapse.api.room_versions import RoomVersions from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from tests import unittest @@ -54,7 +54,7 @@ class EventSigningTestCase(unittest.TestCase): RoomVersions.V1, event_dict, HOSTNAME, self.signing_key ) - event = FrozenEvent(event_dict) + event = make_event_from_dict(event_dict) self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) @@ -88,7 +88,7 @@ class EventSigningTestCase(unittest.TestCase): RoomVersions.V1, event_dict, HOSTNAME, self.signing_key ) - event = FrozenEvent(event_dict) + event = make_event_from_dict(event_dict) self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 2b13980dfd..45d55b9e94 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.events.utils import ( copy_power_levels_contents, prune_event, @@ -30,7 +29,7 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" - return FrozenEvent(kwargs) + return make_event_from_dict(kwargs) class PruneEventTestCase(unittest.TestCase): @@ -38,7 +37,9 @@ class PruneEventTestCase(unittest.TestCase): `matchdict` when it is redacted. """ def run_test(self, evdict, matchdict): - self.assertEquals(prune_event(FrozenEvent(evdict)).get_dict(), matchdict) + self.assertEquals( + prune_event(make_event_from_dict(evdict)).get_dict(), matchdict + ) def test_minimal(self): self.run_test( diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 1ec8c40901..e7d8699040 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin from synapse.rest.client.v1 import login, room @@ -105,7 +105,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): def _create_acl_event(content): - return FrozenEvent( + return make_event_from_dict( { "room_id": "!a:b", "event_id": "$a:b", diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index b1b037006d..d31210fbe4 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -15,7 +15,7 @@ import logging from canonicaljson import encode_canonical_json -from synapse.events import FrozenEvent, _EventInternalMetadata +from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore @@ -90,7 +90,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] = redaction - redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + redacted = make_event_from_dict( + msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() + ) self.check("get_event", [msg.event_id], redacted) def test_backfilled_redactions(self): @@ -110,7 +112,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] = redaction - redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + redacted = make_event_from_dict( + msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() + ) self.check("get_event", [msg.event_id], redacted) def test_invites(self): @@ -345,7 +349,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if redacts is not None: event_dict["redacts"] = redacts - event = FrozenEvent(event_dict, internal_metadata_dict=internal) + event = make_event_from_dict(event_dict, internal_metadata_dict=internal) self.event_id += 1 diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 0f341d3ac3..5bafad9f19 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -22,7 +22,7 @@ import attr from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store from synapse.types import EventID @@ -89,7 +89,7 @@ class FakeEvent(object): if self.state_key is not None: event_dict["state_key"] = self.state_key - return FrozenEvent(event_dict) + return make_event_from_dict(event_dict) # All graphs start with this set of events diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index ca20b085a2..bfa5d6f510 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -18,7 +18,7 @@ import unittest from synapse import event_auth from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict class EventAuthTestCase(unittest.TestCase): @@ -94,7 +94,7 @@ TEST_ROOM_ID = "!test:room" def _create_event(user_id): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -106,7 +106,7 @@ def _create_event(user_id): def _join_event(user_id): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -119,7 +119,7 @@ def _join_event(user_id): def _power_levels_event(sender, content): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -132,7 +132,7 @@ def _power_levels_event(sender, content): def _random_state_event(sender): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), diff --git a/tests/test_federation.py b/tests/test_federation.py index 68684460c6..9b5cf562f3 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -2,7 +2,7 @@ from mock import Mock from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.types import Requester, UserID from synapse.util import Clock @@ -43,7 +43,7 @@ class MessageAcceptTests(unittest.TestCase): ) )[0] - join_event = FrozenEvent( + join_event = make_event_from_dict( { "room_id": self.room_id, "sender": "@baduser:test.serv", @@ -105,7 +105,7 @@ class MessageAcceptTests(unittest.TestCase): )[0] # Now lie about an event - lying_event = FrozenEvent( + lying_event = make_event_from_dict( { "room_id": self.room_id, "sender": "@baduser:test.serv", diff --git a/tests/test_state.py b/tests/test_state.py index 1e4449fa1c..d1578fe581 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.auth import Auth from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext from synapse.state import StateHandler, StateResolutionHandler @@ -66,7 +66,7 @@ def create_event( d.update(kwargs) - event = FrozenEvent(d) + event = make_event_from_dict(d) return event -- cgit 1.5.1 From a92e703ab9d78aecc062e797f941bb7e206650a5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 10 Feb 2020 16:35:26 -0500 Subject: Reject device display names that are too long (#6882) * Reject device display names that are too long. Too long is currently defined as 100 characters in length. * Add a regression test for rejecting a too long device display name. --- changelog.d/6882.misc | 1 + synapse/handlers/device.py | 14 +++++++++++++- tests/handlers/test_device.py | 18 ++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6882.misc (limited to 'tests') diff --git a/changelog.d/6882.misc b/changelog.d/6882.misc new file mode 100644 index 0000000000..e8382e36ae --- /dev/null +++ b/changelog.d/6882.misc @@ -0,0 +1 @@ +Reject device display names over 100 characters in length. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 6d8e48ed39..50cea3f378 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -26,6 +26,7 @@ from synapse.api.errors import ( FederationDeniedError, HttpResponseException, RequestSendFailed, + SynapseError, ) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.types import RoomStreamToken, get_domain_from_id @@ -39,6 +40,8 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) +MAX_DEVICE_DISPLAY_NAME_LEN = 100 + class DeviceWorkerHandler(BaseHandler): def __init__(self, hs): @@ -404,9 +407,18 @@ class DeviceHandler(DeviceWorkerHandler): defer.Deferred: """ + # Reject a new displayname which is too long. + new_display_name = content.get("display_name") + if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN: + raise SynapseError( + 400, + "Device display name is too long (max %i)" + % (MAX_DEVICE_DISPLAY_NAME_LEN,), + ) + try: yield self.store.update_device( - user_id, device_id, new_display_name=content.get("display_name") + user_id, device_id, new_display_name=new_display_name ) yield self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index a3aa0a1cf2..62b47f6574 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -160,6 +160,24 @@ class DeviceTestCase(unittest.HomeserverTestCase): res = self.get_success(self.handler.get_device(user1, "abc")) self.assertEqual(res["display_name"], "new display") + def test_update_device_too_long_display_name(self): + """Update a device with a display name that is invalid (too long).""" + self._record_users() + + # Request to update a device display name with a new value that is longer than allowed. + update = { + "display_name": "a" + * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1) + } + self.get_failure( + self.handler.update_device(user1, "abc", update), + synapse.api.errors.SynapseError, + ) + + # Ensure the display name was not updated. + res = self.get_success(self.handler.get_device(user1, "abc")) + self.assertEqual(res["display_name"], "display 2") + def test_update_unknown_device(self): update = {"display_name": "new_display"} res = self.handler.update_device("user_id", "unknown_device_id", update) -- cgit 1.5.1 From 0295abdcf70548a7ba9d685598233f34c50127b5 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 11 Feb 2020 16:18:29 +0000 Subject: Dinsic Blacking with black==18.6b2 --- synapse/__init__.py | 1 + synapse/_scripts/register_new_matrix_user.py | 24 +- synapse/api/auth.py | 158 ++-- synapse/api/constants.py | 6 +- synapse/api/errors.py | 118 ++- synapse/api/filtering.py | 189 ++--- synapse/api/ratelimiting.py | 16 +- synapse/api/room_versions.py | 25 +- synapse/api/urls.py | 17 +- synapse/app/__init__.py | 4 +- synapse/app/_base.py | 77 +- synapse/app/appservice.py | 29 +- synapse/app/client_reader.py | 29 +- synapse/app/event_creator.py | 39 +- synapse/app/federation_reader.py | 38 +- synapse/app/federation_sender.py | 46 +- synapse/app/frontend_proxy.py | 69 +- synapse/app/homeserver.py | 160 ++-- synapse/app/media_repository.py | 37 +- synapse/app/pusher.py | 60 +- synapse/app/synchrotron.py | 90 ++- synapse/app/user_dir.py | 45 +- synapse/appservice/__init__.py | 39 +- synapse/appservice/api.py | 66 +- synapse/appservice/scheduler.py | 50 +- synapse/config/_base.py | 12 +- synapse/config/api.py | 22 +- synapse/config/appservice.py | 54 +- synapse/config/captcha.py | 1 - synapse/config/consent_config.py | 27 +- synapse/config/database.py | 31 +- synapse/config/emailconfig.py | 71 +- synapse/config/jwt_config.py | 5 +- synapse/config/key.py | 5 +- synapse/config/logger.py | 70 +- synapse/config/metrics.py | 8 +- synapse/config/password_auth_providers.py | 16 +- synapse/config/registration.py | 47 +- synapse/config/repository.py | 74 +- synapse/config/room_directory.py | 19 +- synapse/config/saml2_config.py | 18 +- synapse/config/server.py | 263 +++--- synapse/config/server_notices_config.py | 17 +- synapse/config/tls.py | 50 +- synapse/config/user_directory.py | 12 +- synapse/config/voip.py | 3 +- synapse/config/workers.py | 16 +- synapse/crypto/event_signing.py | 17 +- synapse/crypto/keyring.py | 19 +- synapse/event_auth.py | 177 ++-- synapse/events/__init__.py | 30 +- synapse/events/builder.py | 37 +- synapse/events/snapshot.py | 49 +- synapse/events/spamcheck.py | 24 +- synapse/events/third_party_rules.py | 5 +- synapse/events/utils.py | 58 +- synapse/events/validator.py | 21 +- synapse/federation/federation_base.py | 78 +- synapse/federation/federation_client.py | 217 +++-- synapse/federation/federation_server.py | 234 +++--- synapse/federation/persistence.py | 15 +- synapse/federation/send_queue.py | 145 ++-- synapse/federation/sender/__init__.py | 65 +- synapse/federation/sender/per_destination_queue.py | 7 +- synapse/federation/sender/transaction_manager.py | 38 +- synapse/federation/transport/client.py | 294 ++++--- synapse/federation/transport/server.py | 237 +++--- synapse/federation/units.py | 33 +- synapse/groups/attestations.py | 35 +- synapse/groups/groups_server.py | 317 +++----- synapse/handlers/_base.py | 10 +- synapse/handlers/account_data.py | 23 +- synapse/handlers/account_validity.py | 75 +- synapse/handlers/acme.py | 14 +- synapse/handlers/admin.py | 23 +- synapse/handlers/appservice.py | 74 +- synapse/handlers/auth.py | 266 +++--- synapse/handlers/deactivate_account.py | 12 +- synapse/handlers/device.py | 116 ++- synapse/handlers/devicemessage.py | 7 +- synapse/handlers/directory.py | 136 ++-- synapse/handlers/e2e_keys.py | 119 ++- synapse/handlers/e2e_room_keys.py | 38 +- synapse/handlers/events.py | 49 +- synapse/handlers/federation.py | 887 +++++++++------------ synapse/handlers/groups_local.py | 95 +-- synapse/handlers/identity.py | 165 ++-- synapse/handlers/initial_sync.py | 170 ++-- synapse/handlers/message.py | 256 +++--- synapse/handlers/pagination.py | 81 +- synapse/handlers/password_policy.py | 16 +- synapse/handlers/presence.py | 311 ++++---- synapse/handlers/profile.py | 112 ++- synapse/handlers/read_marker.py | 9 +- synapse/handlers/receipts.py | 26 +- synapse/handlers/register.py | 267 +++---- synapse/handlers/room.py | 310 +++---- synapse/handlers/room_list.py | 254 +++--- synapse/handlers/room_member.py | 329 +++----- synapse/handlers/room_member_worker.py | 8 +- synapse/handlers/search.py | 129 ++- synapse/handlers/set_password.py | 5 +- synapse/handlers/state_deltas.py | 1 - synapse/handlers/stats.py | 2 +- synapse/handlers/sync.py | 694 ++++++++-------- synapse/handlers/typing.py | 74 +- synapse/http/__init__.py | 9 +- synapse/http/additional_resource.py | 1 + synapse/http/client.py | 28 +- synapse/http/endpoint.py | 20 +- synapse/http/federation/matrix_federation_agent.py | 98 ++- synapse/http/federation/srv_resolver.py | 37 +- synapse/http/matrixfederationclient.py | 242 +++--- synapse/http/server.py | 80 +- synapse/http/servlet.py | 45 +- synapse/http/site.py | 48 +- synapse/metrics/__init__.py | 7 +- synapse/metrics/background_process_metrics.py | 33 +- synapse/module_api/__init__.py | 6 +- synapse/notifier.py | 68 +- synapse/push/action_generator.py | 4 +- synapse/push/baserules.py | 404 ++++------ synapse/push/bulk_push_rule_evaluator.py | 63 +- synapse/push/clientformat.py | 38 +- synapse/push/emailpusher.py | 63 +- synapse/push/httppusher.py | 226 +++--- synapse/push/mailer.py | 343 ++++---- synapse/push/presentable_names.py | 45 +- synapse/push/push_rule_evaluator.py | 57 +- synapse/push/push_tools.py | 8 +- synapse/push/pusher.py | 10 +- synapse/push/pusherpool.py | 97 ++- synapse/push/rulekinds.py | 10 +- synapse/python_dependencies.py | 24 +- synapse/replication/http/_base.py | 22 +- synapse/replication/http/federation.py | 51 +- synapse/replication/http/login.py | 7 +- synapse/replication/http/membership.py | 34 +- synapse/replication/http/register.py | 17 +- synapse/replication/http/send_event.py | 13 +- synapse/replication/slave/storage/_base.py | 2 +- synapse/replication/slave/storage/account_data.py | 17 +- synapse/replication/slave/storage/appservice.py | 5 +- synapse/replication/slave/storage/client_ips.py | 4 +- synapse/replication/slave/storage/deviceinbox.py | 6 +- synapse/replication/slave/storage/devices.py | 14 +- synapse/replication/slave/storage/events.py | 77 +- synapse/replication/slave/storage/groups.py | 9 +- synapse/replication/slave/storage/presence.py | 8 +- synapse/replication/slave/storage/push_rule.py | 6 +- synapse/replication/slave/storage/pushers.py | 4 +- synapse/replication/slave/storage/receipts.py | 1 - synapse/replication/slave/storage/room.py | 4 +- synapse/replication/tcp/client.py | 6 +- synapse/replication/tcp/commands.py | 71 +- synapse/replication/tcp/protocol.py | 76 +- synapse/replication/tcp/resource.py | 54 +- synapse/replication/tcp/streams/_base.py | 160 ++-- synapse/replication/tcp/streams/events.py | 32 +- synapse/replication/tcp/streams/federation.py | 12 +- synapse/rest/__init__.py | 1 + synapse/rest/admin/__init__.py | 153 ++-- synapse/rest/admin/server_notice_servlet.py | 13 +- synapse/rest/client/transactions.py | 3 +- synapse/rest/client/v1/directory.py | 31 +- synapse/rest/client/v1/events.py | 7 +- synapse/rest/client/v1/login.py | 121 ++- synapse/rest/client/v1/logout.py | 3 +- synapse/rest/client/v1/presence.py | 2 +- synapse/rest/client/v1/profile.py | 25 +- synapse/rest/client/v1/push_rule.py | 119 ++- synapse/rest/client/v1/pusher.py | 70 +- synapse/rest/client/v1/room.py | 183 ++--- synapse/rest/client/v1/voip.py | 24 +- synapse/rest/client/v2_alpha/_base.py | 12 +- synapse/rest/client/v2_alpha/account.py | 213 ++--- synapse/rest/client/v2_alpha/account_data.py | 18 +- synapse/rest/client/v2_alpha/account_validity.py | 8 +- synapse/rest/client/v2_alpha/auth.py | 62 +- synapse/rest/client/v2_alpha/devices.py | 19 +- synapse/rest/client/v2_alpha/filter.py | 11 +- synapse/rest/client/v2_alpha/groups.py | 157 ++-- synapse/rest/client/v2_alpha/keys.py | 27 +- synapse/rest/client/v2_alpha/notifications.py | 28 +- synapse/rest/client/v2_alpha/openid.py | 22 +- synapse/rest/client/v2_alpha/read_marker.py | 4 +- synapse/rest/client/v2_alpha/receipts.py | 5 +- synapse/rest/client/v2_alpha/register.py | 181 ++--- synapse/rest/client/v2_alpha/relations.py | 5 +- synapse/rest/client/v2_alpha/report_event.py | 4 +- synapse/rest/client/v2_alpha/room_keys.py | 54 +- .../client/v2_alpha/room_upgrade_rest_servlet.py | 9 +- synapse/rest/client/v2_alpha/sendtodevice.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 129 +-- synapse/rest/client/v2_alpha/tags.py | 14 +- synapse/rest/client/v2_alpha/thirdparty.py | 2 +- synapse/rest/client/v2_alpha/tokenrefresh.py | 1 + synapse/rest/client/v2_alpha/user_directory.py | 27 +- synapse/rest/client/versions.py | 43 +- synapse/rest/consent/consent_resource.py | 28 +- synapse/rest/key/v2/local_key_resource.py | 14 +- synapse/rest/key/v2/remote_key_resource.py | 56 +- synapse/rest/media/v0/content_repository.py | 15 +- synapse/rest/media/v1/_base.py | 57 +- synapse/rest/media/v1/config_resource.py | 4 +- synapse/rest/media/v1/download_resource.py | 8 +- synapse/rest/media/v1/filepath.py | 131 ++- synapse/rest/media/v1/media_repository.py | 215 ++--- synapse/rest/media/v1/media_storage.py | 17 +- synapse/rest/media/v1/preview_url_resource.py | 213 +++-- synapse/rest/media/v1/storage_provider.py | 6 +- synapse/rest/media/v1/thumbnail_resource.py | 120 ++- synapse/rest/media/v1/thumbnailer.py | 15 +- synapse/rest/media/v1/upload_resource.py | 30 +- synapse/rest/saml2/metadata_resource.py | 2 +- synapse/rest/saml2/response_resource.py | 13 +- synapse/rest/well_known.py | 11 +- synapse/rulecheck/domain_rule_checker.py | 33 +- synapse/secrets.py | 3 +- synapse/server.py | 158 ++-- synapse/server.pyi | 50 +- synapse/server_notices/consent_server_notices.py | 24 +- .../resource_limits_server_notices.py | 33 +- synapse/server_notices/server_notices_manager.py | 30 +- synapse/server_notices/server_notices_sender.py | 13 +- .../server_notices/worker_server_notices_sender.py | 1 + synapse/state/__init__.py | 110 +-- synapse/state/v1.py | 56 +- synapse/state/v2.py | 92 +-- synapse/storage/__init__.py | 10 +- synapse/storage/_base.py | 10 +- synapse/storage/background_updates.py | 2 +- synapse/storage/devices.py | 12 +- synapse/storage/e2e_room_keys.py | 28 +- synapse/storage/engines/sqlite.py | 2 +- synapse/storage/event_federation.py | 3 +- synapse/storage/event_push_actions.py | 4 +- synapse/storage/events.py | 12 +- synapse/storage/events_bg_updates.py | 16 +- synapse/storage/events_worker.py | 14 +- synapse/storage/group_server.py | 8 +- synapse/storage/keys.py | 2 +- synapse/storage/media_repository.py | 22 +- synapse/storage/monthly_active_users.py | 6 +- synapse/storage/prepare_database.py | 13 +- synapse/storage/profile.py | 46 +- synapse/storage/push_rule.py | 30 +- synapse/storage/pusher.py | 36 +- synapse/storage/receipts.py | 2 +- synapse/storage/registration.py | 112 ++- synapse/storage/relations.py | 6 +- synapse/storage/room.py | 56 +- synapse/storage/roommember.py | 8 +- synapse/storage/schema/delta/20/pushers.py | 22 +- synapse/storage/schema/delta/30/as_users.py | 17 +- synapse/storage/schema/delta/31/pushers.py | 24 +- synapse/storage/schema/delta/33/remote_media_ts.py | 2 +- synapse/storage/schema/delta/47/state_group_seq.py | 5 +- .../schema/delta/48/group_unique_indexes.py | 14 +- .../schema/delta/50/make_event_content_nullable.py | 12 +- synapse/storage/search.py | 6 +- synapse/storage/stats.py | 34 +- synapse/storage/stream.py | 32 +- synapse/streams/config.py | 27 +- synapse/streams/events.py | 43 +- synapse/third_party_rules/access_rules.py | 53 +- synapse/types.py | 108 +-- synapse/util/__init__.py | 23 +- synapse/util/async_helpers.py | 41 +- synapse/util/caches/__init__.py | 6 +- synapse/util/caches/descriptors.py | 100 +-- synapse/util/caches/dictionary_cache.py | 15 +- synapse/util/caches/expiringcache.py | 18 +- synapse/util/caches/lrucache.py | 14 +- synapse/util/caches/response_cache.py | 20 +- synapse/util/caches/stream_change_cache.py | 13 +- synapse/util/caches/treecache.py | 1 + synapse/util/caches/ttlcache.py | 1 + synapse/util/distributor.py | 29 +- synapse/util/frozenutils.py | 9 +- synapse/util/httpresourcetree.py | 8 +- synapse/util/jsonobject.py | 6 +- synapse/util/logcontext.py | 79 +- synapse/util/logformatter.py | 3 +- synapse/util/logutils.py | 51 +- synapse/util/manhole.py | 18 +- synapse/util/metrics.py | 32 +- synapse/util/module_loader.py | 6 +- synapse/util/msisdn.py | 6 +- synapse/util/ratelimitutils.py | 35 +- synapse/util/stringutils.py | 14 +- synapse/util/threepids.py | 25 +- synapse/util/versionstring.py | 61 +- synapse/util/wheel_timer.py | 4 +- synapse/visibility.py | 75 +- tests/api/test_auth.py | 12 +- tests/config/test_server.py | 8 +- tests/config/test_tls.py | 2 +- tests/crypto/test_event_signing.py | 46 +- tests/events/test_utils.py | 94 +-- tests/federation/test_complexity.py | 2 +- tests/federation/test_federation_sender.py | 48 +- tests/handlers/test_auth.py | 10 +- tests/handlers/test_directory.py | 8 +- tests/handlers/test_e2e_room_keys.py | 18 +- tests/handlers/test_identity.py | 57 +- tests/handlers/test_profile.py | 6 +- tests/handlers/test_register.py | 44 +- tests/handlers/test_stats.py | 5 +- tests/handlers/test_typing.py | 16 +- tests/handlers/test_user_directory.py | 8 +- .../federation/test_matrix_federation_agent.py | 216 ++--- tests/http/federation/test_srv_resolver.py | 2 +- tests/http/test_endpoint.py | 12 +- tests/http/test_fedclient.py | 19 +- tests/push/test_email.py | 6 +- tests/rest/admin/test_admin.py | 92 +-- tests/rest/client/test_consent.py | 10 +- tests/rest/client/test_identity.py | 93 +-- tests/rest/client/test_retention.py | 89 +-- tests/rest/client/test_room_access_rules.py | 206 ++--- tests/rest/client/v1/test_profile.py | 7 +- tests/rest/client/v1/test_rooms.py | 46 +- tests/rest/client/v1/utils.py | 6 +- tests/rest/client/v2_alpha/test_account.py | 31 +- tests/rest/client/v2_alpha/test_capabilities.py | 14 +- tests/rest/client/v2_alpha/test_password_policy.py | 62 +- tests/rest/client/v2_alpha/test_register.py | 52 +- tests/rest/media/v1/test_base.py | 4 +- tests/rest/media/v1/test_media_storage.py | 6 +- tests/rest/media/v1/test_url_preview.py | 50 +- tests/rulecheck/test_domainrulecheck.py | 22 +- tests/server.py | 16 +- .../test_resource_limits_server_notices.py | 6 +- tests/state/test_v2.py | 20 +- tests/storage/test_appservice.py | 4 +- tests/storage/test_client_ips.py | 20 +- tests/storage/test_devices.py | 16 +- tests/storage/test_end_to_end_keys.py | 8 +- tests/storage/test_event_federation.py | 14 +- tests/storage/test_event_metrics.py | 36 +- tests/storage/test_monthly_active_users.py | 22 +- tests/storage/test_profile.py | 9 +- tests/storage/test_state.py | 16 +- tests/test_server.py | 12 +- tests/test_state.py | 6 +- tests/test_types.py | 4 +- tests/test_utils/logging_setup.py | 2 +- tests/test_visibility.py | 2 +- tests/unittest.py | 16 +- tests/util/caches/test_descriptors.py | 60 +- tests/util/caches/test_ttlcache.py | 46 +- tests/utils.py | 18 +- 353 files changed, 8823 insertions(+), 10315 deletions(-) (limited to 'tests') diff --git a/synapse/__init__.py b/synapse/__init__.py index 0c01546789..119359be68 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -28,6 +28,7 @@ try: from twisted.internet import protocol from twisted.internet.protocol import Factory from twisted.names.dns import DNSDatagramProtocol + protocol.Factory.noisy = False Factory.noisy = False DNSDatagramProtocol.noisy = False diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 6e93f5a0c6..bdcd915bbe 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -57,18 +57,18 @@ def request_registration( nonce = r.json()["nonce"] - mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1) + mac = hmac.new(key=shared_secret.encode("utf8"), digestmod=hashlib.sha1) - mac.update(nonce.encode('utf8')) + mac.update(nonce.encode("utf8")) mac.update(b"\x00") - mac.update(user.encode('utf8')) + mac.update(user.encode("utf8")) mac.update(b"\x00") - mac.update(password.encode('utf8')) + mac.update(password.encode("utf8")) mac.update(b"\x00") mac.update(b"admin" if admin else b"notadmin") if user_type: mac.update(b"\x00") - mac.update(user_type.encode('utf8')) + mac.update(user_type.encode("utf8")) mac = mac.hexdigest() @@ -134,8 +134,9 @@ def register_new_user(user, password, server_location, shared_secret, admin, use else: admin = False - request_registration(user, password, server_location, shared_secret, - bool(admin), user_type) + request_registration( + user, password, server_location, shared_secret, bool(admin), user_type + ) def main(): @@ -189,7 +190,7 @@ def main(): group.add_argument( "-c", "--config", - type=argparse.FileType('r'), + type=argparse.FileType("r"), help="Path to server config file. Used to read in shared secret.", ) @@ -200,7 +201,7 @@ def main(): parser.add_argument( "server_url", default="https://localhost:8448", - nargs='?', + nargs="?", help="URL to use to talk to the home server. Defaults to " " 'https://localhost:8448'.", ) @@ -220,8 +221,9 @@ def main(): if args.admin or args.no_admin: admin = args.admin - register_new_user(args.user, args.password, args.server_url, secret, - admin, args.user_type) + register_new_user( + args.user, args.password, args.server_url, secret, admin, args.user_type + ) if __name__ == "__main__": diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f505f1ac63..bd0c1a1275 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -36,8 +36,11 @@ logger = logging.getLogger(__name__) AuthEventTypes = ( - EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels, - EventTypes.JoinRules, EventTypes.RoomHistoryVisibility, + EventTypes.Create, + EventTypes.Member, + EventTypes.PowerLevels, + EventTypes.JoinRules, + EventTypes.RoomHistoryVisibility, EventTypes.ThirdPartyInvite, ) @@ -54,6 +57,7 @@ class Auth(object): FIXME: This class contains a mix of functions for authenticating users of our client-server API and authenticating events added to room graphs. """ + def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() @@ -70,15 +74,12 @@ class Auth(object): def check_from_context(self, room_version, event, context, do_sig_check=True): prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in itervalues(auth_events) - } + auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} self.check( - room_version, event, - auth_events=auth_events, do_sig_check=do_sig_check, + room_version, event, auth_events=auth_events, do_sig_check=do_sig_check ) def check(self, room_version, event, auth_events, do_sig_check=True): @@ -115,15 +116,10 @@ class Auth(object): the room. """ if current_state: - member = current_state.get( - (EventTypes.Member, user_id), - None - ) + member = current_state.get((EventTypes.Member, user_id), None) else: member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) self._check_joined_room(member, user_id, room_id) @@ -143,23 +139,17 @@ class Auth(object): the room. This will be the leave event if they have left the room. """ member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None if membership not in (Membership.JOIN, Membership.LEAVE): - raise AuthError(403, "User %s not in room %s" % ( - user_id, room_id - )) + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) if membership == Membership.LEAVE: forgot = yield self.store.did_forget(user_id, room_id) if forgot: - raise AuthError(403, "User %s not in room %s" % ( - user_id, room_id - )) + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) defer.returnValue(member) @@ -171,9 +161,9 @@ class Auth(object): def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s (%s)" % ( - user_id, room_id, repr(member) - )) + raise AuthError( + 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) + ) def can_federate(self, event, auth_events): creation_event = auth_events.get((EventTypes.Create, "")) @@ -185,11 +175,7 @@ class Auth(object): @defer.inlineCallbacks def get_user_by_req( - self, - request, - allow_guest=False, - rights="access", - allow_expired=False, + self, request, allow_guest=False, rights="access", allow_expired=False ): """ Get a registered user's ID. @@ -209,9 +195,8 @@ class Auth(object): try: ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( - b"User-Agent", - default=[b""] - )[0].decode('ascii', 'surrogateescape') + b"User-Agent", default=[b""] + )[0].decode("ascii", "surrogateescape") access_token = self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS @@ -244,11 +229,12 @@ class Auth(object): if self._account_validity.enabled and not allow_expired: user_id = user.to_string() expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - if expiration_ts is not None and self.clock.time_msec() >= expiration_ts: + if ( + expiration_ts is not None + and self.clock.time_msec() >= expiration_ts + ): raise AuthError( - 403, - "User account has expired", - errcode=Codes.EXPIRED_ACCOUNT, + 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) # device_id may not be present if get_user_by_access_token has been @@ -266,18 +252,23 @@ class Auth(object): if is_guest and not allow_guest: raise AuthError( - 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + 403, + "Guest access not allowed", + errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) request.authenticated_entity = user.to_string() - defer.returnValue(synapse.types.create_requester( - user, token_id, is_guest, device_id, app_service=app_service) + defer.returnValue( + synapse.types.create_requester( + user, token_id, is_guest, device_id, app_service=app_service + ) ) except KeyError: raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", - errcode=Codes.MISSING_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Missing access token.", + errcode=Codes.MISSING_TOKEN, ) def _get_appservice_user_id(self, request): @@ -288,32 +279,29 @@ class Auth(object): ) if app_service is None: - return(None, None) + return (None, None) if app_service.ip_range_whitelist: ip_address = IPAddress(self.hs.get_ip_from_request(request)) if ip_address not in app_service.ip_range_whitelist: - return(None, None) + return (None, None) if b"user_id" not in request.args: - return(app_service.sender, app_service) + return (app_service.sender, app_service) - user_id = request.args[b"user_id"][0].decode('utf8') + user_id = request.args[b"user_id"][0].decode("utf8") if app_service.sender == user_id: - return(app_service.sender, app_service) + return (app_service.sender, app_service) if not app_service.is_interested_in_user(user_id): - raise AuthError( - 403, - "Application service cannot masquerade as this user." - ) + raise AuthError(403, "Application service cannot masquerade as this user.") # Let ASes manipulate nonexistent users (e.g. to shadow-register them) # if not (yield self.store.get_user_by_id(user_id)): # raise AuthError( # 403, # "Application service has not registered this user" # ) - return(user_id, app_service) + return (user_id, app_service) @defer.inlineCallbacks def get_user_by_access_token(self, token, rights="access"): @@ -370,13 +358,13 @@ class Auth(object): raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unknown user_id %s" % user_id, - errcode=Codes.UNKNOWN_TOKEN + errcode=Codes.UNKNOWN_TOKEN, ) if not stored_user["is_guest"]: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Guest access token used for regular user", - errcode=Codes.UNKNOWN_TOKEN + errcode=Codes.UNKNOWN_TOKEN, ) ret = { "user": user, @@ -404,8 +392,9 @@ class Auth(object): ) as e: logger.warning("Invalid macaroon in auth: %s %s", type(e), e) raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN, ) def _parse_and_validate_macaroon(self, token, rights="access"): @@ -443,13 +432,13 @@ class Auth(object): guest = True self.validate_macaroon( - macaroon, rights, self.hs.config.expire_access_token, - user_id=user_id, + macaroon, rights, self.hs.config.expire_access_token, user_id=user_id ) except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN, ) if not has_expiry and rights == "access": @@ -474,10 +463,11 @@ class Auth(object): user_prefix = "user_id = " for caveat in macaroon.caveats: if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix):] + return caveat.caveat_id[len(user_prefix) :] raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN, ) def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): @@ -524,7 +514,7 @@ class Auth(object): prefix = "time < " if not caveat.startswith(prefix): return False - expiry = int(caveat[len(prefix):]) + expiry = int(caveat[len(prefix) :]) now = self.hs.get_clock().time_msec() return now < expiry @@ -574,19 +564,19 @@ class Auth(object): auth_ids = [] - key = (EventTypes.PowerLevels, "", ) + key = (EventTypes.PowerLevels, "") power_level_event_id = current_state_ids.get(key) if power_level_event_id: auth_ids.append(power_level_event_id) - key = (EventTypes.JoinRules, "", ) + key = (EventTypes.JoinRules, "") join_rule_event_id = current_state_ids.get(key) - key = (EventTypes.Member, event.sender, ) + key = (EventTypes.Member, event.sender) member_event_id = current_state_ids.get(key) - key = (EventTypes.Create, "", ) + key = (EventTypes.Create, "") create_event_id = current_state_ids.get(key) if create_event_id: auth_ids.append(create_event_id) @@ -612,7 +602,7 @@ class Auth(object): auth_ids.append(member_event_id) if for_verification: - key = (EventTypes.Member, event.state_key, ) + key = (EventTypes.Member, event.state_key) existing_event_id = current_state_ids.get(key) if existing_event_id: auth_ids.append(existing_event_id) @@ -621,7 +611,7 @@ class Auth(object): if "third_party_invite" in event.content: key = ( EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"] + event.content["third_party_invite"]["signed"]["token"], ) third_party_invite_id = current_state_ids.get(key) if third_party_invite_id: @@ -677,7 +667,7 @@ class Auth(object): auth_events[(EventTypes.PowerLevels, "")] = power_level_event send_level = event_auth.get_send_level( - EventTypes.Aliases, "", power_level_event, + EventTypes.Aliases, "", power_level_event ) user_level = event_auth.get_user_power_level(user_id, auth_events) @@ -685,7 +675,7 @@ class Auth(object): raise AuthError( 403, "This server requires you to be a moderator in the room to" - " edit its room list entry" + " edit its room list entry", ) @staticmethod @@ -735,7 +725,7 @@ class Auth(object): ) parts = auth_headers[0].split(b" ") if parts[0] == b"Bearer" and len(parts) == 2: - return parts[1].decode('ascii') + return parts[1].decode("ascii") else: raise AuthError( token_not_found_http_status, @@ -748,10 +738,10 @@ class Auth(object): raise AuthError( token_not_found_http_status, "Missing access token.", - errcode=Codes.MISSING_TOKEN + errcode=Codes.MISSING_TOKEN, ) - return query_params[0].decode('ascii') + return query_params[0].decode("ascii") @defer.inlineCallbacks def check_in_room_or_world_readable(self, room_id, user_id): @@ -778,8 +768,8 @@ class Auth(object): room_id, EventTypes.RoomHistoryVisibility, "" ) if ( - visibility and - visibility.content["history_visibility"] == "world_readable" + visibility + and visibility.content["history_visibility"] == "world_readable" ): defer.returnValue((Membership.JOIN, None)) return @@ -813,10 +803,11 @@ class Auth(object): if self.hs.config.hs_disabled: raise ResourceLimitError( - 403, self.hs.config.hs_disabled_message, + 403, + self.hs.config.hs_disabled_message, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_contact=self.hs.config.admin_contact, - limit_type=self.hs.config.hs_disabled_limit_type + limit_type=self.hs.config.hs_disabled_limit_type, ) if self.hs.config.limit_usage_by_mau is True: assert not (user_id and threepid) @@ -841,8 +832,9 @@ class Auth(object): current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise ResourceLimitError( - 403, "Monthly Active User Limit Exceeded", + 403, + "Monthly Active User Limit Exceeded", admin_contact=self.hs.config.admin_contact, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - limit_type="monthly_active_user" + limit_type="monthly_active_user", ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 731c200c8d..74c2087abc 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -18,7 +18,7 @@ """Contains constants from the specification.""" # the "depth" field on events is limited to 2**63 - 1 -MAX_DEPTH = 2**63 - 1 +MAX_DEPTH = 2 ** 63 - 1 # the maximum length for a room alias is 255 characters MAX_ALIAS_LENGTH = 255 @@ -30,6 +30,7 @@ MAX_USERID_LENGTH = 255 class Membership(object): """Represents the membership states of a user in a room.""" + INVITE = u"invite" JOIN = u"join" KNOCK = u"knock" @@ -40,6 +41,7 @@ class Membership(object): class PresenceState(object): """Represents the presence state of a user.""" + OFFLINE = u"offline" UNAVAILABLE = u"unavailable" ONLINE = u"online" @@ -121,6 +123,7 @@ class UserTypes(object): """Allows for user type specific behaviour. With the benefit of hindsight 'admin' and 'guest' users should also be UserTypes. Normal users are type None """ + SUPPORT = "support" ALL_USER_TYPES = (SUPPORT,) @@ -128,6 +131,7 @@ class UserTypes(object): class RelationTypes(object): """The types of relations known to this server. """ + ANNOTATION = "m.annotation" REPLACE = "m.replace" REFERENCE = "m.reference" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index e46bfdfcb9..214a45636c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -78,6 +78,7 @@ class CodeMessageException(RuntimeError): code (int): HTTP error code msg (str): string describing the error """ + def __init__(self, code, msg): super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) self.code = code @@ -91,6 +92,7 @@ class SynapseError(CodeMessageException): Attributes: errcode (str): Matrix error code e.g 'M_FORBIDDEN' """ + def __init__(self, code, msg, errcode=Codes.UNKNOWN): """Constructs a synapse error. @@ -103,10 +105,7 @@ class SynapseError(CodeMessageException): self.errcode = errcode def error_dict(self): - return cs_error( - self.msg, - self.errcode, - ) + return cs_error(self.msg, self.errcode) class ProxiedRequestError(SynapseError): @@ -115,27 +114,23 @@ class ProxiedRequestError(SynapseError): Attributes: errcode (str): Matrix error code e.g 'M_FORBIDDEN' """ + def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): - super(ProxiedRequestError, self).__init__( - code, msg, errcode - ) + super(ProxiedRequestError, self).__init__(code, msg, errcode) if additional_fields is None: self._additional_fields = {} else: self._additional_fields = dict(additional_fields) def error_dict(self): - return cs_error( - self.msg, - self.errcode, - **self._additional_fields - ) + return cs_error(self.msg, self.errcode, **self._additional_fields) class ConsentNotGivenError(SynapseError): """The error returned to the client when the user has not consented to the privacy policy. """ + def __init__(self, msg, consent_uri): """Constructs a ConsentNotGivenError @@ -144,22 +139,17 @@ class ConsentNotGivenError(SynapseError): consent_url (str): The URL where the user can give their consent """ super(ConsentNotGivenError, self).__init__( - code=http_client.FORBIDDEN, - msg=msg, - errcode=Codes.CONSENT_NOT_GIVEN + code=http_client.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN ) self._consent_uri = consent_uri def error_dict(self): - return cs_error( - self.msg, - self.errcode, - consent_uri=self._consent_uri - ) + return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri) class RegistrationError(SynapseError): """An error raised when a registration event fails.""" + pass @@ -198,15 +188,17 @@ class InteractiveAuthIncompleteError(Exception): result (dict): the server response to the request, which should be passed back to the client """ + def __init__(self, result): super(InteractiveAuthIncompleteError, self).__init__( - "Interactive auth not yet complete", + "Interactive auth not yet complete" ) self.result = result class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" + def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.UNRECOGNIZED @@ -215,21 +207,14 @@ class UnrecognizedRequestError(SynapseError): message = "Unrecognized request" else: message = args[0] - super(UnrecognizedRequestError, self).__init__( - 400, - message, - **kwargs - ) + super(UnrecognizedRequestError, self).__init__(400, message, **kwargs) class NotFoundError(SynapseError): """An error indicating we can't find the thing you asked for""" + def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND): - super(NotFoundError, self).__init__( - 404, - msg, - errcode=errcode - ) + super(NotFoundError, self).__init__(404, msg, errcode=errcode) class AuthError(SynapseError): @@ -246,8 +231,11 @@ class ResourceLimitError(SynapseError): Any error raised when there is a problem with resource usage. For instance, the monthly active user limit for the server has been exceeded """ + def __init__( - self, code, msg, + self, + code, + msg, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_contact=None, limit_type=None, @@ -261,7 +249,7 @@ class ResourceLimitError(SynapseError): self.msg, self.errcode, admin_contact=self.admin_contact, - limit_type=self.limit_type + limit_type=self.limit_type, ) @@ -276,6 +264,7 @@ class EventSizeError(SynapseError): class EventStreamError(SynapseError): """An error raised when there a problem with the event stream.""" + def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.BAD_PAGINATION @@ -284,47 +273,53 @@ class EventStreamError(SynapseError): class LoginError(SynapseError): """An error raised when there was a problem logging in.""" + pass class StoreError(SynapseError): """An error raised when there was a problem storing some data.""" + pass class InvalidCaptchaError(SynapseError): - def __init__(self, code=400, msg="Invalid captcha.", error_url=None, - errcode=Codes.CAPTCHA_INVALID): + def __init__( + self, + code=400, + msg="Invalid captcha.", + error_url=None, + errcode=Codes.CAPTCHA_INVALID, + ): super(InvalidCaptchaError, self).__init__(code, msg, errcode) self.error_url = error_url def error_dict(self): - return cs_error( - self.msg, - self.errcode, - error_url=self.error_url, - ) + return cs_error(self.msg, self.errcode, error_url=self.error_url) class LimitExceededError(SynapseError): """A client has sent too many requests and is being throttled. """ - def __init__(self, code=429, msg="Too Many Requests", retry_after_ms=None, - errcode=Codes.LIMIT_EXCEEDED): + + def __init__( + self, + code=429, + msg="Too Many Requests", + retry_after_ms=None, + errcode=Codes.LIMIT_EXCEEDED, + ): super(LimitExceededError, self).__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms def error_dict(self): - return cs_error( - self.msg, - self.errcode, - retry_after_ms=self.retry_after_ms, - ) + return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store """ + def __init__(self, current_version): """ Args: @@ -339,6 +334,7 @@ class RoomKeysVersionError(SynapseError): class UnsupportedRoomVersionError(SynapseError): """The client's request to create a room used a room version that the server does not support.""" + def __init__(self): super(UnsupportedRoomVersionError, self).__init__( code=400, @@ -362,22 +358,19 @@ class IncompatibleRoomVersionError(SynapseError): Unlike UnsupportedRoomVersionError, it is specific to the case of the make_join failing. """ + def __init__(self, room_version): super(IncompatibleRoomVersionError, self).__init__( code=400, msg="Your homeserver does not support the features required to " - "join this room", + "join this room", errcode=Codes.INCOMPATIBLE_ROOM_VERSION, ) self._room_version = room_version def error_dict(self): - return cs_error( - self.msg, - self.errcode, - room_version=self._room_version, - ) + return cs_error(self.msg, self.errcode, room_version=self._room_version) class PasswordRefusedError(SynapseError): @@ -389,11 +382,7 @@ class PasswordRefusedError(SynapseError): msg="This password doesn't comply with the server's policy", errcode=Codes.WEAK_PASSWORD, ): - super(PasswordRefusedError, self).__init__( - code=400, - msg=msg, - errcode=errcode, - ) + super(PasswordRefusedError, self).__init__(code=400, msg=msg, errcode=errcode) class RequestSendFailed(RuntimeError): @@ -404,11 +393,11 @@ class RequestSendFailed(RuntimeError): networking (e.g. DNS failures, connection timeouts etc), versus unexpected errors (like programming errors). """ + def __init__(self, inner_exception, can_retry): super(RequestSendFailed, self).__init__( - "Failed to send request: %s: %s" % ( - type(inner_exception).__name__, inner_exception, - ) + "Failed to send request: %s: %s" + % (type(inner_exception).__name__, inner_exception) ) self.inner_exception = inner_exception self.can_retry = can_retry @@ -452,7 +441,7 @@ class FederationError(RuntimeError): self.affected = affected self.source = source - msg = "%s %s: %s" % (level, code, reason,) + msg = "%s %s: %s" % (level, code, reason) super(FederationError, self).__init__(msg) def get_dict(self): @@ -472,6 +461,7 @@ class HttpResponseException(CodeMessageException): Attributes: response (bytes): body of response """ + def __init__(self, code, msg, response): """ @@ -510,7 +500,7 @@ class HttpResponseException(CodeMessageException): if not isinstance(j, dict): j = {} - errcode = j.pop('errcode', Codes.UNKNOWN) - errmsg = j.pop('error', self.msg) + errcode = j.pop("errcode", Codes.UNKNOWN) + errmsg = j.pop("error", self.msg) return ProxiedRequestError(self.code, errmsg, errcode, j) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 3906475403..9b3daca29b 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -28,117 +28,55 @@ FILTER_SCHEMA = { "additionalProperties": False, "type": "object", "properties": { - "limit": { - "type": "number" - }, - "senders": { - "$ref": "#/definitions/user_id_array" - }, - "not_senders": { - "$ref": "#/definitions/user_id_array" - }, + "limit": {"type": "number"}, + "senders": {"$ref": "#/definitions/user_id_array"}, + "not_senders": {"$ref": "#/definitions/user_id_array"}, # TODO: We don't limit event type values but we probably should... # check types are valid event types - "types": { - "type": "array", - "items": { - "type": "string" - } - }, - "not_types": { - "type": "array", - "items": { - "type": "string" - } - } - } + "types": {"type": "array", "items": {"type": "string"}}, + "not_types": {"type": "array", "items": {"type": "string"}}, + }, } ROOM_FILTER_SCHEMA = { "additionalProperties": False, "type": "object", "properties": { - "not_rooms": { - "$ref": "#/definitions/room_id_array" - }, - "rooms": { - "$ref": "#/definitions/room_id_array" - }, - "ephemeral": { - "$ref": "#/definitions/room_event_filter" - }, - "include_leave": { - "type": "boolean" - }, - "state": { - "$ref": "#/definitions/room_event_filter" - }, - "timeline": { - "$ref": "#/definitions/room_event_filter" - }, - "account_data": { - "$ref": "#/definitions/room_event_filter" - }, - } + "not_rooms": {"$ref": "#/definitions/room_id_array"}, + "rooms": {"$ref": "#/definitions/room_id_array"}, + "ephemeral": {"$ref": "#/definitions/room_event_filter"}, + "include_leave": {"type": "boolean"}, + "state": {"$ref": "#/definitions/room_event_filter"}, + "timeline": {"$ref": "#/definitions/room_event_filter"}, + "account_data": {"$ref": "#/definitions/room_event_filter"}, + }, } ROOM_EVENT_FILTER_SCHEMA = { "additionalProperties": False, "type": "object", "properties": { - "limit": { - "type": "number" - }, - "senders": { - "$ref": "#/definitions/user_id_array" - }, - "not_senders": { - "$ref": "#/definitions/user_id_array" - }, - "types": { - "type": "array", - "items": { - "type": "string" - } - }, - "not_types": { - "type": "array", - "items": { - "type": "string" - } - }, - "rooms": { - "$ref": "#/definitions/room_id_array" - }, - "not_rooms": { - "$ref": "#/definitions/room_id_array" - }, - "contains_url": { - "type": "boolean" - }, - "lazy_load_members": { - "type": "boolean" - }, - "include_redundant_members": { - "type": "boolean" - }, - } + "limit": {"type": "number"}, + "senders": {"$ref": "#/definitions/user_id_array"}, + "not_senders": {"$ref": "#/definitions/user_id_array"}, + "types": {"type": "array", "items": {"type": "string"}}, + "not_types": {"type": "array", "items": {"type": "string"}}, + "rooms": {"$ref": "#/definitions/room_id_array"}, + "not_rooms": {"$ref": "#/definitions/room_id_array"}, + "contains_url": {"type": "boolean"}, + "lazy_load_members": {"type": "boolean"}, + "include_redundant_members": {"type": "boolean"}, + }, } USER_ID_ARRAY_SCHEMA = { "type": "array", - "items": { - "type": "string", - "format": "matrix_user_id" - } + "items": {"type": "string", "format": "matrix_user_id"}, } ROOM_ID_ARRAY_SCHEMA = { "type": "array", - "items": { - "type": "string", - "format": "matrix_room_id" - } + "items": {"type": "string", "format": "matrix_room_id"}, } USER_FILTER_SCHEMA = { @@ -150,22 +88,13 @@ USER_FILTER_SCHEMA = { "user_id_array": USER_ID_ARRAY_SCHEMA, "filter": FILTER_SCHEMA, "room_filter": ROOM_FILTER_SCHEMA, - "room_event_filter": ROOM_EVENT_FILTER_SCHEMA + "room_event_filter": ROOM_EVENT_FILTER_SCHEMA, }, "properties": { - "presence": { - "$ref": "#/definitions/filter" - }, - "account_data": { - "$ref": "#/definitions/filter" - }, - "room": { - "$ref": "#/definitions/room_filter" - }, - "event_format": { - "type": "string", - "enum": ["client", "federation"] - }, + "presence": {"$ref": "#/definitions/filter"}, + "account_data": {"$ref": "#/definitions/filter"}, + "room": {"$ref": "#/definitions/room_filter"}, + "event_format": {"type": "string", "enum": ["client", "federation"]}, "event_fields": { "type": "array", "items": { @@ -177,26 +106,25 @@ USER_FILTER_SCHEMA = { # # Note that because this is a regular expression, we have to escape # each backslash in the pattern. - "pattern": r"^((?!\\\\).)*$" - } - } + "pattern": r"^((?!\\\\).)*$", + }, + }, }, - "additionalProperties": False + "additionalProperties": False, } -@FormatChecker.cls_checks('matrix_room_id') +@FormatChecker.cls_checks("matrix_room_id") def matrix_room_id_validator(room_id_str): return RoomID.from_string(room_id_str) -@FormatChecker.cls_checks('matrix_user_id') +@FormatChecker.cls_checks("matrix_user_id") def matrix_user_id_validator(user_id_str): return UserID.from_string(user_id_str) class Filtering(object): - def __init__(self, hs): super(Filtering, self).__init__() self.store = hs.get_datastore() @@ -228,8 +156,9 @@ class Filtering(object): # individual top-level key e.g. public_user_data. Filters are made of # many definitions. try: - jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA, - format_checker=FormatChecker()) + jsonschema.validate( + user_filter_json, USER_FILTER_SCHEMA, format_checker=FormatChecker() + ) except jsonschema.ValidationError as e: raise SynapseError(400, str(e)) @@ -240,10 +169,9 @@ class FilterCollection(object): room_filter_json = self._filter_json.get("room", {}) - self._room_filter = Filter({ - k: v for k, v in room_filter_json.items() - if k in ("rooms", "not_rooms") - }) + self._room_filter = Filter( + {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")} + ) self._room_timeline_filter = Filter(room_filter_json.get("timeline", {})) self._room_state_filter = Filter(room_filter_json.get("state", {})) @@ -252,9 +180,7 @@ class FilterCollection(object): self._presence_filter = Filter(filter_json.get("presence", {})) self._account_data = Filter(filter_json.get("account_data", {})) - self.include_leave = filter_json.get("room", {}).get( - "include_leave", False - ) + self.include_leave = filter_json.get("room", {}).get("include_leave", False) self.event_fields = filter_json.get("event_fields", []) self.event_format = filter_json.get("event_format", "client") @@ -299,22 +225,22 @@ class FilterCollection(object): def blocks_all_presence(self): return ( - self._presence_filter.filters_all_types() or - self._presence_filter.filters_all_senders() + self._presence_filter.filters_all_types() + or self._presence_filter.filters_all_senders() ) def blocks_all_room_ephemeral(self): return ( - self._room_ephemeral_filter.filters_all_types() or - self._room_ephemeral_filter.filters_all_senders() or - self._room_ephemeral_filter.filters_all_rooms() + self._room_ephemeral_filter.filters_all_types() + or self._room_ephemeral_filter.filters_all_senders() + or self._room_ephemeral_filter.filters_all_rooms() ) def blocks_all_room_timeline(self): return ( - self._room_timeline_filter.filters_all_types() or - self._room_timeline_filter.filters_all_senders() or - self._room_timeline_filter.filters_all_rooms() + self._room_timeline_filter.filters_all_types() + or self._room_timeline_filter.filters_all_senders() + or self._room_timeline_filter.filters_all_rooms() ) @@ -375,12 +301,7 @@ class Filter(object): # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) - return self.check_fields( - room_id, - sender, - ev_type, - contains_url, - ) + return self.check_fields(room_id, sender, ev_type, contains_url) def check_fields(self, room_id, sender, event_type, contains_url): """Checks whether the filter matches the given event fields. @@ -391,7 +312,7 @@ class Filter(object): literal_keys = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, - "types": lambda v: _matches_wildcard(event_type, v) + "types": lambda v: _matches_wildcard(event_type, v), } for name, match_func in literal_keys.items(): diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 296c4a1c17..09caaa35bd 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -44,7 +44,7 @@ class Ratelimiter(object): """ self.prune_message_counts(time_now_s) message_count, time_start, _ignored = self.message_counts.get( - key, (0., time_now_s, None), + key, (0., time_now_s, None) ) time_delta = time_now_s - time_start sent_count = message_count - time_delta * rate_hz @@ -59,14 +59,10 @@ class Ratelimiter(object): message_count += 1 if update: - self.message_counts[key] = ( - message_count, time_start, rate_hz - ) + self.message_counts[key] = (message_count, time_start, rate_hz) if rate_hz > 0: - time_allowed = ( - time_start + (message_count - burst_count + 1) / rate_hz - ) + time_allowed = time_start + (message_count - burst_count + 1) / rate_hz if time_allowed < time_now_s: time_allowed = time_now_s else: @@ -76,9 +72,7 @@ class Ratelimiter(object): def prune_message_counts(self, time_now_s): for key in list(self.message_counts.keys()): - message_count, time_start, rate_hz = ( - self.message_counts[key] - ) + message_count, time_start, rate_hz = self.message_counts[key] time_delta = time_now_s - time_start if message_count - time_delta * rate_hz > 0: break @@ -92,5 +86,5 @@ class Ratelimiter(object): if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)), + retry_after_ms=int(1000 * (time_allowed - time_now_s)) ) diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index d644803d38..95292b7dec 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -19,9 +19,10 @@ class EventFormatVersions(object): """This is an internal enum for tracking the version of the event format, independently from the room version. """ - V1 = 1 # $id:server event id format - V2 = 2 # MSC1659-style $hash event id format: introduced for room v3 - V3 = 3 # MSC1884-style $hash format: introduced for room v4 + + V1 = 1 # $id:server event id format + V2 = 2 # MSC1659-style $hash event id format: introduced for room v3 + V3 = 3 # MSC1884-style $hash format: introduced for room v4 KNOWN_EVENT_FORMAT_VERSIONS = { @@ -33,8 +34,9 @@ KNOWN_EVENT_FORMAT_VERSIONS = { class StateResolutionVersions(object): """Enum to identify the state resolution algorithms""" - V1 = 1 # room v1 state res - V2 = 2 # MSC1442 state res: room v2 and later + + V1 = 1 # room v1 state res + V2 = 2 # MSC1442 state res: room v2 and later class RoomDisposition(object): @@ -46,10 +48,10 @@ class RoomDisposition(object): class RoomVersion(object): """An object which describes the unique attributes of a room version.""" - identifier = attr.ib() # str; the identifier for this version - disposition = attr.ib() # str; one of the RoomDispositions - event_format = attr.ib() # int; one of the EventFormatVersions - state_res = attr.ib() # int; one of the StateResolutionVersions + identifier = attr.ib() # str; the identifier for this version + disposition = attr.ib() # str; one of the RoomDispositions + event_format = attr.ib() # int; one of the EventFormatVersions + state_res = attr.ib() # int; one of the StateResolutionVersions enforce_key_validity = attr.ib() # bool @@ -92,11 +94,12 @@ class RoomVersions(object): KNOWN_ROOM_VERSIONS = { - v.identifier: v for v in ( + v.identifier: v + for v in ( RoomVersions.V1, RoomVersions.V2, RoomVersions.V3, RoomVersions.V4, RoomVersions.V5, ) -} # type: dict[str, RoomVersion] +} # type: dict[str, RoomVersion] diff --git a/synapse/api/urls.py b/synapse/api/urls.py index e16c386a14..ff1f39e86c 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -42,13 +42,9 @@ class ConsentURIBuilder(object): hs_config (synapse.config.homeserver.HomeServerConfig): """ if hs_config.form_secret is None: - raise ConfigError( - "form_secret not set in config", - ) + raise ConfigError("form_secret not set in config") if hs_config.public_baseurl is None: - raise ConfigError( - "public_baseurl not set in config", - ) + raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") self._public_baseurl = hs_config.public_baseurl @@ -64,15 +60,10 @@ class ConsentURIBuilder(object): (str) the URI where the user can do consent """ mac = hmac.new( - key=self._hmac_secret, - msg=user_id.encode('ascii'), - digestmod=sha256, + key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256 ).hexdigest() consent_uri = "%s_matrix/consent?%s" % ( self._public_baseurl, - urlencode({ - "u": user_id, - "h": mac - }), + urlencode({"u": user_id, "h": mac}), ) return consent_uri diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index f56f5fcc13..d877c77834 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -43,7 +43,7 @@ def check_bind_error(e, address, bind_addresses): address (str): Address on which binding was attempted. bind_addresses (list): Addresses on which the service listens. """ - if address == '0.0.0.0' and '::' in bind_addresses: - logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]') + if address == "0.0.0.0" and "::" in bind_addresses: + logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]") else: raise e diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 8cc990399f..df4c2d4c97 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -75,14 +75,14 @@ def start_worker_reactor(appname, config): def start_reactor( - appname, - soft_file_limit, - gc_thresholds, - pid_file, - daemonize, - cpu_affinity, - print_pidfile, - logger, + appname, + soft_file_limit, + gc_thresholds, + pid_file, + daemonize, + cpu_affinity, + print_pidfile, + logger, ): """ Run the reactor in the main process @@ -149,10 +149,10 @@ def start_reactor( def quit_with_error(error_string): message_lines = error_string.split("\n") line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 - sys.stderr.write("*" * line_length + '\n') + sys.stderr.write("*" * line_length + "\n") for line in message_lines: sys.stderr.write(" %s\n" % (line.rstrip(),)) - sys.stderr.write("*" * line_length + '\n') + sys.stderr.write("*" * line_length + "\n") sys.exit(1) @@ -178,14 +178,7 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): r = [] for address in bind_addresses: try: - r.append( - reactor.listenTCP( - port, - factory, - backlog, - address - ) - ) + r.append(reactor.listenTCP(port, factory, backlog, address)) except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) @@ -205,13 +198,7 @@ def listen_ssl( for address in bind_addresses: try: r.append( - reactor.listenSSL( - port, - factory, - context_factory, - backlog, - address - ) + reactor.listenSSL(port, factory, context_factory, backlog, address) ) except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) @@ -243,15 +230,13 @@ def refresh_certificate(hs): if isinstance(i.factory, TLSMemoryBIOFactory): addr = i.getHost() logger.info( - "Replacing TLS context factory on [%s]:%i", addr.host, addr.port, + "Replacing TLS context factory on [%s]:%i", addr.host, addr.port ) # We want to replace TLS factories with a new one, with the new # TLS configuration. We do this by reaching in and pulling out # the wrappedFactory, and then re-wrapping it. i.factory = TLSMemoryBIOFactory( - hs.tls_server_context_factory, - False, - i.factory.wrappedFactory + hs.tls_server_context_factory, False, i.factory.wrappedFactory ) logger.info("Context factories updated.") @@ -267,6 +252,7 @@ def start(hs, listeners=None): try: # Set up the SIGHUP machinery. if hasattr(signal, "SIGHUP"): + def handle_sighup(*args, **kwargs): for i in _sighup_callbacks: i(hs) @@ -302,10 +288,8 @@ def setup_sentry(hs): return import sentry_sdk - sentry_sdk.init( - dsn=hs.config.sentry_dsn, - release=get_version_string(synapse), - ) + + sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse)) # We set some default tags that give some context to this instance with sentry_sdk.configure_scope() as scope: @@ -326,7 +310,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100): many DNS queries at once """ new_resolver = _LimitedHostnameResolver( - reactor.nameResolver, max_dns_requests_in_flight, + reactor.nameResolver, max_dns_requests_in_flight ) reactor.installNameResolver(new_resolver) @@ -339,11 +323,17 @@ class _LimitedHostnameResolver(object): def __init__(self, resolver, max_dns_requests_in_flight): self._resolver = resolver self._limiter = Linearizer( - name="dns_client_limiter", max_count=max_dns_requests_in_flight, + name="dns_client_limiter", max_count=max_dns_requests_in_flight ) - def resolveHostName(self, resolutionReceiver, hostName, portNumber=0, - addressTypes=None, transportSemantics='TCP'): + def resolveHostName( + self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics="TCP", + ): # We need this function to return `resolutionReceiver` so we do all the # actual logic involving deferreds in a separate function. @@ -363,8 +353,14 @@ class _LimitedHostnameResolver(object): return resolutionReceiver @defer.inlineCallbacks - def _resolve(self, resolutionReceiver, hostName, portNumber=0, - addressTypes=None, transportSemantics='TCP'): + def _resolve( + self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics="TCP", + ): with (yield self._limiter.queue(())): # resolveHostName doesn't return a Deferred, so we need to hook into @@ -374,8 +370,7 @@ class _LimitedHostnameResolver(object): receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred) self._resolver.resolveHostName( - receiver, hostName, portNumber, - addressTypes, transportSemantics, + receiver, hostName, portNumber, addressTypes, transportSemantics ) yield deferred diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 33107f56d1..9120bdb143 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -44,7 +44,9 @@ logger = logging.getLogger("synapse.app.appservice") class AppserviceSlaveStore( - DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore, + DirectoryStore, + SlavedEventStore, + SlavedApplicationServiceStore, SlavedRegistrationStore, ): pass @@ -74,7 +76,7 @@ class AppserviceServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse appservice now listening on port %d", port) @@ -88,18 +90,19 @@ class AppserviceServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -132,9 +135,7 @@ class ASReplicationHandler(ReplicationClientHandler): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse appservice", config_options - ) + config = HomeServerConfig.load_config("Synapse appservice", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -173,6 +174,6 @@ def start(config_options): _base.start_worker_reactor("synapse-appservice", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index a16e037f32..310cdab2e4 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -118,9 +118,7 @@ class ClientReaderServer(HomeServer): PushRuleRestServlet(self).register(resource) VersionsRestServlet().register(resource) - resources.update({ - "/_matrix/client": resource, - }) + resources.update({"/_matrix/client": resource}) root_resource = create_resource_tree(resources, NoResource()) @@ -133,7 +131,7 @@ class ClientReaderServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse client reader now listening on port %d", port) @@ -147,18 +145,19 @@ class ClientReaderServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -170,9 +169,7 @@ class ClientReaderServer(HomeServer): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse client reader", config_options - ) + config = HomeServerConfig.load_config("Synapse client reader", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -199,6 +196,6 @@ def start(config_options): _base.start_worker_reactor("synapse-client-reader", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index b8e5196152..ff522e4499 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -109,12 +109,14 @@ class EventCreatorServer(HomeServer): ProfileAvatarURLRestServlet(self).register(resource) ProfileDisplaynameRestServlet(self).register(resource) ProfileRestServlet(self).register(resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -127,7 +129,7 @@ class EventCreatorServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse event creator now listening on port %d", port) @@ -141,18 +143,19 @@ class EventCreatorServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -164,9 +167,7 @@ class EventCreatorServer(HomeServer): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse event creator", config_options - ) + config = HomeServerConfig.load_config("Synapse event creator", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -198,6 +199,6 @@ def start(config_options): _base.start_worker_reactor("synapse-event-creator", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 7da79dc827..9421420930 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -86,19 +86,18 @@ class FederationReaderServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "federation": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self), - }) + resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) if name == "openid" and "federation" not in res["names"]: # Only load the openid resource separately if federation resource # is not specified since federation resource includes openid # resource. - resources.update({ - FEDERATION_PREFIX: TransportLayerServer( - self, - servlet_groups=["openid"], - ), - }) + resources.update( + { + FEDERATION_PREFIX: TransportLayerServer( + self, servlet_groups=["openid"] + ) + } + ) if name in ["keys", "federation"]: resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) @@ -115,7 +114,7 @@ class FederationReaderServer(HomeServer): root_resource, self.version_string, ), - reactor=self.get_reactor() + reactor=self.get_reactor(), ) logger.info("Synapse federation reader now listening on port %d", port) @@ -129,18 +128,19 @@ class FederationReaderServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -181,6 +181,6 @@ def start(config_options): _base.start_worker_reactor("synapse-federation-reader", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 1d43f2b075..969be58d0b 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -52,8 +52,13 @@ logger = logging.getLogger("synapse.app.federation_sender") class FederationSenderSlaveStore( - SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore, - SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore, + SlavedDeviceInboxStore, + SlavedTransactionStore, + SlavedReceiptsStore, + SlavedEventStore, + SlavedRegistrationStore, + SlavedDeviceStore, + SlavedPresenceStore, ): def __init__(self, db_conn, hs): super(FederationSenderSlaveStore, self).__init__(db_conn, hs) @@ -65,10 +70,7 @@ class FederationSenderSlaveStore( self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) def _get_federation_out_pos(self, db_conn): - sql = ( - "SELECT stream_id FROM federation_stream_position" - " WHERE type = ?" - ) + sql = "SELECT stream_id FROM federation_stream_position" " WHERE type = ?" sql = self.database_engine.convert_param_style(sql) txn = db_conn.cursor() @@ -103,7 +105,7 @@ class FederationSenderServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse federation_sender now listening on port %d", port) @@ -117,18 +119,19 @@ class FederationSenderServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -151,7 +154,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): self.send_handler.process_replication_rows(stream_name, token, rows) def get_streams_to_replicate(self): - args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate() + args = super( + FederationSenderReplicationHandler, self + ).get_streams_to_replicate() args.update(self.send_handler.stream_positions()) return args @@ -203,6 +208,7 @@ class FederationSenderHandler(object): """Processes the replication stream and forwards the appropriate entries to the federation sender. """ + def __init__(self, hs, replication_client): self.store = hs.get_datastore() self._is_mine_id = hs.is_mine_id @@ -241,7 +247,7 @@ class FederationSenderHandler(object): # ... and when new receipts happen elif stream_name == ReceiptsStream.NAME: run_as_background_process( - "process_receipts_for_federation", self._on_new_receipts, rows, + "process_receipts_for_federation", self._on_new_receipts, rows ) @defer.inlineCallbacks @@ -278,12 +284,14 @@ class FederationSenderHandler(object): # We ACK this token over replication so that the master can drop # its in memory queues - self.replication_client.send_federation_ack(self.federation_position) + self.replication_client.send_federation_ack( + self.federation_position + ) self._last_ack = self.federation_position except Exception: logger.exception("Error updating federation stream position") -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 6504da5278..2fd7d57ebf 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -62,14 +62,11 @@ class PresenceStatusStubServlet(RestServlet): # Pass through the auth headers, if any, in case the access token # is there. auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) - headers = { - "Authorization": auth_headers, - } + headers = {"Authorization": auth_headers} try: result = yield self.http_client.get_json( - self.main_uri + request.uri.decode('ascii'), - headers=headers, + self.main_uri + request.uri.decode("ascii"), headers=headers ) except HttpResponseException as e: raise e.to_synapse_error() @@ -105,18 +102,19 @@ class KeyUploadServlet(RestServlet): if device_id is not None: # passing the device_id here is deprecated; however, we allow it # for now for compatibility with older clients. - if (requester.device_id is not None and - device_id != requester.device_id): - logger.warning("Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, device_id) + if requester.device_id is not None and device_id != requester.device_id: + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id if device_id is None: raise SynapseError( - 400, - "To upload keys, you must pass device_id when authenticating" + 400, "To upload keys, you must pass device_id when authenticating" ) if body: @@ -124,13 +122,9 @@ class KeyUploadServlet(RestServlet): # Pass through the auth headers, if any, in case the access token # is there. auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) - headers = { - "Authorization": auth_headers, - } + headers = {"Authorization": auth_headers} result = yield self.http_client.post_json_get_json( - self.main_uri + request.uri.decode('ascii'), - body, - headers=headers, + self.main_uri + request.uri.decode("ascii"), body, headers=headers ) defer.returnValue((200, result)) @@ -171,12 +165,14 @@ class FrontendProxyServer(HomeServer): if not self.config.use_presence: PresenceStatusStubServlet(self).register(resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -190,7 +186,7 @@ class FrontendProxyServer(HomeServer): root_resource, self.version_string, ), - reactor=self.get_reactor() + reactor=self.get_reactor(), ) logger.info("Synapse client reader now listening on port %d", port) @@ -204,18 +200,19 @@ class FrontendProxyServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -227,9 +224,7 @@ class FrontendProxyServer(HomeServer): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse frontend proxy", config_options - ) + config = HomeServerConfig.load_config("Synapse frontend proxy", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -258,6 +253,6 @@ def start(config_options): _base.start_worker_reactor("synapse-frontend-proxy", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index eaf7234ee3..9113d8729c 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -101,13 +101,12 @@ class SynapseHomeServer(HomeServer): # Skip loading openid resource if federation is defined # since federation resource will include openid continue - resources.update(self._configure_named_resource( - name, res.get("compress", False), - )) + resources.update( + self._configure_named_resource(name, res.get("compress", False)) + ) additional_resources = listener_config.get("additional_resources", {}) - logger.debug("Configuring additional resources: %r", - additional_resources) + logger.debug("Configuring additional resources: %r", additional_resources) module_api = ModuleApi(self, self.get_auth_handler()) for path, resmodule in additional_resources.items(): handler_cls, config = load_module(resmodule) @@ -174,59 +173,67 @@ class SynapseHomeServer(HomeServer): if compress: client_resource = gz_wrap(client_resource) - resources.update({ - "/_matrix/client/api/v1": client_resource, - "/_matrix/client/r0": client_resource, - "/_matrix/client/unstable": client_resource, - "/_matrix/client/v2_alpha": client_resource, - "/_matrix/client/versions": client_resource, - "/.well-known/matrix/client": WellKnownResource(self), - "/_synapse/admin": AdminRestResource(self), - }) + resources.update( + { + "/_matrix/client/api/v1": client_resource, + "/_matrix/client/r0": client_resource, + "/_matrix/client/unstable": client_resource, + "/_matrix/client/v2_alpha": client_resource, + "/_matrix/client/versions": client_resource, + "/.well-known/matrix/client": WellKnownResource(self), + "/_synapse/admin": AdminRestResource(self), + } + ) if self.get_config().saml2_enabled: from synapse.rest.saml2 import SAML2Resource + resources["/_matrix/saml2"] = SAML2Resource(self) if name == "consent": from synapse.rest.consent.consent_resource import ConsentResource + consent_resource = ConsentResource(self) if compress: consent_resource = gz_wrap(consent_resource) - resources.update({ - "/_matrix/consent": consent_resource, - }) + resources.update({"/_matrix/consent": consent_resource}) if name == "federation": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self), - }) + resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) if name == "openid": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self, servlet_groups=["openid"]), - }) + resources.update( + { + FEDERATION_PREFIX: TransportLayerServer( + self, servlet_groups=["openid"] + ) + } + ) if name in ["static", "client"]: - resources.update({ - STATIC_PREFIX: File( - os.path.join(os.path.dirname(synapse.__file__), "static") - ), - }) + resources.update( + { + STATIC_PREFIX: File( + os.path.join(os.path.dirname(synapse.__file__), "static") + ) + } + ) if name in ["media", "federation", "client"]: if self.get_config().enable_media_repo: media_repo = self.get_media_repository_resource() - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) + resources.update( + { + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + } + ) elif name == "media": raise ConfigError( - "'media' resource conflicts with enable_media_repo=False", + "'media' resource conflicts with enable_media_repo=False" ) if name in ["keys", "federation"]: @@ -257,18 +264,14 @@ class SynapseHomeServer(HomeServer): for listener in listeners: if listener["type"] == "http": - self._listening_services.extend( - self._listener_http(config, listener) - ) + self._listening_services.extend(self._listener_http(config, listener)) elif listener["type"] == "manhole": listen_tcp( listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "replication": services = listen_tcp( @@ -277,16 +280,17 @@ class SynapseHomeServer(HomeServer): ReplicationStreamProtocolFactory(self), ) for s in services: - reactor.addSystemEventTrigger( - "before", "shutdown", s.stopListening, - ) + reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -312,7 +316,7 @@ current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") registered_reserved_users_mau_gauge = Gauge( "synapse_admin_mau:registered_reserved_users", - "Registered users with reserved threepids" + "Registered users with reserved threepids", ) @@ -327,8 +331,7 @@ def setup(config_options): """ try: config = HomeServerConfig.load_or_generate_config( - "Synapse Homeserver", - config_options, + "Synapse Homeserver", config_options ) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") @@ -339,10 +342,7 @@ def setup(config_options): # generating config files and shouldn't try to continue. sys.exit(0) - synapse.config.logger.setup_logging( - config, - use_worker_options=False - ) + synapse.config.logger.setup_logging(config, use_worker_options=False) events.USE_FROZEN_DICTS = config.use_frozen_dicts @@ -357,7 +357,7 @@ def setup(config_options): database_engine=database_engine, ) - logger.info("Preparing database: %s...", config.database_config['name']) + logger.info("Preparing database: %s...", config.database_config["name"]) try: with hs.get_db_conn(run_new_connection=False) as db_conn: @@ -375,7 +375,7 @@ def setup(config_options): ) sys.exit(1) - logger.info("Database prepared in %s.", config.database_config['name']) + logger.info("Database prepared in %s.", config.database_config["name"]) hs.setup() hs.setup_master() @@ -391,9 +391,7 @@ def setup(config_options): acme = hs.get_acme_handler() # Check how long the certificate is active for. - cert_days_remaining = hs.config.is_disk_cert_valid( - allow_self_signed=False - ) + cert_days_remaining = hs.config.is_disk_cert_valid(allow_self_signed=False) # We want to reprovision if cert_days_remaining is None (meaning no # certificate exists), or the days remaining number it returns @@ -401,8 +399,8 @@ def setup(config_options): provision = False if ( - cert_days_remaining is None or - cert_days_remaining < hs.config.acme_reprovision_threshold + cert_days_remaining is None + or cert_days_remaining < hs.config.acme_reprovision_threshold ): provision = True @@ -433,10 +431,7 @@ def setup(config_options): yield do_acme() # Check if it needs to be reprovisioned every day. - hs.get_clock().looping_call( - reprovision_acme, - 24 * 60 * 60 * 1000 - ) + hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) _base.start(hs, config.listeners) @@ -463,6 +458,7 @@ class SynapseService(service.Service): A twisted Service class that will start synapse. Used to run synapse via twistd and a .tac. """ + def __init__(self, config): self.config = config @@ -479,6 +475,7 @@ class SynapseService(service.Service): def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: + def profile(func): from cProfile import Profile from threading import current_thread @@ -489,13 +486,14 @@ def run(hs): func(*args, **kargs) profile.disable() ident = current_thread().ident - profile.dump_stats("/tmp/%s.%s.%i.pstat" % ( - hs.hostname, func.__name__, ident - )) + profile.dump_stats( + "/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident) + ) return profiled from twisted.python.threadpool import ThreadPool + ThreadPool._worker = profile(ThreadPool._worker) reactor.run = profile(reactor.run) @@ -541,7 +539,9 @@ def run(hs): stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() - stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() + stats[ + "daily_active_rooms" + ] = yield hs.get_datastore().count_daily_active_rooms() stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() r30_results = yield hs.get_datastore().count_r30_users() @@ -580,14 +580,11 @@ def run(hs): logger.info("report_stats can use psutil") stats_process.append(process) except (AttributeError): - logger.warning( - "Unable to read memory/cpu stats. Disabling reporting." - ) + logger.warning("Unable to read memory/cpu stats. Disabling reporting.") def generate_user_daily_visit_stats(): return run_as_background_process( - "generate_user_daily_visits", - hs.get_datastore().generate_user_daily_visits, + "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits ) # Rather than update on per session basis, batch up the requests. @@ -598,9 +595,9 @@ def run(hs): # monthly active user limiting functionality def reap_monthly_active_users(): return run_as_background_process( - "reap_monthly_active_users", - hs.get_datastore().reap_monthly_active_users, + "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users ) + clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) reap_monthly_active_users() @@ -618,8 +615,7 @@ def run(hs): def start_generate_monthly_active_users(): return run_as_background_process( - "generate_monthly_active_users", - generate_monthly_active_users, + "generate_monthly_active_users", generate_monthly_active_users ) start_generate_monthly_active_users() @@ -659,5 +655,5 @@ def main(): run(hs) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index d4cc4e9443..cf0e2036c3 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -72,13 +72,15 @@ class MediaRepositoryServer(HomeServer): resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "media": media_repo = self.get_media_repository_resource() - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) + resources.update( + { + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -91,7 +93,7 @@ class MediaRepositoryServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse media repository now listening on port %d", port) @@ -105,18 +107,19 @@ class MediaRepositoryServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -164,6 +167,6 @@ def start(config_options): _base.start_worker_reactor("synapse-media-repository", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index cbf0d67f51..df29ea5ecb 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -46,36 +46,27 @@ logger = logging.getLogger("synapse.app.pusher") class PusherSlaveStore( - SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, - SlavedAccountDataStore + SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, SlavedAccountDataStore ): - update_pusher_last_stream_ordering_and_success = ( - __func__(DataStore.update_pusher_last_stream_ordering_and_success) + update_pusher_last_stream_ordering_and_success = __func__( + DataStore.update_pusher_last_stream_ordering_and_success ) - update_pusher_failing_since = ( - __func__(DataStore.update_pusher_failing_since) - ) + update_pusher_failing_since = __func__(DataStore.update_pusher_failing_since) - update_pusher_last_stream_ordering = ( - __func__(DataStore.update_pusher_last_stream_ordering) + update_pusher_last_stream_ordering = __func__( + DataStore.update_pusher_last_stream_ordering ) - get_throttle_params_by_room = ( - __func__(DataStore.get_throttle_params_by_room) - ) + get_throttle_params_by_room = __func__(DataStore.get_throttle_params_by_room) - set_throttle_params = ( - __func__(DataStore.set_throttle_params) - ) + set_throttle_params = __func__(DataStore.set_throttle_params) - get_time_of_last_push_action_before = ( - __func__(DataStore.get_time_of_last_push_action_before) + get_time_of_last_push_action_before = __func__( + DataStore.get_time_of_last_push_action_before ) - get_profile_displayname = ( - __func__(DataStore.get_profile_displayname) - ) + get_profile_displayname = __func__(DataStore.get_profile_displayname) class PusherServer(HomeServer): @@ -105,7 +96,7 @@ class PusherServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse pusher now listening on port %d", port) @@ -119,18 +110,19 @@ class PusherServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -161,9 +153,7 @@ class PusherReplicationHandler(ReplicationClientHandler): else: yield self.start_pusher(row.user_id, row.app_id, row.pushkey) elif stream_name == "events": - yield self.pusher_pool.on_new_notifications( - token, token, - ) + yield self.pusher_pool.on_new_notifications(token, token) elif stream_name == "receipts": yield self.pusher_pool.on_new_receipts( token, token, set(row.room_id for row in rows) @@ -188,9 +178,7 @@ class PusherReplicationHandler(ReplicationClientHandler): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse pusher", config_options - ) + config = HomeServerConfig.load_config("Synapse pusher", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -234,6 +222,6 @@ def start(config_options): _base.start_worker_reactor("synapse-pusher", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): ps = start(sys.argv[1:]) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 5388def28a..858949910d 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -98,10 +98,7 @@ class SynchrotronPresence(object): self.notifier = hs.get_notifier() active_presence = self.store.take_presence_startup_info() - self.user_to_current_state = { - state.user_id: state - for state in active_presence - } + self.user_to_current_state = {state.user_id: state for state in active_presence} # user_id -> last_sync_ms. Lists the users that have stopped syncing # but we haven't notified the master of that yet @@ -196,17 +193,26 @@ class SynchrotronPresence(object): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=users_to_states.keys() + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=users_to_states.keys(), ) @defer.inlineCallbacks def process_replication_rows(self, token, rows): - states = [UserPresenceState( - row.user_id, row.state, row.last_active_ts, - row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg, - row.currently_active - ) for row in rows] + states = [ + UserPresenceState( + row.user_id, + row.state, + row.last_active_ts, + row.last_federation_update_ts, + row.last_user_sync_ts, + row.status_msg, + row.currently_active, + ) + for row in rows + ] for state in states: self.user_to_current_state[state.user_id] = state @@ -217,7 +223,8 @@ class SynchrotronPresence(object): def get_currently_syncing_users(self): if self.hs.config.use_presence: return [ - user_id for user_id, count in iteritems(self.user_to_num_current_syncs) + user_id + for user_id, count in iteritems(self.user_to_num_current_syncs) if count > 0 ] else: @@ -281,12 +288,14 @@ class SynchrotronServer(HomeServer): events.register_servlets(self, resource) InitialSyncRestServlet(self).register(resource) RoomInitialSyncRestServlet(self).register(resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -299,7 +308,7 @@ class SynchrotronServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse synchrotron now listening on port %d", port) @@ -313,18 +322,19 @@ class SynchrotronServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -382,40 +392,36 @@ class SyncReplicationHandler(ReplicationClientHandler): ) elif stream_name == "push_rules": self.notifier.on_new_event( - "push_rules_key", token, users=[row.user_id for row in rows], + "push_rules_key", token, users=[row.user_id for row in rows] ) - elif stream_name in ("account_data", "tag_account_data",): + elif stream_name in ("account_data", "tag_account_data"): self.notifier.on_new_event( - "account_data_key", token, users=[row.user_id for row in rows], + "account_data_key", token, users=[row.user_id for row in rows] ) elif stream_name == "receipts": self.notifier.on_new_event( - "receipt_key", token, rooms=[row.room_id for row in rows], + "receipt_key", token, rooms=[row.room_id for row in rows] ) elif stream_name == "typing": self.typing_handler.process_replication_rows(token, rows) self.notifier.on_new_event( - "typing_key", token, rooms=[row.room_id for row in rows], + "typing_key", token, rooms=[row.room_id for row in rows] ) elif stream_name == "to_device": entities = [row.entity for row in rows if row.entity.startswith("@")] if entities: - self.notifier.on_new_event( - "to_device_key", token, users=entities, - ) + self.notifier.on_new_event("to_device_key", token, users=entities) elif stream_name == "device_lists": all_room_ids = set() for row in rows: room_ids = yield self.store.get_rooms_for_user(row.user_id) all_room_ids.update(room_ids) - self.notifier.on_new_event( - "device_list_key", token, rooms=all_room_ids, - ) + self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) elif stream_name == "presence": yield self.presence_handler.process_replication_rows(token, rows) elif stream_name == "receipts": self.notifier.on_new_event( - "groups_key", token, users=[row.user_id for row in rows], + "groups_key", token, users=[row.user_id for row in rows] ) except Exception: logger.exception("Error processing replication") @@ -423,9 +429,7 @@ class SyncReplicationHandler(ReplicationClientHandler): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse synchrotron", config_options - ) + config = HomeServerConfig.load_config("Synapse synchrotron", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -453,6 +457,6 @@ def start(config_options): _base.start_worker_reactor("synapse-synchrotron", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 355f5aa71d..2d9d2e1bbc 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -66,14 +66,16 @@ class UserDirectorySlaveStore( events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( - db_conn, "current_state_delta_stream", + db_conn, + "current_state_delta_stream", entity_column="room_id", stream_column="stream_id", max_value=events_max, # As we share the stream id with events token limit=1000, ) self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", min_curr_state_delta_id, + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, prefilled_cache=curr_state_delta_prefill, ) @@ -110,12 +112,14 @@ class UserDirectoryServer(HomeServer): elif name == "client": resource = JsonResource(self, canonical_json=False) user_directory.register_servlets(self, resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -128,7 +132,7 @@ class UserDirectoryServer(HomeServer): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse user_dir now listening on port %d", port) @@ -142,18 +146,19 @@ class UserDirectoryServer(HomeServer): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -186,9 +191,7 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse user directory", config_options - ) + config = HomeServerConfig.load_config("Synapse user directory", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -227,6 +230,6 @@ def start(config_options): _base.start_worker_reactor("synapse-user-dir", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index c58f83d268..8885d1f131 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -48,9 +48,7 @@ class AppServiceTransaction(object): A Deferred which resolves to True if the transaction was sent. """ return as_api.push_bulk( - service=self.service, - events=self.events, - txn_id=self.id + service=self.service, events=self.events, txn_id=self.id ) def complete(self, store): @@ -64,10 +62,7 @@ class AppServiceTransaction(object): Returns: A Deferred which resolves to True if the transaction was completed. """ - return store.complete_appservice_txn( - service=self.service, - txn_id=self.id - ) + return store.complete_appservice_txn(service=self.service, txn_id=self.id) class ApplicationService(object): @@ -76,6 +71,7 @@ class ApplicationService(object): Provides methods to check if this service is "interested" in events. """ + NS_USERS = "users" NS_ALIASES = "aliases" NS_ROOMS = "rooms" @@ -84,9 +80,19 @@ class ApplicationService(object): # values. NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] - def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None, - sender=None, id=None, protocols=None, rate_limited=True, - ip_range_whitelist=None): + def __init__( + self, + token, + hostname, + url=None, + namespaces=None, + hs_token=None, + sender=None, + id=None, + protocols=None, + rate_limited=True, + ip_range_whitelist=None, + ): self.token = token self.url = url self.hs_token = hs_token @@ -128,9 +134,7 @@ class ApplicationService(object): if not isinstance(regex_obj, dict): raise ValueError("Expected dict regex for ns '%s'" % ns) if not isinstance(regex_obj.get("exclusive"), bool): - raise ValueError( - "Expected bool for 'exclusive' in ns '%s'" % ns - ) + raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns) group_id = regex_obj.get("group_id") if group_id: if not isinstance(group_id, str): @@ -153,9 +157,7 @@ class ApplicationService(object): if isinstance(regex, string_types): regex_obj["regex"] = re.compile(regex) # Pre-compile regex else: - raise ValueError( - "Expected string for 'regex' in ns '%s'" % ns - ) + raise ValueError("Expected string for 'regex' in ns '%s'" % ns) return namespaces def _matches_regex(self, test_string, namespace_key): @@ -178,8 +180,9 @@ class ApplicationService(object): if self.is_interested_in_user(event.sender): defer.returnValue(True) # also check m.room.member state key - if (event.type == EventTypes.Member and - self.is_interested_in_user(event.state_key)): + if event.type == EventTypes.Member and self.is_interested_in_user( + event.state_key + ): defer.returnValue(True) if not store: diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 9ccc5a80fc..571881775b 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -32,19 +32,17 @@ logger = logging.getLogger(__name__) sent_transactions_counter = Counter( "synapse_appservice_api_sent_transactions", "Number of /transactions/ requests sent", - ["service"] + ["service"], ) failed_transactions_counter = Counter( "synapse_appservice_api_failed_transactions", "Number of /transactions/ requests that failed to send", - ["service"] + ["service"], ) sent_events_counter = Counter( - "synapse_appservice_api_sent_events", - "Number of events sent to the AS", - ["service"] + "synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"] ) HOUR_IN_MS = 60 * 60 * 1000 @@ -92,8 +90,9 @@ class ApplicationServiceApi(SimpleHttpClient): super(ApplicationServiceApi, self).__init__(hs) self.clock = hs.get_clock() - self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta", - timeout_ms=HOUR_IN_MS) + self.protocol_meta_cache = ResponseCache( + hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS + ) @defer.inlineCallbacks def query_user(self, service, user_id): @@ -102,9 +101,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) response = None try: - response = yield self.get_json(uri, { - "access_token": service.hs_token - }) + response = yield self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object defer.returnValue(True) except CodeMessageException as e: @@ -123,9 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) response = None try: - response = yield self.get_json(uri, { - "access_token": service.hs_token - }) + response = yield self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object defer.returnValue(True) except CodeMessageException as e: @@ -144,9 +139,7 @@ class ApplicationServiceApi(SimpleHttpClient): elif kind == ThirdPartyEntityKind.LOCATION: required_field = "alias" else: - raise ValueError( - "Unrecognised 'kind' argument %r to query_3pe()", kind - ) + raise ValueError("Unrecognised 'kind' argument %r to query_3pe()", kind) if service.url is None: defer.returnValue([]) @@ -154,14 +147,13 @@ class ApplicationServiceApi(SimpleHttpClient): service.url, APP_SERVICE_PREFIX, kind, - urllib.parse.quote(protocol) + urllib.parse.quote(protocol), ) try: response = yield self.get_json(uri, fields) if not isinstance(response, list): logger.warning( - "query_3pe to %s returned an invalid response %r", - uri, response + "query_3pe to %s returned an invalid response %r", uri, response ) defer.returnValue([]) @@ -171,8 +163,7 @@ class ApplicationServiceApi(SimpleHttpClient): ret.append(r) else: logger.warning( - "query_3pe to %s returned an invalid result %r", - uri, r + "query_3pe to %s returned an invalid result %r", uri, r ) defer.returnValue(ret) @@ -189,27 +180,27 @@ class ApplicationServiceApi(SimpleHttpClient): uri = "%s%s/thirdparty/protocol/%s" % ( service.url, APP_SERVICE_PREFIX, - urllib.parse.quote(protocol) + urllib.parse.quote(protocol), ) try: info = yield self.get_json(uri, {}) if not _is_valid_3pe_metadata(info): - logger.warning("query_3pe_protocol to %s did not return a" - " valid result", uri) + logger.warning( + "query_3pe_protocol to %s did not return a" " valid result", uri + ) defer.returnValue(None) for instance in info.get("instances", []): network_id = instance.get("network_id", None) if network_id is not None: instance["instance_id"] = ThirdPartyInstanceID( - service.id, network_id, + service.id, network_id ).to_string() defer.returnValue(info) except Exception as ex: - logger.warning("query_3pe_protocol to %s threw exception %s", - uri, ex) + logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex) defer.returnValue(None) key = (service.id, protocol) @@ -223,22 +214,19 @@ class ApplicationServiceApi(SimpleHttpClient): events = self._serialize(events) if txn_id is None: - logger.warning("push_bulk: Missing txn ID sending events to %s", - service.url) + logger.warning( + "push_bulk: Missing txn ID sending events to %s", service.url + ) txn_id = str(0) txn_id = str(txn_id) - uri = service.url + ("/transactions/%s" % - urllib.parse.quote(txn_id)) + uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id)) try: yield self.put_json( uri=uri, - json_body={ - "events": events - }, - args={ - "access_token": service.hs_token - }) + json_body={"events": events}, + args={"access_token": service.hs_token}, + ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(events)) defer.returnValue(True) @@ -252,6 +240,4 @@ class ApplicationServiceApi(SimpleHttpClient): def _serialize(self, events): time_now = self.clock.time_msec() - return [ - serialize_event(e, time_now, as_client_event=True) for e in events - ] + return [serialize_event(e, time_now, as_client_event=True) for e in events] diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 685f15c061..b54bf5411f 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -112,15 +112,14 @@ class _ServiceQueuer(object): return run_as_background_process( - "as-sender-%s" % (service.id, ), - self._send_request, service, + "as-sender-%s" % (service.id,), self._send_request, service ) @defer.inlineCallbacks def _send_request(self, service): # sanity-check: we shouldn't get here if this service already has a sender # running. - assert(service.id not in self.requests_in_flight) + assert service.id not in self.requests_in_flight self.requests_in_flight.add(service.id) try: @@ -137,7 +136,6 @@ class _ServiceQueuer(object): class _TransactionController(object): - def __init__(self, clock, store, as_api, recoverer_fn): self.clock = clock self.store = store @@ -149,10 +147,7 @@ class _TransactionController(object): @defer.inlineCallbacks def send(self, service, events): try: - txn = yield self.store.create_appservice_txn( - service=service, - events=events - ) + txn = yield self.store.create_appservice_txn(service=service, events=events) service_is_up = yield self._is_service_up(service) if service_is_up: sent = yield txn.send(self.as_api) @@ -167,12 +162,12 @@ class _TransactionController(object): @defer.inlineCallbacks def on_recovered(self, recoverer): self.recoverers.remove(recoverer) - logger.info("Successfully recovered application service AS ID %s", - recoverer.service.id) + logger.info( + "Successfully recovered application service AS ID %s", recoverer.service.id + ) logger.info("Remaining active recoverers: %s", len(self.recoverers)) yield self.store.set_appservice_state( - recoverer.service, - ApplicationServiceState.UP + recoverer.service, ApplicationServiceState.UP ) def add_recoverers(self, recoverers): @@ -184,13 +179,10 @@ class _TransactionController(object): @defer.inlineCallbacks def _start_recoverer(self, service): try: - yield self.store.set_appservice_state( - service, - ApplicationServiceState.DOWN - ) + yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) logger.info( "Application service falling behind. Starting recoverer. AS ID %s", - service.id + service.id, ) recoverer = self.recoverer_fn(service, self.on_recovered) self.add_recoverers([recoverer]) @@ -205,19 +197,16 @@ class _TransactionController(object): class _Recoverer(object): - @staticmethod @defer.inlineCallbacks def start(clock, store, as_api, callback): - services = yield store.get_appservices_by_state( - ApplicationServiceState.DOWN - ) - recoverers = [ - _Recoverer(clock, store, as_api, s, callback) for s in services - ] + services = yield store.get_appservices_by_state(ApplicationServiceState.DOWN) + recoverers = [_Recoverer(clock, store, as_api, s, callback) for s in services] for r in recoverers: - logger.info("Starting recoverer for AS ID %s which was marked as " - "DOWN", r.service.id) + logger.info( + "Starting recoverer for AS ID %s which was marked as " "DOWN", + r.service.id, + ) r.recover() defer.returnValue(recoverers) @@ -232,9 +221,9 @@ class _Recoverer(object): def recover(self): def _retry(): run_as_background_process( - "as-recoverer-%s" % (self.service.id,), - self.retry, + "as-recoverer-%s" % (self.service.id,), self.retry ) + self.clock.call_later((2 ** self.backoff_counter), _retry) def _backoff(self): @@ -248,8 +237,9 @@ class _Recoverer(object): try: txn = yield self.store.get_oldest_unsent_txn(self.service) if txn: - logger.info("Retrying transaction %s for AS ID %s", - txn.id, txn.service.id) + logger.info( + "Retrying transaction %s for AS ID %s", txn.id, txn.service.id + ) sent = yield txn.send(self.as_api) if sent: yield txn.complete(self.store) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index bf039e5823..d3396922ae 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -285,8 +285,8 @@ class Config(object): if not config_files: config_parser.error( "Must supply a config file.\nA config file can be automatically" - " generated using \"--generate-config -H SERVER_NAME" - " -c CONFIG-FILE\"" + ' generated using "--generate-config -H SERVER_NAME' + ' -c CONFIG-FILE"' ) (config_path,) = config_files if not cls.path_exists(config_path): @@ -314,9 +314,7 @@ class Config(object): if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "w") as config_file: - config_file.write( - "# vim:ft=yaml\n\n" - ) + config_file.write("# vim:ft=yaml\n\n") config_file.write(config_str) config = yaml.safe_load(config_str) @@ -353,8 +351,8 @@ class Config(object): if not config_files: config_parser.error( "Must supply a config file.\nA config file can be automatically" - " generated using \"--generate-config -H SERVER_NAME" - " -c CONFIG-FILE\"" + ' generated using "--generate-config -H SERVER_NAME' + ' -c CONFIG-FILE"' ) obj.read_config_files( diff --git a/synapse/config/api.py b/synapse/config/api.py index 5eb4f86fa2..23b0ea6962 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -18,15 +18,17 @@ from ._base import Config class ApiConfig(Config): - def read_config(self, config): - self.room_invite_state_types = config.get("room_invite_state_types", [ - EventTypes.JoinRules, - EventTypes.CanonicalAlias, - EventTypes.RoomAvatar, - EventTypes.RoomEncryption, - EventTypes.Name, - ]) + self.room_invite_state_types = config.get( + "room_invite_state_types", + [ + EventTypes.JoinRules, + EventTypes.CanonicalAlias, + EventTypes.RoomAvatar, + EventTypes.RoomEncryption, + EventTypes.Name, + ], + ) def default_config(cls, **kwargs): return """\ @@ -40,4 +42,6 @@ class ApiConfig(Config): # - "{RoomAvatar}" # - "{RoomEncryption}" # - "{Name}" - """.format(**vars(EventTypes)) + """.format( + **vars(EventTypes) + ) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 7e89d345d8..679ee62480 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -29,7 +29,6 @@ logger = logging.getLogger(__name__) class AppServiceConfig(Config): - def read_config(self, config): self.app_service_config_files = config.get("app_service_config_files", []) self.notify_appservices = config.get("notify_appservices", True) @@ -53,9 +52,7 @@ class AppServiceConfig(Config): def load_appservices(hostname, config_files): """Returns a list of Application Services from the config files.""" if not isinstance(config_files, list): - logger.warning( - "Expected %s to be a list of AS config files.", config_files - ) + logger.warning("Expected %s to be a list of AS config files.", config_files) return [] # Dicts of value -> filename @@ -66,22 +63,20 @@ def load_appservices(hostname, config_files): for config_file in config_files: try: - with open(config_file, 'r') as f: - appservice = _load_appservice( - hostname, yaml.safe_load(f), config_file - ) + with open(config_file, "r") as f: + appservice = _load_appservice(hostname, yaml.safe_load(f), config_file) if appservice.id in seen_ids: raise ConfigError( "Cannot reuse ID across application services: " - "%s (files: %s, %s)" % ( - appservice.id, config_file, seen_ids[appservice.id], - ) + "%s (files: %s, %s)" + % (appservice.id, config_file, seen_ids[appservice.id]) ) seen_ids[appservice.id] = config_file if appservice.token in seen_as_tokens: raise ConfigError( "Cannot reuse as_token across application services: " - "%s (files: %s, %s)" % ( + "%s (files: %s, %s)" + % ( appservice.token, config_file, seen_as_tokens[appservice.token], @@ -98,28 +93,26 @@ def load_appservices(hostname, config_files): def _load_appservice(hostname, as_info, config_filename): - required_string_fields = [ - "id", "as_token", "hs_token", "sender_localpart" - ] + required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"] for field in required_string_fields: if not isinstance(as_info.get(field), string_types): - raise KeyError("Required string field: '%s' (%s)" % ( - field, config_filename, - )) + raise KeyError( + "Required string field: '%s' (%s)" % (field, config_filename) + ) # 'url' must either be a string or explicitly null, not missing # to avoid accidentally turning off push for ASes. - if (not isinstance(as_info.get("url"), string_types) and - as_info.get("url", "") is not None): + if ( + not isinstance(as_info.get("url"), string_types) + and as_info.get("url", "") is not None + ): raise KeyError( "Required string field or explicit null: 'url' (%s)" % (config_filename,) ) localpart = as_info["sender_localpart"] if urlparse.quote(localpart) != localpart: - raise ValueError( - "sender_localpart needs characters which are not URL encoded." - ) + raise ValueError("sender_localpart needs characters which are not URL encoded.") user = UserID(localpart, hostname) user_id = user.to_string() @@ -138,13 +131,12 @@ def _load_appservice(hostname, as_info, config_filename): for regex_obj in as_info["namespaces"][ns]: if not isinstance(regex_obj, dict): raise ValueError( - "Expected namespace entry in %s to be an object," - " but got %s", ns, regex_obj + "Expected namespace entry in %s to be an object," " but got %s", + ns, + regex_obj, ) if not isinstance(regex_obj.get("regex"), string_types): - raise ValueError( - "Missing/bad type 'regex' key in %s", regex_obj - ) + raise ValueError("Missing/bad type 'regex' key in %s", regex_obj) if not isinstance(regex_obj.get("exclusive"), bool): raise ValueError( "Missing/bad type 'exclusive' key in %s", regex_obj @@ -167,10 +159,8 @@ def _load_appservice(hostname, as_info, config_filename): ) ip_range_whitelist = None - if as_info.get('ip_range_whitelist'): - ip_range_whitelist = IPSet( - as_info.get('ip_range_whitelist') - ) + if as_info.get("ip_range_whitelist"): + ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist")) return ApplicationService( token=as_info["as_token"], diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index f7eebf26d2..e2eb473a92 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -16,7 +16,6 @@ from ._base import Config class CaptchaConfig(Config): - def read_config(self, config): self.recaptcha_private_key = config.get("recaptcha_private_key") self.recaptcha_public_key = config.get("recaptcha_public_key") diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index abeb0180d3..5b0bf919c7 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -89,29 +89,26 @@ class ConsentConfig(Config): if consent_config is None: return self.user_consent_version = str(consent_config["version"]) - self.user_consent_template_dir = self.abspath( - consent_config["template_dir"] - ) + self.user_consent_template_dir = self.abspath(consent_config["template_dir"]) if not path.isdir(self.user_consent_template_dir): raise ConfigError( - "Could not find template directory '%s'" % ( - self.user_consent_template_dir, - ), + "Could not find template directory '%s'" + % (self.user_consent_template_dir,) ) self.user_consent_server_notice_content = consent_config.get( - "server_notice_content", + "server_notice_content" ) self.block_events_without_consent_error = consent_config.get( - "block_events_error", + "block_events_error" + ) + self.user_consent_server_notice_to_guests = bool( + consent_config.get("send_server_notice_to_guests", False) + ) + self.user_consent_at_registration = bool( + consent_config.get("require_at_registration", False) ) - self.user_consent_server_notice_to_guests = bool(consent_config.get( - "send_server_notice_to_guests", False, - )) - self.user_consent_at_registration = bool(consent_config.get( - "require_at_registration", False, - )) self.user_consent_policy_name = consent_config.get( - "policy_name", "Privacy Policy", + "policy_name", "Privacy Policy" ) def default_config(self, **kwargs): diff --git a/synapse/config/database.py b/synapse/config/database.py index 3c27ed6b4a..adc0a47ddf 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -18,29 +18,21 @@ from ._base import Config class DatabaseConfig(Config): - def read_config(self, config): - self.event_cache_size = self.parse_size( - config.get("event_cache_size", "10K") - ) + self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) self.database_config = config.get("database") if self.database_config is None: - self.database_config = { - "name": "sqlite3", - "args": {}, - } + self.database_config = {"name": "sqlite3", "args": {}} name = self.database_config.get("name", None) if name == "psycopg2": pass elif name == "sqlite3": - self.database_config.setdefault("args", {}).update({ - "cp_min": 1, - "cp_max": 1, - "check_same_thread": False, - }) + self.database_config.setdefault("args", {}).update( + {"cp_min": 1, "cp_max": 1, "check_same_thread": False} + ) else: raise RuntimeError("Unsupported database type '%s'" % (name,)) @@ -48,7 +40,8 @@ class DatabaseConfig(Config): def default_config(self, data_dir_path, **kwargs): database_path = os.path.join(data_dir_path, "homeserver.db") - return """\ + return ( + """\ ## Database ## database: @@ -62,7 +55,9 @@ class DatabaseConfig(Config): # Number of events to cache in memory. # #event_cache_size: 10K - """ % locals() + """ + % locals() + ) def read_arguments(self, args): self.set_databasepath(args.database_path) @@ -77,6 +72,8 @@ class DatabaseConfig(Config): def add_arguments(self, parser): db_group = parser.add_argument_group("database") db_group.add_argument( - "-d", "--database-path", metavar="SQLITE_DATABASE_PATH", - help="The path to a sqlite database to use." + "-d", + "--database-path", + metavar="SQLITE_DATABASE_PATH", + help="The path to a sqlite database to use.", ) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 86018dfcce..3a6cb07206 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -56,7 +56,7 @@ class EmailConfig(Config): if self.email_notif_from is not None: # make sure it's valid parsed = email.utils.parseaddr(self.email_notif_from) - if parsed[1] == '': + if parsed[1] == "": raise RuntimeError("Invalid notif_from address") template_dir = email_config.get("template_dir") @@ -65,19 +65,17 @@ class EmailConfig(Config): # (Note that loading as package_resources with jinja.PackageLoader doesn't # work for the same reason.) if not template_dir: - template_dir = pkg_resources.resource_filename( - 'synapse', 'res/templates' - ) + template_dir = pkg_resources.resource_filename("synapse", "res/templates") self.email_template_dir = os.path.abspath(template_dir) self.email_enable_notifs = email_config.get("enable_notifs", False) - account_validity_renewal_enabled = config.get( - "account_validity", {}, - ).get("renew_at") + account_validity_renewal_enabled = config.get("account_validity", {}).get( + "renew_at" + ) email_trust_identity_server_for_password_resets = email_config.get( - "trust_identity_server_for_password_resets", False, + "trust_identity_server_for_password_resets", False ) self.email_password_reset_behaviour = ( "remote" if email_trust_identity_server_for_password_resets else "local" @@ -103,62 +101,59 @@ class EmailConfig(Config): # make sure we can import the required deps import jinja2 import bleach + # prevent unused warnings jinja2 bleach if self.email_password_reset_behaviour == "local": - required = [ - "smtp_host", - "smtp_port", - "notif_from", - ] + required = ["smtp_host", "smtp_port", "notif_from"] missing = [] for k in required: if k not in email_config: missing.append(k) - if (len(missing) > 0): + if len(missing) > 0: raise RuntimeError( "email.password_reset_behaviour is set to 'local' " - "but required keys are missing: %s" % - (", ".join(["email." + k for k in missing]),) + "but required keys are missing: %s" + % (", ".join(["email." + k for k in missing]),) ) # Templates for password reset emails self.email_password_reset_template_html = email_config.get( - "password_reset_template_html", "password_reset.html", + "password_reset_template_html", "password_reset.html" ) self.email_password_reset_template_text = email_config.get( - "password_reset_template_text", "password_reset.txt", + "password_reset_template_text", "password_reset.txt" ) self.email_password_reset_failure_template = email_config.get( - "password_reset_failure_template", "password_reset_failure.html", + "password_reset_failure_template", "password_reset_failure.html" ) # This template does not support any replaceable variables, so we will # read it from the disk once during setup email_password_reset_success_template = email_config.get( - "password_reset_success_template", "password_reset_success.html", + "password_reset_success_template", "password_reset_success.html" ) # Check templates exist - for f in [self.email_password_reset_template_html, - self.email_password_reset_template_text, - self.email_password_reset_failure_template, - email_password_reset_success_template]: + for f in [ + self.email_password_reset_template_html, + self.email_password_reset_template_text, + self.email_password_reset_failure_template, + email_password_reset_success_template, + ]: p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): - raise ConfigError("Unable to find template file %s" % (p, )) + raise ConfigError("Unable to find template file %s" % (p,)) # Retrieve content of web templates filepath = os.path.join( - self.email_template_dir, - email_password_reset_success_template, + self.email_template_dir, email_password_reset_success_template ) self.email_password_reset_success_html_content = self.read_file( - filepath, - "email.password_reset_template_success_html", + filepath, "email.password_reset_template_success_html" ) if config.get("public_baseurl") is None: @@ -182,10 +177,10 @@ class EmailConfig(Config): if k not in email_config: missing.append(k) - if (len(missing) > 0): + if len(missing) > 0: raise RuntimeError( - "email.enable_notifs is True but required keys are missing: %s" % - (", ".join(["email." + k for k in missing]),) + "email.enable_notifs is True but required keys are missing: %s" + % (", ".join(["email." + k for k in missing]),) ) if config.get("public_baseurl") is None: @@ -199,27 +194,25 @@ class EmailConfig(Config): for f in self.email_notif_template_text, self.email_notif_template_html: p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p, )) + raise ConfigError("Unable to find email template file %s" % (p,)) self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True ) - self.email_riot_base_url = email_config.get( - "riot_base_url", None - ) + self.email_riot_base_url = email_config.get("riot_base_url", None) if account_validity_renewal_enabled: self.email_expiry_template_html = email_config.get( - "expiry_template_html", "notice_expiry.html", + "expiry_template_html", "notice_expiry.html" ) self.email_expiry_template_text = email_config.get( - "expiry_template_text", "notice_expiry.txt", + "expiry_template_text", "notice_expiry.txt" ) for f in self.email_expiry_template_text, self.email_expiry_template_html: p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p, )) + raise ConfigError("Unable to find email template file %s" % (p,)) def default_config(self, config_dir_path, server_name, **kwargs): return """ diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py index ecb4124096..b190dcbe38 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py @@ -15,13 +15,11 @@ from ._base import Config, ConfigError -MISSING_JWT = ( - """Missing jwt library. This is required for jwt login. +MISSING_JWT = """Missing jwt library. This is required for jwt login. Install by running: pip install pyjwt """ -) class JWTConfig(Config): @@ -34,6 +32,7 @@ class JWTConfig(Config): try: import jwt + jwt # To stop unused lint. except ImportError: raise ConfigError(MISSING_JWT) diff --git a/synapse/config/key.py b/synapse/config/key.py index 424875feae..94a0f47ea4 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -348,9 +348,8 @@ def _parse_key_servers(key_servers, federation_verify_certificates): result.verify_keys[key_id] = verify_key - if ( - not federation_verify_certificates and - not server.get("accept_keys_insecurely") + if not federation_verify_certificates and not server.get( + "accept_keys_insecurely" ): _assert_keyserver_has_verify_keys(result) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index c1febbe9d3..a22655b125 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -29,7 +29,8 @@ from synapse.util.versionstring import get_version_string from ._base import Config -DEFAULT_LOG_CONFIG = Template(""" +DEFAULT_LOG_CONFIG = Template( + """ version: 1 formatters: @@ -68,11 +69,11 @@ loggers: root: level: INFO handlers: [file, console] -""") +""" +) class LoggingConfig(Config): - def read_config(self, config): self.verbosity = config.get("verbose", 0) self.no_redirect_stdio = config.get("no_redirect_stdio", False) @@ -81,13 +82,16 @@ class LoggingConfig(Config): def default_config(self, config_dir_path, server_name, **kwargs): log_config = os.path.join(config_dir_path, server_name + ".log.config") - return """\ + return ( + """\ ## Logging ## # A yaml python logging config file # log_config: "%(log_config)s" - """ % locals() + """ + % locals() + ) def read_arguments(self, args): if args.verbose is not None: @@ -102,22 +106,31 @@ class LoggingConfig(Config): def add_arguments(cls, parser): logging_group = parser.add_argument_group("logging") logging_group.add_argument( - '-v', '--verbose', dest="verbose", action='count', + "-v", + "--verbose", + dest="verbose", + action="count", help="The verbosity level. Specify multiple times to increase " - "verbosity. (Ignored if --log-config is specified.)" + "verbosity. (Ignored if --log-config is specified.)", ) logging_group.add_argument( - '-f', '--log-file', dest="log_file", - help="File to log to. (Ignored if --log-config is specified.)" + "-f", + "--log-file", + dest="log_file", + help="File to log to. (Ignored if --log-config is specified.)", ) logging_group.add_argument( - '--log-config', dest="log_config", default=None, - help="Python logging config file" + "--log-config", + dest="log_config", + default=None, + help="Python logging config file", ) logging_group.add_argument( - '-n', '--no-redirect-stdio', - action='store_true', default=None, - help="Do not redirect stdout/stderr to the log" + "-n", + "--no-redirect-stdio", + action="store_true", + default=None, + help="Do not redirect stdout/stderr to the log", ) def generate_files(self, config): @@ -125,9 +138,7 @@ class LoggingConfig(Config): if log_config and not os.path.exists(log_config): log_file = self.abspath("homeserver.log") with open(log_config, "w") as log_config_file: - log_config_file.write( - DEFAULT_LOG_CONFIG.substitute(log_file=log_file) - ) + log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file)) def setup_logging(config, use_worker_options=False): @@ -143,10 +154,8 @@ def setup_logging(config, use_worker_options=False): register_sighup (func | None): Function to call to register a sighup handler. """ - log_config = (config.worker_log_config if use_worker_options - else config.log_config) - log_file = (config.worker_log_file if use_worker_options - else config.log_file) + log_config = config.worker_log_config if use_worker_options else config.log_config + log_file = config.worker_log_file if use_worker_options else config.log_file log_format = ( "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" @@ -164,23 +173,23 @@ def setup_logging(config, use_worker_options=False): if config.verbosity > 1: level_for_storage = logging.DEBUG - logger = logging.getLogger('') + logger = logging.getLogger("") logger.setLevel(level) - logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage) + logging.getLogger("synapse.storage.SQL").setLevel(level_for_storage) formatter = logging.Formatter(log_format) if log_file: # TODO: Customisable file size / backup count handler = logging.handlers.RotatingFileHandler( - log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, - encoding='utf8' + log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, encoding="utf8" ) def sighup(signum, stack): logger.info("Closing log file due to SIGHUP") handler.doRollover() logger.info("Opened new log file due to SIGHUP") + else: handler = logging.StreamHandler() @@ -193,8 +202,9 @@ def setup_logging(config, use_worker_options=False): logger.addHandler(handler) else: + def load_log_config(): - with open(log_config, 'r') as f: + with open(log_config, "r") as f: logging.config.dictConfig(yaml.safe_load(f)) def sighup(*args): @@ -209,10 +219,7 @@ def setup_logging(config, use_worker_options=False): # make sure that the first thing we log is a thing we can grep backwards # for logging.warn("***** STARTING SERVER *****") - logging.warn( - "Server %s version %s", - sys.argv[0], get_version_string(synapse), - ) + logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) # It's critical to point twisted's internal logging somewhere, otherwise it @@ -242,8 +249,7 @@ def setup_logging(config, use_worker_options=False): return observer(event) globalLogBeginner.beginLoggingTo( - [_log], - redirectStandardIO=not config.no_redirect_stdio, + [_log], redirectStandardIO=not config.no_redirect_stdio ) if not config.no_redirect_stdio: print("Redirected stdout/stderr to logs") diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 2de51979d8..c85e234d22 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -15,11 +15,9 @@ from ._base import Config, ConfigError -MISSING_SENTRY = ( - """Missing sentry-sdk library. This is required to enable sentry +MISSING_SENTRY = """Missing sentry-sdk library. This is required to enable sentry integration. """ -) class MetricsConfig(Config): @@ -39,7 +37,7 @@ class MetricsConfig(Config): self.sentry_dsn = config["sentry"].get("dsn") if not self.sentry_dsn: raise ConfigError( - "sentry.dsn field is required when sentry integration is enabled", + "sentry.dsn field is required when sentry integration is enabled" ) def default_config(self, report_stats=None, **kwargs): @@ -66,6 +64,6 @@ class MetricsConfig(Config): if report_stats is None: res += "# report_stats: true|false\n" else: - res += "report_stats: %s\n" % ('true' if report_stats else 'false') + res += "report_stats: %s\n" % ("true" if report_stats else "false") return res diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index f0a6be0679..fcf279e8e1 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -17,7 +17,7 @@ from synapse.util.module_loader import load_module from ._base import Config -LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider' +LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider" class PasswordAuthProviderConfig(Config): @@ -29,24 +29,20 @@ class PasswordAuthProviderConfig(Config): # param. ldap_config = config.get("ldap_config", {}) if ldap_config.get("enabled", False): - providers.append({ - 'module': LDAP_PROVIDER, - 'config': ldap_config, - }) + providers.append({"module": LDAP_PROVIDER, "config": ldap_config}) providers.extend(config.get("password_providers", [])) for provider in providers: - mod_name = provider['module'] + mod_name = provider["module"] # This is for backwards compat when the ldap auth provider resided # in this package. if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider": mod_name = LDAP_PROVIDER - (provider_class, provider_config) = load_module({ - "module": mod_name, - "config": provider['config'], - }) + (provider_class, provider_config) = load_module( + {"module": mod_name, "config": provider["config"]} + ) self.password_providers.append((provider_class, provider_config)) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 14752298e9..c845862bd9 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -26,7 +26,7 @@ from synapse.util.stringutils import random_string_with_symbols class AccountValidityConfig(Config): def __init__(self, config, synapse_config): self.enabled = config.get("enabled", False) - self.renew_by_email_enabled = ("renew_at" in config) + self.renew_by_email_enabled = "renew_at" in config if self.enabled: if "period" in config: @@ -77,7 +77,6 @@ class AccountValidityConfig(Config): class RegistrationConfig(Config): - def read_config(self, config): self.enable_registration = bool( strtobool(str(config.get("enable_registration", False))) @@ -88,7 +87,7 @@ class RegistrationConfig(Config): ) self.account_validity = AccountValidityConfig( - config.get("account_validity", {}), config, + config.get("account_validity", {}), config ) self.registrations_require_3pid = config.get("registrations_require_3pid", []) @@ -104,25 +103,24 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.register_mxid_from_3pid = config.get("register_mxid_from_3pid") self.register_just_use_email_for_display_name = config.get( - "register_just_use_email_for_display_name", False, + "register_just_use_email_for_display_name", False ) self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.trusted_third_party_id_servers = config.get( - "trusted_third_party_id_servers", - ["matrix.org", "vector.im"], + "trusted_third_party_id_servers", ["matrix.org", "vector.im"] ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) - self.invite_3pid_guest = ( - self.allow_guest_access and config.get("invite_3pid_guest", False) + self.invite_3pid_guest = self.allow_guest_access and config.get( + "invite_3pid_guest", False ) self.auto_join_rooms = config.get("auto_join_rooms", []) for room_alias in self.auto_join_rooms: if not RoomAlias.is_valid(room_alias): - raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) + raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) self.disable_set_displayname = config.get("disable_set_displayname", False) @@ -130,13 +128,15 @@ class RegistrationConfig(Config): self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", []) if not isinstance(self.replicate_user_profiles_to, list): - self.replicate_user_profiles_to = [self.replicate_user_profiles_to, ] + self.replicate_user_profiles_to = [self.replicate_user_profiles_to] self.shadow_server = config.get("shadow_server", None) - self.rewrite_identity_server_urls = config.get("rewrite_identity_server_urls", {}) + self.rewrite_identity_server_urls = config.get( + "rewrite_identity_server_urls", {} + ) - self.disable_msisdn_registration = ( - config.get("disable_msisdn_registration", False) + self.disable_msisdn_registration = config.get( + "disable_msisdn_registration", False ) def default_config(self, generate_secrets=False, **kwargs): @@ -145,9 +145,12 @@ class RegistrationConfig(Config): random_string_with_symbols(50), ) else: - registration_shared_secret = '# registration_shared_secret: ' + registration_shared_secret = ( + "# registration_shared_secret: " + ) - return """\ + return ( + """\ ## Registration ## # # Registration can be rate-limited using the parameters in the "Ratelimiting" @@ -331,17 +334,19 @@ class RegistrationConfig(Config): # users cannot be auto-joined since they do not exist. # #autocreate_auto_join_rooms: true - """ % locals() + """ + % locals() + ) def add_arguments(self, parser): reg_group = parser.add_argument_group("registration") reg_group.add_argument( - "--enable-registration", action="store_true", default=None, - help="Enable registration for new users." + "--enable-registration", + action="store_true", + default=None, + help="Enable registration for new users.", ) def read_arguments(self, args): if args.enable_registration is not None: - self.enable_registration = bool( - strtobool(str(args.enable_registration)) - ) + self.enable_registration = bool(strtobool(str(args.enable_registration))) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2abede409a..ed2b060b31 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -20,27 +20,11 @@ from synapse.util.module_loader import load_module from ._base import Config, ConfigError DEFAULT_THUMBNAIL_SIZES = [ - { - "width": 32, - "height": 32, - "method": "crop", - }, { - "width": 96, - "height": 96, - "method": "crop", - }, { - "width": 320, - "height": 240, - "method": "scale", - }, { - "width": 640, - "height": 480, - "method": "scale", - }, { - "width": 800, - "height": 600, - "method": "scale" - }, + {"width": 32, "height": 32, "method": "crop"}, + {"width": 96, "height": 96, "method": "crop"}, + {"width": 320, "height": 240, "method": "scale"}, + {"width": 640, "height": 480, "method": "scale"}, + {"width": 800, "height": 600, "method": "scale"}, ] THUMBNAIL_SIZE_YAML = """\ @@ -49,19 +33,15 @@ THUMBNAIL_SIZE_YAML = """\ # method: %(method)s """ -MISSING_NETADDR = ( - "Missing netaddr library. This is required for URL preview API." -) +MISSING_NETADDR = "Missing netaddr library. This is required for URL preview API." -MISSING_LXML = ( - """Missing lxml library. This is required for URL preview API. +MISSING_LXML = """Missing lxml library. This is required for URL preview API. Install by running: pip install lxml Requires libxslt1-dev system package. """ -) ThumbnailRequirement = namedtuple( @@ -69,7 +49,8 @@ ThumbnailRequirement = namedtuple( ) MediaStorageProviderConfig = namedtuple( - "MediaStorageProviderConfig", ( + "MediaStorageProviderConfig", + ( "store_local", # Whether to store newly uploaded local files "store_remote", # Whether to store newly downloaded remote files "store_synchronous", # Whether to wait for successful storage for local uploads @@ -100,8 +81,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): requirements.setdefault("image/gif", []).append(png_thumbnail) requirements.setdefault("image/png", []).append(png_thumbnail) return { - media_type: tuple(thumbnails) - for media_type, thumbnails in requirements.items() + media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items() } @@ -133,15 +113,15 @@ class ContentRepositoryConfig(Config): "Cannot use both 'backup_media_store_path' and 'storage_providers'" ) - storage_providers = [{ - "module": "file_system", - "store_local": True, - "store_synchronous": synchronous_backup_media_store, - "store_remote": True, - "config": { - "directory": backup_media_store_path, + storage_providers = [ + { + "module": "file_system", + "store_local": True, + "store_synchronous": synchronous_backup_media_store, + "store_remote": True, + "config": {"directory": backup_media_store_path}, } - }] + ] # This is a list of config that can be used to create the storage # providers. The entries are tuples of (Class, class_config, @@ -171,18 +151,19 @@ class ContentRepositoryConfig(Config): ) self.media_storage_providers.append( - (provider_class, parsed_config, wrapper_config,) + (provider_class, parsed_config, wrapper_config) ) self.uploads_path = self.ensure_directory(config["uploads_path"]) self.dynamic_thumbnails = config.get("dynamic_thumbnails", False) self.thumbnail_requirements = parse_thumbnail_requirements( - config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES), + config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES) ) self.url_preview_enabled = config.get("url_preview_enabled", False) if self.url_preview_enabled: try: import lxml + lxml # To stop unused lint. except ImportError: raise ConfigError(MISSING_LXML) @@ -205,15 +186,13 @@ class ContentRepositoryConfig(Config): # we always blacklist '0.0.0.0' and '::', which are supposed to be # unroutable addresses. - self.url_preview_ip_range_blacklist.update(['0.0.0.0', '::']) + self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"]) self.url_preview_ip_range_whitelist = IPSet( config.get("url_preview_ip_range_whitelist", ()) ) - self.url_preview_url_blacklist = config.get( - "url_preview_url_blacklist", () - ) + self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ()) def default_config(self, data_dir_path, **kwargs): media_store = os.path.join(data_dir_path, "media_store") @@ -225,7 +204,8 @@ class ContentRepositoryConfig(Config): # strip final NL formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1] - return r""" + return ( + r""" # Directory where uploaded images and attachments are stored. # media_store_path: "%(media_store)s" @@ -372,4 +352,6 @@ class ContentRepositoryConfig(Config): # The largest allowed URL preview spidering size in bytes # #max_spider_size: 10M - """ % locals() + """ + % locals() + ) diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 8a9fded4c5..c1da0e20e0 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -20,9 +20,7 @@ from ._base import Config, ConfigError class RoomDirectoryConfig(Config): def read_config(self, config): - self.enable_room_list_search = config.get( - "enable_room_list_search", True, - ) + self.enable_room_list_search = config.get("enable_room_list_search", True) alias_creation_rules = config.get("alias_creation_rules") @@ -33,11 +31,7 @@ class RoomDirectoryConfig(Config): ] else: self._alias_creation_rules = [ - _RoomDirectoryRule( - "alias_creation_rules", { - "action": "allow", - } - ) + _RoomDirectoryRule("alias_creation_rules", {"action": "allow"}) ] room_list_publication_rules = config.get("room_list_publication_rules") @@ -49,11 +43,7 @@ class RoomDirectoryConfig(Config): ] else: self._room_list_publication_rules = [ - _RoomDirectoryRule( - "room_list_publication_rules", { - "action": "allow", - } - ) + _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"}) ] def default_config(self, config_dir_path, server_name, **kwargs): @@ -178,8 +168,7 @@ class _RoomDirectoryRule(object): self.action = action else: raise ConfigError( - "%s rules can only have action of 'allow'" - " or 'deny'" % (option_name,) + "%s rules can only have action of 'allow'" " or 'deny'" % (option_name,) ) self._alias_matches_all = alias == "*" diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index aa6eac271f..2ec38e48e9 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -28,6 +28,7 @@ class SAML2Config(Config): self.saml2_enabled = True import saml2.config + self.saml2_sp_config = saml2.config.SPConfig() self.saml2_sp_config.load(self._default_saml_config_dict()) self.saml2_sp_config.load(saml2_config.get("sp_config", {})) @@ -41,26 +42,23 @@ class SAML2Config(Config): public_baseurl = self.public_baseurl if public_baseurl is None: - raise ConfigError( - "saml2_config requires a public_baseurl to be set" - ) + raise ConfigError("saml2_config requires a public_baseurl to be set") metadata_url = public_baseurl + "_matrix/saml2/metadata.xml" response_url = public_baseurl + "_matrix/saml2/authn_response" return { "entityid": metadata_url, - "service": { "sp": { "endpoints": { "assertion_consumer_service": [ - (response_url, saml2.BINDING_HTTP_POST), - ], + (response_url, saml2.BINDING_HTTP_POST) + ] }, "required_attributes": ["uid"], "optional_attributes": ["mail", "surname", "givenname"], - }, - } + } + }, } def default_config(self, config_dir_path, server_name, **kwargs): @@ -106,4 +104,6 @@ class SAML2Config(Config): # # separate pysaml2 configuration file: # # # config_path: "%(config_dir_path)s/sp_conf.py" - """ % {"config_dir_path": config_dir_path} + """ % { + "config_dir_path": config_dir_path + } diff --git a/synapse/config/server.py b/synapse/config/server.py index 4f3b033121..be65268363 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -34,13 +34,12 @@ logger = logging.Logger(__name__) # # We later check for errors when binding to 0.0.0.0 and ignore them if :: is also in # in the list. -DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0'] +DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"] DEFAULT_ROOM_VERSION = "4" class ServerConfig(Config): - def read_config(self, config): self.server_name = config["server_name"] self.server_context = config.get("server_context", None) @@ -81,13 +80,13 @@ class ServerConfig(Config): # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. self.require_auth_for_profile_requests = config.get( - "require_auth_for_profile_requests", False, + "require_auth_for_profile_requests", False ) # Whether to require sharing a room with a user to retrieve their # profile data self.limit_profile_requests_to_known_users = config.get( - "limit_profile_requests_to_known_users", False, + "limit_profile_requests_to_known_users", False ) if "restrict_public_rooms_to_local_users" in config and ( @@ -117,17 +116,15 @@ class ServerConfig(Config): "allow_public_rooms_over_federation", True ) - default_room_version = config.get( - "default_room_version", DEFAULT_ROOM_VERSION, - ) + default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION) # Ensure room version is a str default_room_version = str(default_room_version) if default_room_version not in KNOWN_ROOM_VERSIONS: raise ConfigError( - "Unknown default_room_version: %s, known room versions: %s" % - (default_room_version, list(KNOWN_ROOM_VERSIONS.keys())) + "Unknown default_room_version: %s, known room versions: %s" + % (default_room_version, list(KNOWN_ROOM_VERSIONS.keys())) ) # Get the actual room version object rather than just the identifier @@ -142,31 +139,25 @@ class ServerConfig(Config): # Whether we should block invites sent to users on this server # (other than those sent by local server admins) - self.block_non_admin_invites = config.get( - "block_non_admin_invites", False, - ) + self.block_non_admin_invites = config.get("block_non_admin_invites", False) # Whether to enable experimental MSC1849 (aka relations) support self.experimental_msc1849_support_enabled = config.get( - "experimental_msc1849_support_enabled", False, + "experimental_msc1849_support_enabled", False ) # Options to control access by tracking MAU self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.max_mau_value = 0 if self.limit_usage_by_mau: - self.max_mau_value = config.get( - "max_mau_value", 0, - ) + self.max_mau_value = config.get("max_mau_value", 0) self.mau_stats_only = config.get("mau_stats_only", False) self.mau_limits_reserved_threepids = config.get( "mau_limit_reserved_threepids", [] ) - self.mau_trial_days = config.get( - "mau_trial_days", 0, - ) + self.mau_trial_days = config.get("mau_trial_days", 0) # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) @@ -179,9 +170,7 @@ class ServerConfig(Config): # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None - federation_domain_whitelist = config.get( - "federation_domain_whitelist", None, - ) + federation_domain_whitelist = config.get("federation_domain_whitelist", None) if federation_domain_whitelist is not None: # turn the whitelist into a hash for speed of lookup @@ -191,7 +180,7 @@ class ServerConfig(Config): self.federation_domain_whitelist[domain] = True self.federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", [], + "federation_ip_range_blacklist", [] ) # Attempt to create an IPSet from the given ranges @@ -204,13 +193,12 @@ class ServerConfig(Config): self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) except Exception as e: raise ConfigError( - "Invalid range(s) provided in " - "federation_ip_range_blacklist: %s" % e + "Invalid range(s) provided in " "federation_ip_range_blacklist: %s" % e ) if self.public_baseurl is not None: - if self.public_baseurl[-1] != '/': - self.public_baseurl += '/' + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, @@ -221,7 +209,7 @@ class ServerConfig(Config): # Whether to require a user to be in the room to add an alias to it. # Defaults to True. self.require_membership_for_aliases = config.get( - "require_membership_for_aliases", True, + "require_membership_for_aliases", True ) # Whether to allow per-room membership profiles through the send of membership @@ -231,7 +219,7 @@ class ServerConfig(Config): # Whether to show the users on this homeserver in the user directory. Defaults to # True. self.show_users_in_user_directory = config.get( - "show_users_in_user_directory", True, + "show_users_in_user_directory", True ) retention_config = config.get("retention") @@ -275,13 +263,17 @@ class ServerConfig(Config): self.retention_default_min_lifetime = None self.retention_default_max_lifetime = None - self.retention_allowed_lifetime_min = retention_config.get("allowed_lifetime_min") + self.retention_allowed_lifetime_min = retention_config.get( + "allowed_lifetime_min" + ) if self.retention_allowed_lifetime_min is not None: self.retention_allowed_lifetime_min = self.parse_duration( self.retention_allowed_lifetime_min ) - self.retention_allowed_lifetime_max = retention_config.get("allowed_lifetime_max") + self.retention_allowed_lifetime_max = retention_config.get( + "allowed_lifetime_max" + ) if self.retention_allowed_lifetime_max is not None: self.retention_allowed_lifetime_max = self.parse_duration( self.retention_allowed_lifetime_max @@ -290,7 +282,8 @@ class ServerConfig(Config): if ( self.retention_allowed_lifetime_min is not None and self.retention_allowed_lifetime_max is not None - and self.retention_allowed_lifetime_min > self.retention_allowed_lifetime_max + and self.retention_allowed_lifetime_min + > self.retention_allowed_lifetime_max ): raise ConfigError( "Invalid retention policy limits: 'allowed_lifetime_min' can not be" @@ -330,18 +323,22 @@ class ServerConfig(Config): " 'longest_max_lifetime' value." ) - self.retention_purge_jobs.append({ - "interval": interval, - "shortest_max_lifetime": shortest_max_lifetime, - "longest_max_lifetime": longest_max_lifetime, - }) + self.retention_purge_jobs.append( + { + "interval": interval, + "shortest_max_lifetime": shortest_max_lifetime, + "longest_max_lifetime": longest_max_lifetime, + } + ) if not self.retention_purge_jobs: - self.retention_purge_jobs = [{ - "interval": self.parse_duration("1d"), - "shortest_max_lifetime": None, - "longest_max_lifetime": None, - }] + self.retention_purge_jobs = [ + { + "interval": self.parse_duration("1d"), + "shortest_max_lifetime": None, + "longest_max_lifetime": None, + } + ] self.listeners = [] for listener in config.get("listeners", []): @@ -368,9 +365,9 @@ class ServerConfig(Config): # if we still have an empty list of addresses, use the default list if not bind_addresses: - if listener['type'] == 'metrics': + if listener["type"] == "metrics": # the metrics listener doesn't support IPv6 - bind_addresses.append('0.0.0.0') + bind_addresses.append("0.0.0.0") else: bind_addresses.extend(DEFAULT_BIND_ADDRESSES) @@ -390,78 +387,72 @@ class ServerConfig(Config): bind_host = config.get("bind_host", "") gzip_responses = config.get("gzip_responses", True) - self.listeners.append({ - "port": bind_port, - "bind_addresses": [bind_host], - "tls": True, - "type": "http", - "resources": [ - { - "names": ["client"], - "compress": gzip_responses, - }, - { - "names": ["federation"], - "compress": False, - } - ] - }) - - unsecure_port = config.get("unsecure_port", bind_port - 400) - if unsecure_port: - self.listeners.append({ - "port": unsecure_port, + self.listeners.append( + { + "port": bind_port, "bind_addresses": [bind_host], - "tls": False, + "tls": True, "type": "http", "resources": [ - { - "names": ["client"], - "compress": gzip_responses, - }, - { - "names": ["federation"], - "compress": False, - } - ] - }) + {"names": ["client"], "compress": gzip_responses}, + {"names": ["federation"], "compress": False}, + ], + } + ) + + unsecure_port = config.get("unsecure_port", bind_port - 400) + if unsecure_port: + self.listeners.append( + { + "port": unsecure_port, + "bind_addresses": [bind_host], + "tls": False, + "type": "http", + "resources": [ + {"names": ["client"], "compress": gzip_responses}, + {"names": ["federation"], "compress": False}, + ], + } + ) manhole = config.get("manhole") if manhole: - self.listeners.append({ - "port": manhole, - "bind_addresses": ["127.0.0.1"], - "type": "manhole", - "tls": False, - }) + self.listeners.append( + { + "port": manhole, + "bind_addresses": ["127.0.0.1"], + "type": "manhole", + "tls": False, + } + ) metrics_port = config.get("metrics_port") if metrics_port: logger.warn( - ("The metrics_port configuration option is deprecated in Synapse 0.31 " - "in favour of a listener. Please see " - "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst" - " on how to configure the new listener.")) - - self.listeners.append({ - "port": metrics_port, - "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], - "tls": False, - "type": "http", - "resources": [ - { - "names": ["metrics"], - "compress": False, - }, - ] - }) + ( + "The metrics_port configuration option is deprecated in Synapse 0.31 " + "in favour of a listener. Please see " + "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst" + " on how to configure the new listener." + ) + ) + + self.listeners.append( + { + "port": metrics_port, + "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], + "tls": False, + "type": "http", + "resources": [{"names": ["metrics"], "compress": False}], + } + ) _check_resource_config(self.listeners) # An experimental option to try and periodically clean up extremities # by sending dummy events. self.cleanup_extremities_with_dummy_events = config.get( - "cleanup_extremities_with_dummy_events", False, + "cleanup_extremities_with_dummy_events", False ) def has_tls_listener(self): @@ -480,7 +471,8 @@ class ServerConfig(Config): # Bring DEFAULT_ROOM_VERSION into the local-scope for use in the # default config string default_room_version = DEFAULT_ROOM_VERSION - return """\ + return ( + """\ ## Server ## # The domain name of the server, with optional explicit port. @@ -857,7 +849,9 @@ class ServerConfig(Config): # - shortest_max_lifetime: 3d # longest_max_lifetime: 1y # interval: 24h - """ % locals() + """ + % locals() + ) def read_arguments(self, args): if args.manhole is not None: @@ -869,17 +863,26 @@ class ServerConfig(Config): def add_arguments(self, parser): server_group = parser.add_argument_group("server") - server_group.add_argument("-D", "--daemonize", action='store_true', - default=None, - help="Daemonize the home server") - server_group.add_argument("--print-pidfile", action='store_true', - default=None, - help="Print the path to the pidfile just" - " before daemonizing") - server_group.add_argument("--manhole", metavar="PORT", dest="manhole", - type=int, - help="Turn on the twisted telnet manhole" - " service on the given port.") + server_group.add_argument( + "-D", + "--daemonize", + action="store_true", + default=None, + help="Daemonize the home server", + ) + server_group.add_argument( + "--print-pidfile", + action="store_true", + default=None, + help="Print the path to the pidfile just" " before daemonizing", + ) + server_group.add_argument( + "--manhole", + metavar="PORT", + dest="manhole", + type=int, + help="Turn on the twisted telnet manhole" " service on the given port.", + ) def is_threepid_reserved(reserved_threepids, threepid): @@ -893,7 +896,7 @@ def is_threepid_reserved(reserved_threepids, threepid): """ for tp in reserved_threepids: - if (threepid['medium'] == tp['medium'] and threepid['address'] == tp['address']): + if threepid["medium"] == tp["medium"] and threepid["address"] == tp["address"]: return True return False @@ -906,9 +909,7 @@ def read_gc_thresholds(thresholds): return None try: assert len(thresholds) == 3 - return ( - int(thresholds[0]), int(thresholds[1]), int(thresholds[2]), - ) + return (int(thresholds[0]), int(thresholds[1]), int(thresholds[2])) except Exception: raise ConfigError( "Value of `gc_threshold` must be a list of three integers if set" @@ -926,22 +927,22 @@ def _warn_if_webclient_configured(listeners): for listener in listeners: for res in listener.get("resources", []): for name in res.get("names", []): - if name == 'webclient': + if name == "webclient": logger.warning(NO_MORE_WEB_CLIENT_WARNING) return KNOWN_RESOURCES = ( - 'client', - 'consent', - 'federation', - 'keys', - 'media', - 'metrics', - 'openid', - 'replication', - 'static', - 'webclient', + "client", + "consent", + "federation", + "keys", + "media", + "metrics", + "openid", + "replication", + "static", + "webclient", ) @@ -955,11 +956,9 @@ def _check_resource_config(listeners): for resource in resource_names: if resource not in KNOWN_RESOURCES: - raise ConfigError( - "Unknown listener resource '%s'" % (resource, ) - ) + raise ConfigError("Unknown listener resource '%s'" % (resource,)) if resource == "consent": try: - check_requirements('resources.consent') + check_requirements("resources.consent") except DependencyException as e: raise ConfigError(e.message) diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index 529dc0a617..d930eb33b5 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -58,6 +58,7 @@ class ServerNoticesConfig(Config): The name to use for the server notices room. None if server notices are not enabled. """ + def __init__(self): super(ServerNoticesConfig, self).__init__() self.server_notices_mxid = None @@ -70,18 +71,12 @@ class ServerNoticesConfig(Config): if c is None: return - mxid_localpart = c['system_mxid_localpart'] - self.server_notices_mxid = UserID( - mxid_localpart, self.server_name, - ).to_string() - self.server_notices_mxid_display_name = c.get( - 'system_mxid_display_name', None, - ) - self.server_notices_mxid_avatar_url = c.get( - 'system_mxid_avatar_url', None, - ) + mxid_localpart = c["system_mxid_localpart"] + self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string() + self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None) + self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None) # todo: i18n - self.server_notices_room_name = c.get('room_name', "Server Notices") + self.server_notices_room_name = c.get("room_name", "Server Notices") def default_config(self, **kwargs): return DEFAULT_CONFIG diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 658f9dd361..eb44df17b5 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -42,11 +42,11 @@ class TlsConfig(Config): self.acme_enabled = acme_config.get("enabled", False) # hyperlink complains on py2 if this is not a Unicode - self.acme_url = six.text_type(acme_config.get( - "url", u"https://acme-v01.api.letsencrypt.org/directory" - )) + self.acme_url = six.text_type( + acme_config.get("url", u"https://acme-v01.api.letsencrypt.org/directory") + ) self.acme_port = acme_config.get("port", 80) - self.acme_bind_addresses = acme_config.get("bind_addresses", ['::', '0.0.0.0']) + self.acme_bind_addresses = acme_config.get("bind_addresses", ["::", "0.0.0.0"]) self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30) self.acme_domain = acme_config.get("domain", config.get("server_name")) @@ -74,12 +74,12 @@ class TlsConfig(Config): # Whether to verify certificates on outbound federation traffic self.federation_verify_certificates = config.get( - "federation_verify_certificates", True, + "federation_verify_certificates", True ) # Whitelist of domains to not verify certificates for fed_whitelist_entries = config.get( - "federation_certificate_verification_whitelist", [], + "federation_certificate_verification_whitelist", [] ) # Support globs (*) in whitelist values @@ -90,9 +90,7 @@ class TlsConfig(Config): self.federation_certificate_verification_whitelist.append(entry_regex) # List of custom certificate authorities for federation traffic validation - custom_ca_list = config.get( - "federation_custom_ca_list", None, - ) + custom_ca_list = config.get("federation_custom_ca_list", None) # Read in and parse custom CA certificates self.federation_ca_trust_root = None @@ -101,8 +99,10 @@ class TlsConfig(Config): # A trustroot cannot be generated without any CA certificates. # Raise an error if this option has been specified without any # corresponding certificates. - raise ConfigError("federation_custom_ca_list specified without " - "any certificate files") + raise ConfigError( + "federation_custom_ca_list specified without " + "any certificate files" + ) certs = [] for ca_file in custom_ca_list: @@ -114,8 +114,9 @@ class TlsConfig(Config): cert_base = Certificate.loadPEM(content) certs.append(cert_base) except Exception as e: - raise ConfigError("Error parsing custom CA certificate file %s: %s" - % (ca_file, e)) + raise ConfigError( + "Error parsing custom CA certificate file %s: %s" % (ca_file, e) + ) self.federation_ca_trust_root = trustRootFromCertificates(certs) @@ -146,17 +147,21 @@ class TlsConfig(Config): return None try: - with open(self.tls_certificate_file, 'rb') as f: + with open(self.tls_certificate_file, "rb") as f: cert_pem = f.read() except Exception as e: - raise ConfigError("Failed to read existing certificate file %s: %s" - % (self.tls_certificate_file, e)) + raise ConfigError( + "Failed to read existing certificate file %s: %s" + % (self.tls_certificate_file, e) + ) try: tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) except Exception as e: - raise ConfigError("Failed to parse existing certificate file %s: %s" - % (self.tls_certificate_file, e)) + raise ConfigError( + "Failed to parse existing certificate file %s: %s" + % (self.tls_certificate_file, e) + ) if not allow_self_signed: if tls_certificate.get_subject() == tls_certificate.get_issuer(): @@ -166,7 +171,7 @@ class TlsConfig(Config): # YYYYMMDDhhmmssZ -- in UTC expires_on = datetime.strptime( - tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ" + tls_certificate.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ" ) now = datetime.utcnow() days_remaining = (expires_on - now).days @@ -191,7 +196,8 @@ class TlsConfig(Config): except Exception as e: logger.info( "Unable to read TLS certificate (%s). Ignoring as no " - "tls listeners enabled.", e, + "tls listeners enabled.", + e, ) self.tls_fingerprints = list(self._original_tls_fingerprints) @@ -215,8 +221,8 @@ class TlsConfig(Config): # this is to avoid the max line length. Sorrynotsorry proxypassline = ( - 'ProxyPass /.well-known/acme-challenge ' - 'http://localhost:8009/.well-known/acme-challenge' + "ProxyPass /.well-known/acme-challenge " + "http://localhost:8009/.well-known/acme-challenge" ) return ( diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index 4376a23636..3484785746 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -27,14 +27,14 @@ class UserDirectoryConfig(Config): self.user_directory_defer_to_id_server = None user_directory_config = config.get("user_directory", None) if user_directory_config: - self.user_directory_search_enabled = ( - user_directory_config.get("enabled", True) + self.user_directory_search_enabled = user_directory_config.get( + "enabled", True ) - self.user_directory_search_all_users = ( - user_directory_config.get("search_all_users", False) + self.user_directory_search_all_users = user_directory_config.get( + "search_all_users", False ) - self.user_directory_defer_to_id_server = ( - user_directory_config.get("defer_to_id_server", None) + self.user_directory_defer_to_id_server = user_directory_config.get( + "defer_to_id_server", None ) def default_config(self, config_dir_path, server_name, **kwargs): diff --git a/synapse/config/voip.py b/synapse/config/voip.py index 2a1f005a37..82cf8c53a8 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -16,14 +16,13 @@ from ._base import Config class VoipConfig(Config): - def read_config(self, config): self.turn_uris = config.get("turn_uris", []) self.turn_shared_secret = config.get("turn_shared_secret") self.turn_username = config.get("turn_username") self.turn_password = config.get("turn_password") self.turn_user_lifetime = self.parse_duration( - config.get("turn_user_lifetime", "1h"), + config.get("turn_user_lifetime", "1h") ) self.turn_allow_guests = config.get("turn_allow_guests", True) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index bfbd8b6c91..75993abf35 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -52,12 +52,14 @@ class WorkerConfig(Config): # argument. manhole = config.get("worker_manhole") if manhole: - self.worker_listeners.append({ - "port": manhole, - "bind_addresses": ["127.0.0.1"], - "type": "manhole", - "tls": False, - }) + self.worker_listeners.append( + { + "port": manhole, + "bind_addresses": ["127.0.0.1"], + "type": "manhole", + "tls": False, + } + ) if self.worker_listeners: for listener in self.worker_listeners: @@ -67,7 +69,7 @@ class WorkerConfig(Config): if bind_address: bind_addresses.append(bind_address) elif not bind_addresses: - bind_addresses.append('') + bind_addresses.append("") def read_arguments(self, args): # We support a bunch of command line arguments that override options in diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 99a586655b..41eabbe717 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -46,9 +46,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): if name not in hashes: raise SynapseError( 400, - "Algorithm %s not in hashes %s" % ( - name, list(hashes), - ), + "Algorithm %s not in hashes %s" % (name, list(hashes)), Codes.UNAUTHORIZED, ) message_hash_base64 = hashes[name] @@ -56,9 +54,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): message_hash_bytes = decode_base64(message_hash_base64) except Exception: raise SynapseError( - 400, - "Invalid base64: %s" % (message_hash_base64,), - Codes.UNAUTHORIZED, + 400, "Invalid base64: %s" % (message_hash_base64,), Codes.UNAUTHORIZED ) return message_hash_bytes == expected_hash @@ -135,8 +131,9 @@ def compute_event_signature(event_dict, signature_name, signing_key): return redact_json["signatures"] -def add_hashes_and_signatures(event_dict, signature_name, signing_key, - hash_algorithm=hashlib.sha256): +def add_hashes_and_signatures( + event_dict, signature_name, signing_key, hash_algorithm=hashlib.sha256 +): """Add content hash and sign the event Args: @@ -153,7 +150,5 @@ def add_hashes_and_signatures(event_dict, signature_name, signing_key, event_dict.setdefault("hashes", {})[name] = encode_base64(digest) event_dict["signatures"] = compute_event_signature( - event_dict, - signature_name=signature_name, - signing_key=signing_key, + event_dict, signature_name=signature_name, signing_key=signing_key ) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 6f603f1961..6f52440150 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -614,10 +614,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): results = yield logcontext.make_deferred_yieldable( defer.gatherResults( - [ - run_in_background(get_key, server) - for server in self.key_servers - ], + [run_in_background(get_key, server) for server in self.key_servers], consumeErrors=True, ).addErrback(unwrapFirstError) ) @@ -630,9 +627,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): defer.returnValue(union_of_keys) @defer.inlineCallbacks - def get_server_verify_key_v2_indirect( - self, keys_to_fetch, key_server - ): + def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): """ Args: keys_to_fetch (dict[str, dict[str, int]]): @@ -690,10 +685,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ) try: - self._validate_perspectives_response( - key_server, - response, - ) + self._validate_perspectives_response(key_server, response) processed_response = yield self.process_v2_response( perspective_name, response, time_added_ms=time_now_ms @@ -720,9 +712,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): defer.returnValue(keys) - def _validate_perspectives_response( - self, key_server, response, - ): + def _validate_perspectives_response(self, key_server, response): """Optionally check the signature on the result of a /key/query request Args: @@ -826,7 +816,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher): path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id), ignore_backoff=True, - # we only give the remote server 10s to respond. It should be an # easy request to handle, so if it doesn't reply within 10s, it's # probably not going to. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 203490fc36..cd52e3f867 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -85,17 +85,14 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru room_id_domain = get_domain_from_id(event.room_id) if room_id_domain != sender_domain: raise AuthError( - 403, - "Creation event's room_id domain does not match sender's" + 403, "Creation event's room_id domain does not match sender's" ) room_version = event.content.get("room_version", "1") if room_version not in KNOWN_ROOM_VERSIONS: raise AuthError( - 403, - "room appears to have unsupported version %s" % ( - room_version, - )) + 403, "room appears to have unsupported version %s" % (room_version,) + ) # FIXME logger.debug("Allowing! %s", event) return @@ -103,46 +100,30 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru creation_event = auth_events.get((EventTypes.Create, ""), None) if not creation_event: - raise AuthError( - 403, - "No create event in auth events", - ) + raise AuthError(403, "No create event in auth events") creating_domain = get_domain_from_id(event.room_id) originating_domain = get_domain_from_id(event.sender) if creating_domain != originating_domain: if not _can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) + raise AuthError(403, "This room has been marked as unfederatable.") # FIXME: Temp hack if event.type == EventTypes.Aliases: if not event.is_state(): - raise AuthError( - 403, - "Alias event must be a state event", - ) + raise AuthError(403, "Alias event must be a state event") if not event.state_key: - raise AuthError( - 403, - "Alias event must have non-empty state_key" - ) + raise AuthError(403, "Alias event must have non-empty state_key") sender_domain = get_domain_from_id(event.sender) if event.state_key != sender_domain: raise AuthError( - 403, - "Alias event's state_key does not match sender's domain" + 403, "Alias event's state_key does not match sender's domain" ) logger.debug("Allowing! %s", event) return if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] - ) + logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) if event.type == EventTypes.Member: _is_membership_change_allowed(event, auth_events) @@ -159,9 +140,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru invite_level = _get_named_level(auth_events, "invite", 0) if user_level < invite_level: - raise AuthError( - 403, "You don't have permission to invite users", - ) + raise AuthError(403, "You don't have permission to invite users") else: logger.debug("Allowing! %s", event) return @@ -207,7 +186,7 @@ def _is_membership_change_allowed(event, auth_events): # Check if this is the room creator joining: if len(event.prev_event_ids()) == 1 and Membership.JOIN == membership: # Get room creation event: - key = (EventTypes.Create, "", ) + key = (EventTypes.Create, "") create = auth_events.get(key) if create and event.prev_event_ids()[0] == create.event_id: if create.content["creator"] == event.state_key: @@ -219,38 +198,31 @@ def _is_membership_change_allowed(event, auth_events): target_domain = get_domain_from_id(target_user_id) if creating_domain != target_domain: if not _can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) + raise AuthError(403, "This room has been marked as unfederatable.") # get info about the caller - key = (EventTypes.Member, event.user_id, ) + key = (EventTypes.Member, event.user_id) caller = auth_events.get(key) caller_in_room = caller and caller.membership == Membership.JOIN caller_invited = caller and caller.membership == Membership.INVITE # get info about the target - key = (EventTypes.Member, target_user_id, ) + key = (EventTypes.Member, target_user_id) target = auth_events.get(key) target_in_room = target and target.membership == Membership.JOIN target_banned = target and target.membership == Membership.BAN - key = (EventTypes.JoinRules, "", ) + key = (EventTypes.JoinRules, "") join_rule_event = auth_events.get(key) if join_rule_event: - join_rule = join_rule_event.content.get( - "join_rule", JoinRules.INVITE - ) + join_rule = join_rule_event.content.get("join_rule", JoinRules.INVITE) else: join_rule = JoinRules.INVITE user_level = get_user_power_level(event.user_id, auth_events) - target_level = get_user_power_level( - target_user_id, auth_events - ) + target_level = get_user_power_level(target_user_id, auth_events) # FIXME (erikj): What should we do here as the default? ban_level = _get_named_level(auth_events, "ban", 50) @@ -266,29 +238,26 @@ def _is_membership_change_allowed(event, auth_events): "join_rule": join_rule, "target_user_id": target_user_id, "event.user_id": event.user_id, - } + }, ) if Membership.INVITE == membership and "third_party_invite" in event.content: if not _verify_third_party_invite(event, auth_events): raise AuthError(403, "You are not invited to this room.") if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) + raise AuthError(403, "%s is banned from the room" % (target_user_id,)) return if Membership.JOIN != membership: - if (caller_invited - and Membership.LEAVE == membership - and target_user_id == event.user_id): + if ( + caller_invited + and Membership.LEAVE == membership + and target_user_id == event.user_id + ): return if not caller_in_room: # caller isn't joined - raise AuthError( - 403, - "%s not in room %s." % (event.user_id, event.room_id,) - ) + raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id)) if Membership.INVITE == membership: # TODO (erikj): We should probably handle this more intelligently @@ -296,19 +265,14 @@ def _is_membership_change_allowed(event, auth_events): # Invites are valid iff caller is in the room and target isn't. if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) + raise AuthError(403, "%s is banned from the room" % (target_user_id,)) elif target_in_room: # the target is already in the room. - raise AuthError(403, "%s is already in the room." % - target_user_id) + raise AuthError(403, "%s is already in the room." % target_user_id) else: invite_level = _get_named_level(auth_events, "invite", 0) if user_level < invite_level: - raise AuthError( - 403, "You don't have permission to invite users", - ) + raise AuthError(403, "You don't have permission to invite users") elif Membership.JOIN == membership: # Joins are valid iff caller == target and they were: # invited: They are accepting the invitation @@ -329,16 +293,12 @@ def _is_membership_change_allowed(event, auth_events): elif Membership.LEAVE == membership: # TODO (erikj): Implement kicks. if target_banned and user_level < ban_level: - raise AuthError( - 403, "You cannot unban user %s." % (target_user_id,) - ) + raise AuthError(403, "You cannot unban user %s." % (target_user_id,)) elif target_user_id != event.user_id: kick_level = _get_named_level(auth_events, "kick", 50) if user_level < kick_level or user_level <= target_level: - raise AuthError( - 403, "You cannot kick user %s." % target_user_id - ) + raise AuthError(403, "You cannot kick user %s." % target_user_id) elif Membership.BAN == membership: if user_level < ban_level or user_level <= target_level: raise AuthError(403, "You don't have permission to ban") @@ -347,21 +307,17 @@ def _is_membership_change_allowed(event, auth_events): def _check_event_sender_in_room(event, auth_events): - key = (EventTypes.Member, event.user_id, ) + key = (EventTypes.Member, event.user_id) member_event = auth_events.get(key) - return _check_joined_room( - member_event, - event.user_id, - event.room_id - ) + return _check_joined_room(member_event, event.user_id, event.room_id) def _check_joined_room(member, user_id, room_id): if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s (%s)" % ( - user_id, room_id, repr(member) - )) + raise AuthError( + 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) + ) def get_send_level(etype, state_key, power_levels_event): @@ -402,26 +358,21 @@ def get_send_level(etype, state_key, power_levels_event): def _can_send_event(event, auth_events): power_levels_event = _get_power_level_event(auth_events) - send_level = get_send_level( - event.type, event.get("state_key"), power_levels_event, - ) + send_level = get_send_level(event.type, event.get("state_key"), power_levels_event) user_level = get_user_power_level(event.user_id, auth_events) if user_level < send_level: raise AuthError( 403, - "You don't have permission to post that to the room. " + - "user_level (%d) < send_level (%d)" % (user_level, send_level) + "You don't have permission to post that to the room. " + + "user_level (%d) < send_level (%d)" % (user_level, send_level), ) # Check state_key if hasattr(event, "state_key"): if event.state_key.startswith("@"): if event.state_key != event.user_id: - raise AuthError( - 403, - "You are not allowed to set others state" - ) + raise AuthError(403, "You are not allowed to set others state") return True @@ -459,10 +410,7 @@ def check_redaction(room_version, event, auth_events): event.internal_metadata.recheck_redaction = True return True - raise AuthError( - 403, - "You don't have permission to redact events" - ) + raise AuthError(403, "You don't have permission to redact events") def _check_power_levels(event, auth_events): @@ -479,7 +427,7 @@ def _check_power_levels(event, auth_events): except Exception: raise SynapseError(400, "Not a valid power level: %s" % (v,)) - key = (event.type, event.state_key, ) + key = (event.type, event.state_key) current_state = auth_events.get(key) if not current_state: @@ -500,16 +448,12 @@ def _check_power_levels(event, auth_events): old_list = current_state.content.get("users", {}) for user in set(list(old_list) + list(user_list)): - levels_to_check.append( - (user, "users") - ) + levels_to_check.append((user, "users")) old_list = current_state.content.get("events", {}) new_list = event.content.get("events", {}) for ev_id in set(list(old_list) + list(new_list)): - levels_to_check.append( - (ev_id, "events") - ) + levels_to_check.append((ev_id, "events")) old_state = current_state.content new_state = event.content @@ -540,7 +484,7 @@ def _check_power_levels(event, auth_events): raise AuthError( 403, "You don't have permission to remove ops level equal " - "to your own" + "to your own", ) # Check if the old and new levels are greater than the user level @@ -550,8 +494,7 @@ def _check_power_levels(event, auth_events): if old_level_too_big or new_level_too_big: raise AuthError( 403, - "You don't have permission to add ops level greater " - "than your own" + "You don't have permission to add ops level greater " "than your own", ) @@ -587,10 +530,9 @@ def get_user_power_level(user_id, auth_events): # some things which call this don't pass the create event: hack around # that. - key = (EventTypes.Create, "", ) + key = (EventTypes.Create, "") create_event = auth_events.get(key) - if (create_event is not None and - create_event.content["creator"] == user_id): + if create_event is not None and create_event.content["creator"] == user_id: return 100 else: return 0 @@ -636,9 +578,7 @@ def _verify_third_party_invite(event, auth_events): token = signed["token"] - invite_event = auth_events.get( - (EventTypes.ThirdPartyInvite, token,) - ) + invite_event = auth_events.get((EventTypes.ThirdPartyInvite, token)) if not invite_event: return False @@ -661,8 +601,7 @@ def _verify_third_party_invite(event, auth_events): if not key_name.startswith("ed25519:"): continue verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) + key_name, decode_base64(public_key) ) verify_signed_json(signed, server, verify_key) @@ -671,7 +610,7 @@ def _verify_third_party_invite(event, auth_events): # The caller is responsible for checking that the signing # server has not revoked that public key. return True - except (KeyError, SignatureVerifyException,): + except (KeyError, SignatureVerifyException): continue return False @@ -679,9 +618,7 @@ def _verify_third_party_invite(event, auth_events): def get_public_keys(invite_event): public_keys = [] if "public_key" in invite_event.content: - o = { - "public_key": invite_event.content["public_key"], - } + o = {"public_key": invite_event.content["public_key"]} if "key_validity_url" in invite_event.content: o["key_validity_url"] = invite_event.content["key_validity_url"] public_keys.append(o) @@ -702,22 +639,22 @@ def auth_types_for_event(event): auth_types = [] - auth_types.append((EventTypes.PowerLevels, "", )) - auth_types.append((EventTypes.Member, event.sender, )) - auth_types.append((EventTypes.Create, "", )) + auth_types.append((EventTypes.PowerLevels, "")) + auth_types.append((EventTypes.Member, event.sender)) + auth_types.append((EventTypes.Create, "")) if event.type == EventTypes.Member: membership = event.content["membership"] if membership in [Membership.JOIN, Membership.INVITE]: - auth_types.append((EventTypes.JoinRules, "", )) + auth_types.append((EventTypes.JoinRules, "")) - auth_types.append((EventTypes.Member, event.state_key, )) + auth_types.append((EventTypes.Member, event.state_key)) if membership == Membership.INVITE: if "third_party_invite" in event.content: key = ( EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"] + event.content["third_party_invite"]["signed"]["token"], ) auth_types.append(key) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 7154bcbea6..d3de70e671 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -127,25 +127,25 @@ def _event_dict_property(key): except KeyError: raise AttributeError(key) - return property( - getter, - setter, - delete, - ) + return property(getter, setter, delete) class EventBase(object): - def __init__(self, event_dict, signatures={}, unsigned={}, - internal_metadata_dict={}, rejected_reason=None): + def __init__( + self, + event_dict, + signatures={}, + unsigned={}, + internal_metadata_dict={}, + rejected_reason=None, + ): self.signatures = signatures self.unsigned = unsigned self.rejected_reason = rejected_reason self._event_dict = event_dict - self.internal_metadata = _EventInternalMetadata( - internal_metadata_dict - ) + self.internal_metadata = _EventInternalMetadata(internal_metadata_dict) auth_events = _event_dict_property("auth_events") depth = _event_dict_property("depth") @@ -168,10 +168,7 @@ class EventBase(object): def get_dict(self): d = dict(self._event_dict) - d.update({ - "signatures": self.signatures, - "unsigned": dict(self.unsigned), - }) + d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)}) return d @@ -358,6 +355,7 @@ class FrozenEventV2(EventBase): class FrozenEventV3(FrozenEventV2): """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format""" + format_version = EventFormatVersions.V3 # All events of this type are V3 @property @@ -414,6 +412,4 @@ def event_type_from_format_version(format_version): elif format_version == EventFormatVersions.V3: return FrozenEventV3 else: - raise Exception( - "No event format %r" % (format_version,) - ) + raise Exception("No event format %r" % (format_version,)) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 546b6f4982..db011e0407 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -78,7 +78,9 @@ class EventBuilder(object): _redacts = attr.ib(default=None) _origin_server_ts = attr.ib(default=None) - internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({}))) + internal_metadata = attr.ib( + default=attr.Factory(lambda: _EventInternalMetadata({})) + ) @property def state_key(self): @@ -102,11 +104,9 @@ class EventBuilder(object): """ state_ids = yield self._state.get_current_state_ids( - self.room_id, prev_event_ids, - ) - auth_ids = yield self._auth.compute_auth_events( - self, state_ids, + self.room_id, prev_event_ids ) + auth_ids = yield self._auth.compute_auth_events(self, state_ids) if self.format_version == EventFormatVersions.V1: auth_events = yield self._store.add_event_hashes(auth_ids) @@ -115,9 +115,7 @@ class EventBuilder(object): auth_events = auth_ids prev_events = prev_event_ids - old_depth = yield self._store.get_max_depth_of( - prev_event_ids, - ) + old_depth = yield self._store.get_max_depth_of(prev_event_ids) depth = old_depth + 1 # we cap depth of generated events, to ensure that they are not @@ -217,9 +215,14 @@ class EventBuilderFactory(object): ) -def create_local_event_from_event_dict(clock, hostname, signing_key, - format_version, event_dict, - internal_metadata_dict=None): +def create_local_event_from_event_dict( + clock, + hostname, + signing_key, + format_version, + event_dict, + internal_metadata_dict=None, +): """Takes a fully formed event dict, ensuring that fields like `origin` and `origin_server_ts` have correct values for a locally produced event, then signs and hashes it. @@ -237,9 +240,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key, """ if format_version not in KNOWN_EVENT_FORMAT_VERSIONS: - raise Exception( - "No event format defined for version %r" % (format_version,) - ) + raise Exception("No event format defined for version %r" % (format_version,)) if internal_metadata_dict is None: internal_metadata_dict = {} @@ -258,13 +259,9 @@ def create_local_event_from_event_dict(clock, hostname, signing_key, event_dict.setdefault("signatures", {}) - add_hashes_and_signatures( - event_dict, - hostname, - signing_key, - ) + add_hashes_and_signatures(event_dict, hostname, signing_key) return event_type_from_format_version(format_version)( - event_dict, internal_metadata_dict=internal_metadata_dict, + event_dict, internal_metadata_dict=internal_metadata_dict ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index fa09c132a0..a96cdada3d 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -88,8 +88,9 @@ class EventContext(object): self.app_service = None @staticmethod - def with_state(state_group, current_state_ids, prev_state_ids, - prev_group=None, delta_ids=None): + def with_state( + state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None + ): context = EventContext() # The current state including the current event @@ -132,17 +133,19 @@ class EventContext(object): else: prev_state_id = None - defer.returnValue({ - "prev_state_id": prev_state_id, - "event_type": event.type, - "event_state_key": event.state_key if event.is_state() else None, - "state_group": self.state_group, - "rejected": self.rejected, - "prev_group": self.prev_group, - "delta_ids": _encode_state_dict(self.delta_ids), - "prev_state_events": self.prev_state_events, - "app_service_id": self.app_service.id if self.app_service else None - }) + defer.returnValue( + { + "prev_state_id": prev_state_id, + "event_type": event.type, + "event_state_key": event.state_key if event.is_state() else None, + "state_group": self.state_group, + "rejected": self.rejected, + "prev_group": self.prev_group, + "delta_ids": _encode_state_dict(self.delta_ids), + "prev_state_events": self.prev_state_events, + "app_service_id": self.app_service.id if self.app_service else None, + } + ) @staticmethod def deserialize(store, input): @@ -194,7 +197,7 @@ class EventContext(object): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background( - self._fill_out_state, store, + self._fill_out_state, store ) yield make_deferred_yieldable(self._fetching_state_deferred) @@ -214,7 +217,7 @@ class EventContext(object): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background( - self._fill_out_state, store, + self._fill_out_state, store ) yield make_deferred_yieldable(self._fetching_state_deferred) @@ -240,9 +243,7 @@ class EventContext(object): if self.state_group is None: return - self._current_state_ids = yield store.get_state_ids_for_group( - self.state_group, - ) + self._current_state_ids = yield store.get_state_ids_for_group(self.state_group) if self._prev_state_id and self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) @@ -252,8 +253,9 @@ class EventContext(object): self._prev_state_ids = self._current_state_ids @defer.inlineCallbacks - def update_state(self, state_group, prev_state_ids, current_state_ids, - prev_group, delta_ids): + def update_state( + self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids + ): """Replace the state in the context """ @@ -279,10 +281,7 @@ def _encode_state_dict(state_dict): if state_dict is None: return None - return [ - (etype, state_key, v) - for (etype, state_key), v in iteritems(state_dict) - ] + return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)] def _decode_state_dict(input): @@ -291,4 +290,4 @@ def _decode_state_dict(input): if input is None: return None - return frozendict({(etype, state_key,): v for etype, state_key, v in input}) + return frozendict({(etype, state_key): v for etype, state_key, v in input}) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index b8ccced43b..f0de4d961f 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -46,8 +46,15 @@ class SpamChecker(object): return self.spam_checker.check_event_for_spam(event) - def user_may_invite(self, inviter_userid, invitee_userid, third_party_invite, - room_id, new_room, published_room): + def user_may_invite( + self, + inviter_userid, + invitee_userid, + third_party_invite, + room_id, + new_room, + published_room, + ): """Checks if a given user may send an invite If this method returns false, the invite will be rejected. @@ -74,12 +81,17 @@ class SpamChecker(object): return True return self.spam_checker.user_may_invite( - inviter_userid, invitee_userid, third_party_invite, room_id, new_room, + inviter_userid, + invitee_userid, + third_party_invite, + room_id, + new_room, published_room, ) - def user_may_create_room(self, userid, invite_list, third_party_invite_list, - cloning): + def user_may_create_room( + self, userid, invite_list, third_party_invite_list, cloning + ): """Checks if a given user may create a room If this method returns false, the creation request will be rejected. @@ -100,7 +112,7 @@ class SpamChecker(object): return True return self.spam_checker.user_may_create_room( - userid, invite_list, third_party_invite_list, cloning, + userid, invite_list, third_party_invite_list, cloning ) def user_may_create_room_alias(self, userid, room_alias): diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 50ceeb1e8e..8f5d95696b 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -36,8 +36,7 @@ class ThirdPartyEventRules(object): if module is not None: self.third_party_rules = module( - config=config, - http_client=hs.get_simple_http_client(), + config=config, http_client=hs.get_simple_http_client() ) @defer.inlineCallbacks @@ -109,6 +108,6 @@ class ThirdPartyEventRules(object): state_events[key] = room_state_events[event_id] ret = yield self.third_party_rules.check_threepid_can_be_invited( - medium, address, state_events, + medium, address, state_events ) defer.returnValue(ret) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index e2d4384de1..f24f0c16f0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -31,7 +31,7 @@ from . import EventBase # by a match for 'stuff'. # TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as # the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" -SPLIT_FIELD_REGEX = re.compile(r'(? MAX_ALIAS_LENGTH: raise SynapseError( 400, - ("Can't create aliases longer than" - " %d characters" % (MAX_ALIAS_LENGTH,)), + ( + "Can't create aliases longer than" + " %d characters" % (MAX_ALIAS_LENGTH,) + ), Codes.INVALID_PARAM, ) @@ -170,11 +170,7 @@ class EventValidator(object): event (EventBuilder|FrozenEvent) """ - strings = [ - "room_id", - "sender", - "type", - ] + strings = ["room_id", "sender", "type"] if hasattr(event, "state_key"): strings.append("state_key") @@ -187,10 +183,7 @@ class EventValidator(object): UserID.from_string(event.sender) if event.type == EventTypes.Message: - strings = [ - "body", - "msgtype", - ] + strings = ["body", "msgtype"] self._ensure_strings(event.content, strings) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index fc5cfb7d83..58b929363f 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -44,8 +44,9 @@ class FederationBase(object): self._clock = hs.get_clock() @defer.inlineCallbacks - def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version, - outlier=False, include_none=False): + def _check_sigs_and_hash_and_fetch( + self, origin, pdus, room_version, outlier=False, include_none=False + ): """Takes a list of PDUs and checks the signatures and hashs of each one. If a PDU fails its signature check then we check if we have it in the database and if not then request if from the originating server of @@ -79,9 +80,7 @@ class FederationBase(object): if not res: # Check local db. res = yield self.store.get_event( - pdu.event_id, - allow_rejected=True, - allow_none=True, + pdu.event_id, allow_rejected=True, allow_none=True ) if not res and pdu.origin != origin: @@ -98,23 +97,16 @@ class FederationBase(object): if not res: logger.warn( - "Failed to find copy of %s with valid signature", - pdu.event_id, + "Failed to find copy of %s with valid signature", pdu.event_id ) defer.returnValue(res) handle = logcontext.preserve_fn(handle_check_result) - deferreds2 = [ - handle(pdu, deferred) - for pdu, deferred in zip(pdus, deferreds) - ] + deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] valid_pdus = yield logcontext.make_deferred_yieldable( - defer.gatherResults( - deferreds2, - consumeErrors=True, - ) + defer.gatherResults(deferreds2, consumeErrors=True) ).addErrback(unwrapFirstError) if include_none: @@ -124,7 +116,7 @@ class FederationBase(object): def _check_sigs_and_hash(self, room_version, pdu): return logcontext.make_deferred_yieldable( - self._check_sigs_and_hashes(room_version, [pdu])[0], + self._check_sigs_and_hashes(room_version, [pdu])[0] ) def _check_sigs_and_hashes(self, room_version, pdus): @@ -159,11 +151,9 @@ class FederationBase(object): # received event was probably a redacted copy (but we then use our # *actual* redacted copy to be on the safe side.) redacted_event = prune_event(pdu) - if ( - set(redacted_event.keys()) == set(pdu.keys()) and - set(six.iterkeys(redacted_event.content)) - == set(six.iterkeys(pdu.content)) - ): + if set(redacted_event.keys()) == set(pdu.keys()) and set( + six.iterkeys(redacted_event.content) + ) == set(six.iterkeys(pdu.content)): logger.info( "Event %s seems to have been redacted; using our redacted " "copy", @@ -172,14 +162,16 @@ class FederationBase(object): else: logger.warning( "Event %s content has been tampered, redacting", - pdu.event_id, pdu.get_pdu_json(), + pdu.event_id, + pdu.get_pdu_json(), ) return redacted_event if self.spam_checker.check_event_for_spam(pdu): logger.warn( "Event contains spam, redacting %s: %s", - pdu.event_id, pdu.get_pdu_json() + pdu.event_id, + pdu.get_pdu_json(), ) return prune_event(pdu) @@ -190,23 +182,24 @@ class FederationBase(object): with logcontext.PreserveLoggingContext(ctx): logger.warn( "Signature check failed for %s: %s", - pdu.event_id, failure.getErrorMessage(), + pdu.event_id, + failure.getErrorMessage(), ) return failure for deferred, pdu in zip(deferreds, pdus): deferred.addCallbacks( - callback, errback, - callbackArgs=[pdu], - errbackArgs=[pdu], + callback, errback, callbackArgs=[pdu], errbackArgs=[pdu] ) return deferreds -class PduToCheckSig(namedtuple("PduToCheckSig", [ - "pdu", "redacted_pdu_json", "sender_domain", "deferreds", -])): +class PduToCheckSig( + namedtuple( + "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"] + ) +): pass @@ -260,10 +253,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): # First we check that the sender event is signed by the sender's domain # (except if its a 3pid invite, in which case it may be sent by any server) - pdus_to_check_sender = [ - p for p in pdus_to_check - if not _is_invite_via_3pid(p.pdu) - ] + pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] more_deferreds = keyring.verify_json_objects_for_server( [ @@ -297,7 +287,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): # (ie, the room version uses old-style non-hash event IDs). if v.event_format == EventFormatVersions.V1: pdus_to_check_event_id = [ - p for p in pdus_to_check + p + for p in pdus_to_check if p.sender_domain != get_domain_from_id(p.pdu.event_id) ] @@ -315,10 +306,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): def event_err(e, pdu_to_check): errmsg = ( - "event id %s: unable to verify signature for event id domain: %s" % ( - pdu_to_check.pdu.event_id, - e.getErrorMessage(), - ) + "event id %s: unable to verify signature for event id domain: %s" + % (pdu_to_check.pdu.event_id, e.getErrorMessage()) ) # XX as above: not really sure if these are the right codes raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) @@ -368,21 +357,18 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False): """ # we could probably enforce a bunch of other fields here (room_id, sender, # origin, etc etc) - assert_params_in_dict(pdu_json, ('type', 'depth')) + assert_params_in_dict(pdu_json, ("type", "depth")) - depth = pdu_json['depth'] + depth = pdu_json["depth"] if not isinstance(depth, six.integer_types): - raise SynapseError(400, "Depth %r not an intger" % (depth, ), - Codes.BAD_JSON) + raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: raise SynapseError(400, "Depth too small", Codes.BAD_JSON) elif depth > MAX_DEPTH: raise SynapseError(400, "Depth too large", Codes.BAD_JSON) - event = event_type_from_format_version(event_format_version)( - pdu_json, - ) + event = event_type_from_format_version(event_format_version)(pdu_json) event.internal_metadata.outlier = outlier diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 70573746d6..3883eb525e 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -57,6 +57,7 @@ class InvalidResponseError(RuntimeError): """Helper for _try_destination_list: indicates that the server returned a response we couldn't parse """ + pass @@ -65,9 +66,7 @@ class FederationClient(FederationBase): super(FederationClient, self).__init__(hs) self.pdu_destination_tried = {} - self._clock.looping_call( - self._clear_tried_cache, 60 * 1000, - ) + self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() @@ -99,8 +98,14 @@ class FederationClient(FederationBase): self.pdu_destination_tried[event_id] = destination_dict @log_function - def make_query(self, destination, query_type, args, - retry_on_dns_fail=False, ignore_backoff=False): + def make_query( + self, + destination, + query_type, + args, + retry_on_dns_fail=False, + ignore_backoff=False, + ): """Sends a federation Query to a remote homeserver of the given type and arguments. @@ -120,7 +125,10 @@ class FederationClient(FederationBase): sent_queries_counter.labels(query_type).inc() return self.transport_layer.make_query( - destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail, + destination, + query_type, + args, + retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff, ) @@ -137,9 +145,7 @@ class FederationClient(FederationBase): response """ sent_queries_counter.labels("client_device_keys").inc() - return self.transport_layer.query_client_keys( - destination, content, timeout - ) + return self.transport_layer.query_client_keys(destination, content, timeout) @log_function def query_user_devices(self, destination, user_id, timeout=30000): @@ -147,9 +153,7 @@ class FederationClient(FederationBase): server. """ sent_queries_counter.labels("user_devices").inc() - return self.transport_layer.query_user_devices( - destination, user_id, timeout - ) + return self.transport_layer.query_user_devices(destination, user_id, timeout) @log_function def claim_client_keys(self, destination, content, timeout): @@ -164,9 +168,7 @@ class FederationClient(FederationBase): response """ sent_queries_counter.labels("client_one_time_keys").inc() - return self.transport_layer.claim_client_keys( - destination, content, timeout - ) + return self.transport_layer.claim_client_keys(destination, content, timeout) @defer.inlineCallbacks @log_function @@ -191,7 +193,8 @@ class FederationClient(FederationBase): return transaction_data = yield self.transport_layer.backfill( - dest, room_id, extremities, limit) + dest, room_id, extremities, limit + ) logger.debug("backfill transaction_data=%s", repr(transaction_data)) @@ -204,17 +207,19 @@ class FederationClient(FederationBase): ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults( - self._check_sigs_and_hashes(room_version, pdus), - consumeErrors=True, - ).addErrback(unwrapFirstError)) + pdus[:] = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True + ).addErrback(unwrapFirstError) + ) defer.returnValue(pdus) @defer.inlineCallbacks @log_function - def get_pdu(self, destinations, event_id, room_version, outlier=False, - timeout=None): + def get_pdu( + self, destinations, event_id, room_version, outlier=False, timeout=None + ): """Requests the PDU with given origin and ID from the remote home servers. @@ -255,7 +260,7 @@ class FederationClient(FederationBase): try: transaction_data = yield self.transport_layer.get_event( - destination, event_id, timeout=timeout, + destination, event_id, timeout=timeout ) logger.debug( @@ -282,8 +287,7 @@ class FederationClient(FederationBase): except SynapseError as e: logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, + "Failed to get PDU %s from %s because %s", event_id, destination, e ) continue except NotRetryingDestination as e: @@ -296,8 +300,7 @@ class FederationClient(FederationBase): pdu_attempts[destination] = now logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, + "Failed to get PDU %s from %s because %s", event_id, destination, e ) continue @@ -326,7 +329,7 @@ class FederationClient(FederationBase): # we have most of the state and auth_chain already. # However, this may 404 if the other side has an old synapse. result = yield self.transport_layer.get_room_state_ids( - destination, room_id, event_id=event_id, + destination, room_id, event_id=event_id ) state_event_ids = result["pdu_ids"] @@ -340,12 +343,10 @@ class FederationClient(FederationBase): logger.warning( "Failed to fetch missing state/auth events for %s: %s", room_id, - failed_to_fetch + failed_to_fetch, ) - event_map = { - ev.event_id: ev for ev in fetched_events - } + event_map = {ev.event_id: ev for ev in fetched_events} pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] auth_chain = [ @@ -362,15 +363,14 @@ class FederationClient(FederationBase): raise e result = yield self.transport_layer.get_room_state( - destination, room_id, event_id=event_id, + destination, room_id, event_id=event_id ) room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdus = [ - event_from_pdu_json(p, format_ver, outlier=True) - for p in result["pdus"] + event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"] ] auth_chain = [ @@ -378,9 +378,9 @@ class FederationClient(FederationBase): for p in result.get("auth_chain", []) ] - seen_events = yield self.store.get_events([ - ev.event_id for ev in itertools.chain(pdus, auth_chain) - ]) + seen_events = yield self.store.get_events( + [ev.event_id for ev in itertools.chain(pdus, auth_chain)] + ) signed_pdus = yield self._check_sigs_and_hash_and_fetch( destination, @@ -442,7 +442,7 @@ class FederationClient(FederationBase): batch_size = 20 missing_events = list(missing_events) for i in range(0, len(missing_events), batch_size): - batch = set(missing_events[i:i + batch_size]) + batch = set(missing_events[i : i + batch_size]) deferreds = [ run_in_background( @@ -470,21 +470,17 @@ class FederationClient(FederationBase): @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): - res = yield self.transport_layer.get_event_auth( - destination, room_id, event_id, - ) + res = yield self.transport_layer.get_event_auth(destination, room_id, event_id) room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ - event_from_pdu_json(p, format_ver, outlier=True) - for p in res["auth_chain"] + event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"] ] signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, - outlier=True, room_version=room_version, + destination, auth_chain, outlier=True, room_version=room_version ) signed_auth.sort(key=lambda e: e.depth) @@ -527,28 +523,26 @@ class FederationClient(FederationBase): res = yield callback(destination) defer.returnValue(res) except InvalidResponseError as e: - logger.warn( - "Failed to %s via %s: %s", - description, destination, e, - ) + logger.warn("Failed to %s via %s: %s", description, destination, e) except HttpResponseException as e: if not 500 <= e.code < 600: raise e.to_synapse_error() else: logger.warn( "Failed to %s via %s: %i %s", - description, destination, e.code, e.args[0], + description, + destination, + e.code, + e.args[0], ) except Exception: - logger.warn( - "Failed to %s via %s", - description, destination, exc_info=1, - ) + logger.warn("Failed to %s via %s", description, destination, exc_info=1) - raise RuntimeError("Failed to %s via any server" % (description, )) + raise RuntimeError("Failed to %s via any server" % (description,)) - def make_membership_event(self, destinations, room_id, user_id, membership, - content, params): + def make_membership_event( + self, destinations, room_id, user_id, membership, content, params + ): """ Creates an m.room.member event, with context, without participating in the room. @@ -584,14 +578,14 @@ class FederationClient(FederationBase): valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: raise RuntimeError( - "make_membership_event called with membership='%s', must be one of %s" % - (membership, ",".join(valid_memberships)) + "make_membership_event called with membership='%s', must be one of %s" + % (membership, ",".join(valid_memberships)) ) @defer.inlineCallbacks def send_request(destination): ret = yield self.transport_layer.make_membership_event( - destination, room_id, user_id, membership, params, + destination, room_id, user_id, membership, params ) # Note: If not supplied, the room version may be either v1 or v2, @@ -614,16 +608,17 @@ class FederationClient(FederationBase): pdu_dict["prev_state"] = [] ev = builder.create_local_event_from_event_dict( - self._clock, self.hostname, self.signing_key, - format_version=event_format, event_dict=pdu_dict, + self._clock, + self.hostname, + self.signing_key, + format_version=event_format, + event_dict=pdu_dict, ) - defer.returnValue( - (destination, ev, event_format) - ) + defer.returnValue((destination, ev, event_format)) return self._try_destination_list( - "make_" + membership, destinations, send_request, + "make_" + membership, destinations, send_request ) def send_join(self, destinations, pdu, event_format_version): @@ -655,9 +650,7 @@ class FederationClient(FederationBase): create_event = e break else: - raise InvalidResponseError( - "no %s in auth chain" % (EventTypes.Create,), - ) + raise InvalidResponseError("no %s in auth chain" % (EventTypes.Create,)) # the room version should be sane. room_version = create_event.content.get("room_version", "1") @@ -665,9 +658,8 @@ class FederationClient(FederationBase): # This shouldn't be possible, because the remote server should have # rejected the join attempt during make_join. raise InvalidResponseError( - "room appears to have unsupported version %s" % ( - room_version, - )) + "room appears to have unsupported version %s" % (room_version,) + ) @defer.inlineCallbacks def send_request(destination): @@ -691,10 +683,7 @@ class FederationClient(FederationBase): for p in content.get("auth_chain", []) ] - pdus = { - p.event_id: p - for p in itertools.chain(state, auth_chain) - } + pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} room_version = None for e in state: @@ -710,15 +699,13 @@ class FederationClient(FederationBase): raise SynapseError(400, "No create event in state") valid_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, list(pdus.values()), + destination, + list(pdus.values()), outlier=True, room_version=room_version, ) - valid_pdus_map = { - p.event_id: p - for p in valid_pdus - } + valid_pdus_map = {p.event_id: p for p in valid_pdus} # NB: We *need* to copy to ensure that we don't have multiple # references being passed on, as that causes... issues. @@ -741,11 +728,14 @@ class FederationClient(FederationBase): check_authchain_validity(signed_auth) - defer.returnValue({ - "state": signed_state, - "auth_chain": signed_auth, - "origin": destination, - }) + defer.returnValue( + { + "state": signed_state, + "auth_chain": signed_auth, + "origin": destination, + } + ) + return self._try_destination_list("send_join", destinations, send_request) @defer.inlineCallbacks @@ -854,6 +844,7 @@ class FederationClient(FederationBase): Fails with a ``RuntimeError`` if no servers were reachable. """ + @defer.inlineCallbacks def send_request(destination): time_now = self._clock.time_msec() @@ -869,14 +860,23 @@ class FederationClient(FederationBase): return self._try_destination_list("send_leave", destinations, send_request) - def get_public_rooms(self, destination, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None): + def get_public_rooms( + self, + destination, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): if destination == self.server_name: return return self.transport_layer.get_public_rooms( - destination, limit, since_token, search_filter, + destination, + limit, + since_token, + search_filter, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) @@ -891,9 +891,7 @@ class FederationClient(FederationBase): """ time_now = self._clock.time_msec() - send_content = { - "auth_chain": [e.get_pdu_json(time_now) for e in local_auth], - } + send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]} code, content = yield self.transport_layer.send_query_auth( destination=destination, @@ -905,13 +903,10 @@ class FederationClient(FederationBase): room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) - auth_chain = [ - event_from_pdu_json(e, format_ver) - for e in content["auth_chain"] - ] + auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]] signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version, + destination, auth_chain, outlier=True, room_version=room_version ) signed_auth.sort(key=lambda e: e.depth) @@ -925,8 +920,16 @@ class FederationClient(FederationBase): defer.returnValue(ret) @defer.inlineCallbacks - def get_missing_events(self, destination, room_id, earliest_events_ids, - latest_events, limit, min_depth, timeout): + def get_missing_events( + self, + destination, + room_id, + earliest_events_ids, + latest_events, + limit, + min_depth, + timeout, + ): """Tries to fetch events we are missing. This is called when we receive an event without having received all of its ancestors. @@ -957,12 +960,11 @@ class FederationClient(FederationBase): format_ver = room_version_to_event_format(room_version) events = [ - event_from_pdu_json(e, format_ver) - for e in content.get("events", []) + event_from_pdu_json(e, format_ver) for e in content.get("events", []) ] signed_events = yield self._check_sigs_and_hash_and_fetch( - destination, events, outlier=False, room_version=room_version, + destination, events, outlier=False, room_version=room_version ) except HttpResponseException as e: if not e.code == 400: @@ -982,17 +984,14 @@ class FederationClient(FederationBase): try: yield self.transport_layer.exchange_third_party_invite( - destination=destination, - room_id=room_id, - event_dict=event_dict, + destination=destination, room_id=room_id, event_dict=event_dict ) defer.returnValue(None) except CodeMessageException: raise except Exception as e: logger.exception( - "Failed to send_third_party_invite via %s: %s", - destination, str(e) + "Failed to send_third_party_invite via %s: %s", destination, str(e) ) raise RuntimeError("Failed to send to any server.") diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 4c28c1dc3c..2e0cebb638 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -69,7 +69,6 @@ received_queries_counter = Counter( class FederationServer(FederationBase): - def __init__(self, hs): super(FederationServer, self).__init__(hs) @@ -118,11 +117,13 @@ class FederationServer(FederationBase): # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. - with (yield self._transaction_linearizer.queue( - (origin, transaction.transaction_id), - )): + with ( + yield self._transaction_linearizer.queue( + (origin, transaction.transaction_id) + ) + ): result = yield self._handle_incoming_transaction( - origin, transaction, request_time, + origin, transaction, request_time ) defer.returnValue(result) @@ -144,7 +145,7 @@ class FederationServer(FederationBase): if response: logger.debug( "[%s] We've already responded to this request", - transaction.transaction_id + transaction.transaction_id, ) defer.returnValue(response) return @@ -152,18 +153,15 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) # Reject if PDU count > 50 and EDU count > 100 - if (len(transaction.pdus) > 50 - or (hasattr(transaction, "edus") and len(transaction.edus) > 100)): + if len(transaction.pdus) > 50 or ( + hasattr(transaction, "edus") and len(transaction.edus) > 100 + ): - logger.info( - "Transaction PDU or EDU count too large. Returning 400", - ) + logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} yield self.transaction_actions.set_response( - origin, - transaction, - 400, response + origin, transaction, 400, response ) defer.returnValue((400, response)) @@ -230,9 +228,7 @@ class FederationServer(FederationBase): try: yield self.check_server_matches_acl(origin_host, room_id) except AuthError as e: - logger.warn( - "Ignoring PDUs for room %s from banned server", room_id, - ) + logger.warn("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: event_id = pdu.event_id pdu_results[event_id] = e.error_dict() @@ -242,9 +238,7 @@ class FederationServer(FederationBase): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu( - origin, pdu - ) + yield self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: logger.warn("Error handling PDU %s: %s", event_id, e) @@ -259,29 +253,18 @@ class FederationServer(FederationBase): ) yield concurrently_execute( - process_pdus_for_room, pdus_by_room.keys(), - TRANSACTION_CONCURRENCY_LIMIT, + process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu( - origin, - edu.edu_type, - edu.content - ) + yield self.received_edu(origin, edu.edu_type, edu.content) - response = { - "pdus": pdu_results, - } + response = {"pdus": pdu_results} logger.debug("Returning: %s", str(response)) - yield self.transaction_actions.set_response( - origin, - transaction, - 200, response - ) + yield self.transaction_actions.set_response(origin, transaction, 200, response) defer.returnValue((200, response)) @defer.inlineCallbacks @@ -311,7 +294,8 @@ class FederationServer(FederationBase): resp = yield self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, - room_id, event_id, + room_id, + event_id, ) defer.returnValue((200, resp)) @@ -328,24 +312,17 @@ class FederationServer(FederationBase): if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu( - room_id, event_id, - ) + state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) - defer.returnValue((200, { - "pdu_ids": state_ids, - "auth_chain_ids": auth_chain_ids, - })) + defer.returnValue( + (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}) + ) @defer.inlineCallbacks def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu( - room_id, event_id, - ) - auth_chain = yield self.store.get_auth_chain( - [pdu.event_id for pdu in pdus] - ) + pdus = yield self.handler.get_state_for_pdu(room_id, event_id) + auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) for event in auth_chain: # We sign these again because there was a bug where we @@ -355,14 +332,16 @@ class FederationServer(FederationBase): compute_event_signature( event.get_pdu_json(), self.hs.hostname, - self.hs.config.signing_key[0] + self.hs.config.signing_key[0], ) ) - defer.returnValue({ - "pdus": [pdu.get_pdu_json() for pdu in pdus], - "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], - }) + defer.returnValue( + { + "pdus": [pdu.get_pdu_json() for pdu in pdus], + "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], + } + ) @defer.inlineCallbacks @log_function @@ -370,9 +349,7 @@ class FederationServer(FederationBase): pdu = yield self.handler.get_persisted_pdu(origin, event_id) if pdu: - defer.returnValue( - (200, self._transaction_from_pdus([pdu]).get_dict()) - ) + defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict())) else: defer.returnValue((404, "")) @@ -394,10 +371,9 @@ class FederationServer(FederationBase): pdu = yield self.handler.on_make_join_request(room_id, user_id) time_now = self._clock.time_msec() - defer.returnValue({ - "event": pdu.get_pdu_json(time_now), - "room_version": room_version, - }) + defer.returnValue( + {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + ) @defer.inlineCallbacks def on_invite_request(self, origin, content, room_version): @@ -431,12 +407,17 @@ class FederationServer(FederationBase): logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) res_pdus = yield self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() - defer.returnValue((200, { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - })) + defer.returnValue( + ( + 200, + { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [ + p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] + ], + }, + ) + ) @defer.inlineCallbacks def on_make_leave_request(self, origin, room_id, user_id): @@ -447,10 +428,9 @@ class FederationServer(FederationBase): room_version = yield self.store.get_room_version(room_id) time_now = self._clock.time_msec() - defer.returnValue({ - "event": pdu.get_pdu_json(time_now), - "room_version": room_version, - }) + defer.returnValue( + {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + ) @defer.inlineCallbacks def on_send_leave_request(self, origin, content, room_id): @@ -475,9 +455,7 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() auth_pdus = yield self.handler.on_event_auth(event_id) - res = { - "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], - } + res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} defer.returnValue((200, res)) @defer.inlineCallbacks @@ -508,12 +486,11 @@ class FederationServer(FederationBase): format_ver = room_version_to_event_format(room_version) auth_chain = [ - event_from_pdu_json(e, format_ver) - for e in content["auth_chain"] + event_from_pdu_json(e, format_ver) for e in content["auth_chain"] ] signed_auth = yield self._check_sigs_and_hash_and_fetch( - origin, auth_chain, outlier=True, room_version=room_version, + origin, auth_chain, outlier=True, room_version=room_version ) ret = yield self.handler.on_query_auth( @@ -527,17 +504,12 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() send_content = { - "auth_chain": [ - e.get_pdu_json(time_now) - for e in ret["auth_chain"] - ], + "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]], "rejects": ret.get("rejects", []), "missing": ret.get("missing", []), } - defer.returnValue( - (200, send_content) - ) + defer.returnValue((200, send_content)) @log_function def on_query_client_keys(self, origin, content): @@ -566,20 +538,23 @@ class FederationServer(FederationBase): logger.info( "Claimed one-time-keys: %s", - ",".join(( - "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) - )), + ",".join( + ( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) + ) + ), ) defer.returnValue({"one_time_keys": json_result}) @defer.inlineCallbacks @log_function - def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit): + def on_get_missing_events( + self, origin, room_id, earliest_events, latest_events, limit + ): with (yield self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) yield self.check_server_matches_acl(origin_host, room_id) @@ -587,11 +562,13 @@ class FederationServer(FederationBase): logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," " limit: %d", - earliest_events, latest_events, limit, + earliest_events, + latest_events, + limit, ) missing_events = yield self.handler.on_get_missing_events( - origin, room_id, earliest_events, latest_events, limit, + origin, room_id, earliest_events, latest_events, limit ) if len(missing_events) < 5: @@ -603,9 +580,9 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() - defer.returnValue({ - "events": [ev.get_pdu_json(time_now) for ev in missing_events], - }) + defer.returnValue( + {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} + ) @log_function def on_openid_userinfo(self, token): @@ -666,22 +643,17 @@ class FederationServer(FederationBase): # origin. See bug #1893. This is also true for some third party # invites). if not ( - pdu.type == 'm.room.member' and - pdu.content and - pdu.content.get("membership", None) in ( - Membership.JOIN, Membership.INVITE, - ) + pdu.type == "m.room.member" + and pdu.content + and pdu.content.get("membership", None) + in (Membership.JOIN, Membership.INVITE) ): logger.info( - "Discarding PDU %s from invalid origin %s", - pdu.event_id, origin + "Discarding PDU %s from invalid origin %s", pdu.event_id, origin ) return else: - logger.info( - "Accepting join PDU %s from %s", - pdu.event_id, origin - ) + logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point room_version = yield self.store.get_room_version(pdu.room_id) @@ -690,33 +662,19 @@ class FederationServer(FederationBase): try: pdu = yield self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=pdu.event_id, - ) + raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu( - origin, pdu, sent_to_us_directly=True, - ) + yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "" % self.server_name @defer.inlineCallbacks def exchange_third_party_invite( - self, - sender_user_id, - target_user_id, - room_id, - signed, + self, sender_user_id, target_user_id, room_id, signed ): ret = yield self.handler.exchange_third_party_invite( - sender_user_id, - target_user_id, - room_id, - signed, + sender_user_id, target_user_id, room_id, signed ) defer.returnValue(ret) @@ -771,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event): allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. - if server_name[0] == '[': + if server_name[0] == "[": return False # check for ipv4 literals. We can just lift the routine from twisted. @@ -805,7 +763,9 @@ def server_matches_acl_event(server_name, acl_event): def _acl_entry_matches(server_name, acl_entry): if not isinstance(acl_entry, six.string_types): - logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) + logger.warn( + "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) + ) return False regex = glob_to_regex(acl_entry) return regex.match(server_name) @@ -815,6 +775,7 @@ class FederationHandlerRegistry(object): """Allows classes to register themselves as handlers for a given EDU or query type for incoming federation traffic. """ + def __init__(self): self.edu_handlers = {} self.query_handlers = {} @@ -848,9 +809,7 @@ class FederationHandlerRegistry(object): on and the result used as the response to the query request. """ if query_type in self.query_handlers: - raise KeyError( - "Already have a Query handler for %s" % (query_type,) - ) + raise KeyError("Already have a Query handler for %s" % (query_type,)) logger.info("Registering federation query handler for %r", query_type) @@ -905,14 +864,10 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): handler = self.edu_handlers.get(edu_type) if handler: return super(ReplicationFederationHandlerRegistry, self).on_edu( - edu_type, origin, content, + edu_type, origin, content ) - return self._send_edu( - edu_type=edu_type, - origin=origin, - content=content, - ) + return self._send_edu(edu_type=edu_type, origin=origin, content=content) def on_query(self, query_type, args): """Overrides FederationHandlerRegistry @@ -921,7 +876,4 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): if handler: return handler(args) - return self._get_query_client( - query_type=query_type, - args=args, - ) + return self._get_query_client(query_type=query_type, args=args) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 74ffd13b4f..7535f79203 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -46,12 +46,9 @@ class TransactionActions(object): response code and response body. """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " - "transaction_id") + raise RuntimeError("Cannot persist a transaction with no " "transaction_id") - return self.store.get_received_txn_response( - transaction.transaction_id, origin - ) + return self.store.get_received_txn_response(transaction.transaction_id, origin) @log_function def set_response(self, origin, transaction, code, response): @@ -61,14 +58,10 @@ class TransactionActions(object): Deferred """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " - "transaction_id") + raise RuntimeError("Cannot persist a transaction with no " "transaction_id") return self.store.set_received_txn_response( - transaction.transaction_id, - origin, - code, - response, + transaction.transaction_id, origin, code, response ) @defer.inlineCallbacks diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 0240b339b0..454456a52d 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -77,12 +77,22 @@ class FederationRemoteSendQueue(object): # lambda binds to the queue rather than to the name of the queue which # changes. ARGH. def register(name, queue): - LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,), - "", [], lambda: len(queue)) + LaterGauge( + "synapse_federation_send_queue_%s_size" % (queue_name,), + "", + [], + lambda: len(queue), + ) for queue_name in [ - "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", - "edus", "device_messages", "pos_time", "presence_destinations", + "presence_map", + "presence_changed", + "keyed_edu", + "keyed_edu_changed", + "edus", + "device_messages", + "pos_time", + "presence_destinations", ]: register(queue_name, getattr(self, queue_name)) @@ -121,9 +131,7 @@ class FederationRemoteSendQueue(object): del self.presence_changed[key] user_ids = set( - user_id - for uids in self.presence_changed.values() - for user_id in uids + user_id for uids in self.presence_changed.values() for user_id in uids ) keys = self.presence_destinations.keys() @@ -285,19 +293,21 @@ class FederationRemoteSendQueue(object): ] for (key, user_id) in dest_user_ids: - rows.append((key, PresenceRow( - state=self.presence_map[user_id], - ))) + rows.append((key, PresenceRow(state=self.presence_map[user_id]))) # Fetch presence to send to destinations i = self.presence_destinations.bisect_right(from_token) j = self.presence_destinations.bisect_right(to_token) + 1 for pos, (user_id, dests) in self.presence_destinations.items()[i:j]: - rows.append((pos, PresenceDestinationsRow( - state=self.presence_map[user_id], - destinations=list(dests), - ))) + rows.append( + ( + pos, + PresenceDestinationsRow( + state=self.presence_map[user_id], destinations=list(dests) + ), + ) + ) # Fetch changes keyed edus i = self.keyed_edu_changed.bisect_right(from_token) @@ -308,10 +318,14 @@ class FederationRemoteSendQueue(object): keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} for ((destination, edu_key), pos) in iteritems(keyed_edus): - rows.append((pos, KeyedEduRow( - key=edu_key, - edu=self.keyed_edu[(destination, edu_key)], - ))) + rows.append( + ( + pos, + KeyedEduRow( + key=edu_key, edu=self.keyed_edu[(destination, edu_key)] + ), + ) + ) # Fetch changed edus i = self.edus.bisect_right(from_token) @@ -327,9 +341,7 @@ class FederationRemoteSendQueue(object): device_messages = {v: k for k, v in self.device_messages.items()[i:j]} for (destination, pos) in iteritems(device_messages): - rows.append((pos, DeviceRow( - destination=destination, - ))) + rows.append((pos, DeviceRow(destination=destination))) # Sort rows based on pos rows.sort() @@ -377,16 +389,14 @@ class BaseFederationRow(object): raise NotImplementedError() -class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( - "state", # UserPresenceState -))): +class PresenceRow( + BaseFederationRow, namedtuple("PresenceRow", ("state",)) # UserPresenceState +): TypeId = "p" @staticmethod def from_data(data): - return PresenceRow( - state=UserPresenceState.from_dict(data) - ) + return PresenceRow(state=UserPresenceState.from_dict(data)) def to_data(self): return self.state.as_dict() @@ -395,33 +405,35 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( buff.presence.append(self.state) -class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", ( - "state", # UserPresenceState - "destinations", # list[str] -))): +class PresenceDestinationsRow( + BaseFederationRow, + namedtuple( + "PresenceDestinationsRow", + ("state", "destinations"), # UserPresenceState # list[str] + ), +): TypeId = "pd" @staticmethod def from_data(data): return PresenceDestinationsRow( - state=UserPresenceState.from_dict(data["state"]), - destinations=data["dests"], + state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"] ) def to_data(self): - return { - "state": self.state.as_dict(), - "dests": self.destinations, - } + return {"state": self.state.as_dict(), "dests": self.destinations} def add_to_buffer(self, buff): buff.presence_destinations.append((self.state, self.destinations)) -class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( - "key", # tuple(str) - the edu key passed to send_edu - "edu", # Edu -))): +class KeyedEduRow( + BaseFederationRow, + namedtuple( + "KeyedEduRow", + ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu + ), +): """Streams EDUs that have an associated key that is ued to clobber. For example, typing EDUs clobber based on room_id. """ @@ -430,28 +442,19 @@ class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( @staticmethod def from_data(data): - return KeyedEduRow( - key=tuple(data["key"]), - edu=Edu(**data["edu"]), - ) + return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"])) def to_data(self): - return { - "key": self.key, - "edu": self.edu.get_internal_dict(), - } + return {"key": self.key, "edu": self.edu.get_internal_dict()} def add_to_buffer(self, buff): - buff.keyed_edus.setdefault( - self.edu.destination, {} - )[self.key] = self.edu + buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu -class EduRow(BaseFederationRow, namedtuple("EduRow", ( - "edu", # Edu -))): +class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu """Streams EDUs that don't have keys. See KeyedEduRow """ + TypeId = "e" @staticmethod @@ -465,13 +468,12 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ( buff.edus.setdefault(self.edu.destination, []).append(self.edu) -class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( - "destination", # str -))): +class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ("destination",))): # str """Streams the fact that either a) there is pending to device messages for users on the remote, or b) a local users device has changed and needs to be sent to the remote. """ + TypeId = "d" @staticmethod @@ -487,23 +489,20 @@ class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( TypeToRow = { Row.TypeId: Row - for Row in ( - PresenceRow, - PresenceDestinationsRow, - KeyedEduRow, - EduRow, - DeviceRow, - ) + for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow, DeviceRow) } -ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( - "presence", # list(UserPresenceState) - "presence_destinations", # list of tuples of UserPresenceState and destinations - "keyed_edus", # dict of destination -> { key -> Edu } - "edus", # dict of destination -> [Edu] - "device_destinations", # set of destinations -)) +ParsedFederationStreamData = namedtuple( + "ParsedFederationStreamData", + ( + "presence", # list(UserPresenceState) + "presence_destinations", # list of tuples of UserPresenceState and destinations + "keyed_edus", # dict of destination -> { key -> Edu } + "edus", # dict of destination -> [Edu] + "device_destinations", # set of destinations + ), +) def process_rows_for_federation(transaction_queue, rows): @@ -542,7 +541,7 @@ def process_rows_for_federation(transaction_queue, rows): for state, destinations in buff.presence_destinations: transaction_queue.send_presence_to_destinations( - states=[state], destinations=destinations, + states=[state], destinations=destinations ) for destination, edu_map in iteritems(buff.keyed_edus): diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4224b29ecf..45eea4870e 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -44,8 +44,8 @@ sent_pdus_destination_dist_count = Counter( ) sent_pdus_destination_dist_total = Counter( - "synapse_federation_client_sent_pdu_destinations:total", "" - "Total number of PDUs queued for sending across all destinations", + "synapse_federation_client_sent_pdu_destinations:total", + "" "Total number of PDUs queued for sending across all destinations", ) @@ -63,14 +63,15 @@ class FederationSender(object): self._transaction_manager = TransactionManager(hs) # map from destination to PerDestinationQueue - self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] + self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] LaterGauge( "synapse_federation_transaction_queue_pending_destinations", "", [], lambda: sum( - 1 for d in self._per_destination_queues.values() + 1 + for d in self._per_destination_queues.values() if d.transmission_loop_running ), ) @@ -108,8 +109,7 @@ class FederationSender(object): # awaiting a call to flush_read_receipts_for_room. The presence of an entry # here for a given room means that we are rate-limiting RR flushes to that room, # and that there is a pending call to _flush_rrs_for_room in the system. - self._queues_awaiting_rr_flush_by_room = { - } # type: dict[str, set[PerDestinationQueue]] + self._queues_awaiting_rr_flush_by_room = {} # type: dict[str, set[PerDestinationQueue]] self._rr_txn_interval_per_room_ms = ( 1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second @@ -141,8 +141,7 @@ class FederationSender(object): # fire off a processing loop in the background run_as_background_process( - "process_event_queue_for_federation", - self._process_event_queue_loop, + "process_event_queue_for_federation", self._process_event_queue_loop ) @defer.inlineCallbacks @@ -152,7 +151,7 @@ class FederationSender(object): while True: last_token = yield self.store.get_federation_out_pos("events") next_token, events = yield self.store.get_all_new_events_stream( - last_token, self._last_poked_id, limit=100, + last_token, self._last_poked_id, limit=100 ) logger.debug("Handling %s -> %s", last_token, next_token) @@ -179,7 +178,7 @@ class FederationSender(object): # banned then it won't receive the event because it won't # be in the room after the ban. destinations = yield self.state.get_current_hosts_in_room( - event.room_id, latest_event_ids=event.prev_event_ids(), + event.room_id, latest_event_ids=event.prev_event_ids() ) except Exception: logger.exception( @@ -209,37 +208,40 @@ class FederationSender(object): for event in events: events_by_room.setdefault(event.room_id, []).append(event) - yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) - ], - consumeErrors=True - )) - - yield self.store.update_federation_out_pos( - "events", next_token + yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background(handle_room_events, evs) + for evs in itervalues(events_by_room) + ], + consumeErrors=True, + ) ) + yield self.store.update_federation_out_pos("events", next_token) + if events: now = self.clock.time_msec() ts = yield self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_lag.labels( - "federation_sender").set(now - ts) + "federation_sender" + ).set(now - ts) synapse.metrics.event_processing_last_ts.labels( - "federation_sender").set(ts) + "federation_sender" + ).set(ts) events_processed_counter.inc(len(events)) - event_processing_loop_room_count.labels( - "federation_sender" - ).inc(len(events_by_room)) + event_processing_loop_room_count.labels("federation_sender").inc( + len(events_by_room) + ) event_processing_loop_counter.labels("federation_sender").inc() synapse.metrics.event_processing_positions.labels( - "federation_sender").set(next_token) + "federation_sender" + ).set(next_token) finally: self._is_processing = False @@ -312,9 +314,7 @@ class FederationSender(object): if not domains: return - queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get( - room_id - ) + queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(room_id) # if there is no flush yet scheduled, we will send out these receipts with # immediate flushes, and schedule the next flush for this room. @@ -377,10 +377,9 @@ class FederationSender(object): # updates in quick succession are correctly handled. # We only want to send presence for our own users, so lets always just # filter here just in case. - self.pending_presence.update({ - state.user_id: state for state in states - if self.is_mine_id(state.user_id) - }) + self.pending_presence.update( + {state.user_id: state for state in states if self.is_mine_id(state.user_id)} + ) # We then handle the new pending presence in batches, first figuring # out the destinations we need to send each state to and then poking it diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 22a2735405..9aab12c0d3 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -360,7 +360,7 @@ class PerDestinationQueue(object): # Retrieve list of new device updates to send to the destination now_stream_id, results = yield self._store.get_devices_by_remote( - self._destination, last_device_list, limit=limit, + self._destination, last_device_list, limit=limit ) edus = [ Edu( @@ -381,10 +381,7 @@ class PerDestinationQueue(object): last_device_stream_id = self._last_device_stream_id to_device_stream_id = self._store.get_to_device_stream_token() contents, stream_id = yield self._store.get_new_device_msgs_for_remote( - self._destination, - last_device_stream_id, - to_device_stream_id, - limit, + self._destination, last_device_stream_id, to_device_stream_id, limit ) edus = [ Edu( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 35e6b8ff5b..c987bb9a0d 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -29,9 +29,10 @@ class TransactionManager(object): shared between PerDestinationQueue objects """ + def __init__(self, hs): self._server_name = hs.hostname - self.clock = hs.get_clock() # nb must be called this for @measure_func + self.clock = hs.get_clock() # nb must be called this for @measure_func self._store = hs.get_datastore() self._transaction_actions = TransactionActions(self._store) self._transport_layer = hs.get_federation_transport_client() @@ -55,9 +56,9 @@ class TransactionManager(object): txn_id = str(self._next_txn_id) logger.debug( - "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d)", - destination, txn_id, + "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)", + destination, + txn_id, len(pdus), len(edus), ) @@ -79,9 +80,9 @@ class TransactionManager(object): logger.debug("TX [%s] Persisted transaction", destination) logger.info( - "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d)", - destination, txn_id, + "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)", + destination, + txn_id, transaction.transaction_id, len(pdus), len(edus), @@ -112,20 +113,12 @@ class TransactionManager(object): response = e.response if e.code in (401, 404, 429) or 500 <= e.code: - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code - ) + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) raise e - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code - ) + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) - yield self._transaction_actions.delivered( - transaction, code, response - ) + yield self._transaction_actions.delivered(transaction, code, response) logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id) @@ -134,13 +127,18 @@ class TransactionManager(object): if "error" in r: logger.warn( "TX [%s] {%s} Remote returned error for %s: %s", - destination, txn_id, e_id, r, + destination, + txn_id, + e_id, + r, ) else: for p in pdus: logger.warn( "TX [%s] {%s} Failed to send event %s", - destination, txn_id, p.event_id, + destination, + txn_id, + p.event_id, ) success = False diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index e424c40fdf..aecd142309 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -48,12 +48,13 @@ class TransportLayerClient(object): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_room_state dest=%s, room=%s", - destination, room_id) + logger.debug("get_room_state dest=%s, room=%s", destination, room_id) path = _create_v1_path("/state/%s", room_id) return self.client.get_json( - destination, path=path, args={"event_id": event_id}, + destination, + path=path, + args={"event_id": event_id}, try_trailing_slash_on_400=True, ) @@ -71,12 +72,13 @@ class TransportLayerClient(object): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_room_state_ids dest=%s, room=%s", - destination, room_id) + logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) path = _create_v1_path("/state_ids/%s", room_id) return self.client.get_json( - destination, path=path, args={"event_id": event_id}, + destination, + path=path, + args={"event_id": event_id}, try_trailing_slash_on_400=True, ) @@ -94,13 +96,11 @@ class TransportLayerClient(object): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_pdu dest=%s, event_id=%s", - destination, event_id) + logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) path = _create_v1_path("/event/%s", event_id) return self.client.get_json( - destination, path=path, timeout=timeout, - try_trailing_slash_on_400=True, + destination, path=path, timeout=timeout, try_trailing_slash_on_400=True ) @log_function @@ -119,7 +119,10 @@ class TransportLayerClient(object): """ logger.debug( "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s", - destination, room_id, repr(event_tuples), str(limit) + destination, + room_id, + repr(event_tuples), + str(limit), ) if not event_tuples: @@ -128,16 +131,10 @@ class TransportLayerClient(object): path = _create_v1_path("/backfill/%s", room_id) - args = { - "v": event_tuples, - "limit": [str(limit)], - } + args = {"v": event_tuples, "limit": [str(limit)]} return self.client.get_json( - destination, - path=path, - args=args, - try_trailing_slash_on_400=True, + destination, path=path, args=args, try_trailing_slash_on_400=True ) @defer.inlineCallbacks @@ -163,7 +160,8 @@ class TransportLayerClient(object): """ logger.debug( "send_data dest=%s, txid=%s", - transaction.destination, transaction.transaction_id + transaction.destination, + transaction.transaction_id, ) if transaction.destination == self.server_name: @@ -189,8 +187,9 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def make_query(self, destination, query_type, args, retry_on_dns_fail, - ignore_backoff=False): + def make_query( + self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False + ): path = _create_v1_path("/query/%s", query_type) content = yield self.client.get_json( @@ -235,8 +234,8 @@ class TransportLayerClient(object): valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: raise RuntimeError( - "make_membership_event called with membership='%s', must be one of %s" % - (membership, ",".join(valid_memberships)) + "make_membership_event called with membership='%s', must be one of %s" + % (membership, ",".join(valid_memberships)) ) path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id) @@ -268,9 +267,7 @@ class TransportLayerClient(object): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, + destination=destination, path=path, data=content ) defer.returnValue(response) @@ -284,7 +281,6 @@ class TransportLayerClient(object): destination=destination, path=path, data=content, - # we want to do our best to send this through. The problem is # that if it fails, we won't retry it later, so if the remote # server was just having a momentary blip, the room will be out of @@ -300,10 +296,7 @@ class TransportLayerClient(object): path = _create_v1_path("/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) defer.returnValue(response) @@ -314,26 +307,27 @@ class TransportLayerClient(object): path = _create_v2_path("/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) defer.returnValue(response) @defer.inlineCallbacks @log_function - def get_public_rooms(self, remote_server, limit, since_token, - search_filter=None, include_all_networks=False, - third_party_instance_id=None): + def get_public_rooms( + self, + remote_server, + limit, + since_token, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): path = _create_v1_path("/publicRooms") - args = { - "include_all_networks": "true" if include_all_networks else "false", - } + args = {"include_all_networks": "true" if include_all_networks else "false"} if third_party_instance_id: - args["third_party_instance_id"] = third_party_instance_id, + args["third_party_instance_id"] = (third_party_instance_id,) if limit: args["limit"] = [str(limit)] if since_token: @@ -342,10 +336,7 @@ class TransportLayerClient(object): # TODO(erikj): Actually send the search_filter across federation. response = yield self.client.get_json( - destination=remote_server, - path=path, - args=args, - ignore_backoff=True, + destination=remote_server, path=path, args=args, ignore_backoff=True ) defer.returnValue(response) @@ -353,12 +344,10 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): - path = _create_v1_path("/exchange_third_party_invite/%s", room_id,) + path = _create_v1_path("/exchange_third_party_invite/%s", room_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=event_dict, + destination=destination, path=path, data=event_dict ) defer.returnValue(response) @@ -368,10 +357,7 @@ class TransportLayerClient(object): def get_event_auth(self, destination, room_id, event_id): path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) - content = yield self.client.get_json( - destination=destination, - path=path, - ) + content = yield self.client.get_json(destination=destination, path=path) defer.returnValue(content) @@ -381,9 +367,7 @@ class TransportLayerClient(object): path = _create_v1_path("/query_auth/%s/%s", room_id, event_id) content = yield self.client.post_json( - destination=destination, - path=path, - data=content, + destination=destination, path=path, data=content ) defer.returnValue(content) @@ -416,10 +400,7 @@ class TransportLayerClient(object): path = _create_v1_path("/user/keys/query") content = yield self.client.post_json( - destination=destination, - path=path, - data=query_content, - timeout=timeout, + destination=destination, path=path, data=query_content, timeout=timeout ) defer.returnValue(content) @@ -443,9 +424,7 @@ class TransportLayerClient(object): path = _create_v1_path("/user/devices/%s", user_id) content = yield self.client.get_json( - destination=destination, - path=path, - timeout=timeout, + destination=destination, path=path, timeout=timeout ) defer.returnValue(content) @@ -479,18 +458,23 @@ class TransportLayerClient(object): path = _create_v1_path("/user/keys/claim") content = yield self.client.post_json( - destination=destination, - path=path, - data=query_content, - timeout=timeout, + destination=destination, path=path, data=query_content, timeout=timeout ) defer.returnValue(content) @defer.inlineCallbacks @log_function - def get_missing_events(self, destination, room_id, earliest_events, - latest_events, limit, min_depth, timeout): - path = _create_v1_path("/get_missing_events/%s", room_id,) + def get_missing_events( + self, + destination, + room_id, + earliest_events, + latest_events, + limit, + min_depth, + timeout, + ): + path = _create_v1_path("/get_missing_events/%s", room_id) content = yield self.client.post_json( destination=destination, @@ -510,7 +494,7 @@ class TransportLayerClient(object): def get_group_profile(self, destination, group_id, requester_user_id): """Get a group profile """ - path = _create_v1_path("/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id) return self.client.get_json( destination=destination, @@ -529,7 +513,7 @@ class TransportLayerClient(object): requester_user_id (str) content (dict): The new profile of the group """ - path = _create_v1_path("/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id) return self.client.post_json( destination=destination, @@ -543,7 +527,7 @@ class TransportLayerClient(object): def get_group_summary(self, destination, group_id, requester_user_id): """Get a group summary """ - path = _create_v1_path("/groups/%s/summary", group_id,) + path = _create_v1_path("/groups/%s/summary", group_id) return self.client.get_json( destination=destination, @@ -556,7 +540,7 @@ class TransportLayerClient(object): def get_rooms_in_group(self, destination, group_id, requester_user_id): """Get all rooms in a group """ - path = _create_v1_path("/groups/%s/rooms", group_id,) + path = _create_v1_path("/groups/%s/rooms", group_id) return self.client.get_json( destination=destination, @@ -565,11 +549,12 @@ class TransportLayerClient(object): ignore_backoff=True, ) - def add_room_to_group(self, destination, group_id, requester_user_id, room_id, - content): + def add_room_to_group( + self, destination, group_id, requester_user_id, room_id, content + ): """Add a room to a group """ - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.post_json( destination=destination, @@ -579,13 +564,13 @@ class TransportLayerClient(object): ignore_backoff=True, ) - def update_room_in_group(self, destination, group_id, requester_user_id, room_id, - config_key, content): + def update_room_in_group( + self, destination, group_id, requester_user_id, room_id, config_key, content + ): """Update room in group """ path = _create_v1_path( - "/groups/%s/room/%s/config/%s", - group_id, room_id, config_key, + "/groups/%s/room/%s/config/%s", group_id, room_id, config_key ) return self.client.post_json( @@ -599,7 +584,7 @@ class TransportLayerClient(object): def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): """Remove a room from a group """ - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.delete_json( destination=destination, @@ -612,7 +597,7 @@ class TransportLayerClient(object): def get_users_in_group(self, destination, group_id, requester_user_id): """Get users in a group """ - path = _create_v1_path("/groups/%s/users", group_id,) + path = _create_v1_path("/groups/%s/users", group_id) return self.client.get_json( destination=destination, @@ -625,7 +610,7 @@ class TransportLayerClient(object): def get_invited_users_in_group(self, destination, group_id, requester_user_id): """Get users that have been invited to a group """ - path = _create_v1_path("/groups/%s/invited_users", group_id,) + path = _create_v1_path("/groups/%s/invited_users", group_id) return self.client.get_json( destination=destination, @@ -638,16 +623,10 @@ class TransportLayerClient(object): def accept_group_invite(self, destination, group_id, user_id, content): """Accept a group invite """ - path = _create_v1_path( - "/groups/%s/users/%s/accept_invite", - group_id, user_id, - ) + path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function @@ -657,14 +636,13 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): + def invite_to_group( + self, destination, group_id, user_id, requester_user_id, content + ): """Invite a user to a group """ path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) @@ -686,15 +664,13 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def remove_user_from_group(self, destination, group_id, requester_user_id, - user_id, content): + def remove_user_from_group( + self, destination, group_id, requester_user_id, user_id, content + ): """Remove a user fron a group """ path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) @@ -708,8 +684,9 @@ class TransportLayerClient(object): ) @log_function - def remove_user_from_group_notification(self, destination, group_id, user_id, - content): + def remove_user_from_group_notification( + self, destination, group_id, user_id, content + ): """Sent by group server to inform a user's server that they have been kicked from the group. """ @@ -717,10 +694,7 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function @@ -732,24 +706,24 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def update_group_summary_room(self, destination, group_id, user_id, room_id, - category_id, content): + def update_group_summary_room( + self, destination, group_id, user_id, room_id, category_id, content + ): """Update a room entry in a group summary """ if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", - group_id, category_id, room_id, + group_id, + category_id, + room_id, ) else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) return self.client.post_json( destination=destination, @@ -760,17 +734,20 @@ class TransportLayerClient(object): ) @log_function - def delete_group_summary_room(self, destination, group_id, user_id, room_id, - category_id): + def delete_group_summary_room( + self, destination, group_id, user_id, room_id, category_id + ): """Delete a room entry in a group summary """ if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", - group_id, category_id, room_id, + group_id, + category_id, + room_id, ) else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) return self.client.delete_json( destination=destination, @@ -783,7 +760,7 @@ class TransportLayerClient(object): def get_group_categories(self, destination, group_id, requester_user_id): """Get all categories in a group """ - path = _create_v1_path("/groups/%s/categories", group_id,) + path = _create_v1_path("/groups/%s/categories", group_id) return self.client.get_json( destination=destination, @@ -796,7 +773,7 @@ class TransportLayerClient(object): def get_group_category(self, destination, group_id, requester_user_id, category_id): """Get category info in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.get_json( destination=destination, @@ -806,11 +783,12 @@ class TransportLayerClient(object): ) @log_function - def update_group_category(self, destination, group_id, requester_user_id, category_id, - content): + def update_group_category( + self, destination, group_id, requester_user_id, category_id, content + ): """Update a category in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.post_json( destination=destination, @@ -821,11 +799,12 @@ class TransportLayerClient(object): ) @log_function - def delete_group_category(self, destination, group_id, requester_user_id, - category_id): + def delete_group_category( + self, destination, group_id, requester_user_id, category_id + ): """Delete a category in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.delete_json( destination=destination, @@ -838,7 +817,7 @@ class TransportLayerClient(object): def get_group_roles(self, destination, group_id, requester_user_id): """Get all roles in a group """ - path = _create_v1_path("/groups/%s/roles", group_id,) + path = _create_v1_path("/groups/%s/roles", group_id) return self.client.get_json( destination=destination, @@ -851,7 +830,7 @@ class TransportLayerClient(object): def get_group_role(self, destination, group_id, requester_user_id, role_id): """Get a roles info """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.get_json( destination=destination, @@ -861,11 +840,12 @@ class TransportLayerClient(object): ) @log_function - def update_group_role(self, destination, group_id, requester_user_id, role_id, - content): + def update_group_role( + self, destination, group_id, requester_user_id, role_id, content + ): """Update a role in a group """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.post_json( destination=destination, @@ -879,7 +859,7 @@ class TransportLayerClient(object): def delete_group_role(self, destination, group_id, requester_user_id, role_id): """Delete a role in a group """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.delete_json( destination=destination, @@ -889,17 +869,17 @@ class TransportLayerClient(object): ) @log_function - def update_group_summary_user(self, destination, group_id, requester_user_id, - user_id, role_id, content): + def update_group_summary_user( + self, destination, group_id, requester_user_id, user_id, role_id, content + ): """Update a users entry in a group """ if role_id: path = _create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", - group_id, role_id, user_id, + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id ) else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) return self.client.post_json( destination=destination, @@ -910,11 +890,10 @@ class TransportLayerClient(object): ) @log_function - def set_group_join_policy(self, destination, group_id, requester_user_id, - content): + def set_group_join_policy(self, destination, group_id, requester_user_id, content): """Sets the join policy for a group """ - path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,) + path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) return self.client.put_json( destination=destination, @@ -925,17 +904,17 @@ class TransportLayerClient(object): ) @log_function - def delete_group_summary_user(self, destination, group_id, requester_user_id, - user_id, role_id): + def delete_group_summary_user( + self, destination, group_id, requester_user_id, user_id, role_id + ): """Delete a users entry in a group """ if role_id: path = _create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", - group_id, role_id, user_id, + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id ) else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) return self.client.delete_json( destination=destination, @@ -953,10 +932,7 @@ class TransportLayerClient(object): content = {"user_ids": user_ids} return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @@ -975,9 +951,8 @@ def _create_v1_path(path, *args): Returns: str """ - return ( - FEDERATION_V1_PREFIX - + path % tuple(urllib.parse.quote(arg, "") for arg in args) + return FEDERATION_V1_PREFIX + path % tuple( + urllib.parse.quote(arg, "") for arg in args ) @@ -996,7 +971,6 @@ def _create_v2_path(path, *args): Returns: str """ - return ( - FEDERATION_V2_PREFIX - + path % tuple(urllib.parse.quote(arg, "") for arg in args) + return FEDERATION_V2_PREFIX + path % tuple( + urllib.parse.quote(arg, "") for arg in args ) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 6cf213b895..955f0f4308 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -66,8 +66,7 @@ class TransportLayerServer(JsonResource): self.authenticator = Authenticator(hs) self.ratelimiter = FederationRateLimiter( - self.clock, - config=hs.config.rc_federation, + self.clock, config=hs.config.rc_federation ) self.register_servlets() @@ -84,11 +83,13 @@ class TransportLayerServer(JsonResource): class AuthenticationError(SynapseError): """There was a problem authenticating the request""" + pass class NoAuthenticationError(AuthenticationError): """The request had no authentication information""" + pass @@ -105,8 +106,8 @@ class Authenticator(object): def authenticate_request(self, request, content): now = self._clock.time_msec() json_request = { - "method": request.method.decode('ascii'), - "uri": request.uri.decode('ascii'), + "method": request.method.decode("ascii"), + "uri": request.uri.decode("ascii"), "destination": self.server_name, "signatures": {}, } @@ -120,7 +121,7 @@ class Authenticator(object): if not auth_headers: raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED, + 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) for auth in auth_headers: @@ -130,14 +131,14 @@ class Authenticator(object): json_request["signatures"].setdefault(origin, {})[key] = sig if ( - self.federation_domain_whitelist is not None and - origin not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and origin not in self.federation_domain_whitelist ): raise FederationDeniedError(origin) if not json_request["signatures"]: raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED, + 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) yield self.keyring.verify_json_for_server( @@ -177,12 +178,12 @@ def _parse_auth_header(header_bytes): AuthenticationError if the header could not be parsed """ try: - header_str = header_bytes.decode('utf-8') + header_str = header_bytes.decode("utf-8") params = header_str.split(" ")[1].split(",") param_dict = dict(kv.split("=") for kv in params) def strip_quotes(value): - if value.startswith("\""): + if value.startswith('"'): return value[1:-1] else: return value @@ -198,11 +199,11 @@ def _parse_auth_header(header_bytes): except Exception as e: logger.warn( "Error parsing auth header '%s': %s", - header_bytes.decode('ascii', 'replace'), + header_bytes.decode("ascii", "replace"), e, ) raise AuthenticationError( - 400, "Malformed Authorization header", Codes.UNAUTHORIZED, + 400, "Malformed Authorization header", Codes.UNAUTHORIZED ) @@ -242,6 +243,7 @@ class BaseFederationServlet(object): Exception: other exceptions will be caught, logged, and a 500 will be returned. """ + REQUIRE_AUTH = True PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version @@ -293,9 +295,7 @@ class BaseFederationServlet(object): origin, content, request.args, *args, **kwargs ) else: - response = yield func( - origin, content, request.args, *args, **kwargs - ) + response = yield func(origin, content, request.args, *args, **kwargs) defer.returnValue(response) @@ -343,14 +343,12 @@ class FederationSendServlet(BaseFederationServlet): try: transaction_data = content - logger.debug( - "Decoded %s: %s", - transaction_id, str(transaction_data) - ) + logger.debug("Decoded %s: %s", transaction_id, str(transaction_data)) logger.info( "Received txn %s from %s. (PDUs: %d, EDUs: %d)", - transaction_id, origin, + transaction_id, + origin, len(transaction_data.get("pdus", [])), len(transaction_data.get("edus", [])), ) @@ -361,8 +359,7 @@ class FederationSendServlet(BaseFederationServlet): # Add some extra data to the transaction dict that isn't included # in the request body. transaction_data.update( - transaction_id=transaction_id, - destination=self.server_name + transaction_id=transaction_id, destination=self.server_name ) except Exception as e: @@ -372,7 +369,7 @@ class FederationSendServlet(BaseFederationServlet): try: code, response = yield self.handler.on_incoming_transaction( - origin, transaction_data, + origin, transaction_data ) except Exception: logger.exception("on_incoming_transaction failed") @@ -416,7 +413,7 @@ class FederationBackfillServlet(BaseFederationServlet): PATH = "/backfill/(?P[^/]*)/?" def on_GET(self, origin, content, query, context): - versions = [x.decode('ascii') for x in query[b"v"]] + versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) if not limit: @@ -432,7 +429,7 @@ class FederationQueryServlet(BaseFederationServlet): def on_GET(self, origin, content, query, query_type): return self.handler.on_query_request( query_type, - {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()} + {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}, ) @@ -456,15 +453,14 @@ class FederationMakeJoinServlet(BaseFederationServlet): Deferred[(int, object)|None]: either (response code, response object) to return a JSON response, or None if the request has already been handled. """ - versions = query.get(b'ver') + versions = query.get(b"ver") if versions is not None: supported_versions = [v.decode("utf-8") for v in versions] else: supported_versions = ["1"] content = yield self.handler.on_make_join_request( - origin, context, user_id, - supported_versions=supported_versions, + origin, context, user_id, supported_versions=supported_versions ) defer.returnValue((200, content)) @@ -474,9 +470,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_leave_request( - origin, context, user_id, - ) + content = yield self.handler.on_make_leave_request(origin, context, user_id) defer.returnValue((200, content)) @@ -517,7 +511,7 @@ class FederationV1InviteServlet(BaseFederationServlet): # state resolution algorithm, and we don't use that for processing # invites content = yield self.handler.on_invite_request( - origin, content, room_version=RoomVersions.V1.identifier, + origin, content, room_version=RoomVersions.V1.identifier ) # V1 federation API is defined to return a content of `[200, {...}]` @@ -545,7 +539,7 @@ class FederationV2InviteServlet(BaseFederationServlet): event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state content = yield self.handler.on_invite_request( - origin, event, room_version=room_version, + origin, event, room_version=room_version ) defer.returnValue((200, content)) @@ -629,8 +623,10 @@ class On3pidBindServlet(BaseFederationServlet): for invite in content["invites"]: try: if "signed" not in invite or "token" not in invite["signed"]: - message = ("Rejecting received notification of third-" - "party invite without signed: %s" % (invite,)) + message = ( + "Rejecting received notification of third-" + "party invite without signed: %s" % (invite,) + ) logger.info(message) raise SynapseError(400, message) yield self.handler.exchange_third_party_invite( @@ -671,18 +667,23 @@ class OpenIdUserInfo(BaseFederationServlet): def on_GET(self, origin, content, query): token = query.get(b"access_token", [None])[0] if token is None: - defer.returnValue((401, { - "errcode": "M_MISSING_TOKEN", "error": "Access Token required" - })) + defer.returnValue( + (401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}) + ) return - user_id = yield self.handler.on_openid_userinfo(token.decode('ascii')) + user_id = yield self.handler.on_openid_userinfo(token.decode("ascii")) if user_id is None: - defer.returnValue((401, { - "errcode": "M_UNKNOWN_TOKEN", - "error": "Access Token unknown or expired" - })) + defer.returnValue( + ( + 401, + { + "errcode": "M_UNKNOWN_TOKEN", + "error": "Access Token unknown or expired", + }, + ) + ) defer.returnValue((200, {"sub": user_id})) @@ -722,7 +723,7 @@ class PublicRoomList(BaseFederationServlet): def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access): super(PublicRoomList, self).__init__( - handler, authenticator, ratelimiter, server_name, + handler, authenticator, ratelimiter, server_name ) self.allow_access = allow_access @@ -748,9 +749,7 @@ class PublicRoomList(BaseFederationServlet): network_tuple = ThirdPartyInstanceID(None, None) data = yield self.handler.get_local_public_room_list( - limit, since_token, - network_tuple=network_tuple, - from_federation=True, + limit, since_token, network_tuple=network_tuple, from_federation=True ) defer.returnValue((200, data)) @@ -761,17 +760,18 @@ class FederationVersionServlet(BaseFederationServlet): REQUIRE_AUTH = False def on_GET(self, origin, content, query): - return defer.succeed((200, { - "server": { - "name": "Synapse", - "version": get_version_string(synapse) - }, - })) + return defer.succeed( + ( + 200, + {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, + ) + ) class FederationGroupsProfileServlet(BaseFederationServlet): """Get/set the basic profile of a group on behalf of a user """ + PATH = "/groups/(?P[^/]*)/profile" @defer.inlineCallbacks @@ -780,9 +780,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_group_profile( - group_id, requester_user_id - ) + new_content = yield self.handler.get_group_profile(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -808,9 +806,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_group_summary( - group_id, requester_user_id - ) + new_content = yield self.handler.get_group_summary(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -818,6 +814,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet): class FederationGroupsRoomsServlet(BaseFederationServlet): """Get the rooms in a group on behalf of a user """ + PATH = "/groups/(?P[^/]*)/rooms" @defer.inlineCallbacks @@ -826,9 +823,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_rooms_in_group( - group_id, requester_user_id - ) + new_content = yield self.handler.get_rooms_in_group(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -836,6 +831,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): class FederationGroupsAddRoomsServlet(BaseFederationServlet): """Add/remove room from group """ + PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" @defer.inlineCallbacks @@ -857,7 +853,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.remove_room_from_group( - group_id, requester_user_id, room_id, + group_id, requester_user_id, room_id ) defer.returnValue((200, new_content)) @@ -866,6 +862,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): """Update room config in group """ + PATH = ( "/groups/(?P[^/]*)/room/(?P[^/]*)" "/config/(?P[^/]*)" @@ -878,7 +875,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") result = yield self.groups_handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content, + group_id, requester_user_id, room_id, config_key, content ) defer.returnValue((200, result)) @@ -887,6 +884,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): class FederationGroupsUsersServlet(BaseFederationServlet): """Get the users in a group on behalf of a user """ + PATH = "/groups/(?P[^/]*)/users" @defer.inlineCallbacks @@ -895,9 +893,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_users_in_group( - group_id, requester_user_id - ) + new_content = yield self.handler.get_users_in_group(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -905,6 +901,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet): class FederationGroupsInvitedUsersServlet(BaseFederationServlet): """Get the users that have been invited to a group """ + PATH = "/groups/(?P[^/]*)/invited_users" @defer.inlineCallbacks @@ -923,6 +920,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet): class FederationGroupsInviteServlet(BaseFederationServlet): """Ask a group server to invite someone to the group """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" @defer.inlineCallbacks @@ -932,7 +930,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.invite_to_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, new_content)) @@ -941,6 +939,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet): class FederationGroupsAcceptInviteServlet(BaseFederationServlet): """Accept an invitation from the group server """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" @defer.inlineCallbacks @@ -948,9 +947,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") - new_content = yield self.handler.accept_invite( - group_id, user_id, content, - ) + new_content = yield self.handler.accept_invite(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -958,6 +955,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): class FederationGroupsJoinServlet(BaseFederationServlet): """Attempt to join a group """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" @defer.inlineCallbacks @@ -965,9 +963,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet): if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") - new_content = yield self.handler.join_group( - group_id, user_id, content, - ) + new_content = yield self.handler.join_group(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -975,6 +971,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet): class FederationGroupsRemoveUserServlet(BaseFederationServlet): """Leave or kick a user from the group """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" @defer.inlineCallbacks @@ -984,7 +981,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, new_content)) @@ -993,6 +990,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): class FederationGroupsLocalInviteServlet(BaseFederationServlet): """A group server has invited a local user """ + PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" @defer.inlineCallbacks @@ -1000,9 +998,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") - new_content = yield self.handler.on_invite( - group_id, user_id, content, - ) + new_content = yield self.handler.on_invite(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -1010,6 +1006,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): """A group server has removed a local user """ + PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" @defer.inlineCallbacks @@ -1018,7 +1015,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): raise SynapseError(403, "user_id doesn't match origin") new_content = yield self.handler.user_removed_from_group( - group_id, user_id, content, + group_id, user_id, content ) defer.returnValue((200, new_content)) @@ -1027,6 +1024,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): """A group or user's server renews their attestation """ + PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" @defer.inlineCallbacks @@ -1047,6 +1045,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): - /groups/:group/summary/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id """ + PATH = ( "/groups/(?P[^/]*)/summary" "(/categories/(?P[^/]+))?" @@ -1063,7 +1062,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.update_group_summary_room( - group_id, requester_user_id, + group_id, + requester_user_id, room_id=room_id, category_id=category_id, content=content, @@ -1081,9 +1081,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.delete_group_summary_room( - group_id, requester_user_id, - room_id=room_id, - category_id=category_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -1092,9 +1090,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): class FederationGroupsCategoriesServlet(BaseFederationServlet): """Get all categories for a group """ - PATH = ( - "/groups/(?P[^/]*)/categories/?" - ) + + PATH = "/groups/(?P[^/]*)/categories/?" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -1102,9 +1099,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_categories( - group_id, requester_user_id, - ) + resp = yield self.handler.get_group_categories(group_id, requester_user_id) defer.returnValue((200, resp)) @@ -1112,9 +1107,8 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): class FederationGroupsCategoryServlet(BaseFederationServlet): """Add/remove/get a category in a group """ - PATH = ( - "/groups/(?P[^/]*)/categories/(?P[^/]+)" - ) + + PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id, category_id): @@ -1138,7 +1132,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.upsert_group_category( - group_id, requester_user_id, category_id, content, + group_id, requester_user_id, category_id, content ) defer.returnValue((200, resp)) @@ -1153,7 +1147,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.delete_group_category( - group_id, requester_user_id, category_id, + group_id, requester_user_id, category_id ) defer.returnValue((200, resp)) @@ -1162,9 +1156,8 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): class FederationGroupsRolesServlet(BaseFederationServlet): """Get roles in a group """ - PATH = ( - "/groups/(?P[^/]*)/roles/?" - ) + + PATH = "/groups/(?P[^/]*)/roles/?" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -1172,9 +1165,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_roles( - group_id, requester_user_id, - ) + resp = yield self.handler.get_group_roles(group_id, requester_user_id) defer.returnValue((200, resp)) @@ -1182,9 +1173,8 @@ class FederationGroupsRolesServlet(BaseFederationServlet): class FederationGroupsRoleServlet(BaseFederationServlet): """Add/remove/get a role in a group """ - PATH = ( - "/groups/(?P[^/]*)/roles/(?P[^/]+)" - ) + + PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id, role_id): @@ -1192,9 +1182,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_role( - group_id, requester_user_id, role_id - ) + resp = yield self.handler.get_group_role(group_id, requester_user_id, role_id) defer.returnValue((200, resp)) @@ -1208,7 +1196,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.update_group_role( - group_id, requester_user_id, role_id, content, + group_id, requester_user_id, role_id, content ) defer.returnValue((200, resp)) @@ -1223,7 +1211,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.delete_group_role( - group_id, requester_user_id, role_id, + group_id, requester_user_id, role_id ) defer.returnValue((200, resp)) @@ -1236,6 +1224,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): - /groups/:group/summary/users/:user_id - /groups/:group/summary/roles/:role/users/:user_id """ + PATH = ( "/groups/(?P[^/]*)/summary" "(/roles/(?P[^/]+))?" @@ -1252,7 +1241,8 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.update_group_summary_user( - group_id, requester_user_id, + group_id, + requester_user_id, user_id=user_id, role_id=role_id, content=content, @@ -1270,9 +1260,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.delete_group_summary_user( - group_id, requester_user_id, - user_id=user_id, - role_id=role_id, + group_id, requester_user_id, user_id=user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -1281,14 +1269,13 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): """Get roles in a group """ - PATH = ( - "/get_groups_publicised" - ) + + PATH = "/get_groups_publicised" @defer.inlineCallbacks def on_POST(self, origin, content, query): resp = yield self.handler.bulk_get_publicised_groups( - content["user_ids"], proxy=False, + content["user_ids"], proxy=False ) defer.returnValue((200, resp)) @@ -1297,6 +1284,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): """Sets whether a group is joinable without an invite or knock """ + PATH = "/groups/(?P[^/]*)/settings/m.join_policy" @defer.inlineCallbacks @@ -1317,6 +1305,7 @@ class RoomComplexityServlet(BaseFederationServlet): Indicates to other servers how complex (and therefore likely resource-intensive) a public room this server knows about is. """ + PATH = "/rooms/(?P[^/]*)/complexity" PREFIX = FEDERATION_UNSTABLE_PREFIX @@ -1325,9 +1314,7 @@ class RoomComplexityServlet(BaseFederationServlet): store = self.handler.hs.get_datastore() - is_public = yield store.is_room_world_readable_or_publicly_joinable( - room_id - ) + is_public = yield store.is_room_world_readable_or_publicly_joinable(room_id) if not is_public: raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) @@ -1362,13 +1349,9 @@ FEDERATION_SERVLET_CLASSES = ( RoomComplexityServlet, ) -OPENID_SERVLET_CLASSES = ( - OpenIdUserInfo, -) +OPENID_SERVLET_CLASSES = (OpenIdUserInfo,) -ROOM_LIST_CLASSES = ( - PublicRoomList, -) +ROOM_LIST_CLASSES = (PublicRoomList,) GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsProfileServlet, @@ -1399,9 +1382,7 @@ GROUP_LOCAL_SERVLET_CLASSES = ( ) -GROUP_ATTESTATION_SERVLET_CLASSES = ( - FederationGroupsRenewAttestaionServlet, -) +GROUP_ATTESTATION_SERVLET_CLASSES = (FederationGroupsRenewAttestaionServlet,) DEFAULT_SERVLET_GROUPS = ( "federation", diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 025a79c022..14aad8f09d 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -32,21 +32,11 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - valid_keys = [ - "origin", - "destination", - "edu_type", - "content", - ] + valid_keys = ["origin", "destination", "edu_type", "content"] - required_keys = [ - "edu_type", - ] + required_keys = ["edu_type"] - internal_keys = [ - "origin", - "destination", - ] + internal_keys = ["origin", "destination"] class Transaction(JsonEncodedObject): @@ -75,10 +65,7 @@ class Transaction(JsonEncodedObject): "edus", ] - internal_keys = [ - "transaction_id", - "destination", - ] + internal_keys = ["transaction_id", "destination"] required_keys = [ "transaction_id", @@ -98,9 +85,7 @@ class Transaction(JsonEncodedObject): del kwargs["edus"] super(Transaction, self).__init__( - transaction_id=transaction_id, - pdus=pdus, - **kwargs + transaction_id=transaction_id, pdus=pdus, **kwargs ) @staticmethod @@ -109,13 +94,9 @@ class Transaction(JsonEncodedObject): transaction_id and origin_server_ts keys. """ if "origin_server_ts" not in kwargs: - raise KeyError( - "Require 'origin_server_ts' to construct a Transaction" - ) + raise KeyError("Require 'origin_server_ts' to construct a Transaction") if "transaction_id" not in kwargs: - raise KeyError( - "Require 'transaction_id' to construct a Transaction" - ) + raise KeyError("Require 'transaction_id' to construct a Transaction") kwargs["pdus"] = [p.get_pdu_json() for p in pdus] diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 3ba74001d8..e73757570c 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -65,6 +65,7 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 class GroupAttestationSigning(object): """Creates and verifies group attestations. """ + def __init__(self, hs): self.keyring = hs.get_keyring() self.clock = hs.get_clock() @@ -113,11 +114,15 @@ class GroupAttestationSigning(object): validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER) valid_until_ms = int(self.clock.time_msec() + validity_period) - return sign_json({ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": valid_until_ms, - }, self.server_name, self.signing_key) + return sign_json( + { + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": valid_until_ms, + }, + self.server_name, + self.signing_key, + ) class GroupAttestionRenewer(object): @@ -134,7 +139,7 @@ class GroupAttestionRenewer(object): if not hs.config.worker_app: self._renew_attestations_loop = self.clock.looping_call( - self._start_renew_attestations, 30 * 60 * 1000, + self._start_renew_attestations, 30 * 60 * 1000 ) @defer.inlineCallbacks @@ -147,9 +152,7 @@ class GroupAttestionRenewer(object): raise SynapseError(400, "Neither user not group are on this server") yield self.attestations.verify_attestation( - attestation, - user_id=user_id, - group_id=group_id, + attestation, user_id=user_id, group_id=group_id ) yield self.store.update_remote_attestion(group_id, user_id, attestation) @@ -180,7 +183,8 @@ class GroupAttestionRenewer(object): else: logger.warn( "Incorrectly trying to do attestations for user: %r in %r", - user_id, group_id, + user_id, + group_id, ) yield self.store.remove_attestation_renewal(group_id, user_id) return @@ -188,8 +192,7 @@ class GroupAttestionRenewer(object): attestation = self.attestations.create_attestation(group_id, user_id) yield self.transport_client.renew_group_attestation( - destination, group_id, user_id, - content={"attestation": attestation}, + destination, group_id, user_id, content={"attestation": attestation} ) yield self.store.update_attestation_renewal( @@ -197,12 +200,12 @@ class GroupAttestionRenewer(object): ) except (RequestSendFailed, HttpResponseException) as e: logger.warning( - "Failed to renew attestation of %r in %r: %s", - user_id, group_id, e, + "Failed to renew attestation of %r in %r: %s", user_id, group_id, e ) except Exception: - logger.exception("Error renewing attestation of %r in %r", - user_id, group_id) + logger.exception( + "Error renewing attestation of %r in %r", user_id, group_id + ) for row in rows: group_id = row["group_id"] diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 817be40360..168c9e3f84 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -54,8 +54,9 @@ class GroupsServerHandler(object): hs.get_groups_attestation_renewer() @defer.inlineCallbacks - def check_group_is_ours(self, group_id, requester_user_id, - and_exists=False, and_is_admin=None): + def check_group_is_ours( + self, group_id, requester_user_id, and_exists=False, and_is_admin=None + ): """Check that the group is ours, and optionally if it exists. If group does exist then return group. @@ -73,7 +74,9 @@ class GroupsServerHandler(object): if and_exists and not group: raise SynapseError(404, "Unknown group") - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) if group and not is_user_in_group and not group["is_public"]: raise SynapseError(404, "Unknown group") @@ -96,25 +99,27 @@ class GroupsServerHandler(object): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) profile = yield self.get_group_profile(group_id, requester_user_id) users, roles = yield self.store.get_users_for_summary_by_role( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) # TODO: Add profiles to users rooms, categories = yield self.store.get_rooms_for_summary_by_category( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) for room_entry in rooms: room_id = room_entry["room_id"] joined_users = yield self.store.get_users_in_room(room_id) entry = yield self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True, + room_id, len(joined_users), with_alias=False, allow_private=True ) entry = dict(entry) # so we don't change whats cached entry.pop("room_id", None) @@ -134,7 +139,7 @@ class GroupsServerHandler(object): entry["attestation"] = attestation else: entry["attestation"] = self.attestations.create_attestation( - group_id, user_id, + group_id, user_id ) user_profile = yield self.profile_handler.get_profile_from_cache(user_id) @@ -143,34 +148,34 @@ class GroupsServerHandler(object): users.sort(key=lambda e: e.get("order", 0)) membership_info = yield self.store.get_users_membership_info_in_group( - group_id, requester_user_id, + group_id, requester_user_id ) - defer.returnValue({ - "profile": profile, - "users_section": { - "users": users, - "roles": roles, - "total_user_count_estimate": 0, # TODO - }, - "rooms_section": { - "rooms": rooms, - "categories": categories, - "total_room_count_estimate": 0, # TODO - }, - "user": membership_info, - }) + defer.returnValue( + { + "profile": profile, + "users_section": { + "users": users, + "roles": roles, + "total_user_count_estimate": 0, # TODO + }, + "rooms_section": { + "rooms": rooms, + "categories": categories, + "total_room_count_estimate": 0, # TODO + }, + "user": membership_info, + } + ) @defer.inlineCallbacks - def update_group_summary_room(self, group_id, requester_user_id, - room_id, category_id, content): + def update_group_summary_room( + self, group_id, requester_user_id, room_id, category_id, content + ): """Add/update a room to the group summary """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) RoomID.from_string(room_id) # Ensure valid room id @@ -190,21 +195,17 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def delete_group_summary_room(self, group_id, requester_user_id, - room_id, category_id): + def delete_group_summary_room( + self, group_id, requester_user_id, room_id, category_id + ): """Remove a room from the summary """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) yield self.store.remove_room_from_summary( - group_id=group_id, - room_id=room_id, - category_id=category_id, + group_id=group_id, room_id=room_id, category_id=category_id ) defer.returnValue({}) @@ -223,9 +224,7 @@ class GroupsServerHandler(object): join_policy = _parse_join_policy_from_contents(content) if join_policy is None: - raise SynapseError( - 400, "No value specified for 'm.join_policy'" - ) + raise SynapseError(400, "No value specified for 'm.join_policy'") yield self.store.set_group_join_policy(group_id, join_policy=join_policy) @@ -237,9 +236,7 @@ class GroupsServerHandler(object): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - categories = yield self.store.get_group_categories( - group_id=group_id, - ) + categories = yield self.store.get_group_categories(group_id=group_id) defer.returnValue({"categories": categories}) @defer.inlineCallbacks @@ -249,8 +246,7 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) res = yield self.store.get_group_category( - group_id=group_id, - category_id=category_id, + group_id=group_id, category_id=category_id ) defer.returnValue(res) @@ -260,10 +256,7 @@ class GroupsServerHandler(object): """Add/Update a group category """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) @@ -283,15 +276,11 @@ class GroupsServerHandler(object): """Delete a group category """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) yield self.store.remove_group_category( - group_id=group_id, - category_id=category_id, + group_id=group_id, category_id=category_id ) defer.returnValue({}) @@ -302,9 +291,7 @@ class GroupsServerHandler(object): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - roles = yield self.store.get_group_roles( - group_id=group_id, - ) + roles = yield self.store.get_group_roles(group_id=group_id) defer.returnValue({"roles": roles}) @defer.inlineCallbacks @@ -313,10 +300,7 @@ class GroupsServerHandler(object): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - res = yield self.store.get_group_role( - group_id=group_id, - role_id=role_id, - ) + res = yield self.store.get_group_role(group_id=group_id, role_id=role_id) defer.returnValue(res) @defer.inlineCallbacks @@ -324,10 +308,7 @@ class GroupsServerHandler(object): """Add/update a role in a group """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) @@ -335,10 +316,7 @@ class GroupsServerHandler(object): profile = content.get("profile") yield self.store.upsert_group_role( - group_id=group_id, - role_id=role_id, - is_public=is_public, - profile=profile, + group_id=group_id, role_id=role_id, is_public=is_public, profile=profile ) defer.returnValue({}) @@ -348,26 +326,21 @@ class GroupsServerHandler(object): """Remove role from group """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_group_role( - group_id=group_id, - role_id=role_id, - ) + yield self.store.remove_group_role(group_id=group_id, role_id=role_id) defer.returnValue({}) @defer.inlineCallbacks - def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id, - content): + def update_group_summary_user( + self, group_id, requester_user_id, user_id, role_id, content + ): """Add/update a users entry in the group summary """ yield self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) order = content.get("order", None) @@ -389,13 +362,11 @@ class GroupsServerHandler(object): """Remove a user from the group summary """ yield self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) yield self.store.remove_user_from_summary( - group_id=group_id, - user_id=user_id, - role_id=role_id, + group_id=group_id, user_id=user_id, role_id=role_id ) defer.returnValue({}) @@ -411,8 +382,11 @@ class GroupsServerHandler(object): if group: cols = [ - "name", "short_description", "long_description", - "avatar_url", "is_public", + "name", + "short_description", + "long_description", + "avatar_url", + "is_public", ] group_description = {key: group[key] for key in cols} group_description["is_openly_joinable"] = group["join_policy"] == "open" @@ -426,12 +400,11 @@ class GroupsServerHandler(object): """Update the group profile """ yield self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) profile = {} - for keyname in ("name", "avatar_url", "short_description", - "long_description"): + for keyname in ("name", "avatar_url", "short_description", "long_description"): if keyname in content: value = content[keyname] if not isinstance(value, string_types): @@ -449,10 +422,12 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) user_results = yield self.store.get_users_in_group( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) chunk = [] @@ -470,24 +445,25 @@ class GroupsServerHandler(object): entry["is_privileged"] = bool(is_privileged) if not self.is_mine_id(g_user_id): - attestation = yield self.store.get_remote_attestation(group_id, g_user_id) + attestation = yield self.store.get_remote_attestation( + group_id, g_user_id + ) if not attestation: continue entry["attestation"] = attestation else: entry["attestation"] = self.attestations.create_attestation( - group_id, g_user_id, + group_id, g_user_id ) chunk.append(entry) # TODO: If admin add lists of users whose attestations have timed out - defer.returnValue({ - "chunk": chunk, - "total_user_count_estimate": len(user_results), - }) + defer.returnValue( + {"chunk": chunk, "total_user_count_estimate": len(user_results)} + ) @defer.inlineCallbacks def get_invited_users_in_group(self, group_id, requester_user_id): @@ -498,7 +474,9 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) if not is_user_in_group: raise SynapseError(403, "User not in group") @@ -508,9 +486,7 @@ class GroupsServerHandler(object): user_profiles = [] for user_id in invited_users: - user_profile = { - "user_id": user_id - } + user_profile = {"user_id": user_id} try: profile = yield self.profile_handler.get_profile_from_cache(user_id) user_profile.update(profile) @@ -518,10 +494,9 @@ class GroupsServerHandler(object): logger.warn("Error getting profile for %s: %s", user_id, e) user_profiles.append(user_profile) - defer.returnValue({ - "chunk": user_profiles, - "total_user_count_estimate": len(invited_users), - }) + defer.returnValue( + {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} + ) @defer.inlineCallbacks def get_rooms_in_group(self, group_id, requester_user_id): @@ -532,10 +507,12 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) room_results = yield self.store.get_rooms_in_group( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) chunk = [] @@ -544,7 +521,7 @@ class GroupsServerHandler(object): joined_users = yield self.store.get_users_in_room(room_id) entry = yield self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True, + room_id, len(joined_users), with_alias=False, allow_private=True ) if not entry: @@ -556,10 +533,9 @@ class GroupsServerHandler(object): chunk.sort(key=lambda e: -e["num_joined_members"]) - defer.returnValue({ - "chunk": chunk, - "total_room_count_estimate": len(room_results), - }) + defer.returnValue( + {"chunk": chunk, "total_room_count_estimate": len(room_results)} + ) @defer.inlineCallbacks def add_room_to_group(self, group_id, requester_user_id, room_id, content): @@ -578,8 +554,9 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks - def update_room_in_group(self, group_id, requester_user_id, room_id, config_key, - content): + def update_room_in_group( + self, group_id, requester_user_id, room_id, config_key, content + ): """Update room in group """ RoomID.from_string(room_id) # Ensure valid room id @@ -592,8 +569,7 @@ class GroupsServerHandler(object): is_public = _parse_visibility_dict(content) yield self.store.update_room_in_group_visibility( - group_id, room_id, - is_public=is_public, + group_id, room_id, is_public=is_public ) else: raise SynapseError(400, "Uknown config option") @@ -625,10 +601,7 @@ class GroupsServerHandler(object): # TODO: Check if user is already invited content = { - "profile": { - "name": group["name"], - "avatar_url": group["avatar_url"], - }, + "profile": {"name": group["name"], "avatar_url": group["avatar_url"]}, "inviter": requester_user_id, } @@ -638,9 +611,7 @@ class GroupsServerHandler(object): local_attestation = None else: local_attestation = self.attestations.create_attestation(group_id, user_id) - content.update({ - "attestation": local_attestation, - }) + content.update({"attestation": local_attestation}) res = yield self.transport_client.invite_to_group_notification( get_domain_from_id(user_id), group_id, user_id, content @@ -658,31 +629,24 @@ class GroupsServerHandler(object): remote_attestation = res["attestation"] yield self.attestations.verify_attestation( - remote_attestation, - user_id=user_id, - group_id=group_id, + remote_attestation, user_id=user_id, group_id=group_id ) else: remote_attestation = None yield self.store.add_user_to_group( - group_id, user_id, + group_id, + user_id, is_admin=False, is_public=False, # TODO local_attestation=local_attestation, remote_attestation=remote_attestation, ) elif res["state"] == "invite": - yield self.store.add_group_invite( - group_id, user_id, - ) - defer.returnValue({ - "state": "invite" - }) + yield self.store.add_group_invite(group_id, user_id) + defer.returnValue({"state": "invite"}) elif res["state"] == "reject": - defer.returnValue({ - "state": "reject" - }) + defer.returnValue({"state": "reject"}) else: raise SynapseError(502, "Unknown state returned by HS") @@ -693,16 +657,12 @@ class GroupsServerHandler(object): See accept_invite, join_group. """ if not self.hs.is_mine_id(user_id): - local_attestation = self.attestations.create_attestation( - group_id, user_id, - ) + local_attestation = self.attestations.create_attestation(group_id, user_id) remote_attestation = content["attestation"] yield self.attestations.verify_attestation( - remote_attestation, - user_id=user_id, - group_id=group_id, + remote_attestation, user_id=user_id, group_id=group_id ) else: local_attestation = None @@ -711,7 +671,8 @@ class GroupsServerHandler(object): is_public = _parse_visibility_from_contents(content) yield self.store.add_user_to_group( - group_id, user_id, + group_id, + user_id, is_admin=False, is_public=is_public, local_attestation=local_attestation, @@ -731,17 +692,14 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_invited = yield self.store.is_user_invited_to_local_group( - group_id, requester_user_id, + group_id, requester_user_id ) if not is_invited: raise SynapseError(403, "User not invited to group") local_attestation = yield self._add_user(group_id, requester_user_id, content) - defer.returnValue({ - "state": "join", - "attestation": local_attestation, - }) + defer.returnValue({"state": "join", "attestation": local_attestation}) @defer.inlineCallbacks def join_group(self, group_id, requester_user_id, content): @@ -753,15 +711,12 @@ class GroupsServerHandler(object): group_info = yield self.check_group_is_ours( group_id, requester_user_id, and_exists=True ) - if group_info['join_policy'] != "open": + if group_info["join_policy"] != "open": raise SynapseError(403, "Group is not publicly joinable") local_attestation = yield self._add_user(group_id, requester_user_id, content) - defer.returnValue({ - "state": "join", - "attestation": local_attestation, - }) + defer.returnValue({"state": "join", "attestation": local_attestation}) @defer.inlineCallbacks def knock(self, group_id, requester_user_id, content): @@ -800,9 +755,7 @@ class GroupsServerHandler(object): is_kick = True - yield self.store.remove_user_from_group( - group_id, user_id, - ) + yield self.store.remove_user_from_group(group_id, user_id) if is_kick: if self.hs.is_mine_id(user_id): @@ -830,19 +783,20 @@ class GroupsServerHandler(object): if group: raise SynapseError(400, "Group already exists") - is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id)) + is_admin = yield self.auth.is_server_admin( + UserID.from_string(requester_user_id) + ) if not is_admin: if not self.hs.config.enable_group_creation: raise SynapseError( - 403, "Only a server admin can create groups on this server", + 403, "Only a server admin can create groups on this server" ) localpart = group_id_obj.localpart if not localpart.startswith(self.hs.config.group_creation_prefix): raise SynapseError( 400, - "Can only create groups with prefix %r on this server" % ( - self.hs.config.group_creation_prefix, - ), + "Can only create groups with prefix %r on this server" + % (self.hs.config.group_creation_prefix,), ) profile = content.get("profile", {}) @@ -865,21 +819,19 @@ class GroupsServerHandler(object): remote_attestation = content["attestation"] yield self.attestations.verify_attestation( - remote_attestation, - user_id=requester_user_id, - group_id=group_id, + remote_attestation, user_id=requester_user_id, group_id=group_id ) local_attestation = self.attestations.create_attestation( - group_id, - requester_user_id, + group_id, requester_user_id ) else: local_attestation = None remote_attestation = None yield self.store.add_user_to_group( - group_id, requester_user_id, + group_id, + requester_user_id, is_admin=True, is_public=True, # TODO local_attestation=local_attestation, @@ -893,9 +845,7 @@ class GroupsServerHandler(object): avatar_url=user_profile.get("avatar_url"), ) - defer.returnValue({ - "group_id": group_id, - }) + defer.returnValue({"group_id": group_id}) @defer.inlineCallbacks def delete_group(self, group_id, requester_user_id): @@ -911,29 +861,22 @@ class GroupsServerHandler(object): Deferred """ - yield self.check_group_is_ours( - group_id, requester_user_id, - and_exists=True, - ) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) # Only server admins or group admins can delete groups. - is_admin = yield self.store.is_user_admin_in_group( - group_id, requester_user_id - ) + is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id) if not is_admin: is_admin = yield self.auth.is_server_admin( - UserID.from_string(requester_user_id), + UserID.from_string(requester_user_id) ) if not is_admin: raise SynapseError(403, "User is not an admin") # Before deleting the group lets kick everyone out of it - users = yield self.store.get_users_in_group( - group_id, include_private=True, - ) + users = yield self.store.get_users_in_group(group_id, include_private=True) @defer.inlineCallbacks def _kick_user_from_group(user_id): @@ -989,9 +932,7 @@ def _parse_join_policy_dict(join_policy_dict): return "invite" if join_policy_type not in ("invite", "open"): - raise SynapseError( - 400, "Synapse only supports 'invite'/'open' join rule" - ) + raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule") return join_policy_type @@ -1018,7 +959,5 @@ def _parse_visibility_dict(visibility): return True if vis_type not in ("public", "private"): - raise SynapseError( - 400, "Synapse only supports 'public'/'private' visibility" - ) + raise SynapseError(400, "Synapse only supports 'public'/'private' visibility") return vis_type == "public" diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index dca337ec61..c29c78bd65 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -94,14 +94,15 @@ class BaseHandler(object): burst_count = self.hs.config.rc_message.burst_count allowed, time_allowed = self.ratelimiter.can_do_action( - user_id, time_now, + user_id, + time_now, rate_hz=messages_per_second, burst_count=burst_count, update=update, ) if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) @defer.inlineCallbacks @@ -139,7 +140,7 @@ class BaseHandler(object): if member_event.content["membership"] not in { Membership.JOIN, - Membership.INVITE + Membership.INVITE, }: continue @@ -156,8 +157,7 @@ class BaseHandler(object): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = synapse.types.create_requester( - target_user, is_guest=True) + requester = synapse.types.create_requester(target_user, is_guest=True) handler = self.hs.get_room_member_handler() yield handler.update_membership( requester, diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 7fa5d44d29..e62e6cab77 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -20,7 +20,7 @@ class AccountDataEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() - def get_current_key(self, direction='f'): + def get_current_key(self, direction="f"): return self.store.get_max_account_data_stream_id() @defer.inlineCallbacks @@ -34,29 +34,22 @@ class AccountDataEventSource(object): tags = yield self.store.get_updated_tags(user_id, last_stream_id) for room_id, room_tags in tags.items(): - results.append({ - "type": "m.tag", - "content": {"tags": room_tags}, - "room_id": room_id, - }) + results.append( + {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} + ) account_data, room_account_data = ( yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) ) for account_data_type, content in account_data.items(): - results.append({ - "type": account_data_type, - "content": content, - }) + results.append({"type": account_data_type, "content": content}) for room_id, account_data in room_account_data.items(): for account_data_type, content in account_data.items(): - results.append({ - "type": account_data_type, - "content": content, - "room_id": room_id, - }) + results.append( + {"type": account_data_type, "content": content, "room_id": room_id} + ) defer.returnValue((results, current_stream_id)) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 947237d7da..795e9afe44 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -51,12 +51,10 @@ class AccountValidityHandler(object): app_name = self.hs.config.email_app_name self._subject = self._account_validity.renew_email_subject % { - "app": app_name, + "app": app_name } - self._from_string = self.hs.config.email_notif_from % { - "app": app_name, - } + self._from_string = self.hs.config.email_notif_from % {"app": app_name} except Exception: # If substitution failed, fall back to the bare strings. self._subject = self._account_validity.renew_email_subject @@ -71,16 +69,10 @@ class AccountValidityHandler(object): ) # Check the renewal emails to send and send them every 30min. - self.clock.looping_call( - self.send_renewal_emails, - 30 * 60 * 1000, - ) + self.clock.looping_call(self.send_renewal_emails, 30 * 60 * 1000) # Check every hour to remove expired users from the user directory - self.clock.looping_call( - self._mark_expired_users_as_inactive, - 60 * 60 * 1000, - ) + self.clock.looping_call(self._mark_expired_users_as_inactive, 60 * 60 * 1000) @defer.inlineCallbacks def send_renewal_emails(self): @@ -94,8 +86,7 @@ class AccountValidityHandler(object): if expiring_users: for user in expiring_users: yield self._send_renewal_email( - user_id=user["user_id"], - expiration_ts=user["expiration_ts_ms"], + user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) @defer.inlineCallbacks @@ -154,32 +145,33 @@ class AccountValidityHandler(object): for address in addresses: raw_to = email.utils.parseaddr(address)[1] - multipart_msg = MIMEMultipart('alternative') - multipart_msg['Subject'] = self._subject - multipart_msg['From'] = self._from_string - multipart_msg['To'] = address - multipart_msg['Date'] = email.utils.formatdate() - multipart_msg['Message-ID'] = email.utils.make_msgid() + multipart_msg = MIMEMultipart("alternative") + multipart_msg["Subject"] = self._subject + multipart_msg["From"] = self._from_string + multipart_msg["To"] = address + multipart_msg["Date"] = email.utils.formatdate() + multipart_msg["Message-ID"] = email.utils.make_msgid() multipart_msg.attach(text_part) multipart_msg.attach(html_part) logger.info("Sending renewal email to %s", address) - yield make_deferred_yieldable(self.sendmail( - self.hs.config.email_smtp_host, - self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'), - reactor=self.hs.get_reactor(), - port=self.hs.config.email_smtp_port, - requireAuthentication=self.hs.config.email_smtp_user is not None, - username=self.hs.config.email_smtp_user, - password=self.hs.config.email_smtp_pass, - requireTransportSecurity=self.hs.config.require_transport_security - )) - - yield self.store.set_renewal_mail_status( - user_id=user_id, - email_sent=True, - ) + yield make_deferred_yieldable( + self.sendmail( + self.hs.config.email_smtp_host, + self._raw_from, + raw_to, + multipart_msg.as_string().encode("utf8"), + reactor=self.hs.get_reactor(), + port=self.hs.config.email_smtp_port, + requireAuthentication=self.hs.config.email_smtp_user is not None, + username=self.hs.config.email_smtp_user, + password=self.hs.config.email_smtp_pass, + requireTransportSecurity=self.hs.config.require_transport_security, + ) + ) + + yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) @defer.inlineCallbacks def _get_email_addresses_for_user(self, user_id): @@ -264,15 +256,15 @@ class AccountValidityHandler(object): expiration_ts = self.clock.time_msec() + self._account_validity.period yield self.store.set_account_validity_for_user( - user_id=user_id, - expiration_ts=expiration_ts, - email_sent=email_sent, + user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) # Check if renewed users should be reintroduced to the user directory if self._show_users_in_user_directory: # Show the user in the directory again by setting them to active - yield self.profile_handler.set_active(UserID.from_string(user_id), True, True) + yield self.profile_handler.set_active( + UserID.from_string(user_id), True, True + ) defer.returnValue(expiration_ts) @@ -286,10 +278,7 @@ class AccountValidityHandler(object): """ # Get expired users expired_user_ids = yield self.store.get_expired_users() - expired_users = [ - UserID.from_string(user_id) - for user_id in expired_user_ids - ] + expired_users = [UserID.from_string(user_id) for user_id in expired_user_ids] # Mark each one as non-active for user in expired_users: diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 813777bf18..01e0ef408d 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -93,24 +93,20 @@ class AcmeHandler(object): ) well_known = Resource() - well_known.putChild(b'acme-challenge', responder.resource) + well_known.putChild(b"acme-challenge", responder.resource) responder_resource = Resource() - responder_resource.putChild(b'.well-known', well_known) - responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain')) + responder_resource.putChild(b".well-known", well_known) + responder_resource.putChild(b"check", static.Data(b"OK", b"text/plain")) srv = server.Site(responder_resource) bind_addresses = self.hs.config.acme_bind_addresses for host in bind_addresses: logger.info( - "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port, + "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port ) try: - self.reactor.listenTCP( - self.hs.config.acme_port, - srv, - interface=host, - ) + self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host) except twisted.internet.error.CannotListenError as e: check_bind_error(e, host, bind_addresses) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 5d629126fc..941ebfa107 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -23,7 +23,6 @@ logger = logging.getLogger(__name__) class AdminHandler(BaseHandler): - def __init__(self, hs): super(AdminHandler, self).__init__(hs) @@ -33,23 +32,17 @@ class AdminHandler(BaseHandler): sessions = yield self.store.get_user_ip_and_agents(user) for session in sessions: - connections.append({ - "ip": session["ip"], - "last_seen": session["last_seen"], - "user_agent": session["user_agent"], - }) + connections.append( + { + "ip": session["ip"], + "last_seen": session["last_seen"], + "user_agent": session["user_agent"], + } + ) ret = { "user_id": user.to_string(), - "devices": { - "": { - "sessions": [ - { - "connections": connections, - } - ] - }, - }, + "devices": {"": {"sessions": [{"connections": connections}]}}, } defer.returnValue(ret) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 17eedf4dbf..5cc89d43f6 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -38,7 +38,6 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed class ApplicationServicesHandler(object): - def __init__(self, hs): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id @@ -101,9 +100,10 @@ class ApplicationServicesHandler(object): yield self._check_user_exists(event.state_key) if not self.started_scheduler: + def start_scheduler(): return self.scheduler.start().addErrback( - log_failure, "Application Services Failure", + log_failure, "Application Services Failure" ) run_as_background_process("as_scheduler", start_scheduler) @@ -118,10 +118,15 @@ class ApplicationServicesHandler(object): for event in events: yield handle_event(event) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(handle_room_events, evs) + for evs in itervalues(events_by_room) + ], + consumeErrors=True, + ) + ) yield self.store.set_appservice_last_pos(upper_bound) @@ -129,20 +134,23 @@ class ApplicationServicesHandler(object): ts = yield self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_positions.labels( - "appservice_sender").set(upper_bound) + "appservice_sender" + ).set(upper_bound) events_processed_counter.inc(len(events)) - event_processing_loop_room_count.labels( - "appservice_sender" - ).inc(len(events_by_room)) + event_processing_loop_room_count.labels("appservice_sender").inc( + len(events_by_room) + ) event_processing_loop_counter.labels("appservice_sender").inc() synapse.metrics.event_processing_lag.labels( - "appservice_sender").set(now - ts) + "appservice_sender" + ).set(now - ts) synapse.metrics.event_processing_last_ts.labels( - "appservice_sender").set(ts) + "appservice_sender" + ).set(ts) finally: self.is_processing = False @@ -155,13 +163,9 @@ class ApplicationServicesHandler(object): Returns: True if this user exists on at least one application service. """ - user_query_services = yield self._get_services_for_user( - user_id=user_id - ) + user_query_services = yield self._get_services_for_user(user_id=user_id) for user_service in user_query_services: - is_known_user = yield self.appservice_api.query_user( - user_service, user_id - ) + is_known_user = yield self.appservice_api.query_user(user_service, user_id) if is_known_user: defer.returnValue(True) defer.returnValue(False) @@ -179,9 +183,7 @@ class ApplicationServicesHandler(object): room_alias_str = room_alias.to_string() services = self.store.get_app_services() alias_query_services = [ - s for s in services if ( - s.is_interested_in_alias(room_alias_str) - ) + s for s in services if (s.is_interested_in_alias(room_alias_str)) ] for alias_service in alias_query_services: is_known_alias = yield self.appservice_api.query_alias( @@ -189,22 +191,24 @@ class ApplicationServicesHandler(object): ) if is_known_alias: # the alias exists now so don't query more ASes. - result = yield self.store.get_association_from_room_alias( - room_alias - ) + result = yield self.store.get_association_from_room_alias(room_alias) defer.returnValue(result) @defer.inlineCallbacks def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) - results = yield make_deferred_yieldable(defer.DeferredList([ - run_in_background( - self.appservice_api.query_3pe, - service, kind, protocol, fields, + results = yield make_deferred_yieldable( + defer.DeferredList( + [ + run_in_background( + self.appservice_api.query_3pe, service, kind, protocol, fields + ) + for service in services + ], + consumeErrors=True, ) - for service in services - ], consumeErrors=True)) + ) ret = [] for (success, result) in results: @@ -276,18 +280,12 @@ class ApplicationServicesHandler(object): def _get_services_for_user(self, user_id): services = self.store.get_app_services() - interested_list = [ - s for s in services if ( - s.is_interested_in_user(user_id) - ) - ] + interested_list = [s for s in services if (s.is_interested_in_user(user_id))] return defer.succeed(interested_list) def _get_services_for_3pn(self, protocol): services = self.store.get_app_services() - interested_list = [ - s for s in services if s.is_interested_in_protocol(protocol) - ] + interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] return defer.succeed(interested_list) @defer.inlineCallbacks diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 9a2ff177a6..94b3879c61 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -134,13 +134,9 @@ class AuthHandler(BaseHandler): """ # build a list of supported flows - flows = [ - [login_type] for login_type in self._supported_login_types - ] + flows = [[login_type] for login_type in self._supported_login_types] - result, params, _ = yield self.check_auth( - flows, request_body, clientip, - ) + result, params, _ = yield self.check_auth(flows, request_body, clientip) # find the completed login type for login_type in self._supported_login_types: @@ -151,9 +147,7 @@ class AuthHandler(BaseHandler): break else: # this can't happen - raise Exception( - "check_auth returned True but no successful login type", - ) + raise Exception("check_auth returned True but no successful login type") # check that the UI auth matched the access token if user_id != requester.user.to_string(): @@ -215,11 +209,11 @@ class AuthHandler(BaseHandler): authdict = None sid = None - if clientdict and 'auth' in clientdict: - authdict = clientdict['auth'] - del clientdict['auth'] - if 'session' in authdict: - sid = authdict['session'] + if clientdict and "auth" in clientdict: + authdict = clientdict["auth"] + del clientdict["auth"] + if "session" in authdict: + sid = authdict["session"] session = self._get_session_info(sid) if len(clientdict) > 0: @@ -232,27 +226,27 @@ class AuthHandler(BaseHandler): # on a home server. # Revisit: Assumimg the REST APIs do sensible validation, the data # isn't arbintrary. - session['clientdict'] = clientdict + session["clientdict"] = clientdict self._save_session(session) - elif 'clientdict' in session: - clientdict = session['clientdict'] + elif "clientdict" in session: + clientdict = session["clientdict"] if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session), + self._auth_dict_for_flows(flows, session) ) - if 'creds' not in session: - session['creds'] = {} - creds = session['creds'] + if "creds" not in session: + session["creds"] = {} + creds = session["creds"] # check auth type currently being presented errordict = {} - if 'type' in authdict: - login_type = authdict['type'] + if "type" in authdict: + login_type = authdict["type"] try: result = yield self._check_auth_dict( - authdict, clientip, password_servlet=password_servlet, + authdict, clientip, password_servlet=password_servlet ) if result: creds[login_type] = result @@ -281,16 +275,15 @@ class AuthHandler(BaseHandler): # and is not sensitive). logger.info( "Auth completed with creds: %r. Client dict has keys: %r", - creds, list(clientdict) + creds, + list(clientdict), ) - defer.returnValue((creds, clientdict, session['id'])) + defer.returnValue((creds, clientdict, session["id"])) ret = self._auth_dict_for_flows(flows, session) - ret['completed'] = list(creds) + ret["completed"] = list(creds) ret.update(errordict) - raise InteractiveAuthIncompleteError( - ret, - ) + raise InteractiveAuthIncompleteError(ret) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): @@ -300,15 +293,13 @@ class AuthHandler(BaseHandler): """ if stagetype not in self.checkers: raise LoginError(400, "", Codes.MISSING_PARAM) - if 'session' not in authdict: + if "session" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - sess = self._get_session_info( - authdict['session'] - ) - if 'creds' not in sess: - sess['creds'] = {} - creds = sess['creds'] + sess = self._get_session_info(authdict["session"]) + if "creds" not in sess: + sess["creds"] = {} + creds = sess["creds"] result = yield self.checkers[stagetype](authdict, clientip) if result: @@ -329,10 +320,10 @@ class AuthHandler(BaseHandler): not send a session ID, returns None. """ sid = None - if clientdict and 'auth' in clientdict: - authdict = clientdict['auth'] - if 'session' in authdict: - sid = authdict['session'] + if clientdict and "auth" in clientdict: + authdict = clientdict["auth"] + if "session" in authdict: + sid = authdict["session"] return sid def set_session_data(self, session_id, key, value): @@ -347,7 +338,7 @@ class AuthHandler(BaseHandler): value (any): The data to store """ sess = self._get_session_info(session_id) - sess.setdefault('serverdict', {})[key] = value + sess.setdefault("serverdict", {})[key] = value self._save_session(sess) def get_session_data(self, session_id, key, default=None): @@ -360,7 +351,7 @@ class AuthHandler(BaseHandler): default (any): Value to return if the key has not been set """ sess = self._get_session_info(session_id) - return sess.setdefault('serverdict', {}).get(key, default) + return sess.setdefault("serverdict", {}).get(key, default) @defer.inlineCallbacks def _check_auth_dict(self, authdict, clientip, password_servlet=False): @@ -378,15 +369,13 @@ class AuthHandler(BaseHandler): SynapseError if there was a problem with the request LoginError if there was an authentication problem. """ - login_type = authdict['type'] + login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: # XXX: Temporary workaround for having Synapse handle password resets # See AuthHandler.check_auth for further details res = yield checker( - authdict, - clientip=clientip, - password_servlet=password_servlet, + authdict, clientip=clientip, password_servlet=password_servlet ) defer.returnValue(res) @@ -408,13 +397,11 @@ class AuthHandler(BaseHandler): # Client tried to provide captcha but didn't give the parameter: # bad request. raise LoginError( - 400, "Captcha response is required", - errcode=Codes.CAPTCHA_NEEDED + 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED ) logger.info( - "Submitting recaptcha response %s with remoteip %s", - user_response, clientip + "Submitting recaptcha response %s with remoteip %s", user_response, clientip ) # TODO: get this from the homeserver rather than creating a new one for @@ -424,34 +411,34 @@ class AuthHandler(BaseHandler): resp_body = yield client.post_urlencoded_get_json( self.hs.config.recaptcha_siteverify_api, args={ - 'secret': self.hs.config.recaptcha_private_key, - 'response': user_response, - 'remoteip': clientip, - } + "secret": self.hs.config.recaptcha_private_key, + "response": user_response, + "remoteip": clientip, + }, ) except PartialDownloadError as pde: # Twisted is silly data = pde.response resp_body = json.loads(data) - if 'success' in resp_body: + if "success" in resp_body: # Note that we do NOT check the hostname here: we explicitly # intend the CAPTCHA to be presented by whatever client the # user is using, we just care that they have completed a CAPTCHA. logger.info( "%s reCAPTCHA from hostname %s", - "Successful" if resp_body['success'] else "Failed", - resp_body.get('hostname') + "Successful" if resp_body["success"] else "Failed", + resp_body.get("hostname"), ) - if resp_body['success']: + if resp_body["success"]: defer.returnValue(True) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) def _check_email_identity(self, authdict, **kwargs): - return self._check_threepid('email', authdict, **kwargs) + return self._check_threepid("email", authdict, **kwargs) def _check_msisdn(self, authdict, **kwargs): - return self._check_threepid('msisdn', authdict) + return self._check_threepid("msisdn", authdict) def _check_dummy_auth(self, authdict, **kwargs): return defer.succeed(True) @@ -461,10 +448,10 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs): - if 'threepid_creds' not in authdict: + if "threepid_creds" not in authdict: raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) - threepid_creds = authdict['threepid_creds'] + threepid_creds = authdict["threepid_creds"] identity_handler = self.hs.get_handlers().identity_handler @@ -482,31 +469,36 @@ class AuthHandler(BaseHandler): validated=True, ) - threepid = { - "medium": row["medium"], - "address": row["address"], - "validated_at": row["validated_at"], - } if row else None + threepid = ( + { + "medium": row["medium"], + "address": row["address"], + "validated_at": row["validated_at"], + } + if row + else None + ) if row: # Valid threepid returned, delete from the db yield self.store.delete_threepid_session(threepid_creds["sid"]) else: - raise SynapseError(400, "Password resets are not enabled on this homeserver") + raise SynapseError( + 400, "Password resets are not enabled on this homeserver" + ) if not threepid: raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - if threepid['medium'] != medium: + if threepid["medium"] != medium: raise LoginError( 401, - "Expecting threepid of type '%s', got '%s'" % ( - medium, threepid['medium'], - ), - errcode=Codes.UNAUTHORIZED + "Expecting threepid of type '%s', got '%s'" + % (medium, threepid["medium"]), + errcode=Codes.UNAUTHORIZED, ) - threepid['threepid_creds'] = authdict['threepid_creds'] + threepid["threepid_creds"] = authdict["threepid_creds"] defer.returnValue(threepid) @@ -520,13 +512,14 @@ class AuthHandler(BaseHandler): "version": self.hs.config.user_consent_version, "en": { "name": self.hs.config.user_consent_policy_name, - "url": "%s_matrix/consent?v=%s" % ( + "url": "%s_matrix/consent?v=%s" + % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), }, - }, - }, + } + } } def _auth_dict_for_flows(self, flows, session): @@ -547,9 +540,9 @@ class AuthHandler(BaseHandler): params[stage] = get_params[stage]() return { - "session": session['id'], + "session": session["id"], "flows": [{"stages": f} for f in public_flows], - "params": params + "params": params, } def _get_session_info(self, session_id): @@ -560,9 +553,7 @@ class AuthHandler(BaseHandler): # create a new session while session_id is None or session_id in self.sessions: session_id = stringutils.random_string(24) - self.sessions[session_id] = { - "id": session_id, - } + self.sessions[session_id] = {"id": session_id} return self.sessions[session_id] @@ -652,7 +643,8 @@ class AuthHandler(BaseHandler): logger.warn( "Attempted to login as %s but it matches more than one user " "inexactly: %r", - user_id, user_infos.keys() + user_id, + user_infos.keys(), ) defer.returnValue(result) @@ -690,12 +682,10 @@ class AuthHandler(BaseHandler): user is too high too proceed. """ - if username.startswith('@'): + if username.startswith("@"): qualified_user_id = username else: - qualified_user_id = UserID( - username, self.hs.hostname - ).to_string() + qualified_user_id = UserID(username, self.hs.hostname).to_string() self.ratelimit_login_per_account(qualified_user_id) @@ -713,17 +703,15 @@ class AuthHandler(BaseHandler): raise SynapseError(400, "Missing parameter: password") for provider in self.password_providers: - if (hasattr(provider, "check_password") - and login_type == LoginType.PASSWORD): + if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = yield provider.check_password( - qualified_user_id, password, - ) + is_valid = yield provider.check_password(qualified_user_id, password) if is_valid: defer.returnValue((qualified_user_id, None)) - if (not hasattr(provider, "get_supported_login_types") - or not hasattr(provider, "check_auth")): + if not hasattr(provider, "get_supported_login_types") or not hasattr( + provider, "check_auth" + ): # this password provider doesn't understand custom login types continue @@ -744,15 +732,12 @@ class AuthHandler(BaseHandler): login_dict[f] = login_submission[f] if missing_fields: raise SynapseError( - 400, "Missing parameters for login type %s: %s" % ( - login_type, - missing_fields, - ), + 400, + "Missing parameters for login type %s: %s" + % (login_type, missing_fields), ) - result = yield provider.check_auth( - username, login_type, login_dict, - ) + result = yield provider.check_auth(username, login_type, login_dict) if result: if isinstance(result, str): result = (result, None) @@ -762,7 +747,7 @@ class AuthHandler(BaseHandler): known_login_type = True canonical_user_id = yield self._check_local_password( - qualified_user_id, password, + qualified_user_id, password ) if canonical_user_id: @@ -773,7 +758,8 @@ class AuthHandler(BaseHandler): # unknown username or invalid password. self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), time_now_s=self._clock.time(), + qualified_user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=True, @@ -781,10 +767,7 @@ class AuthHandler(BaseHandler): # We raise a 403 here, but note that if we're doing user-interactive # login, it turns all LoginErrors into a 401 anyway. - raise LoginError( - 403, "Invalid password", - errcode=Codes.FORBIDDEN - ) + raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks def check_password_provider_3pid(self, medium, address, password): @@ -810,9 +793,7 @@ class AuthHandler(BaseHandler): # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = yield provider.check_3pid_auth( - medium, address, password, - ) + result = yield provider.check_3pid_auth(medium, address, password) if result: # Check if the return value is a str or a tuple if isinstance(result, str): @@ -853,8 +834,7 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id, device_id=None): access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token, - device_id) + yield self.store.add_access_token_to_user(user_id, access_token, device_id) defer.returnValue(access_token) @defer.inlineCallbacks @@ -896,12 +876,13 @@ class AuthHandler(BaseHandler): # delete pushers associated with this access token if user_info["token_id"] is not None: yield self.hs.get_pusherpool().remove_pushers_by_access_token( - str(user_info["user"]), (user_info["token_id"], ) + str(user_info["user"]), (user_info["token_id"],) ) @defer.inlineCallbacks - def delete_access_tokens_for_user(self, user_id, except_token_id=None, - device_id=None): + def delete_access_tokens_for_user( + self, user_id, except_token_id=None, device_id=None + ): """Invalidate access tokens belonging to a user Args: @@ -915,7 +896,7 @@ class AuthHandler(BaseHandler): Deferred """ tokens_and_devices = yield self.store.user_delete_access_tokens( - user_id, except_token_id=except_token_id, device_id=device_id, + user_id, except_token_id=except_token_id, device_id=device_id ) # see if any of our auth providers want to know about this @@ -923,14 +904,12 @@ class AuthHandler(BaseHandler): if hasattr(provider, "on_logged_out"): for token, token_id, device_id in tokens_and_devices: yield provider.on_logged_out( - user_id=user_id, - device_id=device_id, - access_token=token, + user_id=user_id, device_id=device_id, access_token=token ) # delete pushers associated with the access tokens yield self.hs.get_pusherpool().remove_pushers_by_access_token( - user_id, (token_id for _, token_id, _ in tokens_and_devices), + user_id, (token_id for _, token_id, _ in tokens_and_devices) ) @defer.inlineCallbacks @@ -944,12 +923,11 @@ class AuthHandler(BaseHandler): # of specific types of threepid (and fixes the fact that checking # for the presence of an email address during password reset was # case sensitive). - if medium == 'email': + if medium == "email": address = address.lower() yield self.store.user_add_threepid( - user_id, medium, address, validated_at, - self.hs.get_clock().time_msec() + user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) @defer.inlineCallbacks @@ -973,22 +951,15 @@ class AuthHandler(BaseHandler): """ # 'Canonicalise' email addresses as per above - if medium == 'email': + if medium == "email": address = address.lower() identity_handler = self.hs.get_handlers().identity_handler result = yield identity_handler.try_unbind_threepid( - user_id, - { - 'medium': medium, - 'address': address, - 'id_server': id_server, - }, + user_id, {"medium": medium, "address": address, "id_server": id_server} ) - yield self.store.user_delete_threepid( - user_id, medium, address, - ) + yield self.store.user_delete_threepid(user_id, medium, address) defer.returnValue(result) def _save_session(self, session): @@ -1006,14 +977,15 @@ class AuthHandler(BaseHandler): Returns: Deferred(unicode): Hashed password. """ + def _do_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.hashpw( - pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), + pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), bcrypt.gensalt(self.bcrypt_rounds), - ).decode('ascii') + ).decode("ascii") return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash) @@ -1027,18 +999,19 @@ class AuthHandler(BaseHandler): Returns: Deferred(bool): Whether self.hash(password) == stored_hash. """ + def _do_validate_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( - pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), - stored_hash + pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), + stored_hash, ) if stored_hash: if not isinstance(stored_hash, bytes): - stored_hash = stored_hash.encode('ascii') + stored_hash = stored_hash.encode("ascii") return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: @@ -1058,14 +1031,16 @@ class AuthHandler(BaseHandler): for this user is too high too proceed. """ self._failed_attempts_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), + user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=False, ) self._account_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), + user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, update=True, @@ -1083,9 +1058,9 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat("type = access") # Include a nonce, to make sure that each login gets a different # access token. - macaroon.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) + macaroon.add_first_party_caveat( + "nonce = %s" % (stringutils.random_string_with_symbols(16),) + ) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) return macaroon.serialize() @@ -1116,7 +1091,8 @@ class MacaroonGenerator(object): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index ac8f75d256..602c0e864d 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) class DeactivateAccountHandler(BaseHandler): """Handler which deals with deactivating user accounts.""" + def __init__(self, hs): super(DeactivateAccountHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() @@ -79,9 +80,9 @@ class DeactivateAccountHandler(BaseHandler): result = yield self._identity_handler.try_unbind_threepid( user_id, { - 'medium': threepid['medium'], - 'address': threepid['address'], - 'id_server': id_server, + "medium": threepid["medium"], + "address": threepid["address"], + "id_server": id_server, }, ) identity_server_supports_unbinding &= result @@ -90,7 +91,7 @@ class DeactivateAccountHandler(BaseHandler): logger.exception("Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server") yield self.store.user_delete_threepid( - user_id, threepid['medium'], threepid['address'], + user_id, threepid["medium"], threepid["address"] ) # delete any devices belonging to the user, which will also @@ -231,5 +232,6 @@ class DeactivateAccountHandler(BaseHandler): except Exception: logger.exception( "Failed to part user %r from room %r: ignoring and continuing", - user_id, room_id, + user_id, + room_id, ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d69fc8b061..f99880355d 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -58,9 +58,7 @@ class DeviceWorkerHandler(BaseHandler): device_map = yield self.store.get_devices_by_user(user_id) - ips = yield self.store.get_last_client_ip_by_device( - user_id, device_id=None - ) + ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None) devices = list(device_map.values()) for device in devices: @@ -85,9 +83,7 @@ class DeviceWorkerHandler(BaseHandler): device = yield self.store.get_device(user_id, device_id) except errors.StoreError: raise errors.NotFoundError - ips = yield self.store.get_last_client_ip_by_device( - user_id, device_id, - ) + ips = yield self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) defer.returnValue(device) @@ -114,13 +110,11 @@ class DeviceWorkerHandler(BaseHandler): rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) member_events = yield self.store.get_membership_changes_for_user( - user_id, from_token.room_key, now_room_key, + user_id, from_token.room_key, now_room_key ) rooms_changed.update(event.room_id for event in member_events) - stream_ordering = RoomStreamToken.parse_stream_token( - from_token.room_key - ).stream + stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream possibly_changed = set(changed) possibly_left = set() @@ -206,10 +200,9 @@ class DeviceWorkerHandler(BaseHandler): possibly_joined = [] possibly_left = [] - defer.returnValue({ - "changed": list(possibly_joined), - "left": list(possibly_left), - }) + defer.returnValue( + {"changed": list(possibly_joined), "left": list(possibly_left)} + ) class DeviceHandler(DeviceWorkerHandler): @@ -223,17 +216,18 @@ class DeviceHandler(DeviceWorkerHandler): federation_registry = hs.get_federation_registry() federation_registry.register_edu_handler( - "m.device_list_update", self._edu_updater.incoming_device_list_update, + "m.device_list_update", self._edu_updater.incoming_device_list_update ) federation_registry.register_query_handler( - "user_devices", self.on_federation_query_user_devices, + "user_devices", self.on_federation_query_user_devices ) hs.get_distributor().observe("user_left_room", self.user_left_room) @defer.inlineCallbacks - def check_device_registered(self, user_id, device_id, - initial_device_display_name=None): + def check_device_registered( + self, user_id, device_id, initial_device_display_name=None + ): """ If the given device has not been registered, register it with the supplied display name. @@ -297,12 +291,10 @@ class DeviceHandler(DeviceWorkerHandler): raise yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id, + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device( - user_id=user_id, device_id=device_id - ) + yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) yield self.notify_device_update(user_id, [device_id]) @@ -349,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler): # considered as part of a critical path. for device_id in device_ids: yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id, + user_id, device_id=device_id ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id @@ -372,9 +364,7 @@ class DeviceHandler(DeviceWorkerHandler): try: yield self.store.update_device( - user_id, - device_id, - new_display_name=content.get("display_name") + user_id, device_id, new_display_name=content.get("display_name") ) yield self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: @@ -404,29 +394,26 @@ class DeviceHandler(DeviceWorkerHandler): for device_id in device_ids: logger.debug( - "Notifying about update %r/%r, ID: %r", user_id, device_id, - position, + "Notifying about update %r/%r, ID: %r", user_id, device_id, position ) room_ids = yield self.store.get_rooms_for_user(user_id) - yield self.notifier.on_new_event( - "device_list_key", position, rooms=room_ids, - ) + yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids) if hosts: - logger.info("Sending device list update notif for %r to: %r", user_id, hosts) + logger.info( + "Sending device list update notif for %r to: %r", user_id, hosts + ) for host in hosts: self.federation_sender.send_device_messages(host) @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - defer.returnValue({ - "user_id": user_id, - "stream_id": stream_id, - "devices": devices, - }) + defer.returnValue( + {"user_id": user_id, "stream_id": stream_id, "devices": devices} + ) @defer.inlineCallbacks def user_left_room(self, user, room_id): @@ -440,10 +427,7 @@ class DeviceHandler(DeviceWorkerHandler): def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) - device.update({ - "last_seen_ts": ip.get("last_seen"), - "last_seen_ip": ip.get("ip"), - }) + device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) class DeviceListEduUpdater(object): @@ -481,13 +465,15 @@ class DeviceListEduUpdater(object): device_id = edu_content.pop("device_id") stream_id = str(edu_content.pop("stream_id")) # They may come as ints prev_ids = edu_content.pop("prev_id", []) - prev_ids = [str(p) for p in prev_ids] # They may come as ints + prev_ids = [str(p) for p in prev_ids] # They may come as ints if get_domain_from_id(user_id) != origin: # TODO: Raise? logger.warning( "Got device list update edu for %r/%r from %r", - user_id, device_id, origin, + user_id, + device_id, + origin, ) return @@ -497,13 +483,12 @@ class DeviceListEduUpdater(object): # probably won't get any further updates. logger.warning( "Got device list update edu for %r/%r, but don't share a room", - user_id, device_id, + user_id, + device_id, ) return - logger.debug( - "Received device list update for %r/%r", user_id, device_id, - ) + logger.debug("Received device list update for %r/%r", user_id, device_id) self._pending_updates.setdefault(user_id, []).append( (device_id, stream_id, prev_ids, edu_content) @@ -525,7 +510,10 @@ class DeviceListEduUpdater(object): for device_id, stream_id, prev_ids, content in pending_updates: logger.debug( "Handling update %r/%r, ID: %r, prev: %r ", - user_id, device_id, stream_id, prev_ids, + user_id, + device_id, + stream_id, + prev_ids, ) # Given a list of updates we check if we need to resync. This @@ -540,13 +528,13 @@ class DeviceListEduUpdater(object): try: result = yield self.federation.query_user_devices(origin, user_id) except ( - NotRetryingDestination, RequestSendFailed, HttpResponseException, + NotRetryingDestination, + RequestSendFailed, + HttpResponseException, ): # TODO: Remember that we are now out of sync and try again # later - logger.warn( - "Failed to handle device list update for %s", user_id, - ) + logger.warn("Failed to handle device list update for %s", user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync @@ -571,7 +559,9 @@ class DeviceListEduUpdater(object): for device in devices: logger.debug( "Handling resync update %r/%r, ID: %r", - user_id, device["device_id"], stream_id, + user_id, + device["device_id"], + stream_id, ) # If the remote server has more than ~1000 devices for this user @@ -588,18 +578,21 @@ class DeviceListEduUpdater(object): if len(devices) > 1000: logger.warn( "Ignoring device list snapshot for %s as it has >1K devs (%d)", - user_id, len(devices) + user_id, + len(devices), ) devices = [] for device in devices: logger.debug( "Handling resync update %r/%r, ID: %r", - user_id, device["device_id"], stream_id, + user_id, + device["device_id"], + stream_id, ) yield self.store.update_remote_device_list_cache( - user_id, devices, stream_id, + user_id, devices, stream_id ) device_ids = [device["device_id"] for device in devices] yield self.device_handler.notify_device_update(user_id, device_ids) @@ -612,7 +605,7 @@ class DeviceListEduUpdater(object): # change (because of the single prev_id matching the current cache) for device_id, stream_id, prev_ids, content in pending_updates: yield self.store.update_remote_device_list_cache_entry( - user_id, device_id, content, stream_id, + user_id, device_id, content, stream_id ) yield self.device_handler.notify_device_update( @@ -630,14 +623,9 @@ class DeviceListEduUpdater(object): """ seen_updates = self._seen_updates.get(user_id, set()) - extremity = yield self.store.get_device_list_last_stream_id_for_remote( - user_id - ) + extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id) - logger.debug( - "Current extremity for %r: %r", - user_id, extremity, - ) + logger.debug("Current extremity for %r: %r", user_id, extremity) stream_id_in_updates = set() # stream_ids in updates list for _, stream_id, prev_ids, _ in updates: diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 2e2e5261de..e1ebb6346c 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -25,7 +25,6 @@ logger = logging.getLogger(__name__) class DeviceMessageHandler(object): - def __init__(self, hs): """ Args: @@ -47,15 +46,15 @@ class DeviceMessageHandler(object): if origin != get_domain_from_id(sender_user_id): logger.warn( "Dropping device message from %r with spoofed sender %r", - origin, sender_user_id + origin, + sender_user_id, ) message_type = content["type"] message_id = content["message_id"] for user_id, by_device in content["messages"].items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): - logger.warning("Request for keys for non-local user %s", - user_id) + logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") messages_by_device = { diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index a12f9508d8..42d5b3db30 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -36,7 +36,6 @@ logger = logging.getLogger(__name__) class DirectoryHandler(BaseHandler): - def __init__(self, hs): super(DirectoryHandler, self).__init__(hs) @@ -77,15 +76,19 @@ class DirectoryHandler(BaseHandler): raise SynapseError(400, "Failed to get server list") yield self.store.create_room_alias_association( - room_alias, - room_id, - servers, - creator=creator, + room_alias, room_id, servers, creator=creator ) @defer.inlineCallbacks - def create_association(self, requester, room_alias, room_id, servers=None, - send_event=True, check_membership=True): + def create_association( + self, + requester, + room_alias, + room_id, + servers=None, + send_event=True, + check_membership=True, + ): """Attempt to create a new alias Args: @@ -115,49 +118,40 @@ class DirectoryHandler(BaseHandler): if service: if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( - 400, "This application service has not reserved" - " this kind of alias.", errcode=Codes.EXCLUSIVE + 400, + "This application service has not reserved" " this kind of alias.", + errcode=Codes.EXCLUSIVE, ) else: if self.require_membership and check_membership: rooms_for_user = yield self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( - 403, - "You must be in the room to create an alias for it", + 403, "You must be in the room to create an alias for it" ) if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): - raise AuthError( - 403, "This user is not permitted to create this alias", - ) + raise AuthError(403, "This user is not permitted to create this alias") if not self.config.is_alias_creation_allowed( - user_id, room_id, room_alias.to_string(), + user_id, room_id, room_alias.to_string() ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? - raise SynapseError( - 403, "Not allowed to create alias", - ) + raise SynapseError(403, "Not allowed to create alias") - can_create = yield self.can_modify_alias( - room_alias, - user_id=user_id - ) + can_create = yield self.can_modify_alias(room_alias, user_id=user_id) if not can_create: raise AuthError( - 400, "This alias is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) yield self._create_association(room_alias, room_id, servers, creator=user_id) if send_event: - yield self.send_room_alias_update_event( - requester, - room_id - ) + yield self.send_room_alias_update_event(requester, room_id) @defer.inlineCallbacks def delete_association(self, requester, room_alias, send_event=True): @@ -194,34 +188,24 @@ class DirectoryHandler(BaseHandler): raise if not can_delete: - raise AuthError( - 403, "You don't have permission to delete the alias.", - ) + raise AuthError(403, "You don't have permission to delete the alias.") - can_delete = yield self.can_modify_alias( - room_alias, - user_id=user_id - ) + can_delete = yield self.can_modify_alias(room_alias, user_id=user_id) if not can_delete: raise SynapseError( - 400, "This alias is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) room_id = yield self._delete_association(room_alias) try: if send_event: - yield self.send_room_alias_update_event( - requester, - room_id - ) + yield self.send_room_alias_update_event(requester, room_id) yield self._update_canonical_alias( - requester, - requester.user.to_string(), - room_id, - room_alias, + requester, requester.user.to_string(), room_id, room_alias ) except AuthError as e: logger.info("Failed to update alias events: %s", e) @@ -234,7 +218,7 @@ class DirectoryHandler(BaseHandler): raise SynapseError( 400, "This application service has not reserved this kind of alias", - errcode=Codes.EXCLUSIVE + errcode=Codes.EXCLUSIVE, ) yield self._delete_association(room_alias) @@ -251,9 +235,7 @@ class DirectoryHandler(BaseHandler): def get_association(self, room_alias): room_id = None if self.hs.is_mine(room_alias): - result = yield self.get_association_from_room_alias( - room_alias - ) + result = yield self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id @@ -263,9 +245,7 @@ class DirectoryHandler(BaseHandler): result = yield self.federation.make_query( destination=room_alias.domain, query_type="directory", - args={ - "room_alias": room_alias.to_string(), - }, + args={"room_alias": room_alias.to_string()}, retry_on_dns_fail=False, ignore_backoff=True, ) @@ -284,7 +264,7 @@ class DirectoryHandler(BaseHandler): raise SynapseError( 404, "Room alias %s not found" % (room_alias.to_string(),), - Codes.NOT_FOUND + Codes.NOT_FOUND, ) users = yield self.state.get_current_users_in_room(room_id) @@ -293,41 +273,28 @@ class DirectoryHandler(BaseHandler): # If this server is in the list of servers, return it first. if self.server_name in servers: - servers = ( - [self.server_name] + - [s for s in servers if s != self.server_name] - ) + servers = [self.server_name] + [s for s in servers if s != self.server_name] else: servers = list(servers) - defer.returnValue({ - "room_id": room_id, - "servers": servers, - }) + defer.returnValue({"room_id": room_id, "servers": servers}) return @defer.inlineCallbacks def on_directory_query(self, args): room_alias = RoomAlias.from_string(args["room_alias"]) if not self.hs.is_mine(room_alias): - raise SynapseError( - 400, "Room Alias is not hosted on this Home Server" - ) + raise SynapseError(400, "Room Alias is not hosted on this Home Server") - result = yield self.get_association_from_room_alias( - room_alias - ) + result = yield self.get_association_from_room_alias(room_alias) if result is not None: - defer.returnValue({ - "room_id": result.room_id, - "servers": result.servers, - }) + defer.returnValue({"room_id": result.room_id, "servers": result.servers}) else: raise SynapseError( 404, "Room alias %r not found" % (room_alias.to_string(),), - Codes.NOT_FOUND + Codes.NOT_FOUND, ) @defer.inlineCallbacks @@ -343,7 +310,7 @@ class DirectoryHandler(BaseHandler): "sender": requester.user.to_string(), "content": {"aliases": aliases}, }, - ratelimit=False + ratelimit=False, ) @defer.inlineCallbacks @@ -365,14 +332,12 @@ class DirectoryHandler(BaseHandler): "sender": user_id, "content": {}, }, - ratelimit=False + ratelimit=False, ) @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): - result = yield self.store.get_association_from_room_alias( - room_alias - ) + result = yield self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists as_handler = self.appservice_handler @@ -421,8 +386,7 @@ class DirectoryHandler(BaseHandler): if not self.spam_checker.user_may_publish_room(user_id, room_id): raise AuthError( - 403, - "This user is not permitted to publish rooms to the room list" + 403, "This user is not permitted to publish rooms to the room list" ) if requester.is_guest: @@ -434,8 +398,7 @@ class DirectoryHandler(BaseHandler): if visibility == "public" and not self.enable_room_list_search: # The room list has been disabled. raise AuthError( - 403, - "This user is not permitted to publish rooms to the room list" + 403, "This user is not permitted to publish rooms to the room list" ) room = yield self.store.get_room(room_id) @@ -452,20 +415,19 @@ class DirectoryHandler(BaseHandler): room_aliases.append(canonical_alias) if not self.config.is_publishing_room_allowed( - user_id, room_id, room_aliases, + user_id, room_id, room_aliases ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? - raise SynapseError( - 403, "Not allowed to publish room", - ) + raise SynapseError(403, "Not allowed to publish room") yield self.store.set_room_is_public(room_id, making_public) @defer.inlineCallbacks - def edit_published_appservice_room_list(self, appservice_id, network_id, - room_id, visibility): + def edit_published_appservice_room_list( + self, appservice_id, network_id, room_id, visibility + ): """Add or remove a room from the appservice/network specific public room list. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9dc46aa15f..807900fe52 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -99,9 +99,7 @@ class E2eKeysHandler(object): query_list.append((user_id, None)) user_ids_not_in_cache, remote_results = ( - yield self.store.get_user_devices_from_cache( - query_list - ) + yield self.store.get_user_devices_from_cache(query_list) ) for user_id, devices in iteritems(remote_results): user_devices = results.setdefault(user_id, {}) @@ -126,9 +124,7 @@ class E2eKeysHandler(object): destination_query = remote_queries_not_in_cache[destination] try: remote_result = yield self.federation.query_client_keys( - destination, - {"device_keys": destination_query}, - timeout=timeout + destination, {"device_keys": destination_query}, timeout=timeout ) for user_id, keys in remote_result["device_keys"].items(): @@ -138,14 +134,17 @@ class E2eKeysHandler(object): except Exception as e: failures[destination] = _exception_to_failure(e) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(do_remote_query, destination) - for destination in remote_queries_not_in_cache - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(do_remote_query, destination) + for destination in remote_queries_not_in_cache + ], + consumeErrors=True, + ) + ) - defer.returnValue({ - "device_keys": results, "failures": failures, - }) + defer.returnValue({"device_keys": results, "failures": failures}) @defer.inlineCallbacks def query_local_devices(self, query): @@ -165,8 +164,7 @@ class E2eKeysHandler(object): for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): - logger.warning("Request for keys for non-local user %s", - user_id) + logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") if not device_ids: @@ -231,9 +229,7 @@ class E2eKeysHandler(object): device_keys = remote_queries[destination] try: remote_result = yield self.federation.claim_client_keys( - destination, - {"one_time_keys": device_keys}, - timeout=timeout + destination, {"one_time_keys": device_keys}, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: @@ -241,25 +237,29 @@ class E2eKeysHandler(object): except Exception as e: failures[destination] = _exception_to_failure(e) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(claim_client_keys, destination) - for destination in remote_queries - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(claim_client_keys, destination) + for destination in remote_queries + ], + consumeErrors=True, + ) + ) logger.info( "Claimed one-time-keys: %s", - ",".join(( - "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) - )), + ",".join( + ( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) + ) + ), ) - defer.returnValue({ - "one_time_keys": json_result, - "failures": failures - }) + defer.returnValue({"one_time_keys": json_result, "failures": failures}) @defer.inlineCallbacks def upload_keys_for_user(self, user_id, device_id, keys): @@ -270,11 +270,13 @@ class E2eKeysHandler(object): if device_keys: logger.info( "Updating device_keys for device %r for user %s at %d", - device_id, user_id, time_now + device_id, + user_id, + time_now, ) # TODO: Sign the JSON with the server key changed = yield self.store.set_e2e_device_keys( - user_id, device_id, time_now, device_keys, + user_id, device_id, time_now, device_keys ) if changed: # Only notify about device updates *if* the keys actually changed @@ -283,7 +285,7 @@ class E2eKeysHandler(object): one_time_keys = keys.get("one_time_keys", None) if one_time_keys: yield self._upload_one_time_keys_for_user( - user_id, device_id, time_now, one_time_keys, + user_id, device_id, time_now, one_time_keys ) # the device should have been registered already, but it may have been @@ -298,20 +300,22 @@ class E2eKeysHandler(object): defer.returnValue({"one_time_key_counts": result}) @defer.inlineCallbacks - def _upload_one_time_keys_for_user(self, user_id, device_id, time_now, - one_time_keys): + def _upload_one_time_keys_for_user( + self, user_id, device_id, time_now, one_time_keys + ): logger.info( "Adding one_time_keys %r for device %r for user %r at %d", - one_time_keys.keys(), device_id, user_id, time_now, + one_time_keys.keys(), + device_id, + user_id, + time_now, ) # make a list of (alg, id, key) tuples key_list = [] for key_id, key_obj in one_time_keys.items(): algorithm, key_id = key_id.split(":") - key_list.append(( - algorithm, key_id, key_obj - )) + key_list.append((algorithm, key_id, key_obj)) # First we check if we have already persisted any of the keys. existing_key_map = yield self.store.get_e2e_one_time_keys( @@ -325,42 +329,35 @@ class E2eKeysHandler(object): if not _one_time_keys_match(ex_json, key): raise SynapseError( 400, - ("One time key %s:%s already exists. " - "Old key: %s; new key: %r") % - (algorithm, key_id, ex_json, key) + ( + "One time key %s:%s already exists. " + "Old key: %s; new key: %r" + ) + % (algorithm, key_id, ex_json, key), ) else: - new_keys.append(( - algorithm, key_id, encode_canonical_json(key).decode('ascii'))) + new_keys.append( + (algorithm, key_id, encode_canonical_json(key).decode("ascii")) + ) - yield self.store.add_e2e_one_time_keys( - user_id, device_id, time_now, new_keys - ) + yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) def _exception_to_failure(e): if isinstance(e, CodeMessageException): - return { - "status": e.code, "message": str(e), - } + return {"status": e.code, "message": str(e)} if isinstance(e, NotRetryingDestination): - return { - "status": 503, "message": "Not ready for retry", - } + return {"status": 503, "message": "Not ready for retry"} if isinstance(e, FederationDeniedError): - return { - "status": 403, "message": "Federation Denied", - } + return {"status": 403, "message": "Federation Denied"} # include ConnectionRefused and other errors # # Note that some Exceptions (notably twisted's ResponseFailed etc) don't # give a string for e.message, which json then fails to serialize. - return { - "status": 503, "message": str(e), - } + return {"status": 503, "message": str(e)} def _one_time_keys_match(old_key_json, new_key): diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 7bc174070e..ebd807bca6 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -152,14 +152,14 @@ class E2eRoomKeysHandler(object): else: raise - if version_info['version'] != version: + if version_info["version"] != version: # Check that the version we're trying to upload actually exists try: version_info = yield self.store.get_e2e_room_keys_version_info( - user_id, version, + user_id, version ) # if we get this far, the version must exist - raise RoomKeysVersionError(current_version=version_info['version']) + raise RoomKeysVersionError(current_version=version_info["version"]) except StoreError as e: if e.code == 404: raise NotFoundError("Version '%s' not found" % (version,)) @@ -168,8 +168,8 @@ class E2eRoomKeysHandler(object): # go through the room_keys. # XXX: this should/could be done concurrently, given we're in a lock. - for room_id, room in iteritems(room_keys['rooms']): - for session_id, session in iteritems(room['sessions']): + for room_id, room in iteritems(room_keys["rooms"]): + for session_id, session in iteritems(room["sessions"]): yield self._upload_room_key( user_id, version, room_id, session_id, session ) @@ -223,14 +223,14 @@ class E2eRoomKeysHandler(object): # spelt out with if/elifs rather than nested boolean expressions # purely for legibility. - if room_key['is_verified'] and not current_room_key['is_verified']: + if room_key["is_verified"] and not current_room_key["is_verified"]: return True elif ( - room_key['first_message_index'] < - current_room_key['first_message_index'] + room_key["first_message_index"] + < current_room_key["first_message_index"] ): return True - elif room_key['forwarded_count'] < current_room_key['forwarded_count']: + elif room_key["forwarded_count"] < current_room_key["forwarded_count"]: return True else: return False @@ -328,16 +328,10 @@ class E2eRoomKeysHandler(object): A deferred of an empty dict. """ if "version" not in version_info: - raise SynapseError( - 400, - "Missing version in body", - Codes.MISSING_PARAM - ) + raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM) if version_info["version"] != version: raise SynapseError( - 400, - "Version in body does not match", - Codes.INVALID_PARAM + 400, "Version in body does not match", Codes.INVALID_PARAM ) with (yield self._upload_linearizer.queue(user_id)): try: @@ -350,12 +344,10 @@ class E2eRoomKeysHandler(object): else: raise if old_info["algorithm"] != version_info["algorithm"]: - raise SynapseError( - 400, - "Algorithm does not match", - Codes.INVALID_PARAM - ) + raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) - yield self.store.update_e2e_room_keys_version(user_id, version, version_info) + yield self.store.update_e2e_room_keys_version( + user_id, version, version_info + ) defer.returnValue({}) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index eb525070cf..5836d3c639 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -31,7 +31,6 @@ logger = logging.getLogger(__name__) class EventStreamHandler(BaseHandler): - def __init__(self, hs): super(EventStreamHandler, self).__init__(hs) @@ -53,9 +52,17 @@ class EventStreamHandler(BaseHandler): @defer.inlineCallbacks @log_function - def get_stream(self, auth_user_id, pagin_config, timeout=0, - as_client_event=True, affect_presence=True, - only_keys=None, room_id=None, is_guest=False): + def get_stream( + self, + auth_user_id, + pagin_config, + timeout=0, + as_client_event=True, + affect_presence=True, + only_keys=None, + room_id=None, + is_guest=False, + ): """Fetches the events stream for a given user. If `only_keys` is not None, events from keys will be sent down. @@ -73,7 +80,7 @@ class EventStreamHandler(BaseHandler): presence_handler = self.hs.get_presence_handler() context = yield presence_handler.user_syncing( - auth_user_id, affect_presence=affect_presence, + auth_user_id, affect_presence=affect_presence ) with context: if timeout: @@ -85,9 +92,12 @@ class EventStreamHandler(BaseHandler): timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) events, tokens = yield self.notifier.get_events_for( - auth_user, pagin_config, timeout, + auth_user, + pagin_config, + timeout, only_keys=only_keys, - is_guest=is_guest, explicit_room_id=room_id + is_guest=is_guest, + explicit_room_id=room_id, ) # When the user joins a new room, or another user joins a currently @@ -102,17 +112,15 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.state.get_current_users_in_room(event.room_id) - states = yield presence_handler.get_states( - users, - as_event=True, + users = yield self.state.get_current_users_in_room( + event.room_id ) + states = yield presence_handler.get_states(users, as_event=True) to_add.extend(states) else: ev = yield presence_handler.get_state( - UserID.from_string(event.state_key), - as_event=True, + UserID.from_string(event.state_key), as_event=True ) to_add.append(ev) @@ -121,7 +129,9 @@ class EventStreamHandler(BaseHandler): time_now = self.clock.time_msec() chunks = yield self._event_serializer.serialize_events( - events, time_now, as_client_event=as_client_event, + events, + time_now, + as_client_event=as_client_event, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. bundle_aggregations=False, @@ -137,7 +147,6 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): - @defer.inlineCallbacks def get_event(self, user, room_id, event_id): """Retrieve a single specified event. @@ -164,16 +173,10 @@ class EventHandler(BaseHandler): is_peeking = user.to_string() not in users filtered = yield filter_events_for_client( - self.store, - user.to_string(), - [event], - is_peeking=is_peeking + self.store, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: - raise AuthError( - 403, - "You don't have permission to access that event." - ) + raise AuthError(403, "You don't have permission to access that event.") defer.returnValue(event) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0a52f3d64d..189e6e3737 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -115,14 +115,14 @@ class FederationHandler(BaseHandler): self.config = hs.config self.http_client = hs.get_simple_http_client() - self._send_events_to_master = ( - ReplicationFederationSendEventsRestServlet.make_client(hs) + self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( + hs ) - self._notify_user_membership_change = ( - ReplicationUserJoinedLeftRoomRestServlet.make_client(hs) + self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( + hs ) - self._clean_room_for_join_client = ( - ReplicationCleanRoomRestServlet.make_client(hs) + self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( + hs ) # When joining a room we need to queue any events for that room up @@ -132,9 +132,7 @@ class FederationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() @defer.inlineCallbacks - def on_receive_pdu( - self, origin, pdu, sent_to_us_directly=False, - ): + def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -151,26 +149,19 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - logger.info( - "[%s %s] handling received PDU: %s", - room_id, event_id, pdu, - ) + logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) # We reprocess pdus when we have seen them only as outliers existing = yield self.store.get_event( - event_id, - allow_none=True, - allow_rejected=True, + event_id, allow_none=True, allow_rejected=True ) # FIXME: Currently we fetch an event again when we already have it # if it has been marked as an outlier. - already_seen = ( - existing and ( - not existing.internal_metadata.is_outlier() - or pdu.internal_metadata.is_outlier() - ) + already_seen = existing and ( + not existing.internal_metadata.is_outlier() + or pdu.internal_metadata.is_outlier() ) if already_seen: logger.debug("[%s %s]: Already seen pdu", room_id, event_id) @@ -182,20 +173,19 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: - logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id) - raise FederationError( - "ERROR", - err.code, - err.msg, - affected=pdu.event_id, + logger.warn( + "[%s %s] Received event failed sanity checks", room_id, event_id ) + raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) # If we are currently in the process of joining this room, then we # queue up events for later processing. if room_id in self.room_queues: logger.info( "[%s %s] Queuing PDU from %s for now: join in progress", - room_id, event_id, origin, + room_id, + event_id, + origin, ) self.room_queues[room_id].append((pdu, origin)) return @@ -206,14 +196,13 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = yield self.auth.check_host_in_room( - room_id, - self.server_name - ) + is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( "[%s %s] Ignoring PDU from %s as we're not in the room", - room_id, event_id, origin, + room_id, + event_id, + origin, ) defer.returnValue(None) @@ -223,14 +212,9 @@ class FederationHandler(BaseHandler): # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. - min_depth = yield self.get_min_depth_for_context( - pdu.room_id - ) + min_depth = yield self.get_min_depth_for_context(pdu.room_id) - logger.debug( - "[%s %s] min_depth: %d", - room_id, event_id, min_depth, - ) + logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) seen = yield self.store.have_seen_events(prevs) @@ -248,12 +232,17 @@ class FederationHandler(BaseHandler): # at a time. logger.info( "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s", - room_id, event_id, len(missing_prevs), shortstr(missing_prevs), + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), ) with (yield self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( "[%s %s] Acquired room lock to fetch %d missing prev_events", - room_id, event_id, len(missing_prevs), + room_id, + event_id, + len(missing_prevs), ) yield self._get_missing_events_for_pdu( @@ -267,12 +256,16 @@ class FederationHandler(BaseHandler): if not prevs - seen: logger.info( "[%s %s] Found all missing prev_events", - room_id, event_id, + room_id, + event_id, ) elif missing_prevs: logger.info( "[%s %s] Not recursively fetching %d missing prev_events: %s", - room_id, event_id, len(missing_prevs), shortstr(missing_prevs), + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), ) if prevs - seen: @@ -303,7 +296,10 @@ class FederationHandler(BaseHandler): if sent_to_us_directly: logger.warn( "[%s %s] Rejecting: failed to fetch %d prev events: %s", - room_id, event_id, len(prevs - seen), shortstr(prevs - seen) + room_id, + event_id, + len(prevs - seen), + shortstr(prevs - seen), ) raise FederationError( "ERROR", @@ -318,9 +314,7 @@ class FederationHandler(BaseHandler): # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. auth_chains = set() - event_map = { - event_id: pdu, - } + event_map = {event_id: pdu} try: # Get the state of the events we know about ours = yield self.store.get_state_groups_ids(room_id, seen) @@ -337,7 +331,9 @@ class FederationHandler(BaseHandler): for p in prevs - seen: logger.info( "[%s %s] Requesting state at missing prev_event %s", - room_id, event_id, p, + room_id, + event_id, + p, ) room_version = yield self.store.get_room_version(room_id) @@ -348,19 +344,19 @@ class FederationHandler(BaseHandler): # by the get_pdu_cache in federation_client. remote_state, got_auth_chain = ( yield self.federation_client.get_state_for_room( - origin, room_id, p, + origin, room_id, p ) ) # we want the state *after* p; get_state_for_room returns the # state *before* p. remote_event = yield self.federation_client.get_pdu( - [origin], p, room_version, outlier=True, + [origin], p, room_version, outlier=True ) if remote_event is None: raise Exception( - "Unable to get missing prev_event %s" % (p, ) + "Unable to get missing prev_event %s" % (p,) ) if remote_event.is_state(): @@ -380,7 +376,9 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x state_map = yield resolve_events_with_store( - room_version, state_maps, event_map, + room_version, + state_maps, + event_map, state_res_store=StateResolutionStore(self.store), ) @@ -396,15 +394,15 @@ class FederationHandler(BaseHandler): ) event_map.update(evs) - state = [ - event_map[e] for e in six.itervalues(state_map) - ] + state = [event_map[e] for e in six.itervalues(state_map)] auth_chain = list(auth_chains) except Exception: logger.warn( "[%s %s] Error attempting to resolve state at missing " "prev_events", - room_id, event_id, exc_info=True, + room_id, + event_id, + exc_info=True, ) raise FederationError( "ERROR", @@ -414,10 +412,7 @@ class FederationHandler(BaseHandler): ) yield self._process_received_pdu( - origin, - pdu, - state=state, - auth_chain=auth_chain, + origin, pdu, state=state, auth_chain=auth_chain ) @defer.inlineCallbacks @@ -447,7 +442,10 @@ class FederationHandler(BaseHandler): logger.info( "[%s %s]: Requesting missing events between %s and %s", - room_id, event_id, shortstr(latest), event_id, + room_id, + event_id, + shortstr(latest), + event_id, ) # XXX: we set timeout to 10s to help workaround @@ -512,15 +510,15 @@ class FederationHandler(BaseHandler): # We failed to get the missing events, but since we need to handle # the case of `get_missing_events` not returning the necessary # events anyway, it is safe to simply log the error and continue. - logger.warn( - "[%s %s]: Failed to get prev_events: %s", - room_id, event_id, e, - ) + logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e) return logger.info( "[%s %s]: Got %d prev_events: %s", - room_id, event_id, len(missing_events), shortstr(missing_events), + room_id, + event_id, + len(missing_events), + shortstr(missing_events), ) # We want to sort these by depth so we process them and @@ -530,20 +528,20 @@ class FederationHandler(BaseHandler): for ev in missing_events: logger.info( "[%s %s] Handling received prev_event %s", - room_id, event_id, ev.event_id, + room_id, + event_id, + ev.event_id, ) with logcontext.nested_logging_context(ev.event_id): try: - yield self.on_receive_pdu( - origin, - ev, - sent_to_us_directly=False, - ) + yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: logger.warn( "[%s %s] Received prev_event %s failed history check.", - room_id, event_id, ev.event_id, + room_id, + event_id, + ev.event_id, ) else: raise @@ -556,10 +554,7 @@ class FederationHandler(BaseHandler): room_id = event.room_id event_id = event.event_id - logger.debug( - "[%s %s] Processing event: %s", - room_id, event_id, event, - ) + logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) event_ids = set() if state: @@ -581,43 +576,32 @@ class FederationHandler(BaseHandler): e.internal_metadata.outlier = True auth_ids = e.auth_event_ids() auth = { - (e.type, e.state_key): e for e in auth_chain + (e.type, e.state_key): e + for e in auth_chain if e.event_id in auth_ids or e.type == EventTypes.Create } - event_infos.append({ - "event": e, - "auth_events": auth, - }) + event_infos.append({"event": e, "auth_events": auth}) seen_ids.add(e.event_id) logger.info( "[%s %s] persisting newly-received auth/state events %s", - room_id, event_id, [e["event"].event_id for e in event_infos] + room_id, + event_id, + [e["event"].event_id for e in event_infos], ) yield self._handle_new_events(origin, event_infos) try: - context = yield self._handle_new_event( - origin, - event, - state=state, - ) + context = yield self._handle_new_event(origin, event, state=state) except AuthError as e: - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=event.event_id, - ) + raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) room = yield self.store.get_room(room_id) if not room: try: yield self.store.store_room( - room_id=room_id, - room_creator_user_id="", - is_public=False, + room_id=room_id, room_creator_user_id="", is_public=False ) except StoreError: logger.exception("Failed to store room.") @@ -631,12 +615,10 @@ class FederationHandler(BaseHandler): prev_state_ids = yield context.get_prev_state_ids(self.store) - prev_state_id = prev_state_ids.get( - (event.type, event.state_key) - ) + prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: prev_state = yield self.store.get_event( - prev_state_id, allow_none=True, + prev_state_id, allow_none=True ) if prev_state and prev_state.membership == Membership.JOIN: newly_joined = False @@ -667,10 +649,7 @@ class FederationHandler(BaseHandler): room_version = yield self.store.get_room_version(room_id) events = yield self.federation_client.backfill( - dest, - room_id, - limit=limit, - extremities=extremities, + dest, room_id, limit=limit, extremities=extremities ) # ideally we'd sanity check the events here for excess prev_events etc, @@ -697,16 +676,9 @@ class FederationHandler(BaseHandler): event_ids = set(e.event_id for e in events) - edges = [ - ev.event_id - for ev in events - if set(ev.prev_event_ids()) - event_ids - ] + edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] - logger.info( - "backfill: Got %d events with %d edges", - len(events), len(edges), - ) + logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) # For each edge get the current state. @@ -715,9 +687,7 @@ class FederationHandler(BaseHandler): events_to_state = {} for e_id in edges: state, auth = yield self.federation_client.get_state_for_room( - destination=dest, - room_id=room_id, - event_id=e_id + destination=dest, room_id=room_id, event_id=e_id ) auth_events.update({a.event_id: a for a in auth}) auth_events.update({s.event_id: s for s in state}) @@ -726,12 +696,14 @@ class FederationHandler(BaseHandler): required_auth = set( a_id - for event in events + list(state_events.values()) + list(auth_events.values()) + for event in events + + list(state_events.values()) + + list(auth_events.values()) for a_id in event.auth_event_ids() ) - auth_events.update({ - e_id: event_map[e_id] for e_id in required_auth if e_id in event_map - }) + auth_events.update( + {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} + ) missing_auth = required_auth - set(auth_events) failed_to_fetch = set() @@ -750,27 +722,30 @@ class FederationHandler(BaseHandler): if missing_auth - failed_to_fetch: logger.info( "Fetching missing auth for backfill: %r", - missing_auth - failed_to_fetch + missing_auth - failed_to_fetch, ) - results = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background( - self.federation_client.get_pdu, - [dest], - event_id, - room_version=room_version, - outlier=True, - timeout=10000, - ) - for event_id in missing_auth - failed_to_fetch - ], - consumeErrors=True - )).addErrback(unwrapFirstError) + results = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background( + self.federation_client.get_pdu, + [dest], + event_id, + room_version=room_version, + outlier=True, + timeout=10000, + ) + for event_id in missing_auth - failed_to_fetch + ], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) auth_events.update({a.event_id: a for a in results if a}) required_auth.update( a_id - for event in results if event + for event in results + if event for a_id in event.auth_event_ids() ) missing_auth = required_auth - set(auth_events) @@ -802,15 +777,19 @@ class FederationHandler(BaseHandler): continue a.internal_metadata.outlier = True - ev_infos.append({ - "event": a, - "auth_events": { - (auth_events[a_id].type, auth_events[a_id].state_key): - auth_events[a_id] - for a_id in a.auth_event_ids() - if a_id in auth_events + ev_infos.append( + { + "event": a, + "auth_events": { + ( + auth_events[a_id].type, + auth_events[a_id].state_key, + ): auth_events[a_id] + for a_id in a.auth_event_ids() + if a_id in auth_events + }, } - }) + ) # Step 1b: persist the events in the chunk we fetched state for (i.e. # the backwards extremities) as non-outliers. @@ -818,23 +797,24 @@ class FederationHandler(BaseHandler): # For paranoia we ensure that these events are marked as # non-outliers ev = event_map[e_id] - assert(not ev.internal_metadata.is_outlier()) - - ev_infos.append({ - "event": ev, - "state": events_to_state[e_id], - "auth_events": { - (auth_events[a_id].type, auth_events[a_id].state_key): - auth_events[a_id] - for a_id in ev.auth_event_ids() - if a_id in auth_events + assert not ev.internal_metadata.is_outlier() + + ev_infos.append( + { + "event": ev, + "state": events_to_state[e_id], + "auth_events": { + ( + auth_events[a_id].type, + auth_events[a_id].state_key, + ): auth_events[a_id] + for a_id in ev.auth_event_ids() + if a_id in auth_events + }, } - }) + ) - yield self._handle_new_events( - dest, ev_infos, - backfilled=True, - ) + yield self._handle_new_events(dest, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -845,14 +825,12 @@ class FederationHandler(BaseHandler): # For paranoia we ensure that these events are marked as # non-outliers - assert(not event.internal_metadata.is_outlier()) + assert not event.internal_metadata.is_outlier() # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - yield self._handle_new_event( - dest, event, backfilled=True, - ) + yield self._handle_new_event(dest, event, backfilled=True) defer.returnValue(events) @@ -861,9 +839,7 @@ class FederationHandler(BaseHandler): """Checks the database to see if we should backfill before paginating, and if so do. """ - extremities = yield self.store.get_oldest_events_with_depth_in_room( - room_id - ) + extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id) if not extremities: logger.debug("Not backfilling as no extremeties found.") @@ -895,31 +871,27 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = yield self.store.get_successor_events( - list(extremities), - ) + forward_events = yield self.store.get_successor_events(list(extremities)) extremities_events = yield self.store.get_events( - forward_events, - check_redacted=False, - get_prev_content=False, + forward_events, check_redacted=False, get_prev_content=False ) # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = yield filter_events_for_server( - self.store, self.server_name, list(extremities_events.values()), - redact=False, check_history_visibility_only=True, + self.store, + self.server_name, + list(extremities_events.values()), + redact=False, + check_history_visibility_only=True, ) if not filtered_extremities: defer.returnValue(False) # Check if we reached a point where we should start backfilling. - sorted_extremeties_tuple = sorted( - extremities.items(), - key=lambda e: -int(e[1]) - ) + sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1])) max_depth = sorted_extremeties_tuple[0][1] # We don't want to specify too many extremities as it causes the backfill @@ -928,8 +900,7 @@ class FederationHandler(BaseHandler): if current_depth > max_depth: logger.debug( - "Not backfilling as we don't need to. %d < %d", - max_depth, current_depth, + "Not backfilling as we don't need to. %d < %d", max_depth, current_depth ) return @@ -954,8 +925,7 @@ class FederationHandler(BaseHandler): joined_users = [ (state_key, int(event.depth)) for (e_type, state_key), event in iteritems(state) - if e_type == EventTypes.Member - and event.membership == Membership.JOIN + if e_type == EventTypes.Member and event.membership == Membership.JOIN ] joined_domains = {} @@ -975,8 +945,7 @@ class FederationHandler(BaseHandler): curr_domains = get_domains_from_state(curr_state) likely_domains = [ - domain for domain, depth in curr_domains - if domain != self.server_name + domain for domain, depth in curr_domains if domain != self.server_name ] @defer.inlineCallbacks @@ -985,28 +954,20 @@ class FederationHandler(BaseHandler): for dom in domains: try: yield self.backfill( - dom, room_id, - limit=100, - extremities=extremities, + dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the # appropriate stuff. # TODO: We can probably do something more intelligent here. defer.returnValue(True) except SynapseError as e: - logger.info( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.info("Failed to backfill from %s because %s", dom, e) continue except CodeMessageException as e: if 400 <= e.code < 500: raise - logger.info( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.info("Failed to backfill from %s because %s", dom, e) continue except NotRetryingDestination as e: logger.info(str(e)) @@ -1015,10 +976,7 @@ class FederationHandler(BaseHandler): logger.info(e) continue except Exception as e: - logger.exception( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.exception("Failed to backfill from %s because %s", dom, e) continue defer.returnValue(False) @@ -1039,10 +997,11 @@ class FederationHandler(BaseHandler): resolve = logcontext.preserve_fn( self.state_handler.resolve_state_groups_for_events ) - states = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [resolve(room_id, [e]) for e in event_ids], - consumeErrors=True, - )) + states = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [resolve(room_id, [e]) for e in event_ids], consumeErrors=True + ) + ) # dict[str, dict[tuple, str]], a map from event_id to state map of # event_ids. @@ -1050,23 +1009,23 @@ class FederationHandler(BaseHandler): state_map = yield self.store.get_events( [e_id for ids in itervalues(states) for e_id in itervalues(ids)], - get_prev_content=False + get_prev_content=False, ) states = { key: { k: state_map[e_id] for k, e_id in iteritems(state_dict) if e_id in state_map - } for key, state_dict in iteritems(states) + } + for key, state_dict in iteritems(states) } for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) - success = yield try_backfill([ - dom for dom, _ in likely_domains - if dom not in tried_domains - ]) + success = yield try_backfill( + [dom for dom, _ in likely_domains if dom not in tried_domains] + ) if success: defer.returnValue(True) @@ -1091,20 +1050,20 @@ class FederationHandler(BaseHandler): SynapseError if the event does not pass muster """ if len(ev.prev_event_ids()) > 20: - logger.warn("Rejecting event %s which has %i prev_events", - ev.event_id, len(ev.prev_event_ids())) - raise SynapseError( - http_client.BAD_REQUEST, - "Too many prev_events", + logger.warn( + "Rejecting event %s which has %i prev_events", + ev.event_id, + len(ev.prev_event_ids()), ) + raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: - logger.warn("Rejecting event %s which has %i auth_events", - ev.event_id, len(ev.auth_event_ids())) - raise SynapseError( - http_client.BAD_REQUEST, - "Too many auth_events", + logger.warn( + "Rejecting event %s which has %i auth_events", + ev.event_id, + len(ev.auth_event_ids()), ) + raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events") @defer.inlineCallbacks def send_invite(self, target_host, event): @@ -1116,7 +1075,7 @@ class FederationHandler(BaseHandler): destination=target_host, room_id=event.room_id, event_id=event.event_id, - pdu=event + pdu=event, ) defer.returnValue(pdu) @@ -1125,8 +1084,7 @@ class FederationHandler(BaseHandler): def on_event_auth(self, event_id): event = yield self.store.get_event(event_id) auth = yield self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], - include_given=True + [auth_id for auth_id in event.auth_event_ids()], include_given=True ) defer.returnValue([e for e in auth]) @@ -1152,15 +1110,13 @@ class FederationHandler(BaseHandler): joinee, "join", content, - params={ - "ver": KNOWN_ROOM_VERSIONS, - }, + params={"ver": KNOWN_ROOM_VERSIONS}, ) # This shouldn't happen, because the RoomMemberHandler has a # linearizer lock which only allows one operation per user per room # at a time - so this is just paranoia. - assert (room_id not in self.room_queues) + assert room_id not in self.room_queues self.room_queues[room_id] = [] @@ -1177,7 +1133,7 @@ class FederationHandler(BaseHandler): except ValueError: pass ret = yield self.federation_client.send_join( - target_hosts, event, event_format_version, + target_hosts, event, event_format_version ) origin = ret["origin"] @@ -1196,17 +1152,13 @@ class FederationHandler(BaseHandler): try: yield self.store.store_room( - room_id=room_id, - room_creator_user_id="", - is_public=False + room_id=room_id, room_creator_user_id="", is_public=False ) except Exception: # FIXME pass - yield self._persist_auth_tree( - origin, auth_chain, state, event - ) + yield self._persist_auth_tree(origin, auth_chain, state, event) logger.debug("Finished joining %s to %s", joinee, room_id) finally: @@ -1233,14 +1185,18 @@ class FederationHandler(BaseHandler): """ for p, origin in room_queue: try: - logger.info("Processing queued PDU %s which was received " - "while we were joining %s", p.event_id, p.room_id) + logger.info( + "Processing queued PDU %s which was received " + "while we were joining %s", + p.event_id, + p.room_id, + ) with logcontext.nested_logging_context(p.event_id): yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warn( - "Error handling queued PDU %s from %s: %s", - p.event_id, origin, e) + "Error handling queued PDU %s from %s: %s", p.event_id, origin, e + ) @defer.inlineCallbacks @log_function @@ -1261,30 +1217,30 @@ class FederationHandler(BaseHandler): "room_id": room_id, "sender": user_id, "state_key": user_id, - } + }, ) try: event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) except AuthError as e: logger.warn("Failed to create join %r because %s", event, e) raise e event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Creation of join %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` yield self.auth.check_from_context( - room_version, event, context, do_sig_check=False, + room_version, event, context, do_sig_check=False ) defer.returnValue(event) @@ -1319,17 +1275,15 @@ class FederationHandler(BaseHandler): # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin - context = yield self._handle_new_event( - origin, event - ) + context = yield self._handle_new_event(origin, event) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Sending of join %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) logger.debug( @@ -1350,10 +1304,7 @@ class FederationHandler(BaseHandler): state = yield self.store.get_events(list(prev_state_ids.values())) - defer.returnValue({ - "state": list(state.values()), - "auth_chain": auth_chain, - }) + defer.returnValue({"state": list(state.values()), "auth_chain": auth_chain}) @defer.inlineCallbacks def on_invite_request(self, origin, pdu): @@ -1376,8 +1327,11 @@ class FederationHandler(BaseHandler): is_published = yield self.store.is_room_published(event.room_id) if not self.spam_checker.user_may_invite( - event.sender, event.state_key, None, - room_id=event.room_id, new_room=False, + event.sender, + event.state_key, + None, + room_id=event.room_id, + new_room=False, published_room=is_published, ): raise SynapseError( @@ -1390,26 +1344,23 @@ class FederationHandler(BaseHandler): sender_domain = get_domain_from_id(event.sender) if sender_domain != origin: - raise SynapseError(400, "The invite event was not from the server sending it") + raise SynapseError( + 400, "The invite event was not from the server sending it" + ) if not self.is_mine_id(event.state_key): raise SynapseError(400, "The invite event must be for this server") # block any attempts to invite the server notices mxid if event.state_key == self._server_notices_mxid: - raise SynapseError( - http_client.FORBIDDEN, - "Cannot invite this user", - ) + raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True event.signatures.update( compute_event_signature( - event.get_pdu_json(), - self.hs.hostname, - self.hs.config.signing_key[0] + event.get_pdu_json(), self.hs.hostname, self.hs.config.signing_key[0] ) ) @@ -1421,10 +1372,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, - room_id, - user_id, - "leave" + target_hosts, room_id, user_id, "leave" ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. @@ -1439,10 +1387,7 @@ class FederationHandler(BaseHandler): except ValueError: pass - yield self.federation_client.send_leave( - target_hosts, - event - ) + yield self.federation_client.send_leave(target_hosts, event) context = yield self.state_handler.compute_event_context(event) yield self.persist_events_and_notify([(event, context)]) @@ -1450,25 +1395,21 @@ class FederationHandler(BaseHandler): defer.returnValue(event) @defer.inlineCallbacks - def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, - content={}, params=None): + def _make_and_verify_event( + self, target_hosts, room_id, user_id, membership, content={}, params=None + ): origin, event, format_ver = yield self.federation_client.make_membership_event( - target_hosts, - room_id, - user_id, - membership, - content, - params=params, + target_hosts, room_id, user_id, membership, content, params=params ) logger.debug("Got response to make_%s: %s", membership, event) # We should assert some things. # FIXME: Do this in a nicer way - assert(event.type == EventTypes.Member) - assert(event.user_id == user_id) - assert(event.state_key == user_id) - assert(event.room_id == room_id) + assert event.type == EventTypes.Member + assert event.user_id == user_id + assert event.state_key == user_id + assert event.room_id == room_id defer.returnValue((origin, event, format_ver)) @defer.inlineCallbacks @@ -1487,27 +1428,27 @@ class FederationHandler(BaseHandler): "room_id": room_id, "sender": user_id, "state_key": user_id, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.warning("Creation of leave %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` yield self.auth.check_from_context( - room_version, event, context, do_sig_check=False, + room_version, event, context, do_sig_check=False ) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) @@ -1529,17 +1470,15 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = yield self._handle_new_event( - origin, event - ) + context = yield self._handle_new_event(origin, event) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Sending of leave %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) logger.debug( @@ -1556,18 +1495,14 @@ class FederationHandler(BaseHandler): """ event = yield self.store.get_event( - event_id, allow_none=False, check_room_id=room_id, + event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups( - room_id, [event_id] - ) + state_groups = yield self.store.get_state_groups(room_id, [event_id]) if state_groups: _, state = list(iteritems(state_groups)).pop() - results = { - (e.type, e.state_key): e for e in state - } + results = {(e.type, e.state_key): e for e in state} if event.is_state(): # Get previous state @@ -1589,12 +1524,10 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ event = yield self.store.get_event( - event_id, allow_none=False, check_room_id=room_id, + event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups_ids( - room_id, [event_id] - ) + state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) if state_groups: _, state = list(state_groups.items()).pop() @@ -1620,11 +1553,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - events = yield self.store.get_backfill_events( - room_id, - pdu_list, - limit - ) + events = yield self.store.get_backfill_events(room_id, pdu_list, limit) events = yield filter_events_for_server(self.store, origin, events) @@ -1648,22 +1577,15 @@ class FederationHandler(BaseHandler): AuthError if the server is not currently in the room """ event = yield self.store.get_event( - event_id, - allow_none=True, - allow_rejected=True, + event_id, allow_none=True, allow_rejected=True ) if event: - in_room = yield self.auth.check_host_in_room( - event.room_id, - origin - ) + in_room = yield self.auth.check_host_in_room(event.room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - events = yield filter_events_for_server( - self.store, origin, [event], - ) + events = yield filter_events_for_server(self.store, origin, [event]) event = events[0] defer.returnValue(event) else: @@ -1673,13 +1595,11 @@ class FederationHandler(BaseHandler): return self.store.get_min_depth(context) @defer.inlineCallbacks - def _handle_new_event(self, origin, event, state=None, auth_events=None, - backfilled=False): + def _handle_new_event( + self, origin, event, state=None, auth_events=None, backfilled=False + ): context = yield self._prep_event( - origin, event, - state=state, - auth_events=auth_events, - backfilled=backfilled, + origin, event, state=state, auth_events=auth_events, backfilled=backfilled ) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we @@ -1692,15 +1612,13 @@ class FederationHandler(BaseHandler): ) yield self.persist_events_and_notify( - [(event, context)], - backfilled=backfilled, + [(event, context)], backfilled=backfilled ) success = True finally: if not success: logcontext.run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, + self.store.remove_push_actions_from_staging, event.event_id ) defer.returnValue(context) @@ -1728,12 +1646,15 @@ class FederationHandler(BaseHandler): ) defer.returnValue(res) - contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background(prep, ev_info) - for ev_info in event_infos - ], consumeErrors=True, - )) + contexts = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background(prep, ev_info) + for ev_info in event_infos + ], + consumeErrors=True, + ) + ) yield self.persist_events_and_notify( [ @@ -1768,8 +1689,7 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id] = ctx event_map = { - e.event_id: e - for e in itertools.chain(auth_events, state, [event]) + e.event_id: e for e in itertools.chain(auth_events, state, [event]) } create_event = None @@ -1784,7 +1704,7 @@ class FederationHandler(BaseHandler): raise SynapseError(400, "No create event in state") room_version = create_event.content.get( - "room_version", RoomVersions.V1.identifier, + "room_version", RoomVersions.V1.identifier ) missing_auth_events = set() @@ -1795,11 +1715,7 @@ class FederationHandler(BaseHandler): for e_id in missing_auth_events: m_ev = yield self.federation_client.get_pdu( - [origin], - e_id, - room_version=room_version, - outlier=True, - timeout=10000, + [origin], e_id, room_version=room_version, outlier=True, timeout=10000 ) if m_ev and m_ev.event_id == e_id: event_map[e_id] = m_ev @@ -1824,10 +1740,7 @@ class FederationHandler(BaseHandler): # cause SynapseErrors in auth.check. We don't want to give up # the attempt to federate altogether in such cases. - logger.warn( - "Rejecting %s because %s", - e.event_id, err.msg - ) + logger.warn("Rejecting %s because %s", e.event_id, err.msg) if e == event: raise @@ -1837,16 +1750,14 @@ class FederationHandler(BaseHandler): [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ], + ] ) new_event_context = yield self.state_handler.compute_event_context( event, old_state=state ) - yield self.persist_events_and_notify( - [(event, new_event_context)], - ) + yield self.persist_events_and_notify([(event, new_event_context)]) @defer.inlineCallbacks def _prep_event(self, origin, event, state, auth_events, backfilled): @@ -1862,40 +1773,30 @@ class FederationHandler(BaseHandler): Returns: Deferred, which resolves to synapse.events.snapshot.EventContext """ - context = yield self.state_handler.compute_event_context( - event, old_state=state, - ) + context = yield self.state_handler.compute_event_context(event, old_state=state) if not auth_events: prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in auth_events.values() - } + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. if event.type == EventTypes.Member and not event.auth_event_ids(): if len(event.prev_event_ids()) == 1 and event.depth < 5: c = yield self.store.get_event( - event.prev_event_ids()[0], - allow_none=True, + event.prev_event_ids()[0], allow_none=True ) if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c try: - yield self.do_auth( - origin, event, context, auth_events=auth_events - ) + yield self.do_auth(origin, event, context, auth_events=auth_events) except AuthError as e: - logger.warn( - "[%s %s] Rejecting: %s", - event.room_id, event.event_id, e.msg - ) + logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg) context.rejected = RejectedReason.AUTH_ERROR @@ -1926,9 +1827,7 @@ class FederationHandler(BaseHandler): # "soft-fail" the event. do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier() if do_soft_fail_check: - extrem_ids = yield self.store.get_latest_event_ids_in_room( - event.room_id, - ) + extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id) extrem_ids = set(extrem_ids) prev_event_ids = set(event.prev_event_ids()) @@ -1956,31 +1855,31 @@ class FederationHandler(BaseHandler): # like bans, especially with state res v2. state_sets = yield self.store.get_state_groups( - event.room_id, extrem_ids, + event.room_id, extrem_ids ) state_sets = list(state_sets.values()) state_sets.append(state) current_state_ids = yield self.state_handler.resolve_events( - room_version, state_sets, event, + room_version, state_sets, event ) current_state_ids = { k: e.event_id for k, e in iteritems(current_state_ids) } else: current_state_ids = yield self.state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids, + event.room_id, latest_event_ids=extrem_ids ) logger.debug( "Doing soft-fail check for %s: state %s", - event.event_id, current_state_ids, + event.event_id, + current_state_ids, ) # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) current_state_ids = [ - e for k, e in iteritems(current_state_ids) - if k in auth_types + e for k, e in iteritems(current_state_ids) if k in auth_types ] current_auth_events = yield self.store.get_events(current_state_ids) @@ -1991,19 +1890,14 @@ class FederationHandler(BaseHandler): try: self.auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: - logger.warn( - "Soft-failing %r because %s", - event, e, - ) + logger.warn("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True @defer.inlineCallbacks - def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects, - missing): - in_room = yield self.auth.check_host_in_room( - room_id, - origin - ) + def on_query_auth( + self, origin, event_id, room_id, remote_auth_chain, rejects, missing + ): + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2021,28 +1915,23 @@ class FederationHandler(BaseHandler): # Now get the current auth_chain for the event. local_auth_chain = yield self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], - include_given=True + [auth_id for auth_id in event.auth_event_ids()], include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell # everyone. - ret = yield self.construct_auth_difference( - local_auth_chain, remote_auth_chain - ) + ret = yield self.construct_auth_difference(local_auth_chain, remote_auth_chain) logger.debug("on_query_auth returning: %s", ret) defer.returnValue(ret) @defer.inlineCallbacks - def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit): - in_room = yield self.auth.check_host_in_room( - room_id, - origin - ) + def on_get_missing_events( + self, origin, room_id, earliest_events, latest_events, limit + ): + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2056,7 +1945,7 @@ class FederationHandler(BaseHandler): ) missing_events = yield filter_events_for_server( - self.store, origin, missing_events, + self.store, origin, missing_events ) defer.returnValue(missing_events) @@ -2144,25 +2033,17 @@ class FederationHandler(BaseHandler): if missing_auth: # TODO: can we use store.have_seen_events here instead? - have_events = yield self.store.get_seen_events_with_rejections( - missing_auth - ) + have_events = yield self.store.get_seen_events_with_rejections(missing_auth) logger.debug("Got events %s from store", have_events) missing_auth.difference_update(have_events.keys()) else: have_events = {} - have_events.update({ - e.event_id: "" - for e in auth_events.values() - }) + have_events.update({e.event_id: "" for e in auth_events.values()}) if missing_auth: # If we don't have all the auth events, we need to get them. - logger.info( - "auth_events contains unknown events: %s", - missing_auth, - ) + logger.info("auth_events contains unknown events: %s", missing_auth) try: try: remote_auth_chain = yield self.federation_client.get_event_auth( @@ -2188,18 +2069,16 @@ class FederationHandler(BaseHandler): try: auth_ids = e.auth_event_ids() auth = { - (e.type, e.state_key): e for e in remote_auth_chain + (e.type, e.state_key): e + for e in remote_auth_chain if e.event_id in auth_ids or e.type == EventTypes.Create } e.internal_metadata.outlier = True logger.debug( - "do_auth %s missing_auth: %s", - event.event_id, e.event_id - ) - yield self._handle_new_event( - origin, e, auth_events=auth + "do_auth %s missing_auth: %s", event.event_id, e.event_id ) + yield self._handle_new_event(origin, e, auth_events=auth) if e.event_id in event_auth_events: auth_events[(e.type, e.state_key)] = e @@ -2235,35 +2114,36 @@ class FederationHandler(BaseHandler): room_version = yield self.store.get_room_version(event.room_id) different_events = yield logcontext.make_deferred_yieldable( - defer.gatherResults([ - logcontext.run_in_background( - self.store.get_event, - d, - allow_none=True, - allow_rejected=False, - ) - for d in different_auth - if d in have_events and not have_events[d] - ], consumeErrors=True) + defer.gatherResults( + [ + logcontext.run_in_background( + self.store.get_event, d, allow_none=True, allow_rejected=False + ) + for d in different_auth + if d in have_events and not have_events[d] + ], + consumeErrors=True, + ) ).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) remote_view = dict(auth_events) - remote_view.update({ - (d.type, d.state_key): d for d in different_events if d - }) + remote_view.update( + {(d.type, d.state_key): d for d in different_events if d} + ) new_state = yield self.state_handler.resolve_events( room_version, [list(local_view.values()), list(remote_view.values())], - event + event, ) logger.info( "After state res: updating auth_events with new state %s", { - (d.type, d.state_key): d.event_id for d in new_state.values() + (d.type, d.state_key): d.event_id + for d in new_state.values() if auth_events.get((d.type, d.state_key)) != d }, ) @@ -2275,7 +2155,7 @@ class FederationHandler(BaseHandler): ) yield self._update_context_for_auth_events( - event, context, auth_events, event_key, + event, context, auth_events, event_key ) if not different_auth: @@ -2309,21 +2189,14 @@ class FederationHandler(BaseHandler): prev_state_ids = yield context.get_prev_state_ids(self.store) # 1. Get what we think is the auth chain. - auth_ids = yield self.auth.compute_auth_events( - event, prev_state_ids - ) - local_auth_chain = yield self.store.get_auth_chain( - auth_ids, include_given=True - ) + auth_ids = yield self.auth.compute_auth_events(event, prev_state_ids) + local_auth_chain = yield self.store.get_auth_chain(auth_ids, include_given=True) try: # 2. Get remote difference. try: result = yield self.federation_client.query_auth( - origin, - event.room_id, - event.event_id, - local_auth_chain, + origin, event.room_id, event.event_id, local_auth_chain ) except RequestSendFailed as e: # The other side isn't around or doesn't implement the @@ -2348,19 +2221,15 @@ class FederationHandler(BaseHandler): auth = { (e.type, e.state_key): e for e in result["auth_chain"] - if e.event_id in auth_ids - or event.type == EventTypes.Create + if e.event_id in auth_ids or event.type == EventTypes.Create } ev.internal_metadata.outlier = True logger.debug( - "do_auth %s different_auth: %s", - event.event_id, e.event_id + "do_auth %s different_auth: %s", event.event_id, e.event_id ) - yield self._handle_new_event( - origin, ev, auth_events=auth - ) + yield self._handle_new_event(origin, ev, auth_events=auth) if ev.event_id in event_auth_events: auth_events[(ev.type, ev.state_key)] = ev @@ -2375,12 +2244,11 @@ class FederationHandler(BaseHandler): # TODO. yield self._update_context_for_auth_events( - event, context, auth_events, event_key, + event, context, auth_events, event_key ) @defer.inlineCallbacks - def _update_context_for_auth_events(self, event, context, auth_events, - event_key): + def _update_context_for_auth_events(self, event, context, auth_events, event_key): """Update the state_ids in an event context after auth event resolution, storing the changes as a new state group. @@ -2397,8 +2265,7 @@ class FederationHandler(BaseHandler): this will not be included in the current_state in the context. """ state_updates = { - k: a.event_id for k, a in iteritems(auth_events) - if k != event_key + k: a.event_id for k, a in iteritems(auth_events) if k != event_key } current_state_ids = yield context.get_current_state_ids(self.store) current_state_ids = dict(current_state_ids) @@ -2408,9 +2275,7 @@ class FederationHandler(BaseHandler): prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = dict(prev_state_ids) - prev_state_ids.update({ - k: a.event_id for k, a in iteritems(auth_events) - }) + prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) # create a new state group as a delta from the existing one. prev_group = context.state_group @@ -2559,30 +2424,23 @@ class FederationHandler(BaseHandler): logger.debug("construct_auth_difference returning") - defer.returnValue({ - "auth_chain": local_auth, - "rejects": { - e.event_id: { - "reason": reason_map[e.event_id], - "proof": None, - } - for e in base_remote_rejected - }, - "missing": [e.event_id for e in missing_locals], - }) + defer.returnValue( + { + "auth_chain": local_auth, + "rejects": { + e.event_id: {"reason": reason_map[e.event_id], "proof": None} + for e in base_remote_rejected + }, + "missing": [e.event_id for e in missing_locals], + } + ) @defer.inlineCallbacks @log_function def exchange_third_party_invite( - self, - sender_user_id, - target_user_id, - room_id, - signed, + self, sender_user_id, target_user_id, room_id, signed ): - third_party_invite = { - "signed": signed, - } + third_party_invite = {"signed": signed} event_dict = { "type": EventTypes.Member, @@ -2605,7 +2463,7 @@ class FederationHandler(BaseHandler): ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info( @@ -2613,7 +2471,7 @@ class FederationHandler(BaseHandler): event, ) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) event, context = yield self.add_display_name_to_third_party_invite( @@ -2638,9 +2496,7 @@ class FederationHandler(BaseHandler): else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) yield self.federation_client.forward_third_party_invite( - destinations, - room_id, - event_dict, + destinations, room_id, event_dict ) @defer.inlineCallbacks @@ -2661,19 +2517,18 @@ class FederationHandler(BaseHandler): builder = self.event_builder_factory.new(room_version, event_dict) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.warning( - "Exchange of threepid invite %s forbidden by third-party rules", - event, + "Exchange of threepid invite %s forbidden by third-party rules", event ) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) event, context = yield self.add_display_name_to_third_party_invite( @@ -2695,11 +2550,12 @@ class FederationHandler(BaseHandler): yield member_handler.send_membership_event(None, event, context) @defer.inlineCallbacks - def add_display_name_to_third_party_invite(self, room_version, event_dict, - event, context): + def add_display_name_to_third_party_invite( + self, room_version, event_dict, event, context + ): key = ( EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"] + event.content["third_party_invite"]["signed"]["token"], ) original_invite = None prev_state_ids = yield context.get_prev_state_ids(self.store) @@ -2718,8 +2574,7 @@ class FederationHandler(BaseHandler): event_dict["content"]["third_party_invite"]["display_name"] = display_name else: logger.info( - "Could not find invite event for third_party_invite: %r", - event_dict + "Could not find invite event for third_party_invite: %r", event_dict ) # We don't discard here as this is not the appropriate place to do # auth checks. If we need the invite and don't have it then the @@ -2728,7 +2583,7 @@ class FederationHandler(BaseHandler): builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) EventValidator().validate_new(event, self.config) defer.returnValue((event, context)) @@ -2752,9 +2607,7 @@ class FederationHandler(BaseHandler): token = signed["token"] prev_state_ids = yield context.get_prev_state_ids(self.store) - invite_event_id = prev_state_ids.get( - (EventTypes.ThirdPartyInvite, token,) - ) + invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) invite_event = None if invite_event_id: @@ -2778,38 +2631,42 @@ class FederationHandler(BaseHandler): logger.debug( "Attempting to verify sig with key %s from %r " "against pubkey %r", - key_name, server, public_key_object, + key_name, + server, + public_key_object, ) try: public_key = public_key_object["public_key"] verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) + key_name, decode_base64(public_key) ) verify_signed_json(signed, server, verify_key) logger.debug( "Successfully verified sig with key %s from %r " "against pubkey %r", - key_name, server, public_key_object, + key_name, + server, + public_key_object, ) except Exception: logger.info( "Failed to verify sig with key %s from %r " "against pubkey %r", - key_name, server, public_key_object, + key_name, + server, + public_key_object, ) raise try: if "key_validity_url" in public_key_object: yield self._check_key_revocation( - public_key, - public_key_object["key_validity_url"] + public_key, public_key_object["key_validity_url"] ) except Exception: logger.info( "Failed to query key_validity_url %s", - public_key_object["key_validity_url"] + public_key_object["key_validity_url"], ) raise return @@ -2832,15 +2689,9 @@ class FederationHandler(BaseHandler): for revocation. """ try: - response = yield self.http_client.get_json( - url, - {"public_key": public_key} - ) + response = yield self.http_client.get_json(url, {"public_key": public_key}) except Exception: - raise SynapseError( - 502, - "Third party certificate could not be checked" - ) + raise SynapseError(502, "Third party certificate could not be checked") if "valid" not in response or not response["valid"]: raise AuthError(403, "Third party certificate was invalid") @@ -2861,12 +2712,11 @@ class FederationHandler(BaseHandler): yield self._send_events_to_master( store=self.store, event_and_contexts=event_and_contexts, - backfilled=backfilled + backfilled=backfilled, ) else: max_stream_id = yield self.store.persist_events( - event_and_contexts, - backfilled=backfilled, + event_and_contexts, backfilled=backfilled ) if not backfilled: # Never notify for backfilled events @@ -2900,13 +2750,10 @@ class FederationHandler(BaseHandler): event_stream_id = event.internal_metadata.stream_ordering self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users + event, event_stream_id, max_stream_id, extra_users=extra_users ) - return self.pusher_pool.on_new_notifications( - event_stream_id, max_stream_id, - ) + return self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _clean_room_for_join(self, room_id): """Called to clean up any data in DB for a given room, ready for the @@ -2925,9 +2772,7 @@ class FederationHandler(BaseHandler): """ if self.config.worker_app: return self._notify_user_membership_change( - room_id=room_id, - user_id=user.to_string(), - change="joined", + room_id=room_id, user_id=user.to_string(), change="joined" ) else: return user_joined_room(self.distributor, user, room_id) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index f60ace02e8..7da63bb643 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -30,6 +30,7 @@ def _create_rerouter(func_name): """Returns a function that looks at the group id and calls the function on federation or the local group server if the group is local """ + def f(self, group_id, *args, **kwargs): if self.is_mine_id(group_id): return getattr(self.groups_server_handler, func_name)( @@ -58,6 +59,7 @@ def _create_rerouter(func_name): d.addErrback(http_response_errback) d.addErrback(request_failed_errback) return d + return f @@ -125,7 +127,7 @@ class GroupsLocalHandler(object): ) else: res = yield self.transport_client.get_group_summary( - get_domain_from_id(group_id), group_id, requester_user_id, + get_domain_from_id(group_id), group_id, requester_user_id ) group_server_name = get_domain_from_id(group_id) @@ -182,7 +184,7 @@ class GroupsLocalHandler(object): content["user_profile"] = yield self.profile_handler.get_profile(user_id) res = yield self.transport_client.create_group( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -195,16 +197,15 @@ class GroupsLocalHandler(object): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=True, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue(res) @@ -221,7 +222,7 @@ class GroupsLocalHandler(object): group_server_name = get_domain_from_id(group_id) res = yield self.transport_client.get_users_in_group( - get_domain_from_id(group_id), group_id, requester_user_id, + get_domain_from_id(group_id), group_id, requester_user_id ) chunk = res["chunk"] @@ -250,9 +251,7 @@ class GroupsLocalHandler(object): """Request to join a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.join_group( - group_id, user_id, content - ) + yield self.groups_server_handler.join_group(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -260,7 +259,7 @@ class GroupsLocalHandler(object): content["attestation"] = local_attestation res = yield self.transport_client.join_group( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -276,16 +275,15 @@ class GroupsLocalHandler(object): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=False, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue({}) @@ -294,9 +292,7 @@ class GroupsLocalHandler(object): """Accept an invite to a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.accept_invite( - group_id, user_id, content - ) + yield self.groups_server_handler.accept_invite(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -304,7 +300,7 @@ class GroupsLocalHandler(object): content["attestation"] = local_attestation res = yield self.transport_client.accept_group_invite( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -320,16 +316,15 @@ class GroupsLocalHandler(object): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=False, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue({}) @@ -337,17 +332,17 @@ class GroupsLocalHandler(object): def invite(self, group_id, user_id, requester_user_id, config): """Invite a user to a group """ - content = { - "requester_user_id": requester_user_id, - "config": config, - } + content = {"requester_user_id": requester_user_id, "config": config} if self.is_mine_id(group_id): res = yield self.groups_server_handler.invite_to_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) else: res = yield self.transport_client.invite_to_group( - get_domain_from_id(group_id), group_id, user_id, requester_user_id, + get_domain_from_id(group_id), + group_id, + user_id, + requester_user_id, content, ) @@ -370,13 +365,12 @@ class GroupsLocalHandler(object): local_profile["avatar_url"] = content["profile"]["avatar_url"] token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="invite", content={"profile": local_profile, "inviter": content["inviter"]}, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) try: user_profile = yield self.profile_handler.get_profile(user_id) except Exception as e: @@ -391,25 +385,25 @@ class GroupsLocalHandler(object): """ if user_id == requester_user_id: token = yield self.store.register_user_group_membership( - group_id, user_id, - membership="leave", - ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], + group_id, user_id, membership="leave" ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) # TODO: Should probably remember that we tried to leave so that we can # retry if the group server is currently down. if self.is_mine_id(group_id): res = yield self.groups_server_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) else: content["requester_user_id"] = requester_user_id res = yield self.transport_client.remove_user_from_group( - get_domain_from_id(group_id), group_id, requester_user_id, - user_id, content, + get_domain_from_id(group_id), + group_id, + requester_user_id, + user_id, + content, ) defer.returnValue(res) @@ -420,12 +414,9 @@ class GroupsLocalHandler(object): """ # TODO: Check if user in group token = yield self.store.register_user_group_membership( - group_id, user_id, - membership="leave", - ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], + group_id, user_id, membership="leave" ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) @defer.inlineCallbacks def get_joined_groups(self, user_id): @@ -445,7 +436,7 @@ class GroupsLocalHandler(object): defer.returnValue({"groups": result}) else: bulk_result = yield self.transport_client.bulk_get_publicised_groups( - get_domain_from_id(user_id), [user_id], + get_domain_from_id(user_id), [user_id] ) result = bulk_result.get("users", {}).get(user_id) # TODO: Verify attestations @@ -460,9 +451,7 @@ class GroupsLocalHandler(object): if self.hs.is_mine_id(user_id): local_users.add(user_id) else: - destinations.setdefault( - get_domain_from_id(user_id), set() - ).add(user_id) + destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id) if not proxy and destinations: raise SynapseError(400, "Some user_ids are not local") @@ -472,16 +461,14 @@ class GroupsLocalHandler(object): for destination, dest_user_ids in iteritems(destinations): try: r = yield self.transport_client.bulk_get_publicised_groups( - destination, list(dest_user_ids), + destination, list(dest_user_ids) ) results.update(r["users"]) except Exception: failed_results.extend(dest_user_ids) for uid in local_users: - results[uid] = yield self.store.get_publicised_groups_for_user( - uid - ) + results[uid] = yield self.store.get_publicised_groups_for_user(uid) # Check AS associated groups for this user - this depends on the # RegExps in the AS registration file (under `users`) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index dfc03f51e7..0b40541570 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -41,7 +41,6 @@ logger = logging.getLogger(__name__) class IdentityHandler(BaseHandler): - def __init__(self, hs): super(IdentityHandler, self).__init__(hs) @@ -71,24 +70,24 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def threepid_from_creds(self, creds): - if 'id_server' in creds: - id_server = creds['id_server'] - elif 'idServer' in creds: - id_server = creds['idServer'] + if "id_server" in creds: + id_server = creds["id_server"] + elif "idServer" in creds: + id_server = creds["idServer"] else: raise SynapseError(400, "No id_server in creds") - if 'client_secret' in creds: - client_secret = creds['client_secret'] - elif 'clientSecret' in creds: - client_secret = creds['clientSecret'] + if "client_secret" in creds: + client_secret = creds["client_secret"] + elif "clientSecret" in creds: + client_secret = creds["clientSecret"] else: raise SynapseError(400, "No client_secret in creds") if not self._should_trust_id_server(id_server): logger.warn( - '%s is not a trusted ID server: rejecting 3pid ' + - 'credentials', id_server + "%s is not a trusted ID server: rejecting 3pid " + "credentials", + id_server, ) defer.returnValue(None) # if we have a rewrite rule set for the identity server, @@ -97,17 +96,15 @@ class IdentityHandler(BaseHandler): id_server = self.rewrite_identity_server_urls[id_server] try: data = yield self.http_client.get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/3pid/getValidated3pid" - ), - {'sid': creds['sid'], 'client_secret': client_secret} + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/3pid/getValidated3pid"), + {"sid": creds["sid"], "client_secret": client_secret}, ) except HttpResponseException as e: logger.info("getValidated3pid failed with Matrix error: %r", e) raise e.to_synapse_error() - if 'medium' in data: + if "medium" in data: defer.returnValue(data) defer.returnValue(None) @@ -116,17 +113,17 @@ class IdentityHandler(BaseHandler): logger.debug("binding threepid %r to %s", creds, mxid) data = None - if 'id_server' in creds: - id_server = creds['id_server'] - elif 'idServer' in creds: - id_server = creds['idServer'] + if "id_server" in creds: + id_server = creds["id_server"] + elif "idServer" in creds: + id_server = creds["idServer"] else: raise SynapseError(400, "No id_server in creds") - if 'client_secret' in creds: - client_secret = creds['client_secret'] - elif 'clientSecret' in creds: - client_secret = creds['clientSecret'] + if "client_secret" in creds: + client_secret = creds["client_secret"] + elif "clientSecret" in creds: + client_secret = creds["clientSecret"] else: raise SynapseError(400, "No client_secret in creds") @@ -140,14 +137,8 @@ class IdentityHandler(BaseHandler): try: data = yield self.http_client.post_urlencoded_get_json( - "https://%s%s" % ( - id_server_host, "/_matrix/identity/api/v1/3pid/bind" - ), - { - 'sid': creds['sid'], - 'client_secret': client_secret, - 'mxid': mxid, - } + "https://%s%s" % (id_server_host, "/_matrix/identity/api/v1/3pid/bind"), + {"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid}, ) logger.debug("bound threepid %r to %s", creds, mxid) @@ -183,9 +174,7 @@ class IdentityHandler(BaseHandler): id_servers = [threepid["id_server"]] else: id_servers = yield self.store.get_id_servers_user_bound( - user_id=mxid, - medium=threepid["medium"], - address=threepid["address"], + user_id=mxid, medium=threepid["medium"], address=threepid["address"] ) # We don't know where to unbind, so we don't have a choice but to return @@ -195,7 +184,7 @@ class IdentityHandler(BaseHandler): changed = True for id_server in id_servers: changed &= yield self.try_unbind_threepid_with_id_server( - mxid, threepid, id_server, + mxid, threepid, id_server ) defer.returnValue(changed) @@ -219,10 +208,7 @@ class IdentityHandler(BaseHandler): url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) content = { "mxid": mxid, - "threepid": { - "medium": threepid["medium"], - "address": threepid["address"], - }, + "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, } # we abuse the federation http client to sign the request, but we have to send it @@ -230,14 +216,12 @@ class IdentityHandler(BaseHandler): # 'browser-like' HTTPS. auth_headers = self.federation_http_client.build_auth_headers( destination=None, - method='POST', - url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'), + method="POST", + url_bytes="/_matrix/identity/api/v1/3pid/unbind".encode("ascii"), content=content, destination_is=id_server, ) - headers = { - b"Authorization": auth_headers, - } + headers = {b"Authorization": auth_headers} # if we have a rewrite rule set for the identity server, # apply it now. @@ -250,15 +234,11 @@ class IdentityHandler(BaseHandler): url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) try: - yield self.http_client.post_json_get_json( - url, - content, - headers, - ) + yield self.http_client.post_json_get_json(url, content, headers) changed = True except HttpResponseException as e: changed = False - if e.code in (400, 404, 501,): + if e.code in (400, 404, 501): # The remote server probably doesn't support unbinding (yet) logger.warn("Received %d response while unbinding threepid", e.code) else: @@ -276,23 +256,17 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def requestEmailToken( - self, - id_server, - email, - client_secret, - send_attempt, - next_link=None, + self, id_server, email, client_secret, send_attempt, next_link=None ): if not self._should_trust_id_server(id_server): raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) params = { - 'email': email, - 'client_secret': client_secret, - 'send_attempt': send_attempt, + "email": email, + "client_secret": client_secret, + "send_attempt": send_attempt, } # if we have a rewrite rule set for the identity server, @@ -301,15 +275,13 @@ class IdentityHandler(BaseHandler): id_server = self.rewrite_identity_server_urls[id_server] if next_link: - params.update({'next_link': next_link}) + params.update({"next_link": next_link}) try: data = yield self.http_client.post_json_get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/validate/email/requestToken" - ), - params + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/validate/email/requestToken"), + params, ) defer.returnValue(data) except HttpResponseException as e: @@ -318,20 +290,18 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def requestMsisdnToken( - self, id_server, country, phone_number, - client_secret, send_attempt, **kwargs + self, id_server, country, phone_number, client_secret, send_attempt, **kwargs ): if not self._should_trust_id_server(id_server): raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) params = { - 'country': country, - 'phone_number': phone_number, - 'client_secret': client_secret, - 'send_attempt': send_attempt, + "country": country, + "phone_number": phone_number, + "client_secret": client_secret, + "send_attempt": send_attempt, } params.update(kwargs) # if we have a rewrite rule set for the identity server, @@ -340,11 +310,9 @@ class IdentityHandler(BaseHandler): id_server = self.rewrite_identity_server_urls[id_server] try: data = yield self.http_client.post_json_get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/validate/msisdn/requestToken" - ), - params + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/validate/msisdn/requestToken"), + params, ) defer.returnValue(data) except HttpResponseException as e: @@ -368,13 +336,12 @@ class IdentityHandler(BaseHandler): """ if not self._should_trust_id_server(id_server): raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) if not self._enable_lookup: raise AuthError( - 403, "Looking up third-party identifiers is denied from this server", + 403, "Looking up third-party identifiers is denied from this server" ) target = self.rewrite_identity_server_urls.get(id_server, id_server) @@ -382,10 +349,7 @@ class IdentityHandler(BaseHandler): try: data = yield self.http_client.get_json( "https://%s/_matrix/identity/api/v1/lookup" % (target,), - { - "medium": medium, - "address": address, - } + {"medium": medium, "address": address}, ) if "mxid" in data: @@ -419,13 +383,12 @@ class IdentityHandler(BaseHandler): """ if not self._should_trust_id_server(id_server): raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) if not self._enable_lookup: raise AuthError( - 403, "Looking up third-party identifiers is denied from this server", + 403, "Looking up third-party identifiers is denied from this server" ) target = self.rewrite_identity_server_urls.get(id_server, id_server) @@ -433,9 +396,7 @@ class IdentityHandler(BaseHandler): try: data = yield self.http_client.post_json_get_json( "https://%s/_matrix/identity/api/v1/bulk_lookup" % (target,), - { - "threepids": threepids, - } + {"threepids": threepids}, ) except HttpResponseException as e: @@ -454,20 +415,22 @@ class IdentityHandler(BaseHandler): for key_name, signature in data["signatures"][server_hostname].items(): target = self.rewrite_identity_server_urls.get( - server_hostname, server_hostname, + server_hostname, server_hostname ) key_data = yield self.http_client.get_json( - "https://%s/_matrix/identity/api/v1/pubkey/%s" % - (target, key_name,), + "https://%s/_matrix/identity/api/v1/pubkey/%s" % (target, key_name) ) if "public_key" not in key_data: - raise AuthError(401, "No public key named %s from %s" % - (key_name, server_hostname,)) + raise AuthError( + 401, "No public key named %s from %s" % (key_name, server_hostname) + ) verify_signed_json( data, server_hostname, - decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) + decode_verify_key_bytes( + key_name, decode_base64(key_data["public_key"]) + ), ) return diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index aaee5db0b7..a1fe9d116f 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -44,8 +44,13 @@ class InitialSyncHandler(BaseHandler): self.snapshot_cache = SnapshotCache() self._event_serializer = hs.get_event_client_serializer() - def snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): + def snapshot_all_rooms( + self, + user_id=None, + pagin_config=None, + as_client_event=True, + include_archived=False, + ): """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is @@ -77,13 +82,22 @@ class InitialSyncHandler(BaseHandler): if result is not None: return result - return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - )) + return self.snapshot_cache.set( + now_ms, + key, + self._snapshot_all_rooms( + user_id, pagin_config, as_client_event, include_archived + ), + ) @defer.inlineCallbacks - def _snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): + def _snapshot_all_rooms( + self, + user_id=None, + pagin_config=None, + as_client_event=True, + include_archived=False, + ): memberships = [Membership.INVITE, Membership.JOIN] if include_archived: @@ -128,8 +142,7 @@ class InitialSyncHandler(BaseHandler): "room_id": event.room_id, "membership": event.membership, "visibility": ( - "public" if event.room_id in public_room_ids - else "private" + "public" if event.room_id in public_room_ids else "private" ), } @@ -139,7 +152,7 @@ class InitialSyncHandler(BaseHandler): invite_event = yield self.store.get_event(event.event_id) d["invite"] = yield self._event_serializer.serialize_event( - invite_event, time_now, as_client_event, + invite_event, time_now, as_client_event ) rooms_ret.append(d) @@ -151,14 +164,12 @@ class InitialSyncHandler(BaseHandler): if event.membership == Membership.JOIN: room_end_token = now_token.room_key deferred_room_state = run_in_background( - self.state_handler.get_current_state, - event.room_id, + self.state_handler.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( - self.store.get_state_for_events, - [event.event_id], + self.store.get_state_for_events, [event.event_id] ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -178,9 +189,7 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client( - self.store, user_id, messages - ) + messages = yield filter_events_for_client(self.store, user_id, messages) start_token = now_token.copy_and_replace("room_key", token) end_token = now_token.copy_and_replace("room_key", room_end_token) @@ -189,8 +198,7 @@ class InitialSyncHandler(BaseHandler): d["messages"] = { "chunk": ( yield self._event_serializer.serialize_events( - messages, time_now=time_now, - as_client_event=as_client_event, + messages, time_now=time_now, as_client_event=as_client_event ) ), "start": start_token.to_string(), @@ -200,23 +208,21 @@ class InitialSyncHandler(BaseHandler): d["state"] = yield self._event_serializer.serialize_events( current_state.values(), time_now=time_now, - as_client_event=as_client_event + as_client_event=as_client_event, ) account_data_events = [] tags = tags_by_room.get(event.room_id) if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append( + {"type": "m.tag", "content": {"tags": tags}} + ) account_data = account_data_by_room.get(event.room_id, {}) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append( + {"type": account_data_type, "content": content} + ) d["account_data"] = account_data_events except Exception: @@ -226,10 +232,7 @@ class InitialSyncHandler(BaseHandler): account_data_events = [] for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) now = self.clock.time_msec() @@ -274,7 +277,7 @@ class InitialSyncHandler(BaseHandler): user_id = requester.user.to_string() membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id, + room_id, user_id ) is_peeking = member_event_id is None @@ -290,28 +293,21 @@ class InitialSyncHandler(BaseHandler): account_data_events = [] tags = yield self.store.get_tags_for_room(user_id, room_id) if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) account_data = yield self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) result["account_data"] = account_data_events defer.returnValue(result) @defer.inlineCallbacks - def _room_initial_sync_parted(self, user_id, room_id, pagin_config, - membership, member_event_id, is_peeking): - room_state = yield self.store.get_state_for_events( - [member_event_id], - ) + def _room_initial_sync_parted( + self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking + ): + room_state = yield self.store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -319,14 +315,10 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = yield self.store.get_stream_token_for_event( - member_event_id - ) + stream_token = yield self.store.get_stream_token_for_event(member_event_id) messages, token = yield self.store.get_recent_events_for_room( - room_id, - limit=limit, - end_token=stream_token + room_id, limit=limit, end_token=stream_token ) messages = yield filter_events_for_client( @@ -338,34 +330,39 @@ class InitialSyncHandler(BaseHandler): time_now = self.clock.time_msec() - defer.returnValue({ - "membership": membership, - "room_id": room_id, - "messages": { - "chunk": (yield self._event_serializer.serialize_events( - messages, time_now, - )), - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": (yield self._event_serializer.serialize_events( - room_state.values(), time_now, - )), - "presence": [], - "receipts": [], - }) + defer.returnValue( + { + "membership": membership, + "room_id": room_id, + "messages": { + "chunk": ( + yield self._event_serializer.serialize_events( + messages, time_now + ) + ), + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": ( + yield self._event_serializer.serialize_events( + room_state.values(), time_now + ) + ), + "presence": [], + "receipts": [], + } + ) @defer.inlineCallbacks - def _room_initial_sync_joined(self, user_id, room_id, pagin_config, - membership, is_peeking): - current_state = yield self.state.get_current_state( - room_id=room_id, - ) + def _room_initial_sync_joined( + self, user_id, room_id, pagin_config, membership, is_peeking + ): + current_state = yield self.state.get_current_state(room_id=room_id) # TODO: These concurrently time_now = self.clock.time_msec() state = yield self._event_serializer.serialize_events( - current_state.values(), time_now, + current_state.values(), time_now ) now_token = yield self.hs.get_event_sources().get_current_token() @@ -375,7 +372,8 @@ class InitialSyncHandler(BaseHandler): limit = 10 room_members = [ - m for m in current_state.values() + m + for m in current_state.values() if m.type == EventTypes.Member and m.content["membership"] == Membership.JOIN ] @@ -389,8 +387,7 @@ class InitialSyncHandler(BaseHandler): defer.returnValue([]) states = yield presence_handler.get_states( - [m.user_id for m in room_members], - as_event=True, + [m.user_id for m in room_members], as_event=True ) defer.returnValue(states) @@ -398,8 +395,7 @@ class InitialSyncHandler(BaseHandler): @defer.inlineCallbacks def get_receipts(): receipts = yield self.store.get_linearized_receipts_for_room( - room_id, - to_key=now_token.receipt_key, + room_id, to_key=now_token.receipt_key ) if not receipts: receipts = [] @@ -415,14 +411,14 @@ class InitialSyncHandler(BaseHandler): room_id, limit=limit, end_token=now_token.room_key, - ) + ), ], consumeErrors=True, - ).addErrback(unwrapFirstError), + ).addErrback(unwrapFirstError) ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking, + self.store, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace("room_key", token) @@ -433,9 +429,9 @@ class InitialSyncHandler(BaseHandler): ret = { "room_id": room_id, "messages": { - "chunk": (yield self._event_serializer.serialize_events( - messages, time_now, - )), + "chunk": ( + yield self._event_serializer.serialize_events(messages, time_now) + ), "start": start_token.to_string(), "end": end_token.to_string(), }, @@ -464,8 +460,8 @@ class InitialSyncHandler(BaseHandler): room_id, EventTypes.RoomHistoryVisibility, "" ) if ( - visibility and - visibility.content["history_visibility"] == "world_readable" + visibility + and visibility.content["history_visibility"] == "world_readable" ): defer.returnValue((Membership.JOIN, None)) return diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a212ca0b0e..e4531fc069 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -61,8 +61,9 @@ class MessageHandler(object): self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks - def get_room_data(self, user_id=None, room_id=None, - event_type=None, state_key="", is_guest=False): + def get_room_data( + self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False + ): """ Get data from a room. Args: @@ -77,9 +78,7 @@ class MessageHandler(object): ) if membership == Membership.JOIN: - data = yield self.state.get_current_state( - room_id, event_type, state_key - ) + data = yield self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( @@ -91,8 +90,12 @@ class MessageHandler(object): @defer.inlineCallbacks def get_state_events( - self, user_id, room_id, state_filter=StateFilter.all(), - at_token=None, is_guest=False, + self, + user_id, + room_id, + state_filter=StateFilter.all(), + at_token=None, + is_guest=False, ): """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has @@ -124,11 +127,11 @@ class MessageHandler(object): # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) last_events, _ = yield self.store.get_recent_events_for_room( - room_id, end_token=at_token.room_key, limit=1, + room_id, end_token=at_token.room_key, limit=1 ) if not last_events: - raise NotFoundError("Can't find event for token %s" % (at_token, )) + raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( self.store, user_id, last_events, apply_retention_policies=False @@ -137,37 +140,35 @@ class MessageHandler(object): event = last_events[0] if visible_events: room_state = yield self.store.get_state_for_events( - [event.event_id], state_filter=state_filter, + [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] else: raise AuthError( 403, - "User %s not allowed to view events in room %s at token %s" % ( - user_id, room_id, at_token, - ) + "User %s not allowed to view events in room %s at token %s" + % (user_id, room_id, at_token), ) else: membership, membership_event_id = ( - yield self.auth.check_in_room_or_world_readable( - room_id, user_id, - ) + yield self.auth.check_in_room_or_world_readable(room_id, user_id) ) if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( - room_id, state_filter=state_filter, + room_id, state_filter=state_filter ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - [membership_event_id], state_filter=state_filter, + [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] now = self.clock.time_msec() events = yield self._event_serializer.serialize_events( - room_state.values(), now, + room_state.values(), + now, # We don't bother bundling aggregations in when asked for state # events, as clients won't use them. bundle_aggregations=False, @@ -211,13 +212,15 @@ class MessageHandler(object): # Loop fell through, AS has no interested users in room raise AuthError(403, "Appservice not in room") - defer.returnValue({ - user_id: { - "avatar_url": profile.avatar_url, - "display_name": profile.display_name, + defer.returnValue( + { + user_id: { + "avatar_url": profile.avatar_url, + "display_name": profile.display_name, + } + for user_id, profile in iteritems(users_with_profile) } - for user_id, profile in iteritems(users_with_profile) - }) + ) class EventCreationHandler(object): @@ -269,14 +272,21 @@ class EventCreationHandler(object): self.clock.looping_call( lambda: run_as_background_process( "send_dummy_events_to_fill_extremities", - self._send_dummy_events_to_fill_extremities + self._send_dummy_events_to_fill_extremities, ), 5 * 60 * 1000, ) @defer.inlineCallbacks - def create_event(self, requester, event_dict, token_id=None, txn_id=None, - prev_events_and_hashes=None, require_consent=True): + def create_event( + self, + requester, + event_dict, + token_id=None, + txn_id=None, + prev_events_and_hashes=None, + require_consent=True, + ): """ Given a dict from a client, create a new event. @@ -336,8 +346,7 @@ class EventCreationHandler(object): content["avatar_url"] = yield profile.get_avatar_url(target) except Exception as e: logger.info( - "Failed to get profile information for %r: %s", - target, e + "Failed to get profile information for %r: %s", target, e ) is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester) @@ -373,16 +382,17 @@ class EventCreationHandler(object): prev_event = yield self.store.get_event(prev_event_id, allow_none=True) if not prev_event or prev_event.membership != Membership.JOIN: logger.warning( - ("Attempt to send `m.room.aliases` in room %s by user %s but" - " membership is %s"), + ( + "Attempt to send `m.room.aliases` in room %s by user %s but" + " membership is %s" + ), event.room_id, event.sender, prev_event.membership if prev_event else None, ) raise AuthError( - 403, - "You must be in the room to create an alias for it", + 403, "You must be in the room to create an alias for it" ) self.validator.validate_new(event, self.config) @@ -449,8 +459,8 @@ class EventCreationHandler(object): # exempt the system notices user if ( - self.config.server_notices_mxid is not None and - user_id == self.config.server_notices_mxid + self.config.server_notices_mxid is not None + and user_id == self.config.server_notices_mxid ): return @@ -463,15 +473,10 @@ class EventCreationHandler(object): return consent_uri = self._consent_uri_builder.build_user_consent_uri( - requester.user.localpart, - ) - msg = self._block_events_without_consent_error % { - 'consent_uri': consent_uri, - } - raise ConsentNotGivenError( - msg=msg, - consent_uri=consent_uri, + requester.user.localpart ) + msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} + raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) @defer.inlineCallbacks def send_nonmember_event(self, requester, event, context, ratelimit=True): @@ -486,8 +491,7 @@ class EventCreationHandler(object): """ if event.type == EventTypes.Member: raise SynapseError( - 500, - "Tried to send member event through non-member codepath" + 500, "Tried to send member event through non-member codepath" ) user = UserID.from_string(event.sender) @@ -499,15 +503,13 @@ class EventCreationHandler(object): if prev_state is not None: logger.info( "Not bothering to persist state event %s duplicated by %s", - event.event_id, prev_state.event_id, + event.event_id, + prev_state.event_id, ) defer.returnValue(prev_state) yield self.handle_new_client_event( - requester=requester, - event=event, - context=context, - ratelimit=ratelimit, + requester=requester, event=event, context=context, ratelimit=ratelimit ) @defer.inlineCallbacks @@ -533,11 +535,7 @@ class EventCreationHandler(object): @defer.inlineCallbacks def create_and_send_nonmember_event( - self, - requester, - event_dict, - ratelimit=True, - txn_id=None + self, requester, event_dict, ratelimit=True, txn_id=None ): """ Creates an event, then sends it. @@ -552,32 +550,25 @@ class EventCreationHandler(object): # taking longer. with (yield self.limiter.queue(event_dict["room_id"])): event, context = yield self.create_event( - requester, - event_dict, - token_id=requester.access_token_id, - txn_id=txn_id + requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id ) spam_error = self.spam_checker.check_event_for_spam(event) if spam_error: if not isinstance(spam_error, string_types): spam_error = "Spam is not permitted here" - raise SynapseError( - 403, spam_error, Codes.FORBIDDEN - ) + raise SynapseError(403, spam_error, Codes.FORBIDDEN) yield self.send_nonmember_event( - requester, - event, - context, - ratelimit=ratelimit, + requester, event, context, ratelimit=ratelimit ) defer.returnValue(event) @measure_func("create_new_client_event") @defer.inlineCallbacks - def create_new_client_event(self, builder, requester=None, - prev_events_and_hashes=None): + def create_new_client_event( + self, builder, requester=None, prev_events_and_hashes=None + ): """Create a new event for a local client Args: @@ -597,22 +588,21 @@ class EventCreationHandler(object): """ if prev_events_and_hashes is not None: - assert len(prev_events_and_hashes) <= 10, \ - "Attempting to create an event with %i prev_events" % ( - len(prev_events_and_hashes), + assert len(prev_events_and_hashes) <= 10, ( + "Attempting to create an event with %i prev_events" + % (len(prev_events_and_hashes),) ) else: - prev_events_and_hashes = \ - yield self.store.get_prev_events_for_room(builder.room_id) + prev_events_and_hashes = yield self.store.get_prev_events_for_room( + builder.room_id + ) prev_events = [ (event_id, prev_hashes) for event_id, prev_hashes, _ in prev_events_and_hashes ] - event = yield builder.build( - prev_event_ids=[p for p, _ in prev_events], - ) + event = yield builder.build(prev_event_ids=[p for p, _ in prev_events]) context = yield self.state.compute_event_context(event) if requester: context.app_service = requester.app_service @@ -628,29 +618,19 @@ class EventCreationHandler(object): aggregation_key = relation["key"] already_exists = yield self.store.has_user_annotated_event( - relates_to, event.type, aggregation_key, event.sender, + relates_to, event.type, aggregation_key, event.sender ) if already_exists: raise SynapseError(400, "Can't send same reaction twice") - logger.debug( - "Created event %s", - event.event_id, - ) + logger.debug("Created event %s", event.event_id) - defer.returnValue( - (event, context,) - ) + defer.returnValue((event, context)) @measure_func("handle_new_client_event") @defer.inlineCallbacks def handle_new_client_event( - self, - requester, - event, - context, - ratelimit=True, - extra_users=[], + self, requester, event, context, ratelimit=True, extra_users=[] ): """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -666,19 +646,20 @@ class EventCreationHandler(object): extra_users (list(UserID)): Any extra users to notify about event """ - if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""): - room_version = event.content.get( - "room_version", RoomVersions.V1.identifier - ) + if event.is_state() and (event.type, event.state_key) == ( + EventTypes.Create, + "", + ): + room_version = event.content.get("room_version", RoomVersions.V1.identifier) else: room_version = yield self.store.get_room_version(event.room_id) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) try: @@ -695,9 +676,7 @@ class EventCreationHandler(object): logger.exception("Failed to encode content: %r", event.content) raise - yield self.action_generator.handle_push_actions_for_event( - event, context - ) + yield self.action_generator.handle_push_actions_for_event(event, context) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we # hack around with a try/finally instead. @@ -718,11 +697,7 @@ class EventCreationHandler(object): return yield self.persist_and_notify_client_event( - requester, - event, - context, - ratelimit=ratelimit, - extra_users=extra_users, + requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) success = True @@ -731,18 +706,12 @@ class EventCreationHandler(object): # Ensure that we actually remove the entries in the push actions # staging area, if we calculated them. run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, + self.store.remove_push_actions_from_staging, event.event_id ) @defer.inlineCallbacks def persist_and_notify_client_event( - self, - requester, - event, - context, - ratelimit=True, - extra_users=[], + self, requester, event, context, ratelimit=True, extra_users=[] ): """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -767,20 +736,16 @@ class EventCreationHandler(object): if mapping["room_id"] != event.room_id: raise SynapseError( 400, - "Room alias %s does not point to the room" % ( - room_alias_str, - ) + "Room alias %s does not point to the room" % (room_alias_str,), ) federation_handler = self.hs.get_handlers().federation_handler if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: + def is_inviter_member_event(e): - return ( - e.type == EventTypes.Member and - e.sender == event.sender - ) + return e.type == EventTypes.Member and e.sender == event.sender current_state_ids = yield context.get_current_state_ids(self.store) @@ -810,26 +775,21 @@ class EventCreationHandler(object): # to get them to sign the event. returned_invite = yield federation_handler.send_invite( - invitee.domain, - event, + invitee.domain, event ) event.unsigned.pop("room_state", None) # TODO: Make sure the signatures actually are correct. - event.signatures.update( - returned_invite.signatures - ) + event.signatures.update(returned_invite.signatures) if event.type == EventTypes.Redaction: prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in auth_events.values() - } + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version = yield self.store.get_room_version(event.room_id) if self.auth.check_redaction(room_version, event, auth_events=auth_events): original_event = yield self.store.get_event( @@ -837,13 +797,10 @@ class EventCreationHandler(object): check_redacted=False, get_prev_content=False, allow_rejected=False, - allow_none=False + allow_none=False, ) if event.user_id != original_event.user_id: - raise AuthError( - 403, - "You don't have permission to redact events" - ) + raise AuthError(403, "You don't have permission to redact events") # We've already checked. event.internal_metadata.recheck_redaction = False @@ -851,24 +808,18 @@ class EventCreationHandler(object): if event.type == EventTypes.Create: prev_state_ids = yield context.get_prev_state_ids(self.store) if prev_state_ids: - raise AuthError( - 403, - "Changing the room create event is forbidden", - ) + raise AuthError(403, "Changing the room create event is forbidden") (event_stream_id, max_stream_id) = yield self.store.persist_event( event, context=context ) - yield self.pusher_pool.on_new_notifications( - event_stream_id, max_stream_id, - ) + yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): try: self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users + event, event_stream_id, max_stream_id, extra_users=extra_users ) except Exception: logger.exception("Error notifying about new room event") @@ -895,23 +846,19 @@ class EventCreationHandler(object): """ room_ids = yield self.store.get_rooms_with_many_extremities( - min_count=10, limit=5, + min_count=10, limit=5 ) for room_id in room_ids: # For each room we need to find a joined member we can use to send # the dummy event with. - prev_events_and_hashes = yield self.store.get_prev_events_for_room( - room_id, - ) + prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) - latest_event_ids = ( - event_id for (event_id, _, _) in prev_events_and_hashes - ) + latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) members = yield self.state.get_current_users_in_room( - room_id, latest_event_ids=latest_event_ids, + room_id, latest_event_ids=latest_event_ids ) user_id = None @@ -941,9 +888,4 @@ class EventCreationHandler(object): event.internal_metadata.proactively_send = False - yield self.send_nonmember_event( - requester, - event, - context, - ratelimit=False, - ) + yield self.send_nonmember_event(requester, event, context, ratelimit=False) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 3cf783e3bd..be54084088 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -58,9 +58,7 @@ class PurgeStatus(object): self.status = PurgeStatus.STATUS_ACTIVE def asdict(self): - return { - "status": PurgeStatus.STATUS_TEXT[self.status] - } + return {"status": PurgeStatus.STATUS_TEXT[self.status]} class PaginationHandler(object): @@ -153,20 +151,19 @@ class PaginationHandler(object): # Figure out what token we should start purging at. ts = self.clock.time_msec() - max_lifetime - stream_ordering = ( - yield self.store.find_first_stream_ordering_after_ts(ts) - ) + stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts)) r = ( yield self.store.get_room_event_after_stream_ordering( - room_id, stream_ordering, + room_id, stream_ordering ) ) if not r: logger.warning( "[purge] purging events not possible: No event found " "(ts %i => stream_ordering %i)", - ts, stream_ordering, + ts, + stream_ordering, ) continue @@ -185,13 +182,10 @@ class PaginationHandler(object): # the background so that it's not blocking any other operation apart from # other purges in the same room. run_as_background_process( - "_purge_history", - self._purge_history, - purge_id, room_id, token, True, + "_purge_history", self._purge_history, purge_id, room_id, token, True ) - def start_purge_history(self, room_id, token, - delete_local_events=False): + def start_purge_history(self, room_id, token, delete_local_events=False): """Start off a history purge on a room. Args: @@ -206,8 +200,7 @@ class PaginationHandler(object): """ if room_id in self._purges_in_progress_by_room: raise SynapseError( - 400, - "History purge already in progress for %s" % (room_id, ), + 400, "History purge already in progress for %s" % (room_id,) ) purge_id = random_string(16) @@ -218,14 +211,12 @@ class PaginationHandler(object): self._purges_by_id[purge_id] = PurgeStatus() run_in_background( - self._purge_history, - purge_id, room_id, token, delete_local_events, + self._purge_history, purge_id, room_id, token, delete_local_events ) return purge_id @defer.inlineCallbacks - def _purge_history(self, purge_id, room_id, token, - delete_local_events): + def _purge_history(self, purge_id, room_id, token, delete_local_events): """Carry out a history purge on a room. Args: @@ -241,16 +232,13 @@ class PaginationHandler(object): self._purges_in_progress_by_room.add(room_id) try: with (yield self.pagination_lock.write(room_id)): - yield self.store.purge_history( - room_id, token, delete_local_events, - ) + yield self.store.purge_history(room_id, token, delete_local_events) logger.info("[purge] complete") self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE except Exception: f = Failure() logger.error( - "[purge] failed", - exc_info=(f.type, f.value, f.getTracebackObject()), + "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) ) self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED finally: @@ -259,6 +247,7 @@ class PaginationHandler(object): # remove the purge from the list 24 hours after it completes def clear_purge(): del self._purges_by_id[purge_id] + self.hs.get_reactor().callLater(24 * 3600, clear_purge) def get_purge_status(self, purge_id): @@ -273,8 +262,14 @@ class PaginationHandler(object): return self._purges_by_id.get(purge_id) @defer.inlineCallbacks - def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True, event_filter=None): + def get_messages( + self, + requester, + room_id=None, + pagin_config=None, + as_client_event=True, + event_filter=None, + ): """Get messages in a room. Args: @@ -312,7 +307,7 @@ class PaginationHandler(object): room_id, user_id ) - if source_config.direction == 'b': + if source_config.direction == "b": # if we're going backwards, we might need to backfill. This # requires that we have a topo token. if room_token.topological: @@ -346,27 +341,24 @@ class PaginationHandler(object): event_filter=event_filter, ) - next_token = pagin_config.from_token.copy_and_replace( - "room_key", next_key - ) + next_token = pagin_config.from_token.copy_and_replace("room_key", next_key) if events: if event_filter: events = event_filter.filter(events) events = yield filter_events_for_client( - self.store, - user_id, - events, - is_peeking=(member_event_id is None), + self.store, user_id, events, is_peeking=(member_event_id is None) ) if not events: - defer.returnValue({ - "chunk": [], - "start": pagin_config.from_token.to_string(), - "end": next_token.to_string(), - }) + defer.returnValue( + { + "chunk": [], + "start": pagin_config.from_token.to_string(), + "end": next_token.to_string(), + } + ) state = None if event_filter and event_filter.lazy_load_members() and len(events) > 0: @@ -374,12 +366,11 @@ class PaginationHandler(object): # FIXME: we also care about invite targets etc. state_filter = StateFilter.from_types( - (EventTypes.Member, event.sender) - for event in events + (EventTypes.Member, event.sender) for event in events ) state_ids = yield self.store.get_state_ids_for_event( - events[0].event_id, state_filter=state_filter, + events[0].event_id, state_filter=state_filter ) if state_ids: @@ -391,8 +382,7 @@ class PaginationHandler(object): chunk = { "chunk": ( yield self._event_serializer.serialize_events( - events, time_now, - as_client_event=as_client_event, + events, time_now, as_client_event=as_client_event ) ), "start": pagin_config.from_token.to_string(), @@ -402,8 +392,7 @@ class PaginationHandler(object): if state: chunk["state"] = ( yield self._event_serializer.serialize_events( - state, time_now, - as_client_event=as_client_event, + state, time_now, as_client_event=as_client_event ) ) diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py index 9994b44455..d06b110269 100644 --- a/synapse/handlers/password_policy.py +++ b/synapse/handlers/password_policy.py @@ -57,8 +57,8 @@ class PasswordPolicyHandler(object): ) if ( - self.policy.get("require_digit", False) and - self.regexp_digit.search(password) is None + self.policy.get("require_digit", False) + and self.regexp_digit.search(password) is None ): raise PasswordRefusedError( msg="The password must include at least one digit", @@ -66,8 +66,8 @@ class PasswordPolicyHandler(object): ) if ( - self.policy.get("require_symbol", False) and - self.regexp_symbol.search(password) is None + self.policy.get("require_symbol", False) + and self.regexp_symbol.search(password) is None ): raise PasswordRefusedError( msg="The password must include at least one symbol", @@ -75,8 +75,8 @@ class PasswordPolicyHandler(object): ) if ( - self.policy.get("require_uppercase", False) and - self.regexp_uppercase.search(password) is None + self.policy.get("require_uppercase", False) + and self.regexp_uppercase.search(password) is None ): raise PasswordRefusedError( msg="The password must include at least one uppercase letter", @@ -84,8 +84,8 @@ class PasswordPolicyHandler(object): ) if ( - self.policy.get("require_lowercase", False) and - self.regexp_lowercase.search(password) is None + self.policy.get("require_lowercase", False) + and self.regexp_lowercase.search(password) is None ): raise PasswordRefusedError( msg="The password must include at least one lowercase letter", diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 557fb5f83d..5204073a38 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -50,16 +50,20 @@ logger = logging.getLogger(__name__) notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "") federation_presence_out_counter = Counter( - "synapse_handler_presence_federation_presence_out", "") + "synapse_handler_presence_federation_presence_out", "" +) presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "") timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "") -federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "") +federation_presence_counter = Counter( + "synapse_handler_presence_federation_presence", "" +) bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "") get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"]) notify_reason_counter = Counter( - "synapse_handler_presence_notify_reason", "", ["reason"]) + "synapse_handler_presence_notify_reason", "", ["reason"] +) state_transition_counter = Counter( "synapse_handler_presence_state_transition", "", ["from", "to"] ) @@ -90,7 +94,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class PresenceHandler(object): - def __init__(self, hs): """ @@ -110,31 +113,26 @@ class PresenceHandler(object): federation_registry = hs.get_federation_registry() - federation_registry.register_edu_handler( - "m.presence", self.incoming_presence - ) + federation_registry.register_edu_handler("m.presence", self.incoming_presence) active_presence = self.store.take_presence_startup_info() # A dictionary of the current state of users. This is prefilled with # non-offline presence from the DB. We should fetch from the DB if # we can't find a users presence in here. - self.user_to_current_state = { - state.user_id: state - for state in active_presence - } + self.user_to_current_state = {state.user_id: state for state in active_presence} LaterGauge( - "synapse_handlers_presence_user_to_current_state_size", "", [], - lambda: len(self.user_to_current_state) + "synapse_handlers_presence_user_to_current_state_size", + "", + [], + lambda: len(self.user_to_current_state), ) now = self.clock.time_msec() for state in active_presence: self.wheel_timer.insert( - now=now, - obj=state.user_id, - then=state.last_active_ts + IDLE_TIMER, + now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER ) self.wheel_timer.insert( now=now, @@ -193,27 +191,21 @@ class PresenceHandler(object): "handle_presence_timeouts", self._handle_timeouts ) - self.clock.call_later( - 30, - self.clock.looping_call, - run_timeout_handler, - 5000, - ) + self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000) def run_persister(): return run_as_background_process( "persist_presence_changes", self._persist_unpersisted_changes ) - self.clock.call_later( - 60, - self.clock.looping_call, - run_persister, - 60 * 1000, - ) + self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000) - LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [], - lambda: len(self.wheel_timer)) + LaterGauge( + "synapse_handlers_presence_wheel_timer_size", + "", + [], + lambda: len(self.wheel_timer), + ) # Used to handle sending of presence to newly joined users/servers if hs.config.use_presence: @@ -241,15 +233,17 @@ class PresenceHandler(object): logger.info( "Performing _on_shutdown. Persisting %d unpersisted changes", - len(self.user_to_current_state) + len(self.user_to_current_state), ) if self.unpersisted_users_changes: - yield self.store.update_presence([ - self.user_to_current_state[user_id] - for user_id in self.unpersisted_users_changes - ]) + yield self.store.update_presence( + [ + self.user_to_current_state[user_id] + for user_id in self.unpersisted_users_changes + ] + ) logger.info("Finished _on_shutdown") @defer.inlineCallbacks @@ -261,13 +255,10 @@ class PresenceHandler(object): self.unpersisted_users_changes = set() if unpersisted: - logger.info( - "Persisting %d upersisted presence updates", len(unpersisted) + logger.info("Persisting %d upersisted presence updates", len(unpersisted)) + yield self.store.update_presence( + [self.user_to_current_state[user_id] for user_id in unpersisted] ) - yield self.store.update_presence([ - self.user_to_current_state[user_id] - for user_id in unpersisted - ]) @defer.inlineCallbacks def _update_states(self, new_states): @@ -303,10 +294,11 @@ class PresenceHandler(object): ) new_state, should_notify, should_ping = handle_update( - prev_state, new_state, + prev_state, + new_state, is_mine=self.is_mine_id(user_id), wheel_timer=self.wheel_timer, - now=now + now=now, ) self.user_to_current_state[user_id] = new_state @@ -328,7 +320,8 @@ class PresenceHandler(object): self.unpersisted_users_changes -= set(to_notify.keys()) to_federation_ping = { - user_id: state for user_id, state in to_federation_ping.items() + user_id: state + for user_id, state in to_federation_ping.items() if user_id not in to_notify } if to_federation_ping: @@ -351,8 +344,8 @@ class PresenceHandler(object): # Check whether the lists of syncing processes from an external # process have expired. expired_process_ids = [ - process_id for process_id, last_update - in self.external_process_last_updated_ms.items() + process_id + for process_id, last_update in self.external_process_last_updated_ms.items() if now - last_update > EXTERNAL_PROCESS_EXPIRY ] for process_id in expired_process_ids: @@ -362,9 +355,7 @@ class PresenceHandler(object): self.external_process_last_update.pop(process_id) states = [ - self.user_to_current_state.get( - user_id, UserPresenceState.default(user_id) - ) + self.user_to_current_state.get(user_id, UserPresenceState.default(user_id)) for user_id in users_to_check ] @@ -394,9 +385,7 @@ class PresenceHandler(object): prev_state = yield self.current_state_for_user(user_id) - new_fields = { - "last_active_ts": self.clock.time_msec(), - } + new_fields = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE @@ -430,15 +419,23 @@ class PresenceHandler(object): if prev_state.state == PresenceState.OFFLINE: # If they're currently offline then bring them online, otherwise # just update the last sync times. - yield self._update_states([prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=self.clock.time_msec(), - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=self.clock.time_msec(), + last_user_sync_ts=self.clock.time_msec(), + ) + ] + ) else: - yield self._update_states([prev_state.copy_and_replace( - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec() + ) + ] + ) @defer.inlineCallbacks def _end(): @@ -446,9 +443,13 @@ class PresenceHandler(object): self.user_to_num_current_syncs[user_id] -= 1 prev_state = yield self.current_state_for_user(user_id) - yield self._update_states([prev_state.copy_and_replace( - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec() + ) + ] + ) except Exception: logger.exception("Error updating presence after sync") @@ -469,7 +470,8 @@ class PresenceHandler(object): """ if self.hs.config.use_presence: syncing_user_ids = { - user_id for user_id, count in self.user_to_num_current_syncs.items() + user_id + for user_id, count in self.user_to_num_current_syncs.items() if count } for user_ids in self.external_process_to_current_syncs.values(): @@ -479,7 +481,9 @@ class PresenceHandler(object): return set() @defer.inlineCallbacks - def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): + def update_external_syncs_row( + self, process_id, user_id, is_syncing, sync_time_msec + ): """Update the syncing users for an external process as a delta. Args: @@ -500,20 +504,22 @@ class PresenceHandler(object): updates = [] if is_syncing and user_id not in process_presence: if prev_state.state == PresenceState.OFFLINE: - updates.append(prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=sync_time_msec, - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=sync_time_msec, + last_user_sync_ts=sync_time_msec, + ) + ) else: - updates.append(prev_state.copy_and_replace( - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + ) process_presence.add(user_id) elif user_id in process_presence: - updates.append(prev_state.copy_and_replace( - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + ) if not is_syncing: process_presence.discard(user_id) @@ -537,12 +543,12 @@ class PresenceHandler(object): prev_states = yield self.current_state_for_users(process_presence) time_now_ms = self.clock.time_msec() - yield self._update_states([ - prev_state.copy_and_replace( - last_user_sync_ts=time_now_ms, - ) - for prev_state in itervalues(prev_states) - ]) + yield self._update_states( + [ + prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) + for prev_state in itervalues(prev_states) + ] + ) self.external_process_last_updated_ms.pop(process_id, None) @defer.inlineCallbacks @@ -574,8 +580,7 @@ class PresenceHandler(object): missing = [user_id for user_id, state in iteritems(states) if not state] if missing: new = { - user_id: UserPresenceState.default(user_id) - for user_id in missing + user_id: UserPresenceState.default(user_id) for user_id in missing } states.update(new) self.user_to_current_state.update(new) @@ -593,8 +598,10 @@ class PresenceHandler(object): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states] + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states], ) self._push_to_remotes(states) @@ -605,8 +612,10 @@ class PresenceHandler(object): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states] + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states], ) def _push_to_remotes(self, states): @@ -631,15 +640,15 @@ class PresenceHandler(object): user_id = push.get("user_id", None) if not user_id: logger.info( - "Got presence update from %r with no 'user_id': %r", - origin, push, + "Got presence update from %r with no 'user_id': %r", origin, push ) continue if get_domain_from_id(user_id) != origin: logger.info( "Got presence update from %r with bad 'user_id': %r", - origin, user_id, + origin, + user_id, ) continue @@ -647,14 +656,12 @@ class PresenceHandler(object): if not presence_state: logger.info( "Got presence update from %r with no 'presence_state': %r", - origin, push, + origin, + push, ) continue - new_fields = { - "state": presence_state, - "last_federation_update_ts": now, - } + new_fields = {"state": presence_state, "last_federation_update_ts": now} last_active_ago = push.get("last_active_ago", None) if last_active_ago is not None: @@ -672,10 +679,7 @@ class PresenceHandler(object): @defer.inlineCallbacks def get_state(self, target_user, as_event=False): - results = yield self.get_states( - [target_user.to_string()], - as_event=as_event, - ) + results = yield self.get_states([target_user.to_string()], as_event=as_event) defer.returnValue(results[0]) @@ -699,13 +703,15 @@ class PresenceHandler(object): now = self.clock.time_msec() if as_event: - defer.returnValue([ - { - "type": "m.presence", - "content": format_user_presence_state(state, now), - } - for state in updates - ]) + defer.returnValue( + [ + { + "type": "m.presence", + "content": format_user_presence_state(state, now), + } + for state in updates + ] + ) else: defer.returnValue(updates) @@ -717,7 +723,9 @@ class PresenceHandler(object): presence = state["presence"] valid_presence = ( - PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, ) if presence not in valid_presence: raise SynapseError(400, "Invalid presence state") @@ -726,9 +734,7 @@ class PresenceHandler(object): prev_state = yield self.current_state_for_user(user_id) - new_fields = { - "state": presence - } + new_fields = {"state": presence} if not ignore_status_msg: msg = status_msg if presence != PresenceState.OFFLINE else None @@ -877,8 +883,7 @@ class PresenceHandler(object): hosts = set(host for host in hosts if host != self.server_name) self.federation.send_presence_to_destinations( - states=[state], - destinations=hosts, + states=[state], destinations=hosts ) else: # A remote user has joined the room, so we need to: @@ -904,7 +909,8 @@ class PresenceHandler(object): # default state. now = self.clock.time_msec() states = [ - state for state in states.values() + state + for state in states.values() if state.state != PresenceState.OFFLINE or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 or state.status_msg is not None @@ -912,8 +918,7 @@ class PresenceHandler(object): if states: self.federation.send_presence_to_destinations( - states=states, - destinations=[get_domain_from_id(user_id)], + states=states, destinations=[get_domain_from_id(user_id)] ) @@ -937,7 +942,10 @@ def should_notify(old_state, new_state): notify_reason_counter.labels("current_active_change").inc() return True - if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + if ( + new_state.last_active_ts - old_state.last_active_ts + > LAST_ACTIVE_GRANULARITY + ): # Only notify about last active bumps if we're not currently acive if not new_state.currently_active: notify_reason_counter.labels("last_active_change_online").inc() @@ -958,9 +966,7 @@ def format_user_presence_state(state, now, include_user_id=True): The "user_id" is optional so that this function can be used to format presence updates for client /sync responses and for federation /send requests. """ - content = { - "presence": state.state, - } + content = {"presence": state.state} if include_user_id: content["user_id"] = state.user_id if state.last_active_ts: @@ -986,8 +992,15 @@ class PresenceEventSource(object): @defer.inlineCallbacks @log_function - def get_new_events(self, user, from_key, room_ids=None, include_offline=True, - explicit_room_id=None, **kwargs): + def get_new_events( + self, + user, + from_key, + room_ids=None, + include_offline=True, + explicit_room_id=None, + **kwargs + ): # The process for getting presence events are: # 1. Get the rooms the user is in. # 2. Get the list of user in the rooms. @@ -1030,7 +1043,7 @@ class PresenceEventSource(object): if from_key: user_ids_changed = stream_change_cache.get_entities_changed( - users_interested_in, from_key, + users_interested_in, from_key ) else: user_ids_changed = users_interested_in @@ -1040,10 +1053,16 @@ class PresenceEventSource(object): if include_offline: defer.returnValue((list(updates.values()), max_token)) else: - defer.returnValue(([ - s for s in itervalues(updates) - if s.state != PresenceState.OFFLINE - ], max_token)) + defer.returnValue( + ( + [ + s + for s in itervalues(updates) + if s.state != PresenceState.OFFLINE + ], + max_token, + ) + ) def get_current_key(self): return self.store.get_current_presence_token() @@ -1061,13 +1080,13 @@ class PresenceEventSource(object): users_interested_in.add(user_id) # So that we receive our own presence users_who_share_room = yield self.store.get_users_who_share_room_with_user( - user_id, on_invalidate=cache_context.invalidate, + user_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(users_who_share_room) if explicit_room_id: user_ids = yield self.store.get_users_in_room( - explicit_room_id, on_invalidate=cache_context.invalidate, + explicit_room_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(user_ids) @@ -1123,9 +1142,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): if now - state.last_active_ts > IDLE_TIMER: # Currently online, but last activity ages ago so auto # idle - state = state.copy_and_replace( - state=PresenceState.UNAVAILABLE, - ) + state = state.copy_and_replace(state=PresenceState.UNAVAILABLE) changed = True elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY: # So that we send down a notification that we've @@ -1145,8 +1162,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) if now - sync_or_active > SYNC_ONLINE_TIMEOUT: state = state.copy_and_replace( - state=PresenceState.OFFLINE, - status_msg=None, + state=PresenceState.OFFLINE, status_msg=None ) changed = True else: @@ -1155,10 +1171,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): # no one gets stuck online forever. if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: # The other side seems to have disappeared. - state = state.copy_and_replace( - state=PresenceState.OFFLINE, - status_msg=None, - ) + state = state.copy_and_replace(state=PresenceState.OFFLINE, status_msg=None) changed = True return state if changed else None @@ -1193,21 +1206,17 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): if new_state.state == PresenceState.ONLINE: # Idle timer wheel_timer.insert( - now=now, - obj=user_id, - then=new_state.last_active_ts + IDLE_TIMER + now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER ) active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY - new_state = new_state.copy_and_replace( - currently_active=active, - ) + new_state = new_state.copy_and_replace(currently_active=active) if active: wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY, ) if new_state.state != PresenceState.OFFLINE: @@ -1215,29 +1224,25 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, ) last_federate = new_state.last_federation_update_ts if now - last_federate > FEDERATION_PING_INTERVAL: # Been a while since we've poked remote servers - new_state = new_state.copy_and_replace( - last_federation_update_ts=now, - ) + new_state = new_state.copy_and_replace(last_federation_update_ts=now) federation_ping = True else: wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT + then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT, ) # Check whether the change was something worth notifying about if should_notify(prev_state, new_state): - new_state = new_state.copy_and_replace( - last_federation_update_ts=now, - ) + new_state = new_state.copy_and_replace(last_federation_update_ts=now) persist_and_notify = True return new_state, persist_and_notify, federation_ping diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index f40929e877..36d7ac2741 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -71,7 +71,7 @@ class BaseProfileHandler(BaseHandler): if hs.config.worker_app is None: self.clock.looping_call( - self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, + self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS ) if len(self.hs.config.replicate_user_profiles_to) > 0: @@ -115,7 +115,7 @@ class BaseProfileHandler(BaseHandler): yield self._replicate_host_profile_batch(repl_host, i) except Exception: logger.exception( - "Exception while replicating to %s: aborting for now", repl_host, + "Exception while replicating to %s: aborting for now", repl_host ) @defer.inlineCallbacks @@ -123,18 +123,16 @@ class BaseProfileHandler(BaseHandler): logger.info("Replicating profile batch %d to %s", batchnum, host) batch_rows = yield self.store.get_profile_batch(batchnum) batch = { - UserID(r["user_id"], self.hs.hostname).to_string(): ({ - "display_name": r["displayname"], - "avatar_url": r["avatar_url"], - } if r["active"] else None) for r in batch_rows + UserID(r["user_id"], self.hs.hostname).to_string(): ( + {"display_name": r["displayname"], "avatar_url": r["avatar_url"]} + if r["active"] + else None + ) + for r in batch_rows } url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,) - body = { - "batchnum": batchnum, - "batch": batch, - "origin_server": self.hs.hostname, - } + body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname} signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0]) try: yield self.http_client.post_json_get_json(url, signed_body) @@ -142,7 +140,9 @@ class BaseProfileHandler(BaseHandler): logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host) except Exception: # This will get retried when the looping call next comes around - logger.exception("Failed to replicate profile batch %d to %s", batchnum, host) + logger.exception( + "Failed to replicate profile batch %d to %s", batchnum, host + ) raise @defer.inlineCallbacks @@ -162,18 +162,13 @@ class BaseProfileHandler(BaseHandler): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue({ - "displayname": displayname, - "avatar_url": avatar_url, - }) + defer.returnValue({"displayname": displayname, "avatar_url": avatar_url}) else: try: result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": user_id, - }, + args={"user_id": user_id}, ignore_backoff=True, ) defer.returnValue(result) @@ -202,10 +197,7 @@ class BaseProfileHandler(BaseHandler): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue({ - "displayname": displayname, - "avatar_url": avatar_url, - }) + defer.returnValue({"displayname": displayname, "avatar_url": avatar_url}) else: profile = yield self.store.get_from_remote_profile_cache(user_id) defer.returnValue(profile or {}) @@ -228,10 +220,7 @@ class BaseProfileHandler(BaseHandler): result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": target_user.to_string(), - "field": "displayname", - }, + args={"user_id": target_user.to_string(), "field": "displayname"}, ignore_backoff=True, ) except RequestSendFailed as e: @@ -260,18 +249,22 @@ class BaseProfileHandler(BaseHandler): if not by_admin and self.hs.config.disable_set_displayname: profile = yield self.store.get_profileinfo(target_user.localpart) if profile.display_name: - raise SynapseError(400, "Changing displayname is disabled on this server") + raise SynapseError( + 400, "Changing displayname is disabled on this server" + ) if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( - 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN, ), + 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) - if new_displayname == '': + if new_displayname == "": new_displayname = None if len(self.hs.config.replicate_user_profiles_to) > 0: - cur_batchnum = yield self.store.get_latest_profile_replication_batch_number() + cur_batchnum = ( + yield self.store.get_latest_profile_replication_batch_number() + ) new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 else: new_batchnum = None @@ -307,7 +300,9 @@ class BaseProfileHandler(BaseHandler): where we've already done these checks anyway. """ if len(self.hs.config.replicate_user_profiles_to) > 0: - cur_batchnum = yield self.store.get_latest_profile_replication_batch_number() + cur_batchnum = ( + yield self.store.get_latest_profile_replication_batch_number() + ) new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 else: new_batchnum = None @@ -335,10 +330,7 @@ class BaseProfileHandler(BaseHandler): result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": target_user.to_string(), - "field": "avatar_url", - }, + args={"user_id": target_user.to_string(), "field": "avatar_url"}, ignore_backoff=True, ) except RequestSendFailed as e: @@ -361,17 +353,21 @@ class BaseProfileHandler(BaseHandler): if not by_admin and self.hs.config.disable_set_avatar_url: profile = yield self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: - raise SynapseError(400, "Changing avatar url is disabled on this server") + raise SynapseError( + 400, "Changing avatar url is disabled on this server" + ) if len(self.hs.config.replicate_user_profiles_to) > 0: - cur_batchnum = yield self.store.get_latest_profile_replication_batch_number() + cur_batchnum = ( + yield self.store.get_latest_profile_replication_batch_number() + ) new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 else: new_batchnum = None if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( - 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN, ), + 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) # Enforce a max avatar size if one is defined @@ -389,8 +385,10 @@ class BaseProfileHandler(BaseHandler): media_size = media_info["media_length"] if self.max_avatar_size and media_size > self.max_avatar_size: raise SynapseError( - 400, "Avatars must be less than %s bytes in size" % - (self.max_avatar_size,), errcode=Codes.TOO_LARGE, + 400, + "Avatars must be less than %s bytes in size" + % (self.max_avatar_size,), + errcode=Codes.TOO_LARGE, ) # Ensure the avatar's file type is allowed @@ -399,12 +397,11 @@ class BaseProfileHandler(BaseHandler): and media_info["media_type"] not in self.allowed_avatar_mimetypes ): raise SynapseError( - 400, "Avatar file type '%s' not allowed" % - media_info["media_type"], + 400, "Avatar file type '%s' not allowed" % media_info["media_type"] ) yield self.store.set_profile_avatar_url( - target_user.localpart, new_avatar_url, new_batchnum, + target_user.localpart, new_avatar_url, new_batchnum ) if self.hs.config.user_directory_search_all_users: @@ -465,9 +462,7 @@ class BaseProfileHandler(BaseHandler): yield self.ratelimit(requester) - room_ids = yield self.store.get_rooms_for_user( - target_user.to_string(), - ) + room_ids = yield self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: handler = self.hs.get_room_member_handler() @@ -483,8 +478,7 @@ class BaseProfileHandler(BaseHandler): ) except Exception as e: logger.warn( - "Failed to update join event for room %s - %s", - room_id, str(e) + "Failed to update join event for room %s - %s", room_id, str(e) ) @defer.inlineCallbacks @@ -516,11 +510,9 @@ class BaseProfileHandler(BaseHandler): return try: - requester_rooms = yield self.store.get_rooms_for_user( - requester.to_string() - ) + requester_rooms = yield self.store.get_rooms_for_user(requester.to_string()) target_user_rooms = yield self.store.get_rooms_for_user( - target_user.to_string(), + target_user.to_string() ) # Check if the room lists have no elements in common. @@ -544,12 +536,12 @@ class MasterProfileHandler(BaseProfileHandler): assert hs.config.worker_app is None self.clock.looping_call( - self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, + self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS ) def _start_update_remote_profile_cache(self): return run_as_background_process( - "Update remote profile", self._update_remote_profile_cache, + "Update remote profile", self._update_remote_profile_cache ) @defer.inlineCallbacks @@ -563,7 +555,7 @@ class MasterProfileHandler(BaseProfileHandler): for user_id, displayname, avatar_url in entries: is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( - user_id, + user_id ) if not is_subscribed: yield self.store.maybe_delete_remote_profile_cache(user_id) @@ -573,9 +565,7 @@ class MasterProfileHandler(BaseProfileHandler): profile = yield self.federation.make_query( destination=get_domain_from_id(user_id), query_type="profile", - args={ - "user_id": user_id, - }, + args={"user_id": user_id}, ignore_backoff=True, ) except Exception: @@ -590,6 +580,4 @@ class MasterProfileHandler(BaseProfileHandler): new_avatar = profile.get("avatar_url") # We always hit update to update the last_check timestamp - yield self.store.update_remote_profile_cache( - user_id, new_name, new_avatar - ) + yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 32108568c6..3e4d8c93a4 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -43,7 +43,7 @@ class ReadMarkerHandler(BaseHandler): with (yield self.read_marker_linearizer.queue((room_id, user_id))): existing_read_marker = yield self.store.get_account_data_for_room_and_type( - user_id, room_id, "m.fully_read", + user_id, room_id, "m.fully_read" ) should_update = True @@ -51,14 +51,11 @@ class ReadMarkerHandler(BaseHandler): if existing_read_marker: # Only update if the new marker is ahead in the stream should_update = yield self.store.is_event_after( - event_id, - existing_read_marker['event_id'] + event_id, existing_read_marker["event_id"] ) if should_update: - content = { - "event_id": event_id - } + content = {"event_id": event_id} max_id = yield self.store.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 274d2946ad..a85dd8cdee 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -88,19 +88,16 @@ class ReceiptsHandler(BaseHandler): affected_room_ids = list(set([r.room_id for r in receipts])) - self.notifier.on_new_event( - "receipt_key", max_batch_id, rooms=affected_room_ids - ) + self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. yield self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids, + min_batch_id, max_batch_id, affected_room_ids ) defer.returnValue(True) @defer.inlineCallbacks - def received_client_receipt(self, room_id, receipt_type, user_id, - event_id): + def received_client_receipt(self, room_id, receipt_type, user_id, event_id): """Called when a client tells us a local user has read up to the given event_id in the room. """ @@ -109,9 +106,7 @@ class ReceiptsHandler(BaseHandler): receipt_type=receipt_type, user_id=user_id, event_ids=[event_id], - data={ - "ts": int(self.clock.time_msec()), - }, + data={"ts": int(self.clock.time_msec())}, ) is_new = yield self._handle_new_receipts([receipt]) @@ -125,8 +120,7 @@ class ReceiptsHandler(BaseHandler): """Gets all receipts for a room, upto the given key. """ result = yield self.store.get_linearized_receipts_for_room( - room_id, - to_key=to_key, + room_id, to_key=to_key ) if not result: @@ -148,14 +142,12 @@ class ReceiptEventSource(object): defer.returnValue(([], to_key)) events = yield self.store.get_linearized_receipts_for_rooms( - room_ids, - from_key=from_key, - to_key=to_key, + room_ids, from_key=from_key, to_key=to_key ) defer.returnValue((events, to_key)) - def get_current_key(self, direction='f'): + def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() @defer.inlineCallbacks @@ -169,9 +161,7 @@ class ReceiptEventSource(object): room_ids = yield self.store.get_rooms_for_user(user.to_string()) events = yield self.store.get_linearized_receipts_for_rooms( - room_ids, - from_key=from_key, - to_key=to_key, + room_ids, from_key=from_key, to_key=to_key ) defer.returnValue((events, to_key)) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7747964352..16427277d9 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -47,7 +47,6 @@ logger = logging.getLogger(__name__) class RegistrationHandler(BaseHandler): - def __init__(self, hs): """ @@ -70,7 +69,7 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._generate_user_id_linearizer = Linearizer( - name="_generate_user_id_linearizer", + name="_generate_user_id_linearizer" ) self._server_notices_mxid = hs.config.server_notices_mxid @@ -78,38 +77,31 @@ class RegistrationHandler(BaseHandler): if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) - self._register_device_client = ( - RegisterDeviceReplicationServlet.make_client(hs) + self._register_device_client = RegisterDeviceReplicationServlet.make_client( + hs ) - self._post_registration_client = ( - ReplicationPostRegisterActionsServlet.make_client(hs) + self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client( + hs ) else: self.device_handler = hs.get_device_handler() self.pusher_pool = hs.get_pusherpool() @defer.inlineCallbacks - def check_username(self, localpart, guest_access_token=None, - assigned_user_id=None): + def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", - Codes.INVALID_USERNAME + Codes.INVALID_USERNAME, ) if not localpart: - raise SynapseError( - 400, - "User ID cannot be empty", - Codes.INVALID_USERNAME - ) + raise SynapseError(400, "User ID cannot be empty", Codes.INVALID_USERNAME) - if localpart[0] == '_': + if localpart[0] == "_": raise SynapseError( - 400, - "User ID may not begin with _", - Codes.INVALID_USERNAME + 400, "User ID may not begin with _", Codes.INVALID_USERNAME ) user = UserID(localpart, self.hs.hostname) @@ -129,19 +121,15 @@ class RegistrationHandler(BaseHandler): if len(user_id) > MAX_USERID_LENGTH: raise SynapseError( 400, - "User ID may not be longer than %s characters" % ( - MAX_USERID_LENGTH, - ), - Codes.INVALID_USERNAME + "User ID may not be longer than %s characters" % (MAX_USERID_LENGTH,), + Codes.INVALID_USERNAME, ) users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: if not guest_access_token: raise SynapseError( - 400, - "User ID already taken.", - errcode=Codes.USER_IN_USE, + 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) user_data = yield self.auth.get_user_by_access_token(guest_access_token) if not user_data["is_guest"] or user_data["user"].localpart != localpart: @@ -206,8 +194,7 @@ class RegistrationHandler(BaseHandler): try: int(localpart) raise RegistrationError( - 400, - "Numeric user IDs are reserved for guest users." + 400, "Numeric user IDs are reserved for guest users." ) except ValueError: pass @@ -239,7 +226,7 @@ class RegistrationHandler(BaseHandler): if default_display_name: yield self.profile_handler.set_displayname( - user, None, default_display_name, by_admin=True, + user, None, default_display_name, by_admin=True ) if self.hs.config.user_directory_search_all_users: @@ -273,7 +260,7 @@ class RegistrationHandler(BaseHandler): ) yield self.profile_handler.set_displayname( - user, None, default_display_name, by_admin=True, + user, None, default_display_name, by_admin=True ) except SynapseError: @@ -296,15 +283,13 @@ class RegistrationHandler(BaseHandler): } # Bind email to new account - yield self._register_email_threepid( - user_id, threepid_dict, None, False, - ) + yield self._register_email_threepid(user_id, threepid_dict, None, False) # Prevent the new user from showing up in the user directory if the server # mandates it. if not self._show_in_user_directory: yield self.store.add_account_data_for_user( - user_id, "im.vector.hide_profile", {'hide_profile': True}, + user_id, "im.vector.hide_profile", {"hide_profile": True} ) yield self.profile_handler.set_active(user, False, True) @@ -339,8 +324,8 @@ class RegistrationHandler(BaseHandler): room_alias = RoomAlias.from_string(r) if self.hs.hostname != room_alias.domain: logger.warning( - 'Cannot create room alias %s, ' - 'it does not match server domain', + "Cannot create room alias %s, " + "it does not match server domain", r, ) else: @@ -353,7 +338,7 @@ class RegistrationHandler(BaseHandler): fake_requester, config={ "preset": "public_chat", - "room_alias_name": room_alias_localpart + "room_alias_name": room_alias_localpart, }, ratelimit=False, ) @@ -387,8 +372,9 @@ class RegistrationHandler(BaseHandler): raise AuthError(403, "Invalid application service token.") if not service.is_interested_in_user(user_id): raise SynapseError( - 400, "Invalid user localpart for this application service.", - errcode=Codes.EXCLUSIVE + 400, + "Invalid user localpart for this application service.", + errcode=Codes.EXCLUSIVE, ) service_id = service.id if service.is_exclusive_user(user_id) else None @@ -411,7 +397,7 @@ class RegistrationHandler(BaseHandler): ) yield self.profile_handler.set_displayname( - user, None, display_name, by_admin=True, + user, None, display_name, by_admin=True ) if self.hs.config.user_directory_search_all_users: @@ -431,17 +417,15 @@ class RegistrationHandler(BaseHandler): """ captcha_response = yield self._validate_captcha( - ip, - private_key, - challenge, - response + ip, private_key, challenge, response ) if not captcha_response["valid"]: - logger.info("Invalid captcha entered from %s. Error: %s", - ip, captcha_response["error_url"]) - raise InvalidCaptchaError( - error_url=captcha_response["error_url"] + logger.info( + "Invalid captcha entered from %s. Error: %s", + ip, + captcha_response["error_url"], ) + raise InvalidCaptchaError(error_url=captcha_response["error_url"]) else: logger.info("Valid captcha entered from %s", ip) @@ -452,8 +436,7 @@ class RegistrationHandler(BaseHandler): """ if types.contains_invalid_mxid_characters(localpart): raise SynapseError( - 400, - "User ID can only contain characters a-z, 0-9, or '=_-./'", + 400, "User ID can only contain characters a-z, 0-9, or '=_-./'" ) yield self.auth.check_auth_blocking() user = UserID(localpart, self.hs.hostname) @@ -470,7 +453,7 @@ class RegistrationHandler(BaseHandler): ) yield self.profile_handler.set_displayname( - user, None, user.localpart, by_admin=True, + user, None, user.localpart, by_admin=True ) except Exception as e: yield self.store.add_access_token_to_user(user_id, token) @@ -487,8 +470,11 @@ class RegistrationHandler(BaseHandler): """ for c in threepidCreds: - logger.info("validating threepidcred sid %s on id server %s", - c['sid'], c['idServer']) + logger.info( + "validating threepidcred sid %s on id server %s", + c["sid"], + c["idServer"], + ) try: threepid = yield self.identity_handler.threepid_from_creds(c) except Exception: @@ -497,15 +483,18 @@ class RegistrationHandler(BaseHandler): if not threepid: raise RegistrationError(400, "Couldn't validate 3pid") - logger.info("got threepid with medium '%s' and address '%s'", - threepid['medium'], threepid['address']) + logger.info( + "got threepid with medium '%s' and address '%s'", + threepid["medium"], + threepid["address"], + ) if not ( - yield check_3pid_allowed(self.hs, threepid['medium'], threepid['address']) - ): - raise RegistrationError( - 403, "Third party identifier is not allowed" + yield check_3pid_allowed( + self.hs, threepid["medium"], threepid["address"] ) + ): + raise RegistrationError(403, "Third party identifier is not allowed") @defer.inlineCallbacks def bind_emails(self, user_id, threepidCreds): @@ -524,23 +513,23 @@ class RegistrationHandler(BaseHandler): if self._server_notices_mxid is not None: if user_id == self._server_notices_mxid: raise SynapseError( - 400, "This user ID is reserved.", - errcode=Codes.EXCLUSIVE + 400, "This user ID is reserved.", errcode=Codes.EXCLUSIVE ) # valid user IDs must not clash with any user ID namespaces claimed by # application services. services = self.store.get_app_services() interested_services = [ - s for s in services - if s.is_interested_in_user(user_id) - and s != allowed_appservice + s + for s in services + if s.is_interested_in_user(user_id) and s != allowed_appservice ] for service in interested_services: if service.is_exclusive_user(user_id): raise SynapseError( - 400, "This user ID is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This user ID is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) @defer.inlineCallbacks @@ -556,24 +545,24 @@ class RegistrationHandler(BaseHandler): as_token = self.hs.config.shadow_server.get("as_token") yield self.http_client.post_json_get_json( - "%s/_matrix/client/r0/register?access_token=%s" % ( - shadow_hs_url, as_token, - ), + "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token), { # XXX: auth_result is an unspecified extension for shadow registration - 'auth_result': auth_result, + "auth_result": auth_result, # XXX: another unspecified extension for shadow registration to ensure # that the displayname is correctly set by the masters erver - 'display_name': display_name, - 'username': localpart, - 'password': params.get("password"), - 'bind_email': params.get("bind_email"), - 'bind_msisdn': params.get("bind_msisdn"), - 'device_id': params.get("device_id"), - 'initial_device_display_name': params.get("initial_device_display_name"), - 'inhibit_login': False, - 'access_token': as_token, - } + "display_name": display_name, + "username": localpart, + "password": params.get("password"), + "bind_email": params.get("bind_email"), + "bind_msisdn": params.get("bind_msisdn"), + "device_id": params.get("device_id"), + "initial_device_display_name": params.get( + "initial_device_display_name" + ), + "inhibit_login": False, + "access_token": as_token, + }, ) @defer.inlineCallbacks @@ -599,14 +588,13 @@ class RegistrationHandler(BaseHandler): dict: Containing 'valid'(bool) and 'error_url'(str) if invalid. """ - response = yield self._submit_captcha(ip_addr, private_key, challenge, - response) + response = yield self._submit_captcha(ip_addr, private_key, challenge, response) # parse Google's response. Lovely format.. - lines = response.split('\n') + lines = response.split("\n") json = { - "valid": lines[0] == 'true', - "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" + - "error=%s" % lines[1] + "valid": lines[0] == "true", + "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" + + "error=%s" % lines[1], } defer.returnValue(json) @@ -618,17 +606,16 @@ class RegistrationHandler(BaseHandler): data = yield self.captcha_client.post_urlencoded_get_raw( "http://www.recaptcha.net:80/recaptcha/api/verify", args={ - 'privatekey': private_key, - 'remoteip': ip_addr, - 'challenge': challenge, - 'response': response - } + "privatekey": private_key, + "remoteip": ip_addr, + "challenge": challenge, + "response": response, + }, ) defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, - password_hash=None): + def get_or_create_user(self, requester, localpart, displayname, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -668,7 +655,7 @@ class RegistrationHandler(BaseHandler): ) if displayname is not None: yield self.profile_handler.set_displayname( - user, None, displayname or user.localpart, by_admin=True, + user, None, displayname or user.localpart, by_admin=True ) else: yield self._auth_handler.delete_access_tokens_for_user(user_id) @@ -693,15 +680,12 @@ class RegistrationHandler(BaseHandler): """ access_token = yield self.store.get_3pid_guest_access_token(medium, address) if access_token: - user_info = yield self.auth.get_user_by_access_token( - access_token - ) + user_info = yield self.auth.get_user_by_access_token(access_token) defer.returnValue((user_info["user"].to_string(), access_token)) user_id, access_token = yield self.register( - generate_token=True, - make_guest=True + generate_token=True, make_guest=True ) access_token = yield self.store.save_or_get_3pid_guest_access_token( medium, address, access_token, inviter_user_id @@ -722,9 +706,9 @@ class RegistrationHandler(BaseHandler): ) room_id = room_id.to_string() else: - raise SynapseError(400, "%s was not legal room ID or room alias" % ( - room_identifier, - )) + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) yield room_member_handler.update_membership( requester=requester, @@ -735,10 +719,19 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) - def register_with_store(self, user_id, token=None, password_hash=None, - was_guest=False, make_guest=False, appservice_id=None, - create_profile_with_displayname=None, admin=False, - user_type=None, address=None): + def register_with_store( + self, + user_id, + token=None, + password_hash=None, + was_guest=False, + make_guest=False, + appservice_id=None, + create_profile_with_displayname=None, + admin=False, + user_type=None, + address=None, + ): """Register user in the datastore. Args: @@ -767,14 +760,15 @@ class RegistrationHandler(BaseHandler): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.can_do_action( - address, time_now_s=time_now, + address, + time_now_s=time_now, rate_hz=self.hs.config.rc_registration.per_second, burst_count=self.hs.config.rc_registration.burst_count, ) if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) if self.hs.config.worker_app: @@ -804,8 +798,7 @@ class RegistrationHandler(BaseHandler): ) @defer.inlineCallbacks - def register_device(self, user_id, device_id, initial_display_name, - is_guest=False): + def register_device(self, user_id, device_id, initial_display_name, is_guest=False): """Register a device for a user and generate an access token. Args: @@ -838,14 +831,15 @@ class RegistrationHandler(BaseHandler): ) else: access_token = yield self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, + user_id, device_id=device_id ) defer.returnValue((device_id, access_token)) @defer.inlineCallbacks - def post_registration_actions(self, user_id, auth_result, access_token, - bind_email, bind_msisdn): + def post_registration_actions( + self, user_id, auth_result, access_token, bind_email, bind_msisdn + ): """A user has completed registration Args: @@ -879,20 +873,15 @@ class RegistrationHandler(BaseHandler): yield self.store.upsert_monthly_active_user(user_id) yield self._register_email_threepid( - user_id, threepid, access_token, - bind_email, + user_id, threepid, access_token, bind_email ) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] - yield self._register_msisdn_threepid( - user_id, threepid, bind_msisdn, - ) + yield self._register_msisdn_threepid(user_id, threepid, bind_msisdn) if auth_result and LoginType.TERMS in auth_result: - yield self._on_user_consented( - user_id, self.hs.config.user_consent_version, - ) + yield self._on_user_consented(user_id, self.hs.config.user_consent_version) @defer.inlineCallbacks def _on_user_consented(self, user_id, consent_version): @@ -904,9 +893,7 @@ class RegistrationHandler(BaseHandler): consented to. """ logger.info("%s has consented to the privacy policy", user_id) - yield self.store.user_set_consent_version( - user_id, consent_version, - ) + yield self.store.user_set_consent_version(user_id, consent_version) yield self.post_consent_actions(user_id) @defer.inlineCallbacks @@ -930,33 +917,30 @@ class RegistrationHandler(BaseHandler): Returns: defer.Deferred: """ - reqd = ('medium', 'address', 'validated_at') + reqd = ("medium", "address", "validated_at") if any(x not in threepid for x in reqd): # This will only happen if the ID server returns a malformed response logger.info("Can't add incomplete 3pid") return yield self._auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) # And we add an email pusher for them by default, but only # if email notifications are enabled (so people don't start # getting mail spam where they weren't before if email # notifs are set up on a home server) - if (self.hs.config.email_enable_notifs and - self.hs.config.email_notif_for_new_users - and token): + if ( + self.hs.config.email_enable_notifs + and self.hs.config.email_notif_for_new_users + and token + ): # Pull the ID of the access token back out of the db # It would really make more sense for this to be passed # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. - user_tuple = yield self.store.get_user_by_access_token( - token - ) + user_tuple = yield self.store.get_user_by_access_token(token) token_id = user_tuple["token_id"] yield self.pusher_pool.add_pusher( @@ -973,11 +957,9 @@ class RegistrationHandler(BaseHandler): if bind_email: logger.info("bind_email specified: binding") - logger.debug("Binding emails %s to %s" % ( - threepid, user_id - )) + logger.debug("Binding emails %s to %s" % (threepid, user_id)) yield self.identity_handler.bind_threepid( - threepid['threepid_creds'], user_id + threepid["threepid_creds"], user_id ) else: logger.info("bind_email not specified: not binding email") @@ -1000,7 +982,7 @@ class RegistrationHandler(BaseHandler): defer.Deferred: """ try: - assert_params_in_dict(threepid, ['medium', 'address', 'validated_at']) + assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) except SynapseError as ex: if ex.errcode == Codes.MISSING_PARAM: # This will only happen if the ID server returns a malformed response @@ -1009,17 +991,14 @@ class RegistrationHandler(BaseHandler): raise yield self._auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) if bind_msisdn: logger.info("bind_msisdn specified: binding") logger.debug("Binding msisdn %s to %s", threepid, user_id) yield self.identity_handler.bind_threepid( - threepid['threepid_creds'], user_id + threepid["threepid_creds"], user_id ) else: logger.info("bind_msisdn not specified: not binding msisdn") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7c24f9aac3..2ae9f93e97 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -103,7 +103,7 @@ class RoomCreationHandler(BaseHandler): if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) new_room_id = yield self._generate_room_id( - creator_id=user_id, is_public=r["is_public"], + creator_id=user_id, is_public=r["is_public"] ) logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) @@ -112,7 +112,8 @@ class RoomCreationHandler(BaseHandler): # room, to check our user has perms in the old room. tombstone_event, tombstone_context = ( yield self.event_creation_handler.create_event( - requester, { + requester, + { "type": EventTypes.Tombstone, "state_key": "", "room_id": old_room_id, @@ -120,14 +121,14 @@ class RoomCreationHandler(BaseHandler): "content": { "body": "This room has been replaced", "replacement_room": new_room_id, - } + }, }, token_id=requester.access_token_id, ) ) old_room_version = yield self.store.get_room_version(old_room_id) yield self.auth.check_from_context( - old_room_version, tombstone_event, tombstone_context, + old_room_version, tombstone_event, tombstone_context ) yield self.clone_existing_room( @@ -140,27 +141,27 @@ class RoomCreationHandler(BaseHandler): # now send the tombstone yield self.event_creation_handler.send_nonmember_event( - requester, tombstone_event, tombstone_context, + requester, tombstone_event, tombstone_context ) old_room_state = yield tombstone_context.get_current_state_ids(self.store) # update any aliases yield self._move_aliases_to_new_room( - requester, old_room_id, new_room_id, old_room_state, + requester, old_room_id, new_room_id, old_room_state ) # and finally, shut down the PLs in the old room, and update them in the new # room. yield self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state, + requester, old_room_id, new_room_id, old_room_state ) defer.returnValue(new_room_id) @defer.inlineCallbacks def _update_upgraded_room_pls( - self, requester, old_room_id, new_room_id, old_room_state, + self, requester, old_room_id, new_room_id, old_room_state ): """Send updated power levels in both rooms after an upgrade @@ -178,7 +179,7 @@ class RoomCreationHandler(BaseHandler): if old_room_pl_event_id is None: logger.warning( "Not supported: upgrading a room with no PL event. Not setting PLs " - "in old room.", + "in old room." ) return @@ -199,45 +200,48 @@ class RoomCreationHandler(BaseHandler): if current < restricted_level: logger.info( "Setting level for %s in %s to %i (was %i)", - v, old_room_id, restricted_level, current, + v, + old_room_id, + restricted_level, + current, ) pl_content[v] = restricted_level updated = True else: - logger.info( - "Not setting level for %s (already %i)", - v, current, - ) + logger.info("Not setting level for %s (already %i)", v, current) if updated: try: yield self.event_creation_handler.create_and_send_nonmember_event( - requester, { + requester, + { "type": EventTypes.PowerLevels, - "state_key": '', + "state_key": "", "room_id": old_room_id, "sender": requester.user.to_string(), "content": pl_content, - }, ratelimit=False, + }, + ratelimit=False, ) except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) logger.info("Setting correct PLs in new room") yield self.event_creation_handler.create_and_send_nonmember_event( - requester, { + requester, + { "type": EventTypes.PowerLevels, - "state_key": '', + "state_key": "", "room_id": new_room_id, "sender": requester.user.to_string(), "content": old_room_pl_state.content, - }, ratelimit=False, + }, + ratelimit=False, ) @defer.inlineCallbacks def clone_existing_room( - self, requester, old_room_id, new_room_id, new_room_version, - tombstone_event_id, + self, requester, old_room_id, new_room_id, new_room_version, tombstone_event_id ): """Populate a new room based on an old room @@ -254,30 +258,24 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - if (self._server_notices_mxid is not None and - requester.user.to_string() == self._server_notices_mxid): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) if not is_requester_admin and not self.spam_checker.user_may_create_room( - user_id, - invite_list=[], - third_party_invite_list=[], - cloning=True, + user_id, invite_list=[], third_party_invite_list=[], cloning=True ): raise SynapseError(403, "You are not permitted to create rooms") creation_content = { "room_version": new_room_version, - "predecessor": { - "room_id": old_room_id, - "event_id": tombstone_event_id, - } + "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, } # Check if old room was non-federatable @@ -306,7 +304,7 @@ class RoomCreationHandler(BaseHandler): ) old_room_state_ids = yield self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types(types_to_copy), + old_room_id, StateFilter.from_types(types_to_copy) ) # map from event_id to BaseEvent old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) @@ -319,11 +317,9 @@ class RoomCreationHandler(BaseHandler): yield self._send_events_for_new_room( requester, new_room_id, - # we expect to override all the presets with initial_state, so this is # somewhat arbitrary. preset_config=RoomCreationPreset.PRIVATE_CHAT, - invite_list=[], initial_state=initial_state, creation_content=creation_content, @@ -331,20 +327,22 @@ class RoomCreationHandler(BaseHandler): # Transfer membership events old_room_member_state_ids = yield self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types([(EventTypes.Member, None)]), + old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) ) # map from event_id to BaseEvent old_room_member_state_events = yield self.store.get_events( - old_room_member_state_ids.values(), + old_room_member_state_ids.values() ) for k, old_event in iteritems(old_room_member_state_events): # Only transfer ban events - if ("membership" in old_event.content and - old_event.content["membership"] == "ban"): + if ( + "membership" in old_event.content + and old_event.content["membership"] == "ban" + ): yield self.room_member_handler.update_membership( requester, - UserID.from_string(old_event['state_key']), + UserID.from_string(old_event["state_key"]), new_room_id, "ban", ratelimit=False, @@ -356,7 +354,7 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def _move_aliases_to_new_room( - self, requester, old_room_id, new_room_id, old_room_state, + self, requester, old_room_id, new_room_id, old_room_state ): directory_handler = self.hs.get_handlers().directory_handler @@ -387,14 +385,11 @@ class RoomCreationHandler(BaseHandler): alias = RoomAlias.from_string(alias_str) try: yield directory_handler.delete_association( - requester, alias, send_event=False, + requester, alias, send_event=False ) removed_aliases.append(alias_str) except SynapseError as e: - logger.warning( - "Unable to remove alias %s from old room: %s", - alias, e, - ) + logger.warning("Unable to remove alias %s from old room: %s", alias, e) # if we didn't find any aliases, or couldn't remove anyway, we can skip the rest # of this. @@ -410,30 +405,26 @@ class RoomCreationHandler(BaseHandler): # as when you remove an alias from the directory normally - it just means that # the aliases event gets out of sync with the directory # (cf https://github.com/vector-im/riot-web/issues/2369) - yield directory_handler.send_room_alias_update_event( - requester, old_room_id, - ) + yield directory_handler.send_room_alias_update_event(requester, old_room_id) except AuthError as e: - logger.warning( - "Failed to send updated alias event on old room: %s", e, - ) + logger.warning("Failed to send updated alias event on old room: %s", e) # we can now add any aliases we successfully removed to the new room. for alias in removed_aliases: try: yield directory_handler.create_association( - requester, RoomAlias.from_string(alias), - new_room_id, servers=(self.hs.hostname, ), - send_event=False, check_membership=False, + requester, + RoomAlias.from_string(alias), + new_room_id, + servers=(self.hs.hostname,), + send_event=False, + check_membership=False, ) logger.info("Moved alias %s to new room", alias) except SynapseError as e: # I'm not really expecting this to happen, but it could if the spam # checking module decides it shouldn't, or similar. - logger.error( - "Error adding alias %s to new room: %s", - alias, e, - ) + logger.error("Error adding alias %s to new room: %s", alias, e) try: if canonical_alias and (canonical_alias in removed_aliases): @@ -444,24 +435,19 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "room_id": new_room_id, "sender": requester.user.to_string(), - "content": {"alias": canonical_alias, }, + "content": {"alias": canonical_alias}, }, - ratelimit=False + ratelimit=False, ) - yield directory_handler.send_room_alias_update_event( - requester, new_room_id, - ) + yield directory_handler.send_room_alias_update_event(requester, new_room_id) except SynapseError as e: # again I'm not really expecting this to fail, but if it does, I'd rather # we returned the new room to the client at this point. - logger.error( - "Unable to send updated alias events in new room: %s", e, - ) + logger.error("Unable to send updated alias events in new room: %s", e) @defer.inlineCallbacks - def create_room(self, requester, config, ratelimit=True, - creator_join_profile=None): + def create_room(self, requester, config, ratelimit=True, creator_join_profile=None): """ Creates a new room. Args: @@ -491,21 +477,19 @@ class RoomCreationHandler(BaseHandler): yield self.auth.check_auth_blocking(user_id) - if (self._server_notices_mxid is not None and - requester.user.to_string() == self._server_notices_mxid): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) # Check whether the third party rules allows/changes the room create # request. yield self.third_party_event_rules.on_create_room( - requester, - config, - is_requester_admin=is_requester_admin, + requester, config, is_requester_admin=is_requester_admin ) invite_list = config.get("invite", []) @@ -523,16 +507,11 @@ class RoomCreationHandler(BaseHandler): yield self.ratelimit(requester) room_version = config.get( - "room_version", - self.config.default_room_version.identifier, + "room_version", self.config.default_room_version.identifier ) if not isinstance(room_version, string_types): - raise SynapseError( - 400, - "room_version must be a string", - Codes.BAD_JSON, - ) + raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON) if room_version not in KNOWN_ROOM_VERSIONS: raise SynapseError( @@ -546,20 +525,11 @@ class RoomCreationHandler(BaseHandler): if wchar in config["room_alias_name"]: raise SynapseError(400, "Invalid characters in room alias") - room_alias = RoomAlias( - config["room_alias_name"], - self.hs.hostname, - ) - mapping = yield self.store.get_association_from_room_alias( - room_alias - ) + room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname) + mapping = yield self.store.get_association_from_room_alias(room_alias) if mapping: - raise SynapseError( - 400, - "Room alias already taken", - Codes.ROOM_IN_USE - ) + raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) else: room_alias = None @@ -569,9 +539,7 @@ class RoomCreationHandler(BaseHandler): except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) - yield self.event_creation_handler.assert_accepted_privacy_policy( - requester, - ) + yield self.event_creation_handler.assert_accepted_privacy_policy(requester) visibility = config.get("visibility", None) is_public = visibility == "public" @@ -593,7 +561,7 @@ class RoomCreationHandler(BaseHandler): "preset", RoomCreationPreset.PRIVATE_CHAT if visibility == "private" - else RoomCreationPreset.PUBLIC_CHAT + else RoomCreationPreset.PUBLIC_CHAT, ) raw_initial_state = config.get("initial_state", []) @@ -630,7 +598,8 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "content": {"name": name}, }, - ratelimit=False) + ratelimit=False, + ) if "topic" in config: topic = config["topic"] @@ -643,7 +612,8 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "content": {"topic": topic}, }, - ratelimit=False) + ratelimit=False, + ) for invitee in invite_list: content = {} @@ -680,30 +650,25 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - yield directory_handler.send_room_alias_update_event( - requester, room_id - ) + yield directory_handler.send_room_alias_update_event(requester, room_id) defer.returnValue(result) @defer.inlineCallbacks def _send_events_for_new_room( - self, - creator, # A Requester object. - room_id, - preset_config, - invite_list, - initial_state, - creation_content, - room_alias=None, - power_level_content_override=None, - creator_join_profile=None, + self, + creator, # A Requester object. + room_id, + preset_config, + invite_list, + initial_state, + creation_content, + room_alias=None, + power_level_content_override=None, + creator_join_profile=None, ): def create(etype, content, **kwargs): - e = { - "type": etype, - "content": content, - } + e = {"type": etype, "content": content} e.update(event_keys) e.update(kwargs) @@ -715,26 +680,17 @@ class RoomCreationHandler(BaseHandler): event = create(etype, content, **kwargs) logger.info("Sending %s in new room", etype) yield self.event_creation_handler.create_and_send_nonmember_event( - creator, - event, - ratelimit=False + creator, event, ratelimit=False ) config = RoomCreationHandler.PRESETS_DICT[preset_config] creator_id = creator.user.to_string() - event_keys = { - "room_id": room_id, - "sender": creator_id, - "state_key": "", - } + event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} creation_content.update({"creator": creator_id}) - yield send( - etype=EventTypes.Create, - content=creation_content, - ) + yield send(etype=EventTypes.Create, content=creation_content) logger.info("Sending %s in new room", EventTypes.Member) yield self.room_member_handler.update_membership( @@ -749,17 +705,12 @@ class RoomCreationHandler(BaseHandler): # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. - pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None) + pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - yield send( - etype=EventTypes.PowerLevels, - content=pl_content, - ) + yield send(etype=EventTypes.PowerLevels, content=pl_content) else: power_level_content = { - "users": { - creator_id: 100, - }, + "users": {creator_id: 100}, "users_default": 0, "events": { EventTypes.Name: 50, @@ -783,50 +734,39 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - yield send( - etype=EventTypes.PowerLevels, - content=power_level_content, - ) + yield send(etype=EventTypes.PowerLevels, content=power_level_content) - if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state: + if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: yield send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) - if (EventTypes.JoinRules, '') not in initial_state: + if (EventTypes.JoinRules, "") not in initial_state: yield send( - etype=EventTypes.JoinRules, - content={"join_rule": config["join_rules"]}, + etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) - if (EventTypes.RoomHistoryVisibility, '') not in initial_state: + if (EventTypes.RoomHistoryVisibility, "") not in initial_state: yield send( etype=EventTypes.RoomHistoryVisibility, - content={"history_visibility": config["history_visibility"]} + content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: - if (EventTypes.GuestAccess, '') not in initial_state: + if (EventTypes.GuestAccess, "") not in initial_state: yield send( - etype=EventTypes.GuestAccess, - content={"guest_access": "can_join"} + etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - yield send( - etype=etype, - state_key=state_key, - content=content, - ) + yield send(etype=etype, state_key=state_key, content=content) if "encryption_alg" in config: yield send( etype=EventTypes.Encryption, state_key="", - content={ - 'algorithm': config["encryption_alg"], - } + content={"algorithm": config["encryption_alg"]}, ) @defer.inlineCallbacks @@ -837,12 +777,9 @@ class RoomCreationHandler(BaseHandler): while attempts < 5: try: random_string = stringutils.random_string(18) - gen_room_id = RoomID( - random_string, - self.hs.hostname, - ).to_string() + gen_room_id = RoomID(random_string, self.hs.hostname).to_string() if isinstance(gen_room_id, bytes): - gen_room_id = gen_room_id.decode('utf-8') + gen_room_id = gen_room_id.decode("utf-8") yield self.store.store_room( room_id=gen_room_id, room_creator_user_id=creator_id, @@ -884,24 +821,19 @@ class RoomContextHandler(object): def filter_evts(events): return filter_events_for_client( - self.store, - user.to_string(), - events, - is_peeking=is_peeking + self.store, user.to_string(), events, is_peeking=is_peeking ) - event = yield self.store.get_event(event_id, get_prev_content=True, - allow_none=True) + event = yield self.store.get_event( + event_id, get_prev_content=True, allow_none=True + ) if not event: defer.returnValue(None) return - filtered = yield(filter_evts([event])) + filtered = yield (filter_evts([event])) if not filtered: - raise AuthError( - 403, - "You don't have permission to access that event." - ) + raise AuthError(403, "You don't have permission to access that event.") results = yield self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter @@ -933,7 +865,7 @@ class RoomContextHandler(object): # https://github.com/matrix-org/matrix-doc/issues/687 state = yield self.store.get_state_for_events( - [last_event_id], state_filter=state_filter, + [last_event_id], state_filter=state_filter ) results["state"] = list(state[last_event_id].values()) @@ -945,9 +877,7 @@ class RoomContextHandler(object): "room_key", results["start"] ).to_string() - results["end"] = token.copy_and_replace( - "room_key", results["end"] - ).to_string() + results["end"] = token.copy_and_replace("room_key", results["end"]).to_string() defer.returnValue(results) @@ -958,13 +888,7 @@ class RoomEventSource(object): @defer.inlineCallbacks def get_new_events( - self, - user, - from_key, - limit, - room_ids, - is_guest, - explicit_room_id=None, + self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None ): # We just ignore the key for now. @@ -975,9 +899,7 @@ class RoomEventSource(object): logger.warn("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) - app_service = self.store.get_app_service_by_user_id( - user.to_string() - ) + app_service = self.store.get_app_service_by_user_id(user.to_string()) if app_service: # We no longer support AS users using /sync directly. # See https://github.com/matrix-org/matrix-doc/issues/1144 @@ -992,7 +914,7 @@ class RoomEventSource(object): from_key=from_key, to_key=to_key, limit=limit or 10, - order='ASC', + order="ASC", ) events = list(room_events) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 617d1c9ef8..aae696a7e8 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -46,13 +46,18 @@ class RoomListHandler(BaseHandler): super(RoomListHandler, self).__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache(hs, "room_list") - self.remote_response_cache = ResponseCache(hs, "remote_room_list", - timeout_ms=30 * 1000) + self.remote_response_cache = ResponseCache( + hs, "remote_room_list", timeout_ms=30 * 1000 + ) - def get_local_public_room_list(self, limit=None, since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False): + def get_local_public_room_list( + self, + limit=None, + since_token=None, + search_filter=None, + network_tuple=EMPTY_THIRD_PARTY_ID, + from_federation=False, + ): """Generate a local public room list. There are multiple different lists: the main one plus one per third @@ -68,14 +73,14 @@ class RoomListHandler(BaseHandler): Setting to None returns all public rooms across all lists. """ if not self.enable_room_list_search: - return defer.succeed({ - "chunk": [], - "total_room_count_estimate": 0, - }) + return defer.succeed({"chunk": [], "total_room_count_estimate": 0}) logger.info( "Getting public room list: limit=%r, since=%r, search=%r, network=%r", - limit, since_token, bool(search_filter), network_tuple, + limit, + since_token, + bool(search_filter), + network_tuple, ) if search_filter: @@ -88,24 +93,33 @@ class RoomListHandler(BaseHandler): # solution at some point timeout = self.clock.time() + 60 return self._get_public_room_list( - limit, since_token, search_filter, - network_tuple=network_tuple, timeout=timeout, + limit, + since_token, + search_filter, + network_tuple=network_tuple, + timeout=timeout, ) key = (limit, since_token, network_tuple) return self.response_cache.wrap( key, self._get_public_room_list, - limit, since_token, - network_tuple=network_tuple, from_federation=from_federation, + limit, + since_token, + network_tuple=network_tuple, + from_federation=from_federation, ) @defer.inlineCallbacks - def _get_public_room_list(self, limit=None, since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False, - timeout=None,): + def _get_public_room_list( + self, + limit=None, + since_token=None, + search_filter=None, + network_tuple=EMPTY_THIRD_PARTY_ID, + from_federation=False, + timeout=None, + ): """Generate a public room list. Args: limit (int|None): Maximum amount of rooms to return. @@ -135,15 +149,14 @@ class RoomListHandler(BaseHandler): current_public_id = yield self.store.get_current_public_room_stream_id() public_room_stream_id = since_token.public_room_stream_id newly_visible, newly_unpublished = yield self.store.get_public_room_changes( - public_room_stream_id, current_public_id, - network_tuple=network_tuple, + public_room_stream_id, current_public_id, network_tuple=network_tuple ) else: stream_token = yield self.store.get_room_max_stream_ordering() public_room_stream_id = yield self.store.get_current_public_room_stream_id() room_ids = yield self.store.get_public_room_ids_at_stream_id( - public_room_stream_id, network_tuple=network_tuple, + public_room_stream_id, network_tuple=network_tuple ) # We want to return rooms in a particular order: the number of joined @@ -168,7 +181,7 @@ class RoomListHandler(BaseHandler): return joined_users = yield self.state_handler.get_current_users_in_room( - room_id, latest_event_ids, + room_id, latest_event_ids ) num_joined_users = len(joined_users) @@ -180,8 +193,9 @@ class RoomListHandler(BaseHandler): # We want larger rooms to be first, hence negating num_joined_users rooms_to_order_value[room_id] = (-num_joined_users, room_id) - logger.info("Getting ordering for %i rooms since %s", - len(room_ids), stream_token) + logger.info( + "Getting ordering for %i rooms since %s", len(room_ids), stream_token + ) yield concurrently_execute(get_order_for_room, room_ids, 10) sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) @@ -193,7 +207,8 @@ class RoomListHandler(BaseHandler): # Filter out rooms that we don't want to return rooms_to_scan = [ - r for r in sorted_rooms + r + for r in sorted_rooms if r not in newly_unpublished and rooms_to_num_joined[r] > 0 ] @@ -204,13 +219,12 @@ class RoomListHandler(BaseHandler): # `since_token.current_limit` is the index of the last room we # sent down, so we exclude it and everything before/after it. if since_token.direction_is_forward: - rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:] + rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :] else: - rooms_to_scan = rooms_to_scan[:since_token.current_limit] + rooms_to_scan = rooms_to_scan[: since_token.current_limit] rooms_to_scan.reverse() - logger.info("After sorting and filtering, %i rooms remain", - len(rooms_to_scan)) + logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan)) # _append_room_entry_to_chunk will append to chunk but will stop if # len(chunk) > limit @@ -237,15 +251,19 @@ class RoomListHandler(BaseHandler): if timeout and self.clock.time() > timeout: raise Exception("Timed out searching room directory") - batch = rooms_to_scan[i:i + step] + batch = rooms_to_scan[i : i + step] logger.info("Processing %i rooms for result", len(batch)) yield concurrently_execute( lambda r: self._append_room_entry_to_chunk( - r, rooms_to_num_joined[r], - chunk, limit, search_filter, + r, + rooms_to_num_joined[r], + chunk, + limit, + search_filter, from_federation=from_federation, ), - batch, 5, + batch, + 5, ) logger.info("Now %i rooms in result", len(chunk)) if len(chunk) >= limit + 1: @@ -273,10 +291,7 @@ class RoomListHandler(BaseHandler): new_limit = sorted_rooms.index(last_room_id) - results = { - "chunk": chunk, - "total_room_count_estimate": total_room_count, - } + results = {"chunk": chunk, "total_room_count_estimate": total_room_count} if since_token: results["new_rooms"] = bool(newly_visible) @@ -313,8 +328,15 @@ class RoomListHandler(BaseHandler): defer.returnValue(results) @defer.inlineCallbacks - def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit, - search_filter, from_federation=False): + def _append_room_entry_to_chunk( + self, + room_id, + num_joined_users, + chunk, + limit, + search_filter, + from_federation=False, + ): """Generate the entry for a room in the public room list and append it to the `chunk` if it matches the search filter @@ -345,8 +367,14 @@ class RoomListHandler(BaseHandler): chunk.append(result) @cachedInlineCallbacks(num_args=1, cache_context=True) - def generate_room_entry(self, room_id, num_joined_users, cache_context, - with_alias=True, allow_private=False): + def generate_room_entry( + self, + room_id, + num_joined_users, + cache_context, + with_alias=True, + allow_private=False, + ): """Returns the entry for a room Args: @@ -360,33 +388,31 @@ class RoomListHandler(BaseHandler): Deferred[dict|None]: Returns a room entry as a dictionary, or None if this room was determined not to be shown publicly. """ - result = { - "room_id": room_id, - "num_joined_members": num_joined_users, - } + result = {"room_id": room_id, "num_joined_members": num_joined_users} current_state_ids = yield self.store.get_current_state_ids( - room_id, on_invalidate=cache_context.invalidate, + room_id, on_invalidate=cache_context.invalidate ) - event_map = yield self.store.get_events([ - event_id for key, event_id in iteritems(current_state_ids) - if key[0] in ( - EventTypes.Create, - EventTypes.JoinRules, - EventTypes.Name, - EventTypes.Topic, - EventTypes.CanonicalAlias, - EventTypes.RoomHistoryVisibility, - EventTypes.GuestAccess, - "m.room.avatar", - ) - ]) + event_map = yield self.store.get_events( + [ + event_id + for key, event_id in iteritems(current_state_ids) + if key[0] + in ( + EventTypes.Create, + EventTypes.JoinRules, + EventTypes.Name, + EventTypes.Topic, + EventTypes.CanonicalAlias, + EventTypes.RoomHistoryVisibility, + EventTypes.GuestAccess, + "m.room.avatar", + ) + ] + ) - current_state = { - (ev.type, ev.state_key): ev - for ev in event_map.values() - } + current_state = {(ev.type, ev.state_key): ev for ev in event_map.values()} # Double check that this is actually a public room. @@ -446,14 +472,17 @@ class RoomListHandler(BaseHandler): defer.returnValue(result) @defer.inlineCallbacks - def get_remote_public_room_list(self, server_name, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None,): + def get_remote_public_room_list( + self, + server_name, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): if not self.enable_room_list_search: - defer.returnValue({ - "chunk": [], - "total_room_count_estimate": 0, - }) + defer.returnValue({"chunk": [], "total_room_count_estimate": 0}) if search_filter: # We currently don't support searching across federation, so we have @@ -462,52 +491,75 @@ class RoomListHandler(BaseHandler): since_token = None res = yield self._get_remote_list_cached( - server_name, limit=limit, since_token=since_token, + server_name, + limit=limit, + since_token=since_token, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) if search_filter: - res = {"chunk": [ - entry - for entry in list(res.get("chunk", [])) - if _matches_room_entry(entry, search_filter) - ]} + res = { + "chunk": [ + entry + for entry in list(res.get("chunk", [])) + if _matches_room_entry(entry, search_filter) + ] + } defer.returnValue(res) - def _get_remote_list_cached(self, server_name, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None,): + def _get_remote_list_cached( + self, + server_name, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search return repl_layer.get_public_rooms( - server_name, limit=limit, since_token=since_token, - search_filter=search_filter, include_all_networks=include_all_networks, + server_name, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) key = ( - server_name, limit, since_token, include_all_networks, + server_name, + limit, + since_token, + include_all_networks, third_party_instance_id, ) return self.remote_response_cache.wrap( key, repl_layer.get_public_rooms, - server_name, limit=limit, since_token=since_token, + server_name, + limit=limit, + since_token=since_token, search_filter=search_filter, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) -class RoomListNextBatch(namedtuple("RoomListNextBatch", ( - "stream_ordering", # stream_ordering of the first public room list - "public_room_stream_id", # public room stream id for first public room list - "current_limit", # The number of previous rooms returned - "direction_is_forward", # Bool if this is a next_batch, false if prev_batch -))): +class RoomListNextBatch( + namedtuple( + "RoomListNextBatch", + ( + "stream_ordering", # stream_ordering of the first public room list + "public_room_stream_id", # public room stream id for first public room list + "current_limit", # The number of previous rooms returned + "direction_is_forward", # Bool if this is a next_batch, false if prev_batch + ), + ) +): KEY_DICT = { "stream_ordering": "s", @@ -527,21 +579,19 @@ class RoomListNextBatch(namedtuple("RoomListNextBatch", ( decoded = msgpack.loads(decode_base64(token), raw=False) else: decoded = msgpack.loads(decode_base64(token)) - return RoomListNextBatch(**{ - cls.REVERSE_KEY_DICT[key]: val - for key, val in decoded.items() - }) + return RoomListNextBatch( + **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()} + ) def to_token(self): - return encode_base64(msgpack.dumps({ - self.KEY_DICT[key]: val - for key, val in self._asdict().items() - })) + return encode_base64( + msgpack.dumps( + {self.KEY_DICT[key]: val for key, val in self._asdict().items()} + ) + ) def copy_and_replace(self, **kwds): - return self._replace( - **kwds - ) + return self._replace(**kwds) def _matches_room_entry(room_entry, search_filter): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7f1ee95a7a..0a2ff75629 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -159,7 +159,11 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _local_membership_update( - self, requester, target, room_id, membership, + self, + requester, + target, + room_id, + membership, prev_events_and_hashes, txn_id=None, ratelimit=True, @@ -183,7 +187,6 @@ class RoomMemberHandler(object): "room_id": room_id, "sender": requester.user.to_string(), "state_key": user_id, - # For backwards compatibility: "membership": membership, }, @@ -195,26 +198,19 @@ class RoomMemberHandler(object): # Check if this event matches the previous membership event for the user. duplicate = yield self.event_creation_handler.deduplicate_state_event( - event, context, + event, context ) if duplicate is not None: # Discard the new event since this membership change is a no-op. defer.returnValue(duplicate) yield self.event_creation_handler.handle_new_client_event( - requester, - event, - context, - extra_users=[target], - ratelimit=ratelimit, + requester, event, context, extra_users=[target], ratelimit=ratelimit ) prev_state_ids = yield context.get_prev_state_ids(self.store) - prev_member_event_id = prev_state_ids.get( - (EventTypes.Member, user_id), - None - ) + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) if event.membership == Membership.JOIN: # Only fire user_joined_room if the user has actually joined the @@ -236,11 +232,11 @@ class RoomMemberHandler(object): if predecessor: # It is an upgraded room. Copy over old tags self.copy_room_tags_and_direct_to_room( - predecessor["room_id"], room_id, user_id, + predecessor["room_id"], room_id, user_id ) # Move over old push rules self.store.move_push_rules_from_room_to_room_for_user( - predecessor["room_id"], room_id, user_id, + predecessor["room_id"], room_id, user_id ) elif event.membership == Membership.LEAVE: if prev_member_event_id: @@ -251,12 +247,7 @@ class RoomMemberHandler(object): defer.returnValue(event) @defer.inlineCallbacks - def copy_room_tags_and_direct_to_room( - self, - old_room_id, - new_room_id, - user_id, - ): + def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id): """Copies the tags and direct room state from one room to another. Args: @@ -268,9 +259,7 @@ class RoomMemberHandler(object): Deferred[None] """ # Retrieve user account data for predecessor room - user_account_data, _ = yield self.store.get_account_data_for_user( - user_id, - ) + user_account_data, _ = yield self.store.get_account_data_for_user(user_id) # Copy direct message state if applicable direct_rooms = user_account_data.get("m.direct", {}) @@ -284,35 +273,31 @@ class RoomMemberHandler(object): # Save back to user's m.direct account data yield self.store.add_account_data_for_user( - user_id, "m.direct", direct_rooms, + user_id, "m.direct", direct_rooms ) break # Copy room tags if applicable - room_tags = yield self.store.get_tags_for_room( - user_id, old_room_id, - ) + room_tags = yield self.store.get_tags_for_room(user_id, old_room_id) # Copy each room tag to the new room for tag, tag_content in room_tags.items(): - yield self.store.add_tag_to_room( - user_id, new_room_id, tag, tag_content - ) + yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) @defer.inlineCallbacks def update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - new_room=False, - require_consent=True, + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + new_room=False, + require_consent=True, ): """Update a users membership in a room @@ -357,18 +342,18 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - new_room=False, - require_consent=True, + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + new_room=False, + require_consent=True, ): content_specified = bool(content) if content is None: @@ -402,7 +387,7 @@ class RoomMemberHandler(object): if not remote_room_hosts: remote_room_hosts = [] - if effective_membership_state not in ("leave", "ban",): + if effective_membership_state not in ("leave", "ban"): is_blocked = yield self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -410,22 +395,19 @@ class RoomMemberHandler(object): if effective_membership_state == Membership.INVITE: # block any attempts to invite the server notices mxid if target.to_string() == self._server_notices_mxid: - raise SynapseError( - http_client.FORBIDDEN, - "Cannot invite this user", - ) + raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") block_invite = False - if (self._server_notices_mxid is not None and - requester.user.to_string() == self._server_notices_mxid): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): # allow the server notices mxid to send invites is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) if not is_requester_admin: if self.config.block_non_admin_invites: @@ -438,7 +420,8 @@ class RoomMemberHandler(object): is_published = yield self.store.is_room_published(room_id) if not self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), + requester.user.to_string(), + target.to_string(), third_party_invite=None, room_id=room_id, new_room=new_room, @@ -448,19 +431,13 @@ class RoomMemberHandler(object): block_invite = True if block_invite: - raise SynapseError( - 403, "Invites have been disabled on this server", - ) + raise SynapseError(403, "Invites have been disabled on this server") - prev_events_and_hashes = yield self.store.get_prev_events_for_room( - room_id, - ) - latest_event_ids = ( - event_id for (event_id, _, _) in prev_events_and_hashes - ) + prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) current_state_ids = yield self.state_handler.get_current_state_ids( - room_id, latest_event_ids=latest_event_ids, + room_id, latest_event_ids=latest_event_ids ) # TODO: Refactor into dictionary of explicitly allowed transitions @@ -475,13 +452,13 @@ class RoomMemberHandler(object): 403, "Cannot unban user who was not banned" " (membership=%s)" % old_membership, - errcode=Codes.BAD_STATE + errcode=Codes.BAD_STATE, ) if old_membership == "ban" and action != "unban": raise SynapseError( 403, "Cannot %s user who was banned" % (action,), - errcode=Codes.BAD_STATE + errcode=Codes.BAD_STATE, ) if old_state: @@ -497,8 +474,8 @@ class RoomMemberHandler(object): # we don't allow people to reject invites to the server notice # room, but they can leave it once they are joined. if ( - old_membership == Membership.INVITE and - effective_membership_state == Membership.LEAVE + old_membership == Membership.INVITE + and effective_membership_state == Membership.LEAVE ): is_blocked = yield self._is_server_notice_room(room_id) if is_blocked: @@ -521,27 +498,24 @@ class RoomMemberHandler(object): # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") - if (self._server_notices_mxid is not None and - requester.user.to_string() == self._server_notices_mxid): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): # allow the server notices mxid to join rooms is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) inviter = yield self._get_inviter(target.to_string(), room_id) if not is_requester_admin: # We assume that if the spam checker allowed the user to create # a room then they're allowed to join it. if not new_room and not self.spam_checker.user_may_join_room( - target.to_string(), room_id, - is_invited=inviter is not None, + target.to_string(), room_id, is_invited=inviter is not None ): - raise SynapseError( - 403, "Not allowed to join this room", - ) + raise SynapseError(403, "Not allowed to join this room") if not is_host_in_room: if inviter and not self.hs.is_mine(inviter): @@ -580,7 +554,7 @@ class RoomMemberHandler(object): # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] res = yield self._remote_reject_invite( - requester, remote_room_hosts, room_id, target, + requester, remote_room_hosts, room_id, target ) defer.returnValue(res) @@ -599,12 +573,7 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def send_membership_event( - self, - requester, - event, - context, - remote_room_hosts=None, - ratelimit=True, + self, requester, event, context, remote_room_hosts=None, ratelimit=True ): """ Change the membership status of a user in a room. @@ -631,15 +600,14 @@ class RoomMemberHandler(object): if requester is not None: sender = UserID.from_string(event.sender) assert sender == requester.user, ( - "Sender (%s) must be same as requester (%s)" % - (sender, requester.user) + "Sender (%s) must be same as requester (%s)" % (sender, requester.user) ) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: requester = synapse.types.create_requester(target_user) prev_event = yield self.event_creation_handler.deduplicate_state_event( - event, context, + event, context ) if prev_event is not None: return @@ -659,16 +627,11 @@ class RoomMemberHandler(object): raise SynapseError(403, "This room has been blocked on this server") yield self.event_creation_handler.handle_new_client_event( - requester, - event, - context, - extra_users=[target_user], - ratelimit=ratelimit, + requester, event, context, extra_users=[target_user], ratelimit=ratelimit ) prev_member_event_id = prev_state_ids.get( - (EventTypes.Member, event.state_key), - None + (EventTypes.Member, event.state_key), None ) if event.membership == Membership.JOIN: @@ -738,99 +701,82 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _get_inviter(self, user_id, room_id): invite = yield self.store.get_invite_for_user_in_room( - user_id=user_id, - room_id=room_id, + user_id=user_id, room_id=room_id ) if invite: defer.returnValue(UserID.from_string(invite.sender)) @defer.inlineCallbacks def do_3pid_invite( - self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id, - new_room=False, + self, + room_id, + inviter, + medium, + address, + id_server, + requester, + txn_id, + new_room=False, ): if self.config.block_non_admin_invites: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( - 403, "Invites have been disabled on this server", - Codes.FORBIDDEN, + 403, "Invites have been disabled on this server", Codes.FORBIDDEN ) # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. self.ratelimiter.ratelimit( - requester.user.to_string(), time_now_s=self.hs.clock.time(), + requester.user.to_string(), + time_now_s=self.hs.clock.time(), rate_hz=self.hs.config.rc_third_party_invite.per_second, burst_count=self.hs.config.rc_third_party_invite.burst_count, update=True, ) can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( - medium, address, room_id, + medium, address, room_id ) if not can_invite: raise SynapseError( - 403, "This third-party identifier can not be invited in this room", + 403, + "This third-party identifier can not be invited in this room", Codes.FORBIDDEN, ) can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( - medium, address, room_id, + medium, address, room_id ) if not can_invite: raise SynapseError( - 403, "This third-party identifier can not be invited in this room", + 403, + "This third-party identifier can not be invited in this room", Codes.FORBIDDEN, ) - invitee = yield self._lookup_3pid( - id_server, medium, address - ) + invitee = yield self._lookup_3pid(id_server, medium, address) is_published = yield self.store.is_room_published(room_id) if not self.spam_checker.user_may_invite( - requester.user.to_string(), invitee, - third_party_invite={ - "medium": medium, - "address": address, - }, + requester.user.to_string(), + invitee, + third_party_invite={"medium": medium, "address": address}, room_id=room_id, new_room=new_room, published_room=is_published, ): logger.info("Blocking invite due to spam checker") - raise SynapseError( - 403, "Invites have been disabled on this server", - ) + raise SynapseError(403, "Invites have been disabled on this server") if invitee: yield self.update_membership( - requester, - UserID.from_string(invitee), - room_id, - "invite", - txn_id=txn_id, + requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: yield self._make_and_store_3pid_invite( - requester, - id_server, - medium, - address, - room_id, - inviter, - txn_id=txn_id + requester, id_server, medium, address, room_id, inviter, txn_id=txn_id ) def _get_id_server_target(self, id_server): @@ -869,14 +815,7 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _make_and_store_3pid_invite( - self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id + self, requester, id_server, medium, address, room_id, user, txn_id ): room_state = yield self.state_handler.get_current_state(room_id) @@ -924,7 +863,7 @@ class RoomMemberHandler(object): room_join_rules=room_join_rules, room_name=room_name, inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url + inviter_avatar_url=inviter_avatar_url, ) ) @@ -935,7 +874,6 @@ class RoomMemberHandler(object): "content": { "display_name": display_name, "public_keys": public_keys, - # For backwards compatibility: "key_validity_url": fallback_public_key["key_validity_url"], "public_key": fallback_public_key["public_key"], @@ -950,19 +888,19 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _ask_id_server_for_third_party_invite( - self, - requester, - id_server, - medium, - address, - room_id, - inviter_user_id, - room_alias, - room_avatar_url, - room_join_rules, - room_name, - inviter_display_name, - inviter_avatar_url + self, + requester, + id_server, + medium, + address, + room_id, + inviter_user_id, + room_alias, + room_avatar_url, + room_join_rules, + room_name, + inviter_display_name, + inviter_avatar_url, ): """ Asks an identity server for a third party invite. @@ -995,7 +933,8 @@ class RoomMemberHandler(object): target = self._get_id_server_target(id_server) is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( - id_server_scheme, target, + id_server_scheme, + target, ) invite_config = { @@ -1019,14 +958,15 @@ class RoomMemberHandler(object): inviter_user_id=inviter_user_id, ) - invite_config.update({ - "guest_access_token": guest_access_token, - "guest_user_id": guest_user_id, - }) + invite_config.update( + { + "guest_access_token": guest_access_token, + "guest_user_id": guest_user_id, + } + ) data = yield self.simple_http_client.post_urlencoded_get_json( - is_url, - invite_config + is_url, invite_config ) # TODO: Check for success token = data["token"] @@ -1034,9 +974,8 @@ class RoomMemberHandler(object): if "public_key" in data: fallback_public_key = { "public_key": data["public_key"], - "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, target, - ), + "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" + % (id_server_scheme, target), } else: fallback_public_key = public_keys[0] @@ -1105,10 +1044,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. yield self.federation_handler.do_invite_join( - remote_room_hosts, - room_id, - user.to_string(), - content, + remote_room_hosts, room_id, user.to_string(), content ) yield self._user_joined_room(user, room_id) @@ -1119,9 +1055,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): fed_handler = self.federation_handler try: ret = yield fed_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - target.to_string(), + remote_room_hosts, room_id, target.to_string() ) defer.returnValue(ret) except Exception as e: @@ -1133,9 +1067,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warn("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite( - target.to_string(), room_id - ) + yield self.store.locally_reject_invite(target.to_string(), room_id) defer.returnValue({}) def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): @@ -1159,18 +1091,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): user_id = user.to_string() member = yield self.state_handler.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None if membership is not None and membership not in [ - Membership.LEAVE, Membership.BAN + Membership.LEAVE, + Membership.BAN, ]: - raise SynapseError(400, "User %s in room %s" % ( - user_id, room_id - )) + raise SynapseError(400, "User %s in room %s" % (user_id, room_id)) if membership: yield self.store.forget(user_id, room_id) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index acc6eb8099..da501f38c0 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -71,18 +71,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler): """Implements RoomMemberHandler._user_joined_room """ return self._notify_change_client( - user_id=target.to_string(), - room_id=room_id, - change="joined", + user_id=target.to_string(), room_id=room_id, change="joined" ) def _user_left_room(self, target, room_id): """Implements RoomMemberHandler._user_left_room """ return self._notify_change_client( - user_id=target.to_string(), - room_id=room_id, - change="left", + user_id=target.to_string(), room_id=room_id, change="left" ) def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9bba74d6c9..ddc4430d03 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -32,7 +32,6 @@ logger = logging.getLogger(__name__) class SearchHandler(BaseHandler): - def __init__(self, hs): super(SearchHandler, self).__init__(hs) self._event_serializer = hs.get_event_client_serializer() @@ -93,7 +92,7 @@ class SearchHandler(BaseHandler): batch_token = None if batch: try: - b = decode_base64(batch).decode('ascii') + b = decode_base64(batch).decode("ascii") batch_group, batch_group_key, batch_token = b.split("\n") assert batch_group is not None @@ -104,7 +103,9 @@ class SearchHandler(BaseHandler): logger.info( "Search batch properties: %r, %r, %r", - batch_group, batch_group_key, batch_token, + batch_group, + batch_group_key, + batch_token, ) logger.info("Search content: %s", content) @@ -116,9 +117,9 @@ class SearchHandler(BaseHandler): search_term = room_cat["search_term"] # Which "keys" to search over in FTS query - keys = room_cat.get("keys", [ - "content.body", "content.name", "content.topic", - ]) + keys = room_cat.get( + "keys", ["content.body", "content.name", "content.topic"] + ) # Filter to apply to results filter_dict = room_cat.get("filter", {}) @@ -130,9 +131,7 @@ class SearchHandler(BaseHandler): include_state = room_cat.get("include_state", False) # Include context around each event? - event_context = room_cat.get( - "event_context", None - ) + event_context = room_cat.get("event_context", None) # Group results together? May allow clients to paginate within a # group @@ -140,12 +139,8 @@ class SearchHandler(BaseHandler): group_keys = [g["key"] for g in group_by] if event_context is not None: - before_limit = int(event_context.get( - "before_limit", 5 - )) - after_limit = int(event_context.get( - "after_limit", 5 - )) + before_limit = int(event_context.get("before_limit", 5)) + after_limit = int(event_context.get("after_limit", 5)) # Return the historic display name and avatar for the senders # of the events? @@ -159,7 +154,8 @@ class SearchHandler(BaseHandler): if set(group_keys) - {"room_id", "sender"}: raise SynapseError( 400, - "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},) + "Invalid group by keys: %r" + % (set(group_keys) - {"room_id", "sender"},), ) search_filter = Filter(filter_dict) @@ -190,15 +186,13 @@ class SearchHandler(BaseHandler): room_ids.intersection_update({batch_group_key}) if not room_ids: - defer.returnValue({ - "search_categories": { - "room_events": { - "results": [], - "count": 0, - "highlights": [], + defer.returnValue( + { + "search_categories": { + "room_events": {"results": [], "count": 0, "highlights": []} } } - }) + ) rank_map = {} # event_id -> rank of event allowed_events = [] @@ -213,9 +207,7 @@ class SearchHandler(BaseHandler): count = None if order_by == "rank": - search_result = yield self.store.search_msgs( - room_ids, search_term, keys - ) + search_result = yield self.store.search_msgs(room_ids, search_term, keys) count = search_result["count"] @@ -235,19 +227,17 @@ class SearchHandler(BaseHandler): ) events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = events[:search_filter.limit()] + allowed_events = events[: search_filter.limit()] for e in allowed_events: - rm = room_groups.setdefault(e.room_id, { - "results": [], - "order": rank_map[e.event_id], - }) + rm = room_groups.setdefault( + e.room_id, {"results": [], "order": rank_map[e.event_id]} + ) rm["results"].append(e.event_id) - s = sender_group.setdefault(e.sender, { - "results": [], - "order": rank_map[e.event_id], - }) + s = sender_group.setdefault( + e.sender, {"results": [], "order": rank_map[e.event_id]} + ) s["results"].append(e.event_id) elif order_by == "recent": @@ -262,7 +252,10 @@ class SearchHandler(BaseHandler): while len(room_events) < search_filter.limit() and i < 5: i += 1 search_result = yield self.store.search_rooms( - room_ids, search_term, keys, search_filter.limit() * 2, + room_ids, + search_term, + keys, + search_filter.limit() * 2, pagination_token=pagination_token, ) @@ -277,16 +270,14 @@ class SearchHandler(BaseHandler): rank_map.update({r["event"].event_id: r["rank"] for r in results}) - filtered_events = search_filter.filter([ - r["event"] for r in results - ]) + filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( self.store, user.to_string(), filtered_events ) room_events.extend(events) - room_events = room_events[:search_filter.limit()] + room_events = room_events[: search_filter.limit()] if len(results) < search_filter.limit() * 2: pagination_token = None @@ -295,9 +286,7 @@ class SearchHandler(BaseHandler): pagination_token = results[-1]["pagination_token"] for event in room_events: - group = room_groups.setdefault(event.room_id, { - "results": [], - }) + group = room_groups.setdefault(event.room_id, {"results": []}) group["results"].append(event.event_id) if room_events and len(room_events) >= search_filter.limit(): @@ -309,18 +298,23 @@ class SearchHandler(BaseHandler): # it returns more from the same group (if applicable) rather # than reverting to searching all results again. if batch_group and batch_group_key: - global_next_batch = encode_base64(("%s\n%s\n%s" % ( - batch_group, batch_group_key, pagination_token - )).encode('ascii')) + global_next_batch = encode_base64( + ( + "%s\n%s\n%s" + % (batch_group, batch_group_key, pagination_token) + ).encode("ascii") + ) else: - global_next_batch = encode_base64(("%s\n%s\n%s" % ( - "all", "", pagination_token - )).encode('ascii')) + global_next_batch = encode_base64( + ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") + ) for room_id, group in room_groups.items(): - group["next_batch"] = encode_base64(("%s\n%s\n%s" % ( - "room_id", room_id, pagination_token - )).encode('ascii')) + group["next_batch"] = encode_base64( + ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( + "ascii" + ) + ) allowed_events.extend(room_events) @@ -338,12 +332,13 @@ class SearchHandler(BaseHandler): contexts = {} for event in allowed_events: res = yield self.store.get_events_around( - event.room_id, event.event_id, before_limit, after_limit, + event.room_id, event.event_id, before_limit, after_limit ) logger.info( "Context for search returned %d and %d events", - len(res["events_before"]), len(res["events_after"]), + len(res["events_before"]), + len(res["events_after"]), ) res["events_before"] = yield filter_events_for_client( @@ -403,12 +398,12 @@ class SearchHandler(BaseHandler): for context in contexts.values(): context["events_before"] = ( yield self._event_serializer.serialize_events( - context["events_before"], time_now, + context["events_before"], time_now ) ) context["events_after"] = ( yield self._event_serializer.serialize_events( - context["events_after"], time_now, + context["events_after"], time_now ) ) @@ -426,11 +421,15 @@ class SearchHandler(BaseHandler): results = [] for e in allowed_events: - results.append({ - "rank": rank_map[e.event_id], - "result": (yield self._event_serializer.serialize_event(e, time_now)), - "context": contexts.get(e.event_id, {}), - }) + results.append( + { + "rank": rank_map[e.event_id], + "result": ( + yield self._event_serializer.serialize_event(e, time_now) + ), + "context": contexts.get(e.event_id, {}), + } + ) rooms_cat_res = { "results": results, @@ -442,7 +441,7 @@ class SearchHandler(BaseHandler): s = {} for room_id, state in state_results.items(): s[room_id] = yield self._event_serializer.serialize_events( - state, time_now, + state, time_now ) rooms_cat_res["state"] = s @@ -456,8 +455,4 @@ class SearchHandler(BaseHandler): if global_next_batch: rooms_cat_res["next_batch"] = global_next_batch - defer.returnValue({ - "search_categories": { - "room_events": rooms_cat_res - } - }) + defer.returnValue({"search_categories": {"room_events": rooms_cat_res}}) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index b556d23173..dea9a3bc77 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -26,6 +26,7 @@ logger = logging.getLogger(__name__) class SetPasswordHandler(BaseHandler): """Handler which deals with changing user account passwords""" + def __init__(self, hs): super(SetPasswordHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() @@ -51,11 +52,11 @@ class SetPasswordHandler(BaseHandler): # we want to log out all of the user's other sessions. First delete # all his other devices. yield self._device_handler.delete_all_devices_for_user( - user_id, except_device_id=except_device_id, + user_id, except_device_id=except_device_id ) # and now delete any access tokens which weren't associated with # devices (or were associated with this device). yield self._auth_handler.delete_access_tokens_for_user( - user_id, except_token_id=except_access_token_id, + user_id, except_token_id=except_access_token_id ) diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index b268bbcb2c..6b364befd5 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -21,7 +21,6 @@ logger = logging.getLogger(__name__) class StateDeltasHandler(object): - def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 7ad16c8566..a0ee8db988 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -156,7 +156,7 @@ class StatsHandler(StateDeltasHandler): prev_event_content = {} if prev_event_id is not None: prev_event = yield self.store.get_event( - prev_event_id, allow_none=True, + prev_event_id, allow_none=True ) if prev_event: prev_event_content = prev_event.content diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 62fda0c664..c5188a1f8e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -64,20 +64,14 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000 LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100 -SyncConfig = collections.namedtuple("SyncConfig", [ - "user", - "filter_collection", - "is_guest", - "request_key", - "device_id", -]) - - -class TimelineBatch(collections.namedtuple("TimelineBatch", [ - "prev_batch", - "events", - "limited", -])): +SyncConfig = collections.namedtuple( + "SyncConfig", ["user", "filter_collection", "is_guest", "request_key", "device_id"] +) + + +class TimelineBatch( + collections.namedtuple("TimelineBatch", ["prev_batch", "events", "limited"]) +): __slots__ = [] def __nonzero__(self): @@ -85,18 +79,24 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [ to tell if room needs to be part of the sync result. """ return bool(self.events) + __bool__ = __nonzero__ # python3 -class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "ephemeral", - "account_data", - "unread_notifications", - "summary", -])): +class JoinedSyncResult( + collections.namedtuple( + "JoinedSyncResult", + [ + "room_id", # str + "timeline", # TimelineBatch + "state", # dict[(str, str), FrozenEvent] + "ephemeral", + "account_data", + "unread_notifications", + "summary", + ], + ) +): __slots__ = [] def __nonzero__(self): @@ -111,77 +111,93 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ # nb the notification count does not, er, count: if there's nothing # else in the result, we don't need to send it. ) + __bool__ = __nonzero__ # python3 -class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "account_data", -])): +class ArchivedSyncResult( + collections.namedtuple( + "ArchivedSyncResult", + [ + "room_id", # str + "timeline", # TimelineBatch + "state", # dict[(str, str), FrozenEvent] + "account_data", + ], + ) +): __slots__ = [] def __nonzero__(self): """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ - return bool( - self.timeline - or self.state - or self.account_data - ) + return bool(self.timeline or self.state or self.account_data) + __bool__ = __nonzero__ # python3 -class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ - "room_id", # str - "invite", # FrozenEvent: the invite event -])): +class InvitedSyncResult( + collections.namedtuple( + "InvitedSyncResult", + ["room_id", "invite"], # str # FrozenEvent: the invite event + ) +): __slots__ = [] def __nonzero__(self): """Invited rooms should always be reported to the client""" return True + __bool__ = __nonzero__ # python3 -class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [ - "join", - "invite", - "leave", -])): +class GroupsSyncResult( + collections.namedtuple("GroupsSyncResult", ["join", "invite", "leave"]) +): __slots__ = [] def __nonzero__(self): return bool(self.join or self.invite or self.leave) + __bool__ = __nonzero__ # python3 -class DeviceLists(collections.namedtuple("DeviceLists", [ - "changed", # list of user_ids whose devices may have changed - "left", # list of user_ids whose devices we no longer track -])): +class DeviceLists( + collections.namedtuple( + "DeviceLists", + [ + "changed", # list of user_ids whose devices may have changed + "left", # list of user_ids whose devices we no longer track + ], + ) +): __slots__ = [] def __nonzero__(self): return bool(self.changed or self.left) + __bool__ = __nonzero__ # python3 -class SyncResult(collections.namedtuple("SyncResult", [ - "next_batch", # Token for the next sync - "presence", # List of presence events for the user. - "account_data", # List of account_data events for the user. - "joined", # JoinedSyncResult for each joined room. - "invited", # InvitedSyncResult for each invited room. - "archived", # ArchivedSyncResult for each archived room. - "to_device", # List of direct messages for the device. - "device_lists", # List of user_ids whose devices have changed - "device_one_time_keys_count", # Dict of algorithm to count for one time keys - # for this device - "groups", -])): +class SyncResult( + collections.namedtuple( + "SyncResult", + [ + "next_batch", # Token for the next sync + "presence", # List of presence events for the user. + "account_data", # List of account_data events for the user. + "joined", # JoinedSyncResult for each joined room. + "invited", # InvitedSyncResult for each invited room. + "archived", # ArchivedSyncResult for each archived room. + "to_device", # List of direct messages for the device. + "device_lists", # List of user_ids whose devices have changed + "device_one_time_keys_count", # Dict of algorithm to count for one time keys + # for this device + "groups", + ], + ) +): __slots__ = [] def __nonzero__(self): @@ -190,20 +206,20 @@ class SyncResult(collections.namedtuple("SyncResult", [ events. """ return bool( - self.presence or - self.joined or - self.invited or - self.archived or - self.account_data or - self.to_device or - self.device_lists or - self.groups + self.presence + or self.joined + or self.invited + or self.archived + or self.account_data + or self.to_device + or self.device_lists + or self.groups ) + __bool__ = __nonzero__ # python3 class SyncHandler(object): - def __init__(self, hs): self.hs_config = hs.config self.store = hs.get_datastore() @@ -217,13 +233,16 @@ class SyncHandler(object): # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( - "lazy_loaded_members_cache", self.clock, - max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, + "lazy_loaded_members_cache", + self.clock, + max_len=0, + expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) @defer.inlineCallbacks - def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, - full_state=False): + def wait_for_sync_for_user( + self, sync_config, since_token=None, timeout=0, full_state=False + ): """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. @@ -239,13 +258,15 @@ class SyncHandler(object): res = yield self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, - sync_config, since_token, timeout, full_state, + sync_config, + since_token, + timeout, + full_state, ) defer.returnValue(res) @defer.inlineCallbacks - def _wait_for_sync_for_user(self, sync_config, since_token, timeout, - full_state): + def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): if since_token is None: sync_type = "initial_sync" elif full_state: @@ -261,14 +282,17 @@ class SyncHandler(object): # we are going to return immediately, so don't bother calling # notifier.wait_for_events. result = yield self.current_sync_for_user( - sync_config, since_token, full_state=full_state, + sync_config, since_token, full_state=full_state ) else: + def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) result = yield self.notifier.wait_for_events( - sync_config.user.to_string(), timeout, current_sync_callback, + sync_config.user.to_string(), + timeout, + current_sync_callback, from_token=since_token, ) @@ -281,8 +305,7 @@ class SyncHandler(object): defer.returnValue(result) - def current_sync_for_user(self, sync_config, since_token=None, - full_state=False): + def current_sync_for_user(self, sync_config, since_token=None, full_state=False): """Get the sync for client needed to match what the server has now. Returns: A Deferred SyncResult. @@ -334,8 +357,7 @@ class SyncHandler(object): # result returned by the event source is poor form (it might cache # the object) room_id = event["room_id"] - event_copy = {k: v for (k, v) in iteritems(event) - if k != "room_id"} + event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) receipt_key = since_token.receipt_key if since_token else "0" @@ -353,22 +375,30 @@ class SyncHandler(object): for event in receipts: room_id = event["room_id"] # exclude room id, as above - event_copy = {k: v for (k, v) in iteritems(event) - if k != "room_id"} + event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) defer.returnValue((now_token, ephemeral_by_room)) @defer.inlineCallbacks - def _load_filtered_recents(self, room_id, sync_config, now_token, - since_token=None, recents=None, newly_joined_room=False): + def _load_filtered_recents( + self, + room_id, + sync_config, + now_token, + since_token=None, + recents=None, + newly_joined_room=False, + ): """ Returns: a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): timeline_limit = sync_config.filter_collection.timeline_limit() - block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline() + block_all_timeline = ( + sync_config.filter_collection.blocks_all_room_timeline() + ) if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True @@ -396,11 +426,9 @@ class SyncHandler(object): recents = [] if not limited or block_all_timeline: - defer.returnValue(TimelineBatch( - events=recents, - prev_batch=now_token, - limited=False - )) + defer.returnValue( + TimelineBatch(events=recents, prev_batch=now_token, limited=False) + ) filtering_factor = 2 load_limit = max(timeline_limit * filtering_factor, 10) @@ -427,9 +455,7 @@ class SyncHandler(object): ) else: events, end_key = yield self.store.get_recent_events_for_room( - room_id, - limit=load_limit + 1, - end_token=end_key, + room_id, limit=load_limit + 1, end_token=end_key ) loaded_recents = sync_config.filter_collection.filter_room_timeline( events @@ -462,15 +488,15 @@ class SyncHandler(object): recents = recents[-timeline_limit:] room_key = recents[0].internal_metadata.before - prev_batch_token = now_token.copy_and_replace( - "room_key", room_key - ) + prev_batch_token = now_token.copy_and_replace("room_key", room_key) - defer.returnValue(TimelineBatch( - events=recents, - prev_batch=prev_batch_token, - limited=limited or newly_joined_room - )) + defer.returnValue( + TimelineBatch( + events=recents, + prev_batch=prev_batch_token, + limited=limited or newly_joined_room, + ) + ) @defer.inlineCallbacks def get_state_after_event(self, event, state_filter=StateFilter.all()): @@ -486,7 +512,7 @@ class SyncHandler(object): A Deferred map from ((type, state_key)->Event) """ state_ids = yield self.store.get_state_ids_for_event( - event.event_id, state_filter=state_filter, + event.event_id, state_filter=state_filter ) if event.is_state(): state_ids = state_ids.copy() @@ -511,13 +537,13 @@ class SyncHandler(object): # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) last_events, _ = yield self.store.get_recent_events_for_room( - room_id, end_token=stream_position.room_key, limit=1, + room_id, end_token=stream_position.room_key, limit=1 ) if last_events: last_event = last_events[-1] state = yield self.get_state_after_event( - last_event, state_filter=state_filter, + last_event, state_filter=state_filter ) else: @@ -549,7 +575,7 @@ class SyncHandler(object): # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 last_events, _ = yield self.store.get_recent_event_ids_for_room( - room_id, end_token=now_token.room_key, limit=1, + room_id, end_token=now_token.room_key, limit=1 ) if not last_events: @@ -559,28 +585,25 @@ class SyncHandler(object): last_event = last_events[-1] state_ids = yield self.store.get_state_ids_for_event( last_event.event_id, - state_filter=StateFilter.from_types([ - (EventTypes.Name, ''), - (EventTypes.CanonicalAlias, ''), - ]), + state_filter=StateFilter.from_types( + [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] + ), ) # this is heavily cached, thus: fast. details = yield self.store.get_room_summary(room_id) - name_id = state_ids.get((EventTypes.Name, '')) - canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) + name_id = state_ids.get((EventTypes.Name, "")) + canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) summary = {} empty_ms = MemberSummary([], 0) # TODO: only send these when they change. - summary["m.joined_member_count"] = ( - details.get(Membership.JOIN, empty_ms).count - ) - summary["m.invited_member_count"] = ( - details.get(Membership.INVITE, empty_ms).count - ) + summary["m.joined_member_count"] = details.get(Membership.JOIN, empty_ms).count + summary["m.invited_member_count"] = details.get( + Membership.INVITE, empty_ms + ).count # if the room has a name or canonical_alias set, we can skip # calculating heroes. Empty strings are falsey, so we check @@ -592,7 +615,7 @@ class SyncHandler(object): if canonical_alias_id: canonical_alias = yield self.store.get_event( - canonical_alias_id, allow_none=True, + canonical_alias_id, allow_none=True ) if canonical_alias and canonical_alias.content.get("alias"): defer.returnValue(summary) @@ -600,26 +623,14 @@ class SyncHandler(object): me = sync_config.user.to_string() joined_user_ids = [ - r[0] - for r in details.get(Membership.JOIN, empty_ms).members - if r[0] != me + r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me ] invited_user_ids = [ - r[0] - for r in details.get(Membership.INVITE, empty_ms).members - if r[0] != me + r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me ] - gone_user_ids = ( - [ - r[0] - for r in details.get(Membership.LEAVE, empty_ms).members - if r[0] != me - ] + [ - r[0] - for r in details.get(Membership.BAN, empty_ms).members - if r[0] != me - ] - ) + gone_user_ids = [ + r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me + ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me] # FIXME: only build up a member_ids list for our heroes member_ids = {} @@ -627,20 +638,18 @@ class SyncHandler(object): Membership.JOIN, Membership.INVITE, Membership.LEAVE, - Membership.BAN + Membership.BAN, ): for user_id, event_id in details.get(membership, empty_ms).members: member_ids[user_id] = event_id # FIXME: order by stream ordering rather than as returned by SQL - if (joined_user_ids or invited_user_ids): - summary['m.heroes'] = sorted( + if joined_user_ids or invited_user_ids: + summary["m.heroes"] = sorted( [user_id for user_id in (joined_user_ids + invited_user_ids)] )[0:5] else: - summary['m.heroes'] = sorted( - [user_id for user_id in gone_user_ids] - )[0:5] + summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5] if not sync_config.filter_collection.lazy_load_members(): defer.returnValue(summary) @@ -652,8 +661,7 @@ class SyncHandler(object): # track which members the client should already know about via LL: # Ones which are already in state... existing_members = set( - user_id for (typ, user_id) in state.keys() - if typ == EventTypes.Member + user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member ) # ...or ones which are in the timeline... @@ -664,10 +672,10 @@ class SyncHandler(object): # ...and then ensure any missing ones get included in state. missing_hero_event_ids = [ member_ids[hero_id] - for hero_id in summary['m.heroes'] + for hero_id in summary["m.heroes"] if ( - cache.get(hero_id) != member_ids[hero_id] and - hero_id not in existing_members + cache.get(hero_id) != member_ids[hero_id] + and hero_id not in existing_members ) ] @@ -691,8 +699,9 @@ class SyncHandler(object): return cache @defer.inlineCallbacks - def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token, - full_state): + def compute_state_delta( + self, room_id, batch, sync_config, since_token, now_token, full_state + ): """ Works out the difference in state between the start of the timeline and the previous sync. @@ -745,23 +754,23 @@ class SyncHandler(object): timeline_state = { (event.type, event.state_key): event.event_id - for event in batch.events if event.is_state() + for event in batch.events + if event.is_state() } if full_state: if batch: current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter, + batch.events[-1].event_id, state_filter=state_filter ) state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter, + batch.events[0].event_id, state_filter=state_filter ) else: current_state_ids = yield self.get_state_at( - room_id, stream_position=now_token, - state_filter=state_filter, + room_id, stream_position=now_token, state_filter=state_filter ) state_ids = current_state_ids @@ -775,7 +784,7 @@ class SyncHandler(object): ) elif batch.limited: state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter, + batch.events[0].event_id, state_filter=state_filter ) # for now, we disable LL for gappy syncs - see @@ -793,12 +802,11 @@ class SyncHandler(object): state_filter = StateFilter.all() state_at_previous_sync = yield self.get_state_at( - room_id, stream_position=since_token, - state_filter=state_filter, + room_id, stream_position=since_token, state_filter=state_filter ) current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter, + batch.events[-1].event_id, state_filter=state_filter ) state_ids = _calculate_state( @@ -854,8 +862,7 @@ class SyncHandler(object): # add any member IDs we are about to send into our LruCache for t, event_id in itertools.chain( - state_ids.items(), - timeline_state.items(), + state_ids.items(), timeline_state.items() ): if t[0] == EventTypes.Member: cache.set(t[1], event_id) @@ -864,10 +871,14 @@ class SyncHandler(object): if state_ids: state = yield self.store.get_events(list(state_ids.values())) - defer.returnValue({ - (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state(list(state.values())) - }) + defer.returnValue( + { + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state( + list(state.values()) + ) + } + ) @defer.inlineCallbacks def unread_notifs_for_room_id(self, room_id, sync_config): @@ -875,7 +886,7 @@ class SyncHandler(object): last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, - receipt_type="m.read" + receipt_type="m.read", ) notifs = [] @@ -909,7 +920,9 @@ class SyncHandler(object): logger.info( "Calculating sync response for %r between %s and %s", - sync_config.user, since_token, now_token, + sync_config.user, + since_token, + now_token, ) user_id = sync_config.user.to_string() @@ -920,11 +933,12 @@ class SyncHandler(object): raise NotImplementedError() else: joined_room_ids = yield self.get_rooms_for_user_at( - user_id, now_token.room_stream_id, + user_id, now_token.room_stream_id ) sync_result_builder = SyncResultBuilder( - sync_config, full_state, + sync_config, + full_state, since_token=since_token, now_token=now_token, joined_room_ids=joined_room_ids, @@ -941,8 +955,7 @@ class SyncHandler(object): _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( - since_token is None and - sync_config.filter_collection.blocks_all_presence() + since_token is None and sync_config.filter_collection.blocks_all_presence() ) if self.hs_config.use_presence and not block_all_presence_data: yield self._generate_sync_entry_for_presence( @@ -973,22 +986,23 @@ class SyncHandler(object): room_id = joined_room.room_id if room_id in newly_joined_rooms: issue4422_logger.debug( - "Sync result for newly joined room %s: %r", - room_id, joined_room, + "Sync result for newly joined room %s: %r", room_id, joined_room ) - defer.returnValue(SyncResult( - presence=sync_result_builder.presence, - account_data=sync_result_builder.account_data, - joined=sync_result_builder.joined, - invited=sync_result_builder.invited, - archived=sync_result_builder.archived, - to_device=sync_result_builder.to_device, - device_lists=device_lists, - groups=sync_result_builder.groups, - device_one_time_keys_count=one_time_key_counts, - next_batch=sync_result_builder.now_token, - )) + defer.returnValue( + SyncResult( + presence=sync_result_builder.presence, + account_data=sync_result_builder.account_data, + joined=sync_result_builder.joined, + invited=sync_result_builder.invited, + archived=sync_result_builder.archived, + to_device=sync_result_builder.to_device, + device_lists=device_lists, + groups=sync_result_builder.groups, + device_one_time_keys_count=one_time_key_counts, + next_batch=sync_result_builder.now_token, + ) + ) @measure_func("_generate_sync_entry_for_groups") @defer.inlineCallbacks @@ -999,11 +1013,11 @@ class SyncHandler(object): if since_token and since_token.groups_key: results = yield self.store.get_groups_changes_for_user( - user_id, since_token.groups_key, now_token.groups_key, + user_id, since_token.groups_key, now_token.groups_key ) else: results = yield self.store.get_all_groups_for_user( - user_id, now_token.groups_key, + user_id, now_token.groups_key ) invited = {} @@ -1031,17 +1045,19 @@ class SyncHandler(object): left[group_id] = content["content"] sync_result_builder.groups = GroupsSyncResult( - join=joined, - invite=invited, - leave=left, + join=joined, invite=invited, leave=left ) @measure_func("_generate_sync_entry_for_device_list") @defer.inlineCallbacks - def _generate_sync_entry_for_device_list(self, sync_result_builder, - newly_joined_rooms, - newly_joined_or_invited_users, - newly_left_rooms, newly_left_users): + def _generate_sync_entry_for_device_list( + self, + sync_result_builder, + newly_joined_rooms, + newly_joined_or_invited_users, + newly_left_rooms, + newly_left_users, + ): user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token @@ -1065,24 +1081,20 @@ class SyncHandler(object): changed.update(newly_joined_or_invited_users) if not changed and not newly_left_users: - defer.returnValue(DeviceLists( - changed=[], - left=newly_left_users, - )) + defer.returnValue(DeviceLists(changed=[], left=newly_left_users)) users_who_share_room = yield self.store.get_users_who_share_room_with_user( user_id ) - defer.returnValue(DeviceLists( - changed=users_who_share_room & changed, - left=set(newly_left_users) - users_who_share_room, - )) + defer.returnValue( + DeviceLists( + changed=users_who_share_room & changed, + left=set(newly_left_users) - users_who_share_room, + ) + ) else: - defer.returnValue(DeviceLists( - changed=[], - left=[], - )) + defer.returnValue(DeviceLists(changed=[], left=[])) @defer.inlineCallbacks def _generate_sync_entry_for_to_device(self, sync_result_builder): @@ -1109,8 +1121,9 @@ class SyncHandler(object): deleted = yield self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) - logger.debug("Deleted %d to-device messages up to %d", - deleted, since_stream_id) + logger.debug( + "Deleted %d to-device messages up to %d", deleted, since_stream_id + ) messages, stream_id = yield self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key @@ -1118,7 +1131,10 @@ class SyncHandler(object): logger.debug( "Returning %d to-device messages between %d and %d (current token: %d)", - len(messages), since_stream_id, stream_id, now_token.to_device_key + len(messages), + since_stream_id, + stream_id, + now_token.to_device_key, ) sync_result_builder.now_token = now_token.copy_and_replace( "to_device_key", stream_id @@ -1145,8 +1161,7 @@ class SyncHandler(object): if since_token and not sync_result_builder.full_state: account_data, account_data_by_room = ( yield self.store.get_updated_account_data_for_user( - user_id, - since_token.account_data_key, + user_id, since_token.account_data_key ) ) @@ -1160,27 +1175,28 @@ class SyncHandler(object): ) else: account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user( - sync_config.user.to_string() - ) + yield self.store.get_account_data_for_user(sync_config.user.to_string()) ) - account_data['m.push_rules'] = yield self.push_rules_for_user( + account_data["m.push_rules"] = yield self.push_rules_for_user( sync_config.user ) - account_data_for_user = sync_config.filter_collection.filter_account_data([ - {"type": account_data_type, "content": content} - for account_data_type, content in account_data.items() - ]) + account_data_for_user = sync_config.filter_collection.filter_account_data( + [ + {"type": account_data_type, "content": content} + for account_data_type, content in account_data.items() + ] + ) sync_result_builder.account_data = account_data_for_user defer.returnValue(account_data_by_room) @defer.inlineCallbacks - def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms, - newly_joined_or_invited_users): + def _generate_sync_entry_for_presence( + self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users + ): """Generates the presence portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1223,17 +1239,13 @@ class SyncHandler(object): extra_users_ids.discard(user.to_string()) if extra_users_ids: - states = yield self.presence_handler.get_states( - extra_users_ids, - ) + states = yield self.presence_handler.get_states(extra_users_ids) presence.extend(states) # Deduplicate the presence entries so that there's at most one per user presence = list({p.user_id: p for p in presence}.values()) - presence = sync_config.filter_collection.filter_presence( - presence - ) + presence = sync_config.filter_collection.filter_presence(presence) sync_result_builder.presence = presence @@ -1253,8 +1265,8 @@ class SyncHandler(object): """ user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( - sync_result_builder.since_token is None and - sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() + sync_result_builder.since_token is None + and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) if block_all_room_ephemeral: @@ -1275,15 +1287,14 @@ class SyncHandler(object): have_changed = yield self._have_rooms_changed(sync_result_builder) if not have_changed: tags_by_room = yield self.store.get_updated_tags( - user_id, - since_token.account_data_key, + user_id, since_token.account_data_key ) if not tags_by_room: logger.debug("no-oping sync") defer.returnValue(([], [], [], [])) ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id=user_id, + "m.ignored_user_list", user_id=user_id ) if ignored_account_data: @@ -1296,7 +1307,7 @@ class SyncHandler(object): room_entries, invited, newly_joined_rooms, newly_left_rooms = res tags_by_room = yield self.store.get_updated_tags( - user_id, since_token.account_data_key, + user_id, since_token.account_data_key ) else: res = yield self._get_all_rooms(sync_result_builder, ignored_users) @@ -1331,8 +1342,8 @@ class SyncHandler(object): for event in it: if event.type == EventTypes.Member: if ( - event.membership == Membership.JOIN or - event.membership == Membership.INVITE + event.membership == Membership.JOIN + or event.membership == Membership.INVITE ): newly_joined_or_invited_users.add(event.state_key) else: @@ -1343,12 +1354,14 @@ class SyncHandler(object): newly_left_users -= newly_joined_or_invited_users - defer.returnValue(( - newly_joined_rooms, - newly_joined_or_invited_users, - newly_left_rooms, - newly_left_users, - )) + defer.returnValue( + ( + newly_joined_rooms, + newly_joined_or_invited_users, + newly_left_rooms, + newly_left_users, + ) + ) @defer.inlineCallbacks def _have_rooms_changed(self, sync_result_builder): @@ -1454,7 +1467,9 @@ class SyncHandler(object): prev_membership = old_mem_ev.membership issue4422_logger.debug( "Previous membership for room %s with join: %s (event %s)", - room_id, prev_membership, old_mem_ev_id, + room_id, + prev_membership, + old_mem_ev_id, ) if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: @@ -1476,8 +1491,7 @@ class SyncHandler(object): if not old_state_ids: old_state_ids = yield self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get( - (EventTypes.Member, user_id), - None, + (EventTypes.Member, user_id), None ) old_mem_ev = None if old_mem_ev_id: @@ -1498,7 +1512,8 @@ class SyncHandler(object): # Always include leave/ban events. Just take the last one. # TODO: How do we handle ban -> leave in same batch? leave_events = [ - e for e in non_joins + e + for e in non_joins if e.membership in (Membership.LEAVE, Membership.BAN) ] @@ -1526,15 +1541,17 @@ class SyncHandler(object): else: batch_events = None - room_entries.append(RoomSyncResultBuilder( - room_id=room_id, - rtype="archived", - events=batch_events, - newly_joined=room_id in newly_joined_rooms, - full_state=False, - since_token=since_token, - upto_token=leave_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=room_id, + rtype="archived", + events=batch_events, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=leave_token, + ) + ) timeline_limit = sync_config.filter_collection.timeline_limit() @@ -1581,7 +1598,8 @@ class SyncHandler(object): # debugging for https://github.com/matrix-org/synapse/issues/4422 issue4422_logger.debug( "RoomSyncResultBuilder events for newly joined room %s: %r", - room_id, entry.events, + room_id, + entry.events, ) room_entries.append(entry) @@ -1606,12 +1624,14 @@ class SyncHandler(object): sync_config = sync_result_builder.sync_config membership_list = ( - Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + Membership.BAN, ) room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user_id, - membership_list=membership_list + user_id=user_id, membership_list=membership_list ) room_entries = [] @@ -1619,23 +1639,22 @@ class SyncHandler(object): for event in room_list: if event.membership == Membership.JOIN: - room_entries.append(RoomSyncResultBuilder( - room_id=event.room_id, - rtype="joined", - events=None, - newly_joined=False, - full_state=True, - since_token=since_token, - upto_token=now_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=event.room_id, + rtype="joined", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=now_token, + ) + ) elif event.membership == Membership.INVITE: if event.sender in ignored_users: continue invite = yield self.store.get_event(event.event_id) - invited.append(InvitedSyncResult( - room_id=event.room_id, - invite=invite, - )) + invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) elif event.membership in (Membership.LEAVE, Membership.BAN): # Always send down rooms we were banned or kicked from. if not sync_config.filter_collection.include_leave: @@ -1646,22 +1665,31 @@ class SyncHandler(object): leave_token = now_token.copy_and_replace( "room_key", "s%d" % (event.stream_ordering,) ) - room_entries.append(RoomSyncResultBuilder( - room_id=event.room_id, - rtype="archived", - events=None, - newly_joined=False, - full_state=True, - since_token=since_token, - upto_token=leave_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=event.room_id, + rtype="archived", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=leave_token, + ) + ) defer.returnValue((room_entries, invited, [])) @defer.inlineCallbacks - def _generate_room_entry(self, sync_result_builder, ignored_users, - room_builder, ephemeral, tags, account_data, - always_include=False): + def _generate_room_entry( + self, + sync_result_builder, + ignored_users, + room_builder, + ephemeral, + tags, + account_data, + always_include=False, + ): """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. @@ -1678,9 +1706,7 @@ class SyncHandler(object): """ newly_joined = room_builder.newly_joined full_state = ( - room_builder.full_state - or newly_joined - or sync_result_builder.full_state + room_builder.full_state or newly_joined or sync_result_builder.full_state ) events = room_builder.events @@ -1697,7 +1723,8 @@ class SyncHandler(object): upto_token = room_builder.upto_token batch = yield self._load_filtered_recents( - room_id, sync_config, + room_id, + sync_config, now_token=upto_token, since_token=since_token, recents=events, @@ -1708,7 +1735,8 @@ class SyncHandler(object): # debug for https://github.com/matrix-org/synapse/issues/4422 issue4422_logger.debug( "Timeline events after filtering in newly-joined room %s: %r", - room_id, batch, + room_id, + batch, ) # When we join the room (or the client requests full_state), we should @@ -1726,16 +1754,10 @@ class SyncHandler(object): account_data_events = [] if tags is not None: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) account_data_events = sync_config.filter_collection.filter_room_account_data( account_data_events @@ -1743,16 +1765,13 @@ class SyncHandler(object): ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) - if not (always_include - or batch - or account_data_events - or ephemeral - or full_state): + if not ( + always_include or batch or account_data_events or ephemeral or full_state + ): return state = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, now_token, - full_state=full_state + room_id, batch, sync_config, since_token, now_token, full_state=full_state ) summary = {} @@ -1760,22 +1779,19 @@ class SyncHandler(object): # we include a summary in room responses when we're lazy loading # members (as the client otherwise doesn't have enough info to form # the name itself). - if ( - sync_config.filter_collection.lazy_load_members() and - ( - # we recalulate the summary: - # if there are membership changes in the timeline, or - # if membership has changed during a gappy sync, or - # if this is an initial sync. - any(ev.type == EventTypes.Member for ev in batch.events) or - ( - # XXX: this may include false positives in the form of LL - # members which have snuck into state - batch.limited and - any(t == EventTypes.Member for (t, k) in state) - ) or - since_token is None + if sync_config.filter_collection.lazy_load_members() and ( + # we recalulate the summary: + # if there are membership changes in the timeline, or + # if membership has changed during a gappy sync, or + # if this is an initial sync. + any(ev.type == EventTypes.Member for ev in batch.events) + or ( + # XXX: this may include false positives in the form of LL + # members which have snuck into state + batch.limited + and any(t == EventTypes.Member for (t, k) in state) ) + or since_token is None ): summary = yield self.compute_summary( room_id, sync_config, batch, state, now_token @@ -1794,9 +1810,7 @@ class SyncHandler(object): ) if room_sync or always_include: - notifs = yield self.unread_notifs_for_room_id( - room_id, sync_config - ) + notifs = yield self.unread_notifs_for_room_id(room_id, sync_config) if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"] @@ -1807,11 +1821,8 @@ class SyncHandler(object): if batch.limited and since_token: user_id = sync_result_builder.sync_config.user.to_string() logger.info( - "Incremental gappy sync of %s for user %s with %d state events" % ( - room_id, - user_id, - len(state), - ) + "Incremental gappy sync of %s for user %s with %d state events" + % (room_id, user_id, len(state)) ) elif room_builder.rtype == "archived": room_sync = ArchivedSyncResult( @@ -1841,9 +1852,7 @@ class SyncHandler(object): Deferred[frozenset[str]]: Set of room_ids the user is in at given stream_ordering. """ - joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering( - user_id, - ) + joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id) joined_room_ids = set() @@ -1862,11 +1871,9 @@ class SyncHandler(object): logger.info("User joined room after current token: %s", room_id) extrems = yield self.store.get_forward_extremeties_for_room( - room_id, stream_ordering, - ) - users_in_room = yield self.state.get_current_users_in_room( - room_id, extrems, + room_id, stream_ordering ) + users_in_room = yield self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: joined_room_ids.add(room_id) @@ -1886,7 +1893,7 @@ def _action_has_highlight(actions): def _calculate_state( - timeline_contains, timeline_start, previous, current, lazy_load_members, + timeline_contains, timeline_start, previous, current, lazy_load_members ): """Works out what state to include in a sync response. @@ -1930,15 +1937,12 @@ def _calculate_state( if lazy_load_members: p_ids.difference_update( - e for t, e in iteritems(timeline_start) - if t[0] == EventTypes.Member + e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member ) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids - return { - event_id_to_key[e]: e for e in state_ids - } + return {event_id_to_key[e]: e for e in state_ids} class SyncResultBuilder(object): @@ -1961,8 +1965,10 @@ class SyncResultBuilder(object): groups (GroupsSyncResult|None) to_device (list) """ - def __init__(self, sync_config, full_state, since_token, now_token, - joined_room_ids): + + def __init__( + self, sync_config, full_state, since_token, now_token, joined_room_ids + ): """ Args: sync_config (SyncConfig) @@ -1991,8 +1997,10 @@ class RoomSyncResultBuilder(object): """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. """ - def __init__(self, room_id, rtype, events, newly_joined, full_state, - since_token, upto_token): + + def __init__( + self, room_id, rtype, events, newly_joined, full_state, since_token, upto_token + ): """ Args: room_id(str) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 972662eb48..f8062c8671 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -68,13 +68,10 @@ class TypingHandler(object): # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( - "TypingStreamChangeCache", self._latest_room_serial, + "TypingStreamChangeCache", self._latest_room_serial ) - self.clock.looping_call( - self._handle_timeouts, - 5000, - ) + self.clock.looping_call(self._handle_timeouts, 5000) def _reset(self): """ @@ -108,19 +105,11 @@ class TypingHandler(object): if self.hs.is_mine_id(member.user_id): last_fed_poke = self._member_last_federation_poke.get(member, None) if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - run_in_background( - self._push_remote, - member=member, - typing=True - ) + run_in_background(self._push_remote, member=member, typing=True) # Add a paranoia timer to ensure that we always have a timer for # each person typing. - self.wheel_timer.insert( - now=now, - obj=member, - then=now + 60 * 1000, - ) + self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member): return member.user_id in self._room_typing.get(member.room_id, []) @@ -138,9 +127,7 @@ class TypingHandler(object): yield self.auth.check_joined_room(room_id, target_user_id) - logger.debug( - "%s has started typing in %s", target_user_id, room_id - ) + logger.debug("%s has started typing in %s", target_user_id, room_id) member = RoomMember(room_id=room_id, user_id=target_user_id) @@ -149,20 +136,13 @@ class TypingHandler(object): now = self.clock.time_msec() self._member_typing_until[member] = now + timeout - self.wheel_timer.insert( - now=now, - obj=member, - then=now + timeout, - ) + self.wheel_timer.insert(now=now, obj=member, then=now + timeout) if was_present: # No point sending another notification defer.returnValue(None) - self._push_update( - member=member, - typing=True, - ) + self._push_update(member=member, typing=True) @defer.inlineCallbacks def stopped_typing(self, target_user, auth_user, room_id): @@ -177,9 +157,7 @@ class TypingHandler(object): yield self.auth.check_joined_room(room_id, target_user_id) - logger.debug( - "%s has stopped typing in %s", target_user_id, room_id - ) + logger.debug("%s has stopped typing in %s", target_user_id, room_id) member = RoomMember(room_id=room_id, user_id=target_user_id) @@ -200,20 +178,14 @@ class TypingHandler(object): self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) - self._push_update( - member=member, - typing=False, - ) + self._push_update(member=member, typing=False) def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. run_in_background(self._push_remote, member, typing) - self._push_update_local( - member=member, - typing=typing - ) + self._push_update_local(member=member, typing=typing) @defer.inlineCallbacks def _push_remote(self, member, typing): @@ -223,9 +195,7 @@ class TypingHandler(object): now = self.clock.time_msec() self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_PING_INTERVAL, + now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) for domain in set(get_domain_from_id(u) for u in users): @@ -256,8 +226,7 @@ class TypingHandler(object): if user.domain != origin: logger.info( - "Got typing update from %r with bad 'user_id': %r", - origin, user_id, + "Got typing update from %r with bad 'user_id': %r", origin, user_id ) return @@ -268,15 +237,8 @@ class TypingHandler(object): logger.info("Got typing update from %s: %r", user_id, content) now = self.clock.time_msec() self._member_typing_until[member] = now + FEDERATION_TIMEOUT - self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_TIMEOUT, - ) - self._push_update_local( - member=member, - typing=content["typing"] - ) + self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT) + self._push_update_local(member=member, typing=content["typing"]) def _push_update_local(self, member, typing): room_set = self._room_typing.setdefault(member.room_id, set()) @@ -288,7 +250,7 @@ class TypingHandler(object): self._latest_room_serial += 1 self._room_serials[member.room_id] = self._latest_room_serial self._typing_stream_change_cache.entity_has_changed( - member.room_id, self._latest_room_serial, + member.room_id, self._latest_room_serial ) self.notifier.on_new_event( @@ -300,7 +262,7 @@ class TypingHandler(object): return [] changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( - last_id, + last_id ) if changed_rooms is None: @@ -334,9 +296,7 @@ class TypingNotificationEventSource(object): return { "type": "m.typing", "room_id": room_id, - "content": { - "user_ids": list(typing), - }, + "content": {"user_ids": list(typing)}, } def get_new_events(self, from_key, room_ids, **kwargs): diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index d36bcd6336..3acf772cd1 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" + def __init__(self): super(RequestTimedOutError, self).__init__(504, "Timed out") @@ -40,15 +41,12 @@ def cancelled_to_request_timed_out_error(value, timeout): return value -ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') +ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") def redact_uri(uri): """Strips access tokens from the uri replaces with """ - return ACCESS_TOKEN_RE.sub( - r'\1\3', - uri - ) + return ACCESS_TOKEN_RE.sub(r"\1\3", uri) class QuieterFileBodyProducer(FileBodyProducer): @@ -57,6 +55,7 @@ class QuieterFileBodyProducer(FileBodyProducer): Workaround for https://github.com/matrix-org/synapse/issues/4003 / https://twistedmatrix.com/trac/ticket/6528 """ + def stopProducing(self): try: FileBodyProducer.stopProducing(self) diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 0e10e3f8f7..096619a8c2 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -28,6 +28,7 @@ class AdditionalResource(Resource): This class is also where we wrap the request handler with logging, metrics, and exception handling. """ + def __init__(self, hs, handler): """Initialise AdditionalResource diff --git a/synapse/http/client.py b/synapse/http/client.py index d1869f308b..0e5c30db0a 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -104,8 +104,8 @@ class IPBlacklistingResolver(object): ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info( - "Dropped %s from DNS resolution to %s due to blacklist" % - (ip_address, hostname) + "Dropped %s from DNS resolution to %s due to blacklist" + % (ip_address, hostname) ) has_bad_ip = True @@ -157,7 +157,7 @@ class BlacklistingAgentWrapper(Agent): self._ip_blacklist = ip_blacklist def request(self, method, uri, headers=None, bodyProducer=None): - h = urllib.parse.urlparse(uri.decode('ascii')) + h = urllib.parse.urlparse(uri.decode("ascii")) try: ip_address = IPAddress(h.hostname) @@ -165,10 +165,7 @@ class BlacklistingAgentWrapper(Agent): if check_against_blacklist( ip_address, self._ip_whitelist, self._ip_blacklist ): - logger.info( - "Blocking access to %s due to blacklist" % - (ip_address,) - ) + logger.info("Blocking access to %s due to blacklist" % (ip_address,)) e = SynapseError(403, "IP address blocked by IP blacklist entry") return defer.fail(Failure(e)) except Exception: @@ -217,7 +214,7 @@ class SimpleHttpClient(object): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) - self.user_agent = self.user_agent.encode('ascii') + self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: real_reactor = hs.get_reactor() @@ -533,8 +530,8 @@ class SimpleHttpClient(object): resp_headers = dict(response.headers.getAllRawHeaders()) if ( - b'Content-Length' in resp_headers - and int(resp_headers[b'Content-Length'][0]) > max_size + b"Content-Length" in resp_headers + and int(resp_headers[b"Content-Length"][0]) > max_size ): logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( @@ -559,18 +556,13 @@ class SimpleHttpClient(object): # This can happen e.g. because the body is too large. raise except Exception as e: - raise_from( - SynapseError( - 502, ("Failed to download remote body: %s" % e), - ), - e - ) + raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e) defer.returnValue( ( length, resp_headers, - response.request.absoluteURI.decode('ascii'), + response.request.absoluteURI.decode("ascii"), response.code, ) ) @@ -660,7 +652,7 @@ def encode_urlencode_args(args): def encode_urlencode_arg(arg): if isinstance(arg, text_type): - return arg.encode('utf-8') + return arg.encode("utf-8") elif isinstance(arg, list): return [encode_urlencode_arg(i) for i in arg] else: diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index cd79ebab62..92a5b606c8 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -31,7 +31,7 @@ def parse_server_name(server_name): ValueError if the server name could not be parsed. """ try: - if server_name[-1] == ']': + if server_name[-1] == "]": # ipv6 literal, hopefully return server_name, None @@ -43,9 +43,7 @@ def parse_server_name(server_name): raise ValueError("Invalid server name '%s'" % server_name) -VALID_HOST_REGEX = re.compile( - "\\A[0-9a-zA-Z.-]+\\Z", -) +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") def parse_and_validate_server_name(server_name): @@ -67,17 +65,15 @@ def parse_and_validate_server_name(server_name): # that nobody is sneaking IP literals in that look like hostnames, etc. # look for ipv6 literals - if host[0] == '[': - if host[-1] != ']': - raise ValueError("Mismatched [...] in server name '%s'" % ( - server_name, - )) + if host[0] == "[": + if host[-1] != "]": + raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) return host, port # otherwise it should only be alphanumerics. if not VALID_HOST_REGEX.match(host): - raise ValueError("Server name '%s' contains invalid characters" % ( - server_name, - )) + raise ValueError( + "Server name '%s' contains invalid characters" % (server_name,) + ) return host, port diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index b4cbe97b41..414cde0777 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -48,7 +48,7 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 logger = logging.getLogger(__name__) -well_known_cache = TTLCache('well-known') +well_known_cache = TTLCache("well-known") @implementer(IAgent) @@ -78,7 +78,9 @@ class MatrixFederationAgent(object): """ def __init__( - self, reactor, tls_client_options_factory, + self, + reactor, + tls_client_options_factory, _well_known_tls_policy=None, _srv_resolver=None, _well_known_cache=well_known_cache, @@ -100,9 +102,9 @@ class MatrixFederationAgent(object): if _well_known_tls_policy is not None: # the param is called 'contextFactory', but actually passing a # contextfactory is deprecated, and it expects an IPolicyForHTTPS. - agent_args['contextFactory'] = _well_known_tls_policy + agent_args["contextFactory"] = _well_known_tls_policy _well_known_agent = RedirectAgent( - Agent(self._reactor, pool=self._pool, **agent_args), + Agent(self._reactor, pool=self._pool, **agent_args) ) self._well_known_agent = _well_known_agent @@ -149,7 +151,7 @@ class MatrixFederationAgent(object): tls_options = None else: tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii"), + res.tls_server_name.decode("ascii") ) # make sure that the Host header is set correctly @@ -158,14 +160,14 @@ class MatrixFederationAgent(object): else: headers = headers.copy() - if not headers.hasHeader(b'host'): - headers.addRawHeader(b'host', res.host_header) + if not headers.hasHeader(b"host"): + headers.addRawHeader(b"host", res.host_header) class EndpointFactory(object): @staticmethod def endpointForURI(_uri): ep = LoggingHostnameEndpoint( - self._reactor, res.target_host, res.target_port, + self._reactor, res.target_host, res.target_port ) if tls_options is not None: ep = wrapClientTLS(tls_options, ep) @@ -203,21 +205,25 @@ class MatrixFederationAgent(object): port = parsed_uri.port if port == -1: port = 8448 - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=port, + ) + ) if parsed_uri.port != -1: # there is an explicit port - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=parsed_uri.port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=parsed_uri.port, + ) + ) if lookup_well_known: # try a .well-known lookup @@ -229,8 +235,8 @@ class MatrixFederationAgent(object): # parse the server name in the .well-known response into host/port. # (This code is lifted from twisted.web.client.URI.fromBytes). - if b':' in well_known_server: - well_known_host, well_known_port = well_known_server.rsplit(b':', 1) + if b":" in well_known_server: + well_known_host, well_known_port = well_known_server.rsplit(b":", 1) try: well_known_port = int(well_known_port) except ValueError: @@ -264,21 +270,27 @@ class MatrixFederationAgent(object): port = 8448 logger.debug( "No SRV record for %s, using %s:%i", - parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port, + parsed_uri.host.decode("ascii"), + target_host.decode("ascii"), + port, ) else: target_host, port = pick_server_from_list(server_list) logger.debug( "Picked %s:%i from SRV records for %s", - target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"), + target_host.decode("ascii"), + port, + parsed_uri.host.decode("ascii"), ) - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=target_host, - target_port=port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=target_host, + target_port=port, + ) + ) @defer.inlineCallbacks def _get_well_known(self, server_name): @@ -318,18 +330,18 @@ class MatrixFederationAgent(object): - None if there was no .well-known file. - INVALID_WELL_KNOWN if the .well-known was invalid """ - uri = b"https://%s/.well-known/matrix/server" % (server_name, ) + uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") logger.info("Fetching %s", uri_str) try: response = yield make_deferred_yieldable( - self._well_known_agent.request(b"GET", uri), + self._well_known_agent.request(b"GET", uri) ) body = yield make_deferred_yieldable(readBody(response)) if response.code != 200: - raise Exception("Non-200 response %s" % (response.code, )) + raise Exception("Non-200 response %s" % (response.code,)) - parsed_body = json.loads(body.decode('utf-8')) + parsed_body = json.loads(body.decode("utf-8")) logger.info("Response from .well-known: %s", parsed_body) if not isinstance(parsed_body, dict): raise Exception("not a dict") @@ -347,8 +359,7 @@ class MatrixFederationAgent(object): result = parsed_body["m.server"].encode("ascii") cache_period = _cache_period_from_headers( - response.headers, - time_now=self._reactor.seconds, + response.headers, time_now=self._reactor.seconds ) if cache_period is None: cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD @@ -364,6 +375,7 @@ class MatrixFederationAgent(object): @implementer(IStreamClientEndpoint) class LoggingHostnameEndpoint(object): """A wrapper for HostnameEndpint which logs when it connects""" + def __init__(self, reactor, host, port, *args, **kwargs): self.host = host self.port = port @@ -377,17 +389,17 @@ class LoggingHostnameEndpoint(object): def _cache_period_from_headers(headers, time_now=time.time): cache_controls = _parse_cache_control(headers) - if b'no-store' in cache_controls: + if b"no-store" in cache_controls: return 0 - if b'max-age' in cache_controls: + if b"max-age" in cache_controls: try: - max_age = int(cache_controls[b'max-age']) + max_age = int(cache_controls[b"max-age"]) return max_age except ValueError: pass - expires = headers.getRawHeaders(b'expires') + expires = headers.getRawHeaders(b"expires") if expires is not None: try: expires_date = stringToDatetime(expires[-1]) @@ -403,9 +415,9 @@ def _cache_period_from_headers(headers, time_now=time.time): def _parse_cache_control(headers): cache_controls = {} - for hdr in headers.getRawHeaders(b'cache-control', []): - for directive in hdr.split(b','): - splits = [x.strip() for x in directive.split(b'=', 1)] + for hdr in headers.getRawHeaders(b"cache-control", []): + for directive in hdr.split(b","): + splits = [x.strip() for x in directive.split(b"=", 1)] k = splits[0].lower() v = splits[1] if len(splits) > 1 else None cache_controls[k] = v diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 71830c549d..1f22f78a75 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -45,6 +45,7 @@ class Server(object): expires (int): when the cache should expire this record - in *seconds* since the epoch """ + host = attr.ib() port = attr.ib() priority = attr.ib(default=0) @@ -79,9 +80,7 @@ def pick_server_from_list(server_list): return s.host, s.port # this should be impossible. - raise RuntimeError( - "pick_server_from_list got to end of eligible server list.", - ) + raise RuntimeError("pick_server_from_list got to end of eligible server list.") class SrvResolver(object): @@ -95,6 +94,7 @@ class SrvResolver(object): cache (dict): cache object get_time (callable): clock implementation. Should return seconds since the epoch """ + def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): self._dns_client = dns_client self._cache = cache @@ -124,7 +124,7 @@ class SrvResolver(object): try: answers, _, _ = yield make_deferred_yieldable( - self._dns_client.lookupService(service_name), + self._dns_client.lookupService(service_name) ) except DNSNameError: # TODO: cache this. We can get the SOA out of the exception, and use @@ -136,17 +136,18 @@ class SrvResolver(object): cache_entry = self._cache.get(service_name, None) if cache_entry: logger.warn( - "Failed to resolve %r, falling back to cache. %r", - service_name, e + "Failed to resolve %r, falling back to cache. %r", service_name, e ) defer.returnValue(list(cache_entry)) else: raise e - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name(b'.')): + if ( + len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name(b".") + ): raise ConnectError("Service %s unavailable" % service_name) servers = [] @@ -157,13 +158,15 @@ class SrvResolver(object): payload = answer.payload - servers.append(Server( - host=payload.target.name, - port=payload.port, - priority=payload.priority, - weight=payload.weight, - expires=now + answer.ttl, - )) + servers.append( + Server( + host=payload.target.name, + port=payload.port, + priority=payload.priority, + weight=payload.weight, + expires=now + answer.ttl, + ) + ) self._cache[service_name] = list(servers) defer.returnValue(servers) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 663ea72a7a..5ef8bb60a3 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -54,10 +54,12 @@ from synapse.util.metrics import Measure logger = logging.getLogger(__name__) -outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests", - "", ["method"]) -incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses", - "", ["method", "code"]) +outgoing_requests_counter = Counter( + "synapse_http_matrixfederationclient_requests", "", ["method"] +) +incoming_responses_counter = Counter( + "synapse_http_matrixfederationclient_responses", "", ["method", "code"] +) MAX_LONG_RETRIES = 10 @@ -137,11 +139,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): check_content_type_is_json(response.headers) d = treq.json_content(response) - d = timeout_deferred( - d, - timeout=timeout_sec, - reactor=reactor, - ) + d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) body = yield make_deferred_yieldable(d) except Exception as e: @@ -157,7 +155,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), ) defer.returnValue(body) @@ -181,7 +179,7 @@ class MatrixFederationHttpClient(object): # We need to use a DNS resolver which filters out blacklisted IP # addresses, to prevent DNS rebinding. nameResolver = IPBlacklistingResolver( - real_reactor, None, hs.config.federation_ip_range_blacklist, + real_reactor, None, hs.config.federation_ip_range_blacklist ) @implementer(IReactorPluggableNameResolver) @@ -194,21 +192,19 @@ class MatrixFederationHttpClient(object): self.reactor = Reactor() - self.agent = MatrixFederationAgent( - self.reactor, - tls_client_options_factory, - ) + self.agent = MatrixFederationAgent(self.reactor, tls_client_options_factory) # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( - self.agent, self.reactor, + self.agent, + self.reactor, ip_blacklist=hs.config.federation_ip_range_blacklist, ) self.clock = hs.get_clock() self._store = hs.get_datastore() - self.version_string_bytes = hs.version_string.encode('ascii') + self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout = 60 def schedule(x): @@ -218,10 +214,7 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def _send_request_with_optional_trailing_slash( - self, - request, - try_trailing_slash_on_400=False, - **send_request_args + self, request, try_trailing_slash_on_400=False, **send_request_args ): """Wrapper for _send_request which can optionally retry the request upon receiving a combination of a 400 HTTP response code and a @@ -244,9 +237,7 @@ class MatrixFederationHttpClient(object): Deferred[Dict]: Parsed JSON response body. """ try: - response = yield self._send_request( - request, **send_request_args - ) + response = yield self._send_request(request, **send_request_args) except HttpResponseException as e: # Received an HTTP error > 300. Check if it meets the requirements # to retry with a trailing slash @@ -262,9 +253,7 @@ class MatrixFederationHttpClient(object): logger.info("Retrying request with trailing slash") request.path += "/" - response = yield self._send_request( - request, **send_request_args - ) + response = yield self._send_request(request, **send_request_args) defer.returnValue(response) @@ -329,8 +318,8 @@ class MatrixFederationHttpClient(object): _sec_timeout = self.default_timeout if ( - self.hs.config.federation_domain_whitelist is not None and - request.destination not in self.hs.config.federation_domain_whitelist + self.hs.config.federation_domain_whitelist is not None + and request.destination not in self.hs.config.federation_domain_whitelist ): raise FederationDeniedError(request.destination) @@ -350,9 +339,7 @@ class MatrixFederationHttpClient(object): else: query_bytes = b"" - headers_dict = { - b"User-Agent": [self.version_string_bytes], - } + headers_dict = {b"User-Agent": [self.version_string_bytes]} with limiter: # XXX: Would be much nicer to retry only at the transaction-layer @@ -362,16 +349,14 @@ class MatrixFederationHttpClient(object): else: retries_left = MAX_SHORT_RETRIES - url_bytes = urllib.parse.urlunparse(( - b"matrix", destination_bytes, - path_bytes, None, query_bytes, b"", - )) - url_str = url_bytes.decode('ascii') + url_bytes = urllib.parse.urlunparse( + (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"") + ) + url_str = url_bytes.decode("ascii") - url_to_sign_bytes = urllib.parse.urlunparse(( - b"", b"", - path_bytes, None, query_bytes, b"", - )) + url_to_sign_bytes = urllib.parse.urlunparse( + (b"", b"", path_bytes, None, query_bytes, b"") + ) while True: try: @@ -379,26 +364,27 @@ class MatrixFederationHttpClient(object): if json: headers_dict[b"Content-Type"] = [b"application/json"] auth_headers = self.build_auth_headers( - destination_bytes, method_bytes, url_to_sign_bytes, - json, + destination_bytes, method_bytes, url_to_sign_bytes, json ) data = encode_canonical_json(json) producer = QuieterFileBodyProducer( - BytesIO(data), - cooperator=self._cooperator, + BytesIO(data), cooperator=self._cooperator ) else: producer = None auth_headers = self.build_auth_headers( - destination_bytes, method_bytes, url_to_sign_bytes, + destination_bytes, method_bytes, url_to_sign_bytes ) headers_dict[b"Authorization"] = auth_headers logger.info( "{%s} [%s] Sending request: %s %s; timeout %fs", - request.txn_id, request.destination, request.method, - url_str, _sec_timeout, + request.txn_id, + request.destination, + request.method, + url_str, + _sec_timeout, ) try: @@ -430,7 +416,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), ) if 200 <= response.code < 300: @@ -440,9 +426,7 @@ class MatrixFederationHttpClient(object): # Update transactions table? d = treq.content(response) d = timeout_deferred( - d, - timeout=_sec_timeout, - reactor=self.reactor, + d, timeout=_sec_timeout, reactor=self.reactor ) try: @@ -460,9 +444,7 @@ class MatrixFederationHttpClient(object): ) body = None - e = HttpResponseException( - response.code, response.phrase, body - ) + e = HttpResponseException(response.code, response.phrase, body) # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException @@ -521,7 +503,7 @@ class MatrixFederationHttpClient(object): defer.returnValue(response) def build_auth_headers( - self, destination, method, url_bytes, content=None, destination_is=None, + self, destination, method, url_bytes, content=None, destination_is=None ): """ Builds the Authorization headers for a federation request @@ -538,11 +520,7 @@ class MatrixFederationHttpClient(object): Returns: list[bytes]: a list of headers to be added as "Authorization:" headers """ - request = { - "method": method, - "uri": url_bytes, - "origin": self.server_name, - } + request = {"method": method, "uri": url_bytes, "origin": self.server_name} if destination is not None: request["destination"] = destination @@ -558,20 +536,28 @@ class MatrixFederationHttpClient(object): auth_headers = [] for key, sig in request["signatures"][self.server_name].items(): - auth_headers.append(( - "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( - self.server_name, key, sig, - )).encode('ascii') + auth_headers.append( + ( + 'X-Matrix origin=%s,key="%s",sig="%s"' + % (self.server_name, key, sig) + ).encode("ascii") ) return auth_headers @defer.inlineCallbacks - def put_json(self, destination, path, args={}, data={}, - json_data_callback=None, - long_retries=False, timeout=None, - ignore_backoff=False, - backoff_on_404=False, - try_trailing_slash_on_400=False): + def put_json( + self, + destination, + path, + args={}, + data={}, + json_data_callback=None, + long_retries=False, + timeout=None, + ignore_backoff=False, + backoff_on_404=False, + try_trailing_slash_on_400=False, + ): """ Sends the specifed json data using PUT Args: @@ -635,14 +621,22 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def post_json(self, destination, path, data={}, long_retries=False, - timeout=None, ignore_backoff=False, args={}): + def post_json( + self, + destination, + path, + data={}, + long_retries=False, + timeout=None, + ignore_backoff=False, + args={}, + ): """ Sends the specifed json data using POST Args: @@ -681,11 +675,7 @@ class MatrixFederationHttpClient(object): """ request = MatrixFederationRequest( - method="POST", - destination=destination, - path=path, - query=args, - json=data, + method="POST", destination=destination, path=path, query=args, json=data ) response = yield self._send_request( @@ -701,14 +691,21 @@ class MatrixFederationHttpClient(object): _sec_timeout = self.default_timeout body = yield _handle_json_response( - self.reactor, _sec_timeout, request, response, + self.reactor, _sec_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def get_json(self, destination, path, args=None, retry_on_dns_fail=True, - timeout=None, ignore_backoff=False, - try_trailing_slash_on_400=False): + def get_json( + self, + destination, + path, + args=None, + retry_on_dns_fail=True, + timeout=None, + ignore_backoff=False, + try_trailing_slash_on_400=False, + ): """ GETs some json from the given host homeserver and path Args: @@ -745,10 +742,7 @@ class MatrixFederationHttpClient(object): remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="GET", - destination=destination, - path=path, - query=args, + method="GET", destination=destination, path=path, query=args ) response = yield self._send_request_with_optional_trailing_slash( @@ -761,14 +755,21 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def delete_json(self, destination, path, long_retries=False, - timeout=None, ignore_backoff=False, args={}): + def delete_json( + self, + destination, + path, + long_retries=False, + timeout=None, + ignore_backoff=False, + args={}, + ): """Send a DELETE request to the remote expecting some json response Args: @@ -802,10 +803,7 @@ class MatrixFederationHttpClient(object): remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="DELETE", - destination=destination, - path=path, - query=args, + method="DELETE", destination=destination, path=path, query=args ) response = yield self._send_request( @@ -816,14 +814,21 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def get_file(self, destination, path, output_stream, args={}, - retry_on_dns_fail=True, max_size=None, - ignore_backoff=False): + def get_file( + self, + destination, + path, + output_stream, + args={}, + retry_on_dns_fail=True, + max_size=None, + ignore_backoff=False, + ): """GETs a file from a given homeserver Args: destination (str): The remote server to send the HTTP request to. @@ -848,16 +853,11 @@ class MatrixFederationHttpClient(object): remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="GET", - destination=destination, - path=path, - query=args, + method="GET", destination=destination, path=path, query=args ) response = yield self._send_request( - request, - retry_on_dns_fail=retry_on_dns_fail, - ignore_backoff=ignore_backoff, + request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff ) headers = dict(response.headers.getAllRawHeaders()) @@ -879,7 +879,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), length, ) defer.returnValue((length, headers)) @@ -896,11 +896,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback(SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - )) + self.deferred.errback( + SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + ) + ) self.deferred = defer.Deferred() self.transport.loseConnection() @@ -920,8 +922,7 @@ def _readBodyToFile(response, stream, max_size): def _flatten_response_never_received(e): if hasattr(e, "reasons"): reasons = ", ".join( - _flatten_response_never_received(f.value) - for f in e.reasons + _flatten_response_never_received(f.value) for f in e.reasons ) return "%s:[%s]" % (type(e).__name__, reasons) @@ -943,16 +944,15 @@ def check_content_type_is_json(headers): """ c_type = headers.getRawHeaders(b"Content-Type") if c_type is None: - raise RequestSendFailed(RuntimeError( - "No Content-Type header" - ), can_retry=False) + raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False) - c_type = c_type[0].decode('ascii') # only the first header + c_type = c_type[0].decode("ascii") # only the first header val, options = cgi.parse_header(c_type) if val != "application/json": - raise RequestSendFailed(RuntimeError( - "Content-Type not application/json: was '%s'" % c_type - ), can_retry=False) + raise RequestSendFailed( + RuntimeError("Content-Type not application/json: was '%s'" % c_type), + can_retry=False, + ) def encode_query_args(args): @@ -967,4 +967,4 @@ def encode_query_args(args): query_bytes = urllib.parse.urlencode(encoded_args, True) - return query_bytes.encode('utf8') + return query_bytes.encode("utf8") diff --git a/synapse/http/server.py b/synapse/http/server.py index 16fb7935da..6fd13e87d1 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -81,9 +81,7 @@ def wrap_json_request_handler(h): yield h(self, request) except SynapseError as e: code = e.code - logger.info( - "%s SynapseError: %s - %s", request, code, e.msg - ) + logger.info("%s SynapseError: %s - %s", request, code, e.msg) # Only respond with an error response if we haven't already started # writing, otherwise lets just kill the connection @@ -96,7 +94,10 @@ def wrap_json_request_handler(h): pass else: respond_with_json( - request, code, e.error_dict(), send_cors=True, + request, + code, + e.error_dict(), + send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -124,10 +125,7 @@ def wrap_json_request_handler(h): respond_with_json( request, 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, + {"error": "Internal server error", "errcode": Codes.UNKNOWN}, send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -143,6 +141,7 @@ def wrap_html_request_handler(h): The handler method must have a signature of "handle_foo(self, request)", where "request" must be a SynapseRequest. """ + def wrapped_request_handler(self, request): d = defer.maybeDeferred(h, self, request) d.addErrback(_return_html_error, request) @@ -164,9 +163,7 @@ def _return_html_error(f, request): msg = cme.msg if isinstance(cme, SynapseError): - logger.info( - "%s SynapseError: %s - %s", request, code, msg - ) + logger.info("%s SynapseError: %s - %s", request, code, msg) else: logger.error( "Failed handle request %r", @@ -183,9 +180,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) - body = HTML_ERROR_TEMPLATE.format( - code=code, msg=cgi.escape(msg), - ).encode("utf-8") + body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8") request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%i" % (len(body),)) @@ -205,6 +200,7 @@ def wrap_async_request_handler(h): The handler may return a deferred, in which case the completion of the request isn't logged until the deferred completes. """ + @defer.inlineCallbacks def wrapped_async_request_handler(self, request): with request.processing(): @@ -306,12 +302,14 @@ class JsonResource(HttpServer, resource.Resource): # URL again (as it was decoded by _get_handler_for request), as # ASCII because it's a URL, and then decode it to get the UTF-8 # characters that were quoted. - return urllib.parse.unquote(s.encode('ascii')).decode('utf8') + return urllib.parse.unquote(s.encode("ascii")).decode("utf8") - kwargs = intern_dict({ - name: _unquote(value) if value else value - for name, value in group_dict.items() - }) + kwargs = intern_dict( + { + name: _unquote(value) if value else value + for name, value in group_dict.items() + } + ) callback_return = yield callback(request, **kwargs) if callback_return is not None: @@ -339,7 +337,7 @@ class JsonResource(HttpServer, resource.Resource): # Loop through all the registered callbacks to check if the method # and path regex match for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path.decode('ascii')) + m = path_entry.pattern.match(request.path.decode("ascii")) if m: # We found a match! return path_entry.callback, m.groupdict() @@ -347,11 +345,14 @@ class JsonResource(HttpServer, resource.Resource): # Huh. No one wanted to handle that? Fiiiiiine. Send 400. return _unrecognised_request_handler, {} - def _send_response(self, request, code, response_json_object, - response_code_message=None): + def _send_response( + self, request, code, response_json_object, response_code_message=None + ): # TODO: Only enable CORS for the requests that need it. respond_with_json( - request, code, response_json_object, + request, + code, + response_json_object, send_cors=True, response_code_message=response_code_message, pretty_print=_request_user_agent_is_curl(request), @@ -395,7 +396,7 @@ class RootRedirect(resource.Resource): self.url = path def render_GET(self, request): - return redirectTo(self.url.encode('ascii'), request) + return redirectTo(self.url.encode("ascii"), request) def getChild(self, name, request): if len(name) == 0: @@ -403,16 +404,22 @@ class RootRedirect(resource.Resource): return resource.Resource.getChild(self, name, request) -def respond_with_json(request, code, json_object, send_cors=False, - response_code_message=None, pretty_print=False, - canonical_json=True): +def respond_with_json( + request, + code, + json_object, + send_cors=False, + response_code_message=None, + pretty_print=False, + canonical_json=True, +): # could alternatively use request.notifyFinish() and flip a flag when # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. if request._disconnected: logger.warn( - "Not sending response to request %s, already disconnected.", - request) + "Not sending response to request %s, already disconnected.", request + ) return if pretty_print: @@ -425,14 +432,17 @@ def respond_with_json(request, code, json_object, send_cors=False, json_bytes = json.dumps(json_object).encode("utf-8") return respond_with_json_bytes( - request, code, json_bytes, + request, + code, + json_bytes, send_cors=send_cors, response_code_message=response_code_message, ) -def respond_with_json_bytes(request, code, json_bytes, send_cors=False, - response_code_message=None): +def respond_with_json_bytes( + request, code, json_bytes, send_cors=False, response_code_message=None +): """Sends encoded JSON in response to the given request. Args: @@ -474,7 +484,7 @@ def set_cors_headers(request): ) request.setHeader( b"Access-Control-Allow-Headers", - b"Origin, X-Requested-With, Content-Type, Accept, Authorization" + b"Origin, X-Requested-With, Content-Type, Accept, Authorization", ) @@ -498,9 +508,7 @@ def finish_request(request): def _request_user_agent_is_curl(request): - user_agents = request.requestHeaders.getRawHeaders( - b"User-Agent", default=[] - ) + user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[]) for user_agent in user_agents: if b"curl" in user_agent: return True diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 197c652850..cd8415acd5 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -48,7 +48,7 @@ def parse_integer(request, name, default=None, required=False): def parse_integer_from_args(args, name, default=None, required=False): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: try: @@ -89,18 +89,14 @@ def parse_boolean(request, name, default=None, required=False): def parse_boolean_from_args(args, name, default=None, required=False): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: try: - return { - b"true": True, - b"false": False, - }[args[name][0]] + return {b"true": True, b"false": False}[args[name][0]] except Exception: message = ( - "Boolean query parameter %r must be one of" - " ['true', 'false']" + "Boolean query parameter %r must be one of" " ['true', 'false']" ) % (name,) raise SynapseError(400, message) else: @@ -111,8 +107,15 @@ def parse_boolean_from_args(args, name, default=None, required=False): return default -def parse_string(request, name, default=None, required=False, - allowed_values=None, param_type="string", encoding='ascii'): +def parse_string( + request, + name, + default=None, + required=False, + allowed_values=None, + param_type="string", + encoding="ascii", +): """ Parse a string parameter from the request query string. @@ -145,11 +148,18 @@ def parse_string(request, name, default=None, required=False, ) -def parse_string_from_args(args, name, default=None, required=False, - allowed_values=None, param_type="string", encoding='ascii'): +def parse_string_from_args( + args, + name, + default=None, + required=False, + allowed_values=None, + param_type="string", + encoding="ascii", +): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: value = args[name][0] @@ -159,7 +169,8 @@ def parse_string_from_args(args, name, default=None, required=False, if allowed_values is not None and value not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( - name, ", ".join(repr(v) for v in allowed_values) + name, + ", ".join(repr(v) for v in allowed_values), ) raise SynapseError(400, message) else: @@ -201,7 +212,7 @@ def parse_json_value_from_request(request, allow_empty_body=False): # Decode to Unicode so that simplejson will return Unicode strings on # Python 2 try: - content_unicode = content_bytes.decode('utf8') + content_unicode = content_bytes.decode("utf8") except UnicodeDecodeError: logger.warn("Unable to decode UTF-8") raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) @@ -227,9 +238,7 @@ def parse_json_object_from_request(request, allow_empty_body=False): SynapseError if the request body couldn't be decoded as JSON or if it wasn't a JSON object. """ - content = parse_json_value_from_request( - request, allow_empty_body=allow_empty_body, - ) + content = parse_json_value_from_request(request, allow_empty_body=allow_empty_body) if allow_empty_body and content is None: return {} diff --git a/synapse/http/site.py b/synapse/http/site.py index e508c0bd4f..93f679ea48 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -46,10 +46,11 @@ class SynapseRequest(Request): Attributes: logcontext(LoggingContext) : the log context for this request """ + def __init__(self, site, channel, *args, **kw): Request.__init__(self, channel, *args, **kw) self.site = site - self._channel = channel # this is used by the tests + self._channel = channel # this is used by the tests self.authenticated_entity = None self.start_time = 0 @@ -72,12 +73,12 @@ class SynapseRequest(Request): def __repr__(self): # We overwrite this so that we don't log ``access_token`` - return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( + return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % ( self.__class__.__name__, id(self), self.get_method(), self.get_redacted_uri(), - self.clientproto.decode('ascii', errors='replace'), + self.clientproto.decode("ascii", errors="replace"), self.site.site_tag, ) @@ -87,7 +88,7 @@ class SynapseRequest(Request): def get_redacted_uri(self): uri = self.uri if isinstance(uri, bytes): - uri = self.uri.decode('ascii') + uri = self.uri.decode("ascii") return redact_uri(uri) def get_method(self): @@ -102,7 +103,7 @@ class SynapseRequest(Request): """ method = self.method if isinstance(method, bytes): - method = self.method.decode('ascii') + method = self.method.decode("ascii") return method def get_user_agent(self): @@ -134,8 +135,7 @@ class SynapseRequest(Request): # dispatching to the handler, so that the handler # can update the servlet name in the request # metrics - requests_counter.labels(self.get_method(), - self.request_metrics.name).inc() + requests_counter.labels(self.get_method(), self.request_metrics.name).inc() @contextlib.contextmanager def processing(self): @@ -200,7 +200,7 @@ class SynapseRequest(Request): # the client disconnects. with PreserveLoggingContext(self.logcontext): logger.warn( - "Error processing request %r: %s %s", self, reason.type, reason.value, + "Error processing request %r: %s %s", self, reason.type, reason.value ) if not self._is_processing: @@ -222,7 +222,7 @@ class SynapseRequest(Request): self.start_time = time.time() self.request_metrics = RequestMetrics() self.request_metrics.start( - self.start_time, name=servlet_name, method=self.get_method(), + self.start_time, name=servlet_name, method=self.get_method() ) self.site.access_logger.info( @@ -230,7 +230,7 @@ class SynapseRequest(Request): self.getClientIP(), self.site.site_tag, self.get_method(), - self.get_redacted_uri() + self.get_redacted_uri(), ) def _finished_processing(self): @@ -282,7 +282,7 @@ class SynapseRequest(Request): self.site.access_logger.info( "%s - %s - {%s}" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" - " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", + ' %sB %s "%s %s %s" "%s" [%d dbevts]', self.getClientIP(), self.site.site_tag, authenticated_entity, @@ -297,7 +297,7 @@ class SynapseRequest(Request): code, self.get_method(), self.get_redacted_uri(), - self.clientproto.decode('ascii', errors='replace'), + self.clientproto.decode("ascii", errors="replace"), user_agent, usage.evt_db_fetch_count, ) @@ -316,14 +316,19 @@ class XForwardedForRequest(SynapseRequest): Add a layer on top of another request that only uses the value of an X-Forwarded-For header as the result of C{getClientIP}. """ + def getClientIP(self): """ @return: The client address (the first address) in the value of the I{X-Forwarded-For header}. If the header is not present, return C{b"-"}. """ - return self.requestHeaders.getRawHeaders( - b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii') + return ( + self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0] + .split(b",")[0] + .strip() + .decode("ascii") + ) class SynapseRequestFactory(object): @@ -343,8 +348,17 @@ class SynapseSite(Site): Subclass of a twisted http Site that does access logging with python's standard logging """ - def __init__(self, logger_name, site_tag, config, resource, - server_version_string, *args, **kwargs): + + def __init__( + self, + logger_name, + site_tag, + config, + resource, + server_version_string, + *args, + **kwargs + ): Site.__init__(self, resource, *args, **kwargs) self.site_tag = site_tag @@ -352,7 +366,7 @@ class SynapseSite(Site): proxied = config.get("x_forwarded", False) self.requestFactory = SynapseRequestFactory(self, proxied) self.access_logger = logging.getLogger(logger_name) - self.server_version_string = server_version_string.encode('ascii') + self.server_version_string = server_version_string.encode("ascii") def log(self, request): pass diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 8aee14a8a8..1f30179b51 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -231,10 +231,7 @@ class BucketCollector(object): res.append(["+Inf", sum(data.values())]) metric = HistogramMetricFamily( - self.name, - "", - buckets=res, - sum_value=sum([x * y for x, y in data.items()]), + self.name, "", buckets=res, sum_value=sum([x * y for x, y in data.items()]) ) yield metric @@ -263,7 +260,7 @@ class CPUMetrics(object): ticks_per_sec = 100 try: # Try and get the system config - ticks_per_sec = os.sysconf('SC_CLK_TCK') + ticks_per_sec = os.sysconf("SC_CLK_TCK") except (ValueError, TypeError, AttributeError): pass diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 037f1c490e..167e2c068a 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -60,8 +60,10 @@ _background_process_db_txn_count = Counter( _background_process_db_txn_duration = Counter( "synapse_background_process_db_txn_duration_seconds", - ("Seconds spent by background processes waiting for database " - "transactions, excluding scheduling time"), + ( + "Seconds spent by background processes waiting for database " + "transactions, excluding scheduling time" + ), ["name"], registry=None, ) @@ -94,6 +96,7 @@ class _Collector(object): Ensures that all of the metrics are up-to-date with any in-flight processes before they are returned. """ + def collect(self): background_process_in_flight_count = GaugeMetricFamily( "synapse_background_process_in_flight_count", @@ -105,14 +108,11 @@ class _Collector(object): # We also copy the process lists as that can also change with _bg_metrics_lock: _background_processes_copy = { - k: list(v) - for k, v in six.iteritems(_background_processes) + k: list(v) for k, v in six.iteritems(_background_processes) } for desc, processes in six.iteritems(_background_processes_copy): - background_process_in_flight_count.add_metric( - (desc,), len(processes), - ) + background_process_in_flight_count.add_metric((desc,), len(processes)) for process in processes: process.update_metrics() @@ -121,11 +121,11 @@ class _Collector(object): # now we need to run collect() over each of the static Counters, and # yield each metric they return. for m in ( - _background_process_ru_utime, - _background_process_ru_stime, - _background_process_db_txn_count, - _background_process_db_txn_duration, - _background_process_db_sched_duration, + _background_process_ru_utime, + _background_process_ru_stime, + _background_process_db_txn_count, + _background_process_db_txn_duration, + _background_process_db_sched_duration, ): for r in m.collect(): yield r @@ -151,14 +151,12 @@ class _BackgroundProcess(object): _background_process_ru_utime.labels(self.desc).inc(diff.ru_utime) _background_process_ru_stime.labels(self.desc).inc(diff.ru_stime) - _background_process_db_txn_count.labels(self.desc).inc( - diff.db_txn_count, - ) + _background_process_db_txn_count.labels(self.desc).inc(diff.db_txn_count) _background_process_db_txn_duration.labels(self.desc).inc( - diff.db_txn_duration_sec, + diff.db_txn_duration_sec ) _background_process_db_sched_duration.labels(self.desc).inc( - diff.db_sched_duration_sec, + diff.db_sched_duration_sec ) @@ -182,6 +180,7 @@ def run_as_background_process(desc, func, *args, **kwargs): Returns: Deferred which returns the result of func, but note that it does not follow the synapse logcontext rules. """ + @defer.inlineCallbacks def run(): with _bg_metrics_lock: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b3abd1b3c6..bf43ca09be 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -21,6 +21,7 @@ class ModuleApi(object): """A proxy object that gets passed to password auth providers so they can register new users etc if necessary. """ + def __init__(self, hs, auth_handler): self.hs = hs @@ -57,7 +58,7 @@ class ModuleApi(object): Returns: str: qualified @user:id """ - if username.startswith('@'): + if username.startswith("@"): return username return UserID(username, self.hs.hostname).to_string() @@ -89,8 +90,7 @@ class ModuleApi(object): # Register the user reg = self.hs.get_registration_handler() user_id, access_token = yield reg.register( - localpart=localpart, default_display_name=displayname, - bind_emails=emails, + localpart=localpart, default_display_name=displayname, bind_emails=emails ) defer.returnValue((user_id, access_token)) diff --git a/synapse/notifier.py b/synapse/notifier.py index ff589660da..1ac9c5bb84 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -37,7 +37,8 @@ logger = logging.getLogger(__name__) notified_events_counter = Counter("synapse_notifier_notified_events", "") users_woken_by_stream_counter = Counter( - "synapse_notifier_users_woken_by_stream", "", ["stream"]) + "synapse_notifier_users_woken_by_stream", "", ["stream"] +) # TODO(paul): Should be shared somewhere @@ -55,6 +56,7 @@ class _NotificationListener(object): The events stream handler will have yielded to the deferred, so to notify the handler it is sufficient to resolve the deferred. """ + __slots__ = ["deferred"] def __init__(self, deferred): @@ -95,9 +97,7 @@ class _NotifierUserStream(object): stream_id(str): The new id for the stream the event came from. time_now_ms(int): The current time in milliseconds. """ - self.current_token = self.current_token.copy_and_advance( - stream_key, stream_id - ) + self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.last_notified_token = self.current_token self.last_notified_ms = time_now_ms noify_deferred = self.notify_deferred @@ -141,6 +141,7 @@ class _NotifierUserStream(object): class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): def __nonzero__(self): return bool(self.events) + __bool__ = __nonzero__ # python3 @@ -190,15 +191,17 @@ class Notifier(object): all_user_streams.add(x) return sum(stream.count_listeners() for stream in all_user_streams) + LaterGauge("synapse_notifier_listeners", "", [], count_listeners) LaterGauge( - "synapse_notifier_rooms", "", [], + "synapse_notifier_rooms", + "", + [], lambda: count(bool, list(self.room_to_user_streams.values())), ) LaterGauge( - "synapse_notifier_users", "", [], - lambda: len(self.user_to_user_stream), + "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream) ) def add_replication_callback(self, cb): @@ -209,8 +212,9 @@ class Notifier(object): """ self.replication_callbacks.append(cb) - def on_new_room_event(self, event, room_stream_id, max_room_stream_id, - extra_users=[]): + def on_new_room_event( + self, event, room_stream_id, max_room_stream_id, extra_users=[] + ): """ Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -222,9 +226,7 @@ class Notifier(object): until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append(( - room_stream_id, event, extra_users - )) + self.pending_new_room_events.append((room_stream_id, event, extra_users)) self._notify_pending_new_room_events(max_room_stream_id) self.notify_replication() @@ -240,9 +242,9 @@ class Notifier(object): self.pending_new_room_events = [] for room_stream_id, event, extra_users in pending: if room_stream_id > max_room_stream_id: - self.pending_new_room_events.append(( - room_stream_id, event, extra_users - )) + self.pending_new_room_events.append( + (room_stream_id, event, extra_users) + ) else: self._on_new_room_event(event, room_stream_id, extra_users) @@ -250,8 +252,7 @@ class Notifier(object): """Notify any user streams that are interested in this room event""" # poke any interested application service. run_as_background_process( - "notify_app_services", - self._notify_app_services, room_stream_id, + "notify_app_services", self._notify_app_services, room_stream_id ) if self.federation_sender: @@ -261,9 +262,7 @@ class Notifier(object): self._user_joined_room(event.state_key, event.room_id) self.on_new_event( - "room_key", room_stream_id, - users=extra_users, - rooms=[event.room_id], + "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] ) @defer.inlineCallbacks @@ -305,8 +304,9 @@ class Notifier(object): self.notify_replication() @defer.inlineCallbacks - def wait_for_events(self, user_id, timeout, callback, room_ids=None, - from_token=StreamToken.START): + def wait_for_events( + self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START + ): """Wait until the callback returns a non empty response or the timeout fires. """ @@ -368,9 +368,15 @@ class Notifier(object): defer.returnValue(result) @defer.inlineCallbacks - def get_events_for(self, user, pagination_config, timeout, - only_keys=None, - is_guest=False, explicit_room_id=None): + def get_events_for( + self, + user, + pagination_config, + timeout, + only_keys=None, + is_guest=False, + explicit_room_id=None, + ): """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. @@ -419,10 +425,7 @@ class Notifier(object): if name == "room": new_events = yield filter_events_for_client( - self.store, - user.to_string(), - new_events, - is_peeking=is_peeking, + self.store, user.to_string(), new_events, is_peeking=is_peeking ) elif name == "presence": now = self.clock.time_msec() @@ -450,7 +453,8 @@ class Notifier(object): # # I am sorry for what I have done. user_id_for_stream = "_PEEKING_%s_%s" % ( - explicit_room_id, user_id_for_stream + explicit_room_id, + user_id_for_stream, ) result = yield self.wait_for_events( @@ -477,9 +481,7 @@ class Notifier(object): @defer.inlineCallbacks def _is_world_readable(self, room_id): state = yield self.state_handler.get_current_state( - room_id, - EventTypes.RoomHistoryVisibility, - "", + room_id, EventTypes.RoomHistoryVisibility, "" ) if state and "history_visibility" in state.content: defer.returnValue(state.content["history_visibility"] == "world_readable") diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index a5de75c48a..1ffd5e2df3 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -40,6 +40,4 @@ class ActionGenerator(object): @defer.inlineCallbacks def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "action_for_event_by_user"): - yield self.bulk_evaluator.action_for_event_by_user( - event, context - ) + yield self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 3523a40108..96d087de22 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -31,48 +31,54 @@ def list_with_base_rules(rawrules): # Grab the base rules that the user has modified. # The modified base rules have a priority_class of -1. - modified_base_rules = { - r['rule_id']: r for r in rawrules if r['priority_class'] < 0 - } + modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0} # Remove the modified base rules from the list, They'll be added back # in the default postions in the list. - rawrules = [r for r in rawrules if r['priority_class'] >= 0] + rawrules = [r for r in rawrules if r["priority_class"] >= 0] # shove the server default rules for each kind onto the end of each current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1] - ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules - )) + ruleslist.extend( + make_base_prepend_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + ) + ) for r in rawrules: - if r['priority_class'] < current_prio_class: - while r['priority_class'] < current_prio_class: - ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - )) - current_prio_class -= 1 - if current_prio_class > 0: - ruleslist.extend(make_base_prepend_rules( + if r["priority_class"] < current_prio_class: + while r["priority_class"] < current_prio_class: + ruleslist.extend( + make_base_append_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules, - )) + ) + ) + current_prio_class -= 1 + if current_prio_class > 0: + ruleslist.extend( + make_base_prepend_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + ) + ) ruleslist.append(r) while current_prio_class > 0: - ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - )) + ruleslist.extend( + make_base_append_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + ) + ) current_prio_class -= 1 if current_prio_class > 0: - ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - )) + ruleslist.extend( + make_base_prepend_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + ) + ) return ruleslist @@ -80,20 +86,20 @@ def list_with_base_rules(rawrules): def make_base_append_rules(kind, modified_base_rules): rules = [] - if kind == 'override': + if kind == "override": rules = BASE_APPEND_OVERRIDE_RULES - elif kind == 'underride': + elif kind == "underride": rules = BASE_APPEND_UNDERRIDE_RULES - elif kind == 'content': + elif kind == "content": rules = BASE_APPEND_CONTENT_RULES # Copy the rules before modifying them rules = copy.deepcopy(rules) for r in rules: # Only modify the actions, keep the conditions the same. - modified = modified_base_rules.get(r['rule_id']) + modified = modified_base_rules.get(r["rule_id"]) if modified: - r['actions'] = modified['actions'] + r["actions"] = modified["actions"] return rules @@ -101,103 +107,86 @@ def make_base_append_rules(kind, modified_base_rules): def make_base_prepend_rules(kind, modified_base_rules): rules = [] - if kind == 'override': + if kind == "override": rules = BASE_PREPEND_OVERRIDE_RULES # Copy the rules before modifying them rules = copy.deepcopy(rules) for r in rules: # Only modify the actions, keep the conditions the same. - modified = modified_base_rules.get(r['rule_id']) + modified = modified_base_rules.get(r["rule_id"]) if modified: - r['actions'] = modified['actions'] + r["actions"] = modified["actions"] return rules BASE_APPEND_CONTENT_RULES = [ { - 'rule_id': 'global/content/.m.rule.contains_user_name', - 'conditions': [ + "rule_id": "global/content/.m.rule.contains_user_name", + "conditions": [ { - 'kind': 'event_match', - 'key': 'content.body', - 'pattern_type': 'user_localpart' + "kind": "event_match", + "key": "content.body", + "pattern_type": "user_localpart", } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default', - }, { - 'set_tweak': 'highlight' - } - ] - }, + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + ], + } ] BASE_PREPEND_OVERRIDE_RULES = [ { - 'rule_id': 'global/override/.m.rule.master', - 'enabled': False, - 'conditions': [], - 'actions': [ - "dont_notify" - ] + "rule_id": "global/override/.m.rule.master", + "enabled": False, + "conditions": [], + "actions": ["dont_notify"], } ] BASE_APPEND_OVERRIDE_RULES = [ { - 'rule_id': 'global/override/.m.rule.suppress_notices', - 'conditions': [ + "rule_id": "global/override/.m.rule.suppress_notices", + "conditions": [ { - 'kind': 'event_match', - 'key': 'content.msgtype', - 'pattern': 'm.notice', - '_id': '_suppress_notices', + "kind": "event_match", + "key": "content.msgtype", + "pattern": "m.notice", + "_id": "_suppress_notices", } ], - 'actions': [ - 'dont_notify', - ] + "actions": ["dont_notify"], }, # NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event # otherwise invites will be matched by .m.rule.member_event { - 'rule_id': 'global/override/.m.rule.invite_for_me', - 'conditions': [ + "rule_id": "global/override/.m.rule.invite_for_me", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.member', - '_id': '_member', + "kind": "event_match", + "key": "type", + "pattern": "m.room.member", + "_id": "_member", }, { - 'kind': 'event_match', - 'key': 'content.membership', - 'pattern': 'invite', - '_id': '_invite_member', - }, - { - 'kind': 'event_match', - 'key': 'state_key', - 'pattern_type': 'user_id' + "kind": "event_match", + "key": "content.membership", + "pattern": "invite", + "_id": "_invite_member", }, + {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, + ], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] }, # Will we sometimes want to know about people joining and leaving? # Perhaps: if so, this could be expanded upon. Seems the most usual case @@ -206,217 +195,164 @@ BASE_APPEND_OVERRIDE_RULES = [ # join/leave/avatar/displayname events. # See also: https://matrix.org/jira/browse/SYN-607 { - 'rule_id': 'global/override/.m.rule.member_event', - 'conditions': [ + "rule_id": "global/override/.m.rule.member_event", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.member', - '_id': '_member', + "kind": "event_match", + "key": "type", + "pattern": "m.room.member", + "_id": "_member", } ], - 'actions': [ - 'dont_notify' - ] + "actions": ["dont_notify"], }, # This was changed from underride to override so it's closer in priority # to the content rules where the user name highlight rule lives. This # way a room rule is lower priority than both but a custom override rule # is higher priority than both. { - 'rule_id': 'global/override/.m.rule.contains_display_name', - 'conditions': [ - { - 'kind': 'contains_display_name' - } + "rule_id": "global/override/.m.rule.contains_display_name", + "conditions": [{"kind": "contains_display_name"}], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight' - } - ] }, { - 'rule_id': 'global/override/.m.rule.roomnotif', - 'conditions': [ + "rule_id": "global/override/.m.rule.roomnotif", + "conditions": [ { - 'kind': 'event_match', - 'key': 'content.body', - 'pattern': '@room', - '_id': '_roomnotif_content', + "kind": "event_match", + "key": "content.body", + "pattern": "@room", + "_id": "_roomnotif_content", }, { - 'kind': 'sender_notification_permission', - 'key': 'room', - '_id': '_roomnotif_pl', + "kind": "sender_notification_permission", + "key": "room", + "_id": "_roomnotif_pl", }, ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': True, - } - ] + "actions": ["notify", {"set_tweak": "highlight", "value": True}], }, { - 'rule_id': 'global/override/.m.rule.tombstone', - 'conditions': [ + "rule_id": "global/override/.m.rule.tombstone", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.tombstone', - '_id': '_tombstone', + "kind": "event_match", + "key": "type", + "pattern": "m.room.tombstone", + "_id": "_tombstone", } ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': True, - } - ] - } + "actions": ["notify", {"set_tweak": "highlight", "value": True}], + }, ] BASE_APPEND_UNDERRIDE_RULES = [ { - 'rule_id': 'global/underride/.m.rule.call', - 'conditions': [ + "rule_id": "global/underride/.m.rule.call", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.call.invite', - '_id': '_call', + "kind": "event_match", + "key": "type", + "pattern": "m.call.invite", + "_id": "_call", } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'ring' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": [ + "notify", + {"set_tweak": "sound", "value": "ring"}, + {"set_tweak": "highlight", "value": False}, + ], }, # XXX: once m.direct is standardised everywhere, we should use it to detect # a DM from the user's perspective rather than this heuristic. { - 'rule_id': 'global/underride/.m.rule.room_one_to_one', - 'conditions': [ + "rule_id": "global/underride/.m.rule.room_one_to_one", + "conditions": [ + {"kind": "room_member_count", "is": "2", "_id": "member_count"}, { - 'kind': 'room_member_count', - 'is': '2', - '_id': 'member_count', + "kind": "event_match", + "key": "type", + "pattern": "m.room.message", + "_id": "_message", }, - { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.message', - '_id': '_message', - } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, + ], }, # XXX: this is going to fire for events which aren't m.room.messages # but are encrypted (e.g. m.call.*)... { - 'rule_id': 'global/underride/.m.rule.encrypted_room_one_to_one', - 'conditions': [ + "rule_id": "global/underride/.m.rule.encrypted_room_one_to_one", + "conditions": [ + {"kind": "room_member_count", "is": "2", "_id": "member_count"}, { - 'kind': 'room_member_count', - 'is': '2', - '_id': 'member_count', + "kind": "event_match", + "key": "type", + "pattern": "m.room.encrypted", + "_id": "_encrypted", }, - { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.encrypted', - '_id': '_encrypted', - } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, + ], }, { - 'rule_id': 'global/underride/.m.rule.message', - 'conditions': [ + "rule_id": "global/underride/.m.rule.message", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.message', - '_id': '_message', + "kind": "event_match", + "key": "type", + "pattern": "m.room.message", + "_id": "_message", } ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": ["notify", {"set_tweak": "highlight", "value": False}], }, # XXX: this is going to fire for events which aren't m.room.messages # but are encrypted (e.g. m.call.*)... { - 'rule_id': 'global/underride/.m.rule.encrypted', - 'conditions': [ + "rule_id": "global/underride/.m.rule.encrypted", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.encrypted', - '_id': '_encrypted', + "kind": "event_match", + "key": "type", + "pattern": "m.room.encrypted", + "_id": "_encrypted", } ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': False - } - ] - } + "actions": ["notify", {"set_tweak": "highlight", "value": False}], + }, ] BASE_RULE_IDS = set() for r in BASE_APPEND_CONTENT_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['content'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["content"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) for r in BASE_PREPEND_OVERRIDE_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['override'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) for r in BASE_APPEND_OVERRIDE_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['override'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) for r in BASE_APPEND_UNDERRIDE_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['underride'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["underride"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8f9a76147f..c8a5b381da 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -39,9 +39,11 @@ rules_by_room = {} push_rules_invalidation_counter = Counter( - "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "") + "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" +) push_rules_state_size_counter = Counter( - "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "") + "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "" +) # Measures whether we use the fast path of using state deltas, or if we have to # recalculate from scratch @@ -83,7 +85,7 @@ class BulkPushRuleEvaluator(object): # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited - if event.type == 'm.room.member' and event.content['membership'] == 'invite': + if event.type == "m.room.member" and event.content["membership"] == "invite": invited = event.state_key if invited and self.hs.is_mine_id(invited): has_pusher = yield self.store.user_has_pusher(invited) @@ -106,7 +108,9 @@ class BulkPushRuleEvaluator(object): # before any lookup methods get called on it as otherwise there may be # a race if invalidate_all gets called (which assumes its in the cache) return RulesForRoom( - self.hs, room_id, self._get_rules_for_room.cache, + self.hs, + room_id, + self._get_rules_for_room.cache, self.room_push_rule_cache_metrics, ) @@ -121,12 +125,10 @@ class BulkPushRuleEvaluator(object): auth_events = {POWER_KEY: pl_event} else: auth_events_ids = yield self.auth.compute_auth_events( - event, prev_state_ids, for_verification=False, + event, prev_state_ids, for_verification=False ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in itervalues(auth_events) - } + auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} sender_level = get_user_power_level(event.sender, auth_events) @@ -145,16 +147,14 @@ class BulkPushRuleEvaluator(object): rules_by_user = yield self._get_rules_for_event(event, context) actions_by_user = {} - room_members = yield self.store.get_joined_users_from_context( - event, context - ) + room_members = yield self.store.get_joined_users_from_context(event, context) (power_levels, sender_power_level) = ( yield self._get_power_levels_and_sender_level(event, context) ) evaluator = PushRuleEvaluatorForEvent( - event, len(room_members), sender_power_level, power_levels, + event, len(room_members), sender_power_level, power_levels ) condition_cache = {} @@ -180,15 +180,15 @@ class BulkPushRuleEvaluator(object): display_name = event.content.get("displayname", None) for rule in rules: - if 'enabled' in rule and not rule['enabled']: + if "enabled" in rule and not rule["enabled"]: continue matches = _condition_checker( - evaluator, rule['conditions'], uid, display_name, condition_cache + evaluator, rule["conditions"], uid, display_name, condition_cache ) if matches: - actions = [x for x in rule['actions'] if x != 'dont_notify'] - if actions and 'notify' in actions: + actions = [x for x in rule["actions"] if x != "dont_notify"] + if actions and "notify" in actions: # Push rules say we should notify the user of this event actions_by_user[uid] = actions break @@ -196,9 +196,7 @@ class BulkPushRuleEvaluator(object): # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) - yield self.store.add_push_actions_to_staging( - event.event_id, actions_by_user, - ) + yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -361,19 +359,19 @@ class RulesForRoom(object): self.sequence, members={}, # There were no membership changes rules_by_user=ret_rules_by_user, - state_group=state_group + state_group=state_group, ) if logger.isEnabledFor(logging.DEBUG): logger.debug( - "Returning push rules for %r %r", - self.room_id, ret_rules_by_user.keys(), + "Returning push rules for %r %r", self.room_id, ret_rules_by_user.keys() ) defer.returnValue(ret_rules_by_user) @defer.inlineCallbacks - def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids, - state_group, event): + def _update_rules_with_member_event_ids( + self, ret_rules_by_user, member_event_ids, state_group, event + ): """Update the partially filled rules_by_user dict by fetching rules for any newly joined users in the `member_event_ids` list. @@ -391,16 +389,13 @@ class RulesForRoom(object): table="room_memberships", column="event_id", iterable=member_event_ids.values(), - retcols=('user_id', 'membership', 'event_id'), + retcols=("user_id", "membership", "event_id"), keyvalues={}, batch_size=500, desc="_get_rules_for_member_event_ids", ) - members = { - row["event_id"]: (row["user_id"], row["membership"]) - for row in rows - } + members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} # If the event is a join event then it will be in current state evnts # map but not in the DB, so we have to explicitly insert it. @@ -413,15 +408,15 @@ class RulesForRoom(object): logger.debug("Found members %r: %r", self.room_id, members.values()) interested_in_user_ids = set( - user_id for user_id, membership in itervalues(members) + user_id + for user_id, membership in itervalues(members) if membership == Membership.JOIN ) logger.debug("Joined: %r", interested_in_user_ids) if_users_with_pushers = yield self.store.get_if_users_have_pushers( - interested_in_user_ids, - on_invalidate=self.invalidate_all_cb, + interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) user_ids = set( @@ -431,7 +426,7 @@ class RulesForRoom(object): logger.debug("With pushers: %r", user_ids) users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( - self.room_id, on_invalidate=self.invalidate_all_cb, + self.room_id, on_invalidate=self.invalidate_all_cb ) logger.debug("With receipts: %r", users_with_receipts) @@ -442,7 +437,7 @@ class RulesForRoom(object): user_ids.add(uid) rules_by_user = yield self.store.bulk_get_push_rules( - user_ids, on_invalidate=self.invalidate_all_cb, + user_ids, on_invalidate=self.invalidate_all_cb ) ret_rules_by_user.update( diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 8bd96b1178..a59b639f15 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -25,14 +25,14 @@ def format_push_rules_for_user(user, ruleslist): # We're going to be mutating this a lot, so do a deep copy ruleslist = copy.deepcopy(ruleslist) - rules = {'global': {}, 'device': {}} + rules = {"global": {}, "device": {}} - rules['global'] = _add_empty_priority_class_arrays(rules['global']) + rules["global"] = _add_empty_priority_class_arrays(rules["global"]) for r in ruleslist: rulearray = None - template_name = _priority_class_to_template_name(r['priority_class']) + template_name = _priority_class_to_template_name(r["priority_class"]) # Remove internal stuff. for c in r["conditions"]: @@ -44,14 +44,14 @@ def format_push_rules_for_user(user, ruleslist): elif pattern_type == "user_localpart": c["pattern"] = user.localpart - rulearray = rules['global'][template_name] + rulearray = rules["global"][template_name] template_rule = _rule_to_template(r) if template_rule: - if 'enabled' in r: - template_rule['enabled'] = r['enabled'] + if "enabled" in r: + template_rule["enabled"] = r["enabled"] else: - template_rule['enabled'] = True + template_rule["enabled"] = True rulearray.append(template_rule) return rules @@ -65,33 +65,33 @@ def _add_empty_priority_class_arrays(d): def _rule_to_template(rule): unscoped_rule_id = None - if 'rule_id' in rule: - unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id']) + if "rule_id" in rule: + unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"]) - template_name = _priority_class_to_template_name(rule['priority_class']) - if template_name in ['override', 'underride']: + template_name = _priority_class_to_template_name(rule["priority_class"]) + if template_name in ["override", "underride"]: templaterule = {k: rule[k] for k in ["conditions", "actions"]} elif template_name in ["sender", "room"]: - templaterule = {'actions': rule['actions']} - unscoped_rule_id = rule['conditions'][0]['pattern'] - elif template_name == 'content': + templaterule = {"actions": rule["actions"]} + unscoped_rule_id = rule["conditions"][0]["pattern"] + elif template_name == "content": if len(rule["conditions"]) != 1: return None thecond = rule["conditions"][0] if "pattern" not in thecond: return None - templaterule = {'actions': rule['actions']} + templaterule = {"actions": rule["actions"]} templaterule["pattern"] = thecond["pattern"] if unscoped_rule_id: - templaterule['rule_id'] = unscoped_rule_id - if 'default' in rule: - templaterule['default'] = rule['default'] + templaterule["rule_id"] = unscoped_rule_id + if "default" in rule: + templaterule["default"] = rule["default"] return templaterule def _rule_id_from_namespaced(in_rule_id): - return in_rule_id.split('/')[-1] + return in_rule_id.split("/")[-1] def _priority_class_to_template_name(pc): diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index c89a8438a9..424ffa8b68 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -32,13 +32,13 @@ DELAY_BEFORE_MAIL_MS = 10 * 60 * 1000 THROTTLE_START_MS = 10 * 60 * 1000 THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h # THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours -THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day +THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day # If no event triggers a notification for this long after the previous, # the throttle is released. # 12 hours - a gap of 12 hours in conversation is surely enough to merit a new # notification when things get going again... -THROTTLE_RESET_AFTER_MS = (12 * 60 * 60 * 1000) +THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000 # does each email include all unread notifs, or just the ones which have happened # since the last mail? @@ -53,17 +53,18 @@ class EmailPusher(object): This shares quite a bit of code with httpusher: it would be good to factor out the common parts """ + def __init__(self, hs, pusherdict, mailer): self.hs = hs self.mailer = mailer self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self.pusher_id = pusherdict['id'] - self.user_id = pusherdict['user_name'] - self.app_id = pusherdict['app_id'] - self.email = pusherdict['pushkey'] - self.last_stream_ordering = pusherdict['last_stream_ordering'] + self.pusher_id = pusherdict["id"] + self.user_id = pusherdict["user_name"] + self.app_id = pusherdict["app_id"] + self.email = pusherdict["pushkey"] + self.last_stream_ordering = pusherdict["last_stream_ordering"] self.timed_call = None self.throttle_params = None @@ -93,7 +94,9 @@ class EmailPusher(object): def on_new_notifications(self, min_stream_ordering, max_stream_ordering): if self.max_stream_ordering: - self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) + self.max_stream_ordering = max( + max_stream_ordering, self.max_stream_ordering + ) else: self.max_stream_ordering = max_stream_ordering self._start_processing() @@ -174,14 +177,12 @@ class EmailPusher(object): return for push_action in unprocessed: - received_at = push_action['received_ts'] + received_at = push_action["received_ts"] if received_at is None: received_at = 0 notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS - room_ready_at = self.room_ready_to_notify_at( - push_action['room_id'] - ) + room_ready_at = self.room_ready_to_notify_at(push_action["room_id"]) should_notify_at = max(notif_ready_at, room_ready_at) @@ -192,25 +193,23 @@ class EmailPusher(object): # to be delivered. reason = { - 'room_id': push_action['room_id'], - 'now': self.clock.time_msec(), - 'received_at': received_at, - 'delay_before_mail_ms': DELAY_BEFORE_MAIL_MS, - 'last_sent_ts': self.get_room_last_sent_ts(push_action['room_id']), - 'throttle_ms': self.get_room_throttle_ms(push_action['room_id']), + "room_id": push_action["room_id"], + "now": self.clock.time_msec(), + "received_at": received_at, + "delay_before_mail_ms": DELAY_BEFORE_MAIL_MS, + "last_sent_ts": self.get_room_last_sent_ts(push_action["room_id"]), + "throttle_ms": self.get_room_throttle_ms(push_action["room_id"]), } yield self.send_notification(unprocessed, reason) - yield self.save_last_stream_ordering_and_success(max([ - ea['stream_ordering'] for ea in unprocessed - ])) + yield self.save_last_stream_ordering_and_success( + max([ea["stream_ordering"] for ea in unprocessed]) + ) # we update the throttle on all the possible unprocessed push actions for ea in unprocessed: - yield self.sent_notif_update_throttle( - ea['room_id'], ea - ) + yield self.sent_notif_update_throttle(ea["room_id"], ea) break else: if soonest_due_at is None or should_notify_at < soonest_due_at: @@ -236,8 +235,11 @@ class EmailPusher(object): self.last_stream_ordering = last_stream_ordering yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, self.email, self.user_id, - last_stream_ordering, self.clock.time_msec() + self.app_id, + self.email, + self.user_id, + last_stream_ordering, + self.clock.time_msec(), ) def seconds_until(self, ts_msec): @@ -276,10 +278,10 @@ class EmailPusher(object): # THROTTLE_RESET_AFTER_MS after the previous one that triggered a # notif, we release the throttle. Otherwise, the throttle is increased. time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before( - notified_push_action['stream_ordering'] + notified_push_action["stream_ordering"] ) - time_of_this_notifs = notified_push_action['received_ts'] + time_of_this_notifs = notified_push_action["received_ts"] if time_of_previous_notifs is not None and time_of_this_notifs is not None: gap = time_of_this_notifs - time_of_previous_notifs @@ -298,12 +300,11 @@ class EmailPusher(object): new_throttle_ms = THROTTLE_START_MS else: new_throttle_ms = min( - current_throttle_ms * THROTTLE_MULTIPLIER, - THROTTLE_MAX_MS + current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS ) self.throttle_params[room_id] = { "last_sent_ts": self.clock.time_msec(), - "throttle_ms": new_throttle_ms + "throttle_ms": new_throttle_ms, } yield self.store.set_throttle_params( self.pusher_id, room_id, self.throttle_params[room_id] diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index a21b164266..6ae4415860 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -65,16 +65,16 @@ class HttpPusher(object): self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() self.state_handler = self.hs.get_state_handler() - self.user_id = pusherdict['user_name'] - self.app_id = pusherdict['app_id'] - self.app_display_name = pusherdict['app_display_name'] - self.device_display_name = pusherdict['device_display_name'] - self.pushkey = pusherdict['pushkey'] - self.pushkey_ts = pusherdict['ts'] - self.data = pusherdict['data'] - self.last_stream_ordering = pusherdict['last_stream_ordering'] + self.user_id = pusherdict["user_name"] + self.app_id = pusherdict["app_id"] + self.app_display_name = pusherdict["app_display_name"] + self.device_display_name = pusherdict["device_display_name"] + self.pushkey = pusherdict["pushkey"] + self.pushkey_ts = pusherdict["ts"] + self.data = pusherdict["data"] + self.last_stream_ordering = pusherdict["last_stream_ordering"] self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.failing_since = pusherdict['failing_since'] + self.failing_since = pusherdict["failing_since"] self.timed_call = None self._is_processing = False @@ -85,32 +85,26 @@ class HttpPusher(object): # off as None though as we don't know any better. self.max_stream_ordering = None - if 'data' not in pusherdict: - raise PusherConfigException( - "No 'data' key for HTTP pusher" - ) - self.data = pusherdict['data'] + if "data" not in pusherdict: + raise PusherConfigException("No 'data' key for HTTP pusher") + self.data = pusherdict["data"] self.name = "%s/%s/%s" % ( - pusherdict['user_name'], - pusherdict['app_id'], - pusherdict['pushkey'], + pusherdict["user_name"], + pusherdict["app_id"], + pusherdict["pushkey"], ) if self.data is None: - raise PusherConfigException( - "data can not be null for HTTP pusher" - ) + raise PusherConfigException("data can not be null for HTTP pusher") - if 'url' not in self.data: - raise PusherConfigException( - "'url' required in data for HTTP pusher" - ) - self.url = self.data['url'] + if "url" not in self.data: + raise PusherConfigException("'url' required in data for HTTP pusher") + self.url = self.data["url"] self.http_client = hs.get_proxied_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) - del self.data_minus_url['url'] + del self.data_minus_url["url"] def on_started(self, should_check_for_notifs): """Called when this pusher has been started. @@ -124,7 +118,9 @@ class HttpPusher(object): self._start_processing() def on_new_notifications(self, min_stream_ordering, max_stream_ordering): - self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0) + self.max_stream_ordering = max( + max_stream_ordering, self.max_stream_ordering or 0 + ) self._start_processing() def on_new_receipts(self, min_stream_id, max_stream_id): @@ -192,7 +188,9 @@ class HttpPusher(object): logger.info( "Processing %i unprocessed push actions for %s starting at " "stream_ordering %s", - len(unprocessed), self.name, self.last_stream_ordering, + len(unprocessed), + self.name, + self.last_stream_ordering, ) for push_action in unprocessed: @@ -200,71 +198,72 @@ class HttpPusher(object): if processed: http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.last_stream_ordering = push_action['stream_ordering'] + self.last_stream_ordering = push_action["stream_ordering"] yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, self.pushkey, self.user_id, + self.app_id, + self.pushkey, + self.user_id, self.last_stream_ordering, - self.clock.time_msec() + self.clock.time_msec(), ) if self.failing_since: self.failing_since = None yield self.store.update_pusher_failing_since( - self.app_id, self.pushkey, self.user_id, - self.failing_since + self.app_id, self.pushkey, self.user_id, self.failing_since ) else: http_push_failed_counter.inc() if not self.failing_since: self.failing_since = self.clock.time_msec() yield self.store.update_pusher_failing_since( - self.app_id, self.pushkey, self.user_id, - self.failing_since + self.app_id, self.pushkey, self.user_id, self.failing_since ) if ( - self.failing_since and - self.failing_since < - self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS + self.failing_since + and self.failing_since + < self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS ): # we really only give up so that if the URL gets # fixed, we don't suddenly deliver a load # of old notifications. - logger.warn("Giving up on a notification to user %s, " - "pushkey %s", - self.user_id, self.pushkey) + logger.warn( + "Giving up on a notification to user %s, " "pushkey %s", + self.user_id, + self.pushkey, + ) self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.last_stream_ordering = push_action['stream_ordering'] + self.last_stream_ordering = push_action["stream_ordering"] yield self.store.update_pusher_last_stream_ordering( self.app_id, self.pushkey, self.user_id, - self.last_stream_ordering + self.last_stream_ordering, ) self.failing_since = None yield self.store.update_pusher_failing_since( - self.app_id, - self.pushkey, - self.user_id, - self.failing_since + self.app_id, self.pushkey, self.user_id, self.failing_since ) else: logger.info("Push failed: delaying for %ds", self.backoff_delay) self.timed_call = self.hs.get_reactor().callLater( self.backoff_delay, self.on_timer ) - self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC) + self.backoff_delay = min( + self.backoff_delay * 2, self.MAX_BACKOFF_SEC + ) break @defer.inlineCallbacks def _process_one(self, push_action): - if 'notify' not in push_action['actions']: + if "notify" not in push_action["actions"]: defer.returnValue(True) - tweaks = push_rule_evaluator.tweaks_for_actions(push_action['actions']) + tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - event = yield self.store.get_event(push_action['event_id'], allow_none=True) + event = yield self.store.get_event(push_action["event_id"], allow_none=True) if event is None: defer.returnValue(True) # It's been redacted rejected = yield self.dispatch_push(event, tweaks, badge) @@ -277,37 +276,30 @@ class HttpPusher(object): # for sanity, we only remove the pushkey if it # was the one we actually sent... logger.warn( - ("Ignoring rejected pushkey %s because we" - " didn't send it"), pk + ("Ignoring rejected pushkey %s because we" " didn't send it"), + pk, ) else: - logger.info( - "Pushkey %s was rejected: removing", - pk - ) - yield self.hs.remove_pusher( - self.app_id, pk, self.user_id - ) + logger.info("Pushkey %s was rejected: removing", pk) + yield self.hs.remove_pusher(self.app_id, pk, self.user_id) defer.returnValue(True) @defer.inlineCallbacks def _build_notification_dict(self, event, tweaks, badge): - if self.data.get('format') == 'event_id_only': + if self.data.get("format") == "event_id_only": d = { - 'notification': { - 'event_id': event.event_id, - 'room_id': event.room_id, - 'counts': { - 'unread': badge, - }, - 'devices': [ + "notification": { + "event_id": event.event_id, + "room_id": event.room_id, + "counts": {"unread": badge}, + "devices": [ { - 'app_id': self.app_id, - 'pushkey': self.pushkey, - 'pushkey_ts': long(self.pushkey_ts / 1000), - 'data': self.data_minus_url, + "app_id": self.app_id, + "pushkey": self.pushkey, + "pushkey_ts": long(self.pushkey_ts / 1000), + "data": self.data_minus_url, } - ] + ], } } defer.returnValue(d) @@ -317,41 +309,41 @@ class HttpPusher(object): ) d = { - 'notification': { - 'id': event.event_id, # deprecated: remove soon - 'event_id': event.event_id, - 'room_id': event.room_id, - 'type': event.type, - 'sender': event.user_id, - 'counts': { # -- we don't mark messages as read yet so - # we have no way of knowing + "notification": { + "id": event.event_id, # deprecated: remove soon + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "sender": event.user_id, + "counts": { # -- we don't mark messages as read yet so + # we have no way of knowing # Just set the badge to 1 until we have read receipts - 'unread': badge, + "unread": badge, # 'missed_calls': 2 }, - 'devices': [ + "devices": [ { - 'app_id': self.app_id, - 'pushkey': self.pushkey, - 'pushkey_ts': long(self.pushkey_ts / 1000), - 'data': self.data_minus_url, - 'tweaks': tweaks + "app_id": self.app_id, + "pushkey": self.pushkey, + "pushkey_ts": long(self.pushkey_ts / 1000), + "data": self.data_minus_url, + "tweaks": tweaks, } - ] + ], } } - if event.type == 'm.room.member' and event.is_state(): - d['notification']['membership'] = event.content['membership'] - d['notification']['user_is_target'] = event.state_key == self.user_id + if event.type == "m.room.member" and event.is_state(): + d["notification"]["membership"] = event.content["membership"] + d["notification"]["user_is_target"] = event.state_key == self.user_id if self.hs.config.push_include_content and event.content: - d['notification']['content'] = event.content + d["notification"]["content"] = event.content # We no longer send aliases separately, instead, we send the human # readable name of the room, which may be an alias. - if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0: - d['notification']['sender_display_name'] = ctx['sender_display_name'] - if 'name' in ctx and len(ctx['name']) > 0: - d['notification']['room_name'] = ctx['name'] + if "sender_display_name" in ctx and len(ctx["sender_display_name"]) > 0: + d["notification"]["sender_display_name"] = ctx["sender_display_name"] + if "name" in ctx and len(ctx["name"]) > 0: + d["notification"]["room_name"] = ctx["name"] defer.returnValue(d) @@ -361,16 +353,21 @@ class HttpPusher(object): if not notification_dict: defer.returnValue([]) try: - resp = yield self.http_client.post_json_get_json(self.url, notification_dict) + resp = yield self.http_client.post_json_get_json( + self.url, notification_dict + ) except Exception as e: logger.warning( "Failed to push event %s to %s: %s %s", - event.event_id, self.name, type(e), e, + event.event_id, + self.name, + type(e), + e, ) defer.returnValue(False) rejected = [] - if 'rejected' in resp: - rejected = resp['rejected'] + if "rejected" in resp: + rejected = resp["rejected"] defer.returnValue(rejected) @defer.inlineCallbacks @@ -381,21 +378,19 @@ class HttpPusher(object): """ logger.info("Sending updated badge count %d to %s", badge, self.name) d = { - 'notification': { - 'id': '', - 'type': None, - 'sender': '', - 'counts': { - 'unread': badge - }, - 'devices': [ + "notification": { + "id": "", + "type": None, + "sender": "", + "counts": {"unread": badge}, + "devices": [ { - 'app_id': self.app_id, - 'pushkey': self.pushkey, - 'pushkey_ts': long(self.pushkey_ts / 1000), - 'data': self.data_minus_url, + "app_id": self.app_id, + "pushkey": self.pushkey, + "pushkey_ts": long(self.pushkey_ts / 1000), + "data": self.data_minus_url, } - ] + ], } } try: @@ -403,7 +398,6 @@ class HttpPusher(object): http_badges_processed_counter.inc() except Exception as e: logger.warning( - "Failed to send badge count to %s: %s %s", - self.name, type(e), e, + "Failed to send badge count to %s: %s %s", self.name, type(e), e ) http_badges_failed_counter.inc() diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 099f9545ab..a5c9959a46 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -42,17 +42,17 @@ from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) -MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \ - "in the %(room)s room..." +MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " "in the %(room)s room..." MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..." MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..." MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..." -MESSAGES_IN_ROOM_AND_OTHERS = \ +MESSAGES_IN_ROOM_AND_OTHERS = ( "You have messages on %(app)s in the %(room)s room and others..." -MESSAGES_FROM_PERSON_AND_OTHERS = \ +) +MESSAGES_FROM_PERSON_AND_OTHERS = ( "You have messages on %(app)s from %(person)s and others..." -INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \ - "%(room)s room on %(app)s..." +) +INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " "%(room)s room on %(app)s..." INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..." CONTEXT_BEFORE = 1 @@ -60,12 +60,38 @@ CONTEXT_AFTER = 1 # From https://github.com/matrix-org/matrix-react-sdk/blob/master/src/HtmlUtils.js ALLOWED_TAGS = [ - 'font', # custom to matrix for IRC-style font coloring - 'del', # for markdown + "font", # custom to matrix for IRC-style font coloring + "del", # for markdown # deliberately no h1/h2 to stop people shouting. - 'h3', 'h4', 'h5', 'h6', 'blockquote', 'p', 'a', 'ul', 'ol', - 'nl', 'li', 'b', 'i', 'u', 'strong', 'em', 'strike', 'code', 'hr', 'br', 'div', - 'table', 'thead', 'caption', 'tbody', 'tr', 'th', 'td', 'pre' + "h3", + "h4", + "h5", + "h6", + "blockquote", + "p", + "a", + "ul", + "ol", + "nl", + "li", + "b", + "i", + "u", + "strong", + "em", + "strike", + "code", + "hr", + "br", + "div", + "table", + "thead", + "caption", + "tbody", + "tr", + "th", + "td", + "pre", ] ALLOWED_ATTRS = { # custom ones first: @@ -94,13 +120,7 @@ class Mailer(object): logger.info("Created Mailer for app_name %s" % app_name) @defer.inlineCallbacks - def send_password_reset_mail( - self, - email_address, - token, - client_secret, - sid, - ): + def send_password_reset_mail(self, email_address, token, client_secret, sid): """Send an email with a password reset link to a user Args: @@ -112,19 +132,16 @@ class Mailer(object): group together multiple email sending attempts sid (str): The generated session ID """ - if email.utils.parseaddr(email_address)[1] == '': + if email.utils.parseaddr(email_address)[1] == "": raise RuntimeError("Invalid 'to' email address") link = ( - self.hs.config.public_baseurl + - "_matrix/client/unstable/password_reset/email/submit_token" - "?token=%s&client_secret=%s&sid=%s" % - (token, client_secret, sid) + self.hs.config.public_baseurl + + "_matrix/client/unstable/password_reset/email/submit_token" + "?token=%s&client_secret=%s&sid=%s" % (token, client_secret, sid) ) - template_vars = { - "link": link, - } + template_vars = {"link": link} yield self.send_email( email_address, @@ -133,15 +150,14 @@ class Mailer(object): ) @defer.inlineCallbacks - def send_notification_mail(self, app_id, user_id, email_address, - push_actions, reason): + def send_notification_mail( + self, app_id, user_id, email_address, push_actions, reason + ): """Send email regarding a user's room notifications""" - rooms_in_order = deduped_ordered_list( - [pa['room_id'] for pa in push_actions] - ) + rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions]) notif_events = yield self.store.get_events( - [pa['event_id'] for pa in push_actions] + [pa["event_id"] for pa in push_actions] ) notifs_by_room = {} @@ -171,9 +187,7 @@ class Mailer(object): yield concurrently_execute(_fetch_room_state, rooms_in_order, 3) # actually sort our so-called rooms_in_order list, most recent room first - rooms_in_order.sort( - key=lambda r: -(notifs_by_room[r][-1]['received_ts'] or 0) - ) + rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0)) rooms = [] @@ -183,9 +197,11 @@ class Mailer(object): ) rooms.append(roomvars) - reason['room_name'] = yield calculate_room_name( - self.store, state_by_room[reason['room_id']], user_id, - fallback_to_members=True + reason["room_name"] = yield calculate_room_name( + self.store, + state_by_room[reason["room_id"]], + user_id, + fallback_to_members=True, ) summary_text = yield self.make_summary_text( @@ -204,25 +220,21 @@ class Mailer(object): } yield self.send_email( - email_address, - "[%s] %s" % (self.app_name, summary_text), - template_vars, + email_address, "[%s] %s" % (self.app_name, summary_text), template_vars ) @defer.inlineCallbacks def send_email(self, email_address, subject, template_vars): """Send an email with the given information and template text""" try: - from_string = self.hs.config.email_notif_from % { - "app": self.app_name - } + from_string = self.hs.config.email_notif_from % {"app": self.app_name} except TypeError: from_string = self.hs.config.email_notif_from raw_from = email.utils.parseaddr(from_string)[1] raw_to = email.utils.parseaddr(email_address)[1] - if raw_to == '': + if raw_to == "": raise RuntimeError("Invalid 'to' address") html_text = self.template_html.render(**template_vars) @@ -231,27 +243,31 @@ class Mailer(object): plain_text = self.template_text.render(**template_vars) text_part = MIMEText(plain_text, "plain", "utf8") - multipart_msg = MIMEMultipart('alternative') - multipart_msg['Subject'] = subject - multipart_msg['From'] = from_string - multipart_msg['To'] = email_address - multipart_msg['Date'] = email.utils.formatdate() - multipart_msg['Message-ID'] = email.utils.make_msgid() + multipart_msg = MIMEMultipart("alternative") + multipart_msg["Subject"] = subject + multipart_msg["From"] = from_string + multipart_msg["To"] = email_address + multipart_msg["Date"] = email.utils.formatdate() + multipart_msg["Message-ID"] = email.utils.make_msgid() multipart_msg.attach(text_part) multipart_msg.attach(html_part) logger.info("Sending email push notification to %s" % email_address) - yield make_deferred_yieldable(self.sendmail( - self.hs.config.email_smtp_host, - raw_from, raw_to, multipart_msg.as_string().encode('utf8'), - reactor=self.hs.get_reactor(), - port=self.hs.config.email_smtp_port, - requireAuthentication=self.hs.config.email_smtp_user is not None, - username=self.hs.config.email_smtp_user, - password=self.hs.config.email_smtp_pass, - requireTransportSecurity=self.hs.config.require_transport_security - )) + yield make_deferred_yieldable( + self.sendmail( + self.hs.config.email_smtp_host, + raw_from, + raw_to, + multipart_msg.as_string().encode("utf8"), + reactor=self.hs.get_reactor(), + port=self.hs.config.email_smtp_port, + requireAuthentication=self.hs.config.email_smtp_user is not None, + username=self.hs.config.email_smtp_user, + password=self.hs.config.email_smtp_pass, + requireTransportSecurity=self.hs.config.require_transport_security, + ) + ) @defer.inlineCallbacks def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids): @@ -272,17 +288,18 @@ class Mailer(object): if not is_invite: for n in notifs: notifvars = yield self.get_notif_vars( - n, user_id, notif_events[n['event_id']], room_state_ids + n, user_id, notif_events[n["event_id"]], room_state_ids ) # merge overlapping notifs together. # relies on the notifs being in chronological order. merge = False - if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]: - prev_messages = room_vars['notifs'][-1]['messages'] - for message in notifvars['messages']: - pm = list(filter(lambda pm: pm['id'] == message['id'], - prev_messages)) + if room_vars["notifs"] and "messages" in room_vars["notifs"][-1]: + prev_messages = room_vars["notifs"][-1]["messages"] + for message in notifvars["messages"]: + pm = list( + filter(lambda pm: pm["id"] == message["id"], prev_messages) + ) if pm: if not message["is_historical"]: pm[0]["is_historical"] = False @@ -293,20 +310,22 @@ class Mailer(object): prev_messages.append(message) if not merge: - room_vars['notifs'].append(notifvars) + room_vars["notifs"].append(notifvars) defer.returnValue(room_vars) @defer.inlineCallbacks def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): results = yield self.store.get_events_around( - notif['room_id'], notif['event_id'], - before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER + notif["room_id"], + notif["event_id"], + before_limit=CONTEXT_BEFORE, + after_limit=CONTEXT_AFTER, ) ret = { "link": self.make_notif_link(notif), - "ts": notif['received_ts'], + "ts": notif["received_ts"], "messages": [], } @@ -318,7 +337,7 @@ class Mailer(object): for event in the_events: messagevars = yield self.get_message_vars(notif, event, room_state_ids) if messagevars is not None: - ret['messages'].append(messagevars) + ret["messages"].append(messagevars) defer.returnValue(ret) @@ -340,7 +359,7 @@ class Mailer(object): ret = { "msgtype": msgtype, - "is_historical": event.event_id != notif['event_id'], + "is_historical": event.event_id != notif["event_id"], "id": event.event_id, "ts": event.origin_server_ts, "sender_name": sender_name, @@ -379,8 +398,9 @@ class Mailer(object): return messagevars @defer.inlineCallbacks - def make_summary_text(self, notifs_by_room, room_state_ids, - notif_events, user_id, reason): + def make_summary_text( + self, notifs_by_room, room_state_ids, notif_events, user_id, reason + ): if len(notifs_by_room) == 1: # Only one room has new stuff room_id = list(notifs_by_room.keys())[0] @@ -404,16 +424,19 @@ class Mailer(object): inviter_name = name_from_member_event(inviter_member_event) if room_name is None: - defer.returnValue(INVITE_FROM_PERSON % { - "person": inviter_name, - "app": self.app_name - }) + defer.returnValue( + INVITE_FROM_PERSON + % {"person": inviter_name, "app": self.app_name} + ) else: - defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % { - "person": inviter_name, - "room": room_name, - "app": self.app_name, - }) + defer.returnValue( + INVITE_FROM_PERSON_TO_ROOM + % { + "person": inviter_name, + "room": room_name, + "app": self.app_name, + } + ) sender_name = None if len(notifs_by_room[room_id]) == 1: @@ -427,67 +450,86 @@ class Mailer(object): sender_name = name_from_member_event(state_event) if sender_name is not None and room_name is not None: - defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % { - "person": sender_name, - "room": room_name, - "app": self.app_name, - }) + defer.returnValue( + MESSAGE_FROM_PERSON_IN_ROOM + % { + "person": sender_name, + "room": room_name, + "app": self.app_name, + } + ) elif sender_name is not None: - defer.returnValue(MESSAGE_FROM_PERSON % { - "person": sender_name, - "app": self.app_name, - }) + defer.returnValue( + MESSAGE_FROM_PERSON + % {"person": sender_name, "app": self.app_name} + ) else: # There's more than one notification for this room, so just # say there are several if room_name is not None: - defer.returnValue(MESSAGES_IN_ROOM % { - "room": room_name, - "app": self.app_name, - }) + defer.returnValue( + MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name} + ) else: # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" - sender_ids = list(set([ - notif_events[n['event_id']].sender - for n in notifs_by_room[room_id] - ])) - - member_events = yield self.store.get_events([ - room_state_ids[room_id][("m.room.member", s)] - for s in sender_ids - ]) - - defer.returnValue(MESSAGES_FROM_PERSON % { - "person": descriptor_from_member_events(member_events.values()), - "app": self.app_name, - }) + sender_ids = list( + set( + [ + notif_events[n["event_id"]].sender + for n in notifs_by_room[room_id] + ] + ) + ) + + member_events = yield self.store.get_events( + [ + room_state_ids[room_id][("m.room.member", s)] + for s in sender_ids + ] + ) + + defer.returnValue( + MESSAGES_FROM_PERSON + % { + "person": descriptor_from_member_events( + member_events.values() + ), + "app": self.app_name, + } + ) else: # Stuff's happened in multiple different rooms # ...but we still refer to the 'reason' room which triggered the mail - if reason['room_name'] is not None: - defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % { - "room": reason['room_name'], - "app": self.app_name, - }) + if reason["room_name"] is not None: + defer.returnValue( + MESSAGES_IN_ROOM_AND_OTHERS + % {"room": reason["room_name"], "app": self.app_name} + ) else: # If the reason room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" - sender_ids = list(set([ - notif_events[n['event_id']].sender - for n in notifs_by_room[reason['room_id']] - ])) + sender_ids = list( + set( + [ + notif_events[n["event_id"]].sender + for n in notifs_by_room[reason["room_id"]] + ] + ) + ) - member_events = yield self.store.get_events([ - room_state_ids[room_id][("m.room.member", s)] - for s in sender_ids - ]) + member_events = yield self.store.get_events( + [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids] + ) - defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % { - "person": descriptor_from_member_events(member_events.values()), - "app": self.app_name, - }) + defer.returnValue( + MESSAGES_FROM_PERSON_AND_OTHERS + % { + "person": descriptor_from_member_events(member_events.values()), + "app": self.app_name, + } + ) def make_room_link(self, room_id): if self.hs.config.email_riot_base_url: @@ -503,17 +545,17 @@ class Mailer(object): if self.hs.config.email_riot_base_url: return "%s/#/room/%s/%s" % ( self.hs.config.email_riot_base_url, - notif['room_id'], notif['event_id'] + notif["room_id"], + notif["event_id"], ) elif self.app_name == "Vector": # need /beta for Universal Links to work on iOS return "https://vector.im/beta/#/room/%s/%s" % ( - notif['room_id'], notif['event_id'] + notif["room_id"], + notif["event_id"], ) else: - return "https://matrix.to/#/%s/%s" % ( - notif['room_id'], notif['event_id'] - ) + return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"]) def make_unsubscribe_link(self, user_id, app_id, email_address): params = { @@ -530,12 +572,18 @@ class Mailer(object): def safe_markup(raw_html): - return jinja2.Markup(bleach.linkify(bleach.clean( - raw_html, tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRS, - # bleach master has this, but it isn't released yet - # protocols=ALLOWED_SCHEMES, - strip=True - ))) + return jinja2.Markup( + bleach.linkify( + bleach.clean( + raw_html, + tags=ALLOWED_TAGS, + attributes=ALLOWED_ATTRS, + # bleach master has this, but it isn't released yet + # protocols=ALLOWED_SCHEMES, + strip=True, + ) + ) + ) def safe_text(raw_text): @@ -543,10 +591,9 @@ def safe_text(raw_text): Process text: treat it as HTML but escape any tags (ie. just escape the HTML) then linkify it. """ - return jinja2.Markup(bleach.linkify(bleach.clean( - raw_text, tags=[], attributes={}, - strip=False - ))) + return jinja2.Markup( + bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False)) + ) def deduped_ordered_list(l): @@ -595,15 +642,11 @@ def _create_mxc_to_http_filter(config): serverAndMediaId = value[6:] fragment = None - if '#' in serverAndMediaId: - (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1) + if "#" in serverAndMediaId: + (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1) fragment = "#" + fragment - params = { - "width": width, - "height": height, - "method": resize_method, - } + params = {"width": width, "height": height, "method": resize_method} return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( config.public_baseurl, serverAndMediaId, diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 0c66702325..06056fbf4f 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -28,8 +28,13 @@ ALL_ALONE = "Empty Room" @defer.inlineCallbacks -def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True, - fallback_to_single_member=True): +def calculate_room_name( + store, + room_state_ids, + user_id, + fallback_to_members=True, + fallback_to_single_member=True, +): """ Works out a user-facing name for the given room as per Matrix spec recommendations. @@ -58,8 +63,10 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True room_state_ids[("m.room.canonical_alias", "")], allow_none=True ) if ( - canon_alias and canon_alias.content and canon_alias.content["alias"] and - _looks_like_an_alias(canon_alias.content["alias"]) + canon_alias + and canon_alias.content + and canon_alias.content["alias"] + and _looks_like_an_alias(canon_alias.content["alias"]) ): defer.returnValue(canon_alias.content["alias"]) @@ -71,9 +78,7 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True if "m.room.aliases" in room_state_bytype_ids: m_room_aliases = room_state_bytype_ids["m.room.aliases"] for alias_id in m_room_aliases.values(): - alias_event = yield store.get_event( - alias_id, allow_none=True - ) + alias_event = yield store.get_event(alias_id, allow_none=True) if alias_event and alias_event.content.get("aliases"): the_aliases = alias_event.content["aliases"] if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]): @@ -89,8 +94,8 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True ) if ( - my_member_event is not None and - my_member_event.content['membership'] == "invite" + my_member_event is not None + and my_member_event.content["membership"] == "invite" ): if ("m.room.member", my_member_event.sender) in room_state_ids: inviter_member_event = yield store.get_event( @@ -100,9 +105,8 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True if inviter_member_event: if fallback_to_single_member: defer.returnValue( - "Invite from %s" % ( - name_from_member_event(inviter_member_event), - ) + "Invite from %s" + % (name_from_member_event(inviter_member_event),) ) else: return @@ -116,8 +120,10 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True list(room_state_bytype_ids["m.room.member"].values()) ) all_members = [ - ev for ev in member_events.values() - if ev.content['membership'] == "join" or ev.content['membership'] == "invite" + ev + for ev in member_events.values() + if ev.content["membership"] == "join" + or ev.content["membership"] == "invite" ] # Sort the member events oldest-first so the we name people in the # order the joined (it should at least be deterministic rather than @@ -134,9 +140,9 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True # or inbound invite, or outbound 3PID invite. if all_members[0].sender == user_id: if "m.room.third_party_invite" in room_state_bytype_ids: - third_party_invites = ( - room_state_bytype_ids["m.room.third_party_invite"].values() - ) + third_party_invites = room_state_bytype_ids[ + "m.room.third_party_invite" + ].values() if len(third_party_invites) > 0: # technically third party invite events are not member @@ -191,8 +197,9 @@ def descriptor_from_member_events(member_events): def name_from_member_event(member_event): if ( - member_event.content and "displayname" in member_event.content and - member_event.content["displayname"] + member_event.content + and "displayname" in member_event.content + and member_event.content["displayname"] ): return member_event.content["displayname"] return member_event.state_key diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index cf6c8b875e..5ed9147de4 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -26,8 +26,8 @@ from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) -GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]') -IS_GLOB = re.compile(r'[\?\*\[\]]') +GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]") +IS_GLOB = re.compile(r"[\?\*\[\]]") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") @@ -36,20 +36,20 @@ def _room_member_count(ev, condition, room_member_count): def _sender_notification_permission(ev, condition, sender_power_level, power_levels): - notif_level_key = condition.get('key') + notif_level_key = condition.get("key") if notif_level_key is None: return False - notif_levels = power_levels.get('notifications', {}) + notif_levels = power_levels.get("notifications", {}) room_notif_level = notif_levels.get(notif_level_key, 50) return sender_power_level >= room_notif_level def _test_ineq_condition(condition, number): - if 'is' not in condition: + if "is" not in condition: return False - m = INEQUALITY_EXPR.match(condition['is']) + m = INEQUALITY_EXPR.match(condition["is"]) if not m: return False ineq = m.group(1) @@ -58,15 +58,15 @@ def _test_ineq_condition(condition, number): return False rhs = int(rhs) - if ineq == '' or ineq == '==': + if ineq == "" or ineq == "==": return number == rhs - elif ineq == '<': + elif ineq == "<": return number < rhs - elif ineq == '>': + elif ineq == ">": return number > rhs - elif ineq == '>=': + elif ineq == ">=": return number >= rhs - elif ineq == '<=': + elif ineq == "<=": return number <= rhs else: return False @@ -77,8 +77,8 @@ def tweaks_for_actions(actions): for a in actions: if not isinstance(a, dict): continue - if 'set_tweak' in a and 'value' in a: - tweaks[a['set_tweak']] = a['value'] + if "set_tweak" in a and "value" in a: + tweaks[a["set_tweak"]] = a["value"] return tweaks @@ -93,26 +93,24 @@ class PushRuleEvaluatorForEvent(object): self._value_cache = _flatten_dict(event) def matches(self, condition, user_id, display_name): - if condition['kind'] == 'event_match': + if condition["kind"] == "event_match": return self._event_match(condition, user_id) - elif condition['kind'] == 'contains_display_name': + elif condition["kind"] == "contains_display_name": return self._contains_display_name(display_name) - elif condition['kind'] == 'room_member_count': - return _room_member_count( - self._event, condition, self._room_member_count - ) - elif condition['kind'] == 'sender_notification_permission': + elif condition["kind"] == "room_member_count": + return _room_member_count(self._event, condition, self._room_member_count) + elif condition["kind"] == "sender_notification_permission": return _sender_notification_permission( - self._event, condition, self._sender_power_level, self._power_levels, + self._event, condition, self._sender_power_level, self._power_levels ) else: return True def _event_match(self, condition, user_id): - pattern = condition.get('pattern', None) + pattern = condition.get("pattern", None) if not pattern: - pattern_type = condition.get('pattern_type', None) + pattern_type = condition.get("pattern_type", None) if pattern_type == "user_id": pattern = user_id elif pattern_type == "user_localpart": @@ -123,14 +121,14 @@ class PushRuleEvaluatorForEvent(object): return False # XXX: optimisation: cache our pattern regexps - if condition['key'] == 'content.body': + if condition["key"] == "content.body": body = self._event.content.get("body", None) if not body: return False return _glob_matches(pattern, body, word_boundary=True) else: - haystack = self._get_value(condition['key']) + haystack = self._get_value(condition["key"]) if haystack is None: return False @@ -193,16 +191,13 @@ def _glob_to_re(glob, word_boundary): if IS_GLOB.search(glob): r = re.escape(glob) - r = r.replace(r'\*', '.*?') - r = r.replace(r'\?', '.') + r = r.replace(r"\*", ".*?") + r = r.replace(r"\?", ".") # handle [abc], [a-z] and [!a-z] style ranges. r = GLOB_REGEX.sub( lambda x: ( - '[%s%s]' % ( - x.group(1) and '^' or '', - x.group(2).replace(r'\\\-', '-') - ) + "[%s%s]" % (x.group(1) and "^" or "", x.group(2).replace(r"\\\-", "-")) ), r, ) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 8049c298c2..e37269cdb9 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -23,9 +23,7 @@ def get_badge_count(store, user_id): invites = yield store.get_invited_rooms_for_user(user_id) joins = yield store.get_rooms_for_user(user_id) - my_receipts_by_room = yield store.get_receipts_for_user( - user_id, "m.read", - ) + my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") badge = len(invites) @@ -57,10 +55,10 @@ def get_context_for_event(store, state_handler, ev, user_id): store, room_state_ids, user_id, fallback_to_single_member=False ) if name: - ctx['name'] = name + ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] sender_state_event = yield store.get_event(sender_state_event_id) - ctx['sender_display_name'] = name_from_member_event(sender_state_event) + ctx["sender_display_name"] = name_from_member_event(sender_state_event) defer.returnValue(ctx) diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index aff85daeb5..a9c64a9c54 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -36,9 +36,7 @@ class PusherFactory(object): def __init__(self, hs): self.hs = hs - self.pusher_types = { - "http": HttpPusher, - } + self.pusher_types = {"http": HttpPusher} logger.info("email enable notifs: %r", hs.config.email_enable_notifs) if hs.config.email_enable_notifs: @@ -56,7 +54,7 @@ class PusherFactory(object): logger.info("defined email pusher type") def create_pusher(self, pusherdict): - kind = pusherdict['kind'] + kind = pusherdict["kind"] f = self.pusher_types.get(kind, None) if not f: return None @@ -77,8 +75,8 @@ class PusherFactory(object): return EmailPusher(self.hs, pusherdict, mailer) def _app_name_from_pusherdict(self, pusherdict): - if 'data' in pusherdict and 'brand' in pusherdict['data']: - app_name = pusherdict['data']['brand'] + if "data" in pusherdict and "brand" in pusherdict["data"]: + app_name = pusherdict["data"]["brand"] else: app_name = self.hs.config.email_app_name diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 63c583565f..df6f670740 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -40,6 +40,7 @@ class PusherPool: notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and Pusher.on_new_receipts are not expected to return deferreds. """ + def __init__(self, _hs): self.hs = _hs self.pusher_factory = PusherFactory(_hs) @@ -57,9 +58,19 @@ class PusherPool: run_as_background_process("start_pushers", self._start_pushers) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, kind, app_id, - app_display_name, device_display_name, pushkey, lang, data, - profile_tag=""): + def add_pusher( + self, + user_id, + access_token, + kind, + app_id, + app_display_name, + device_display_name, + pushkey, + lang, + data, + profile_tag="", + ): """Creates a new pusher and adds it to the pool Returns: @@ -71,21 +82,23 @@ class PusherPool: # will then get pulled out of the database, # recreated, added and started: this means we have only one # code path adding pushers. - self.pusher_factory.create_pusher({ - "id": None, - "user_name": user_id, - "kind": kind, - "app_id": app_id, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "pushkey": pushkey, - "ts": time_now_msec, - "lang": lang, - "data": data, - "last_stream_ordering": None, - "last_success": None, - "failing_since": None - }) + self.pusher_factory.create_pusher( + { + "id": None, + "user_name": user_id, + "kind": kind, + "app_id": app_id, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "pushkey": pushkey, + "ts": time_now_msec, + "lang": lang, + "data": data, + "last_stream_ordering": None, + "last_success": None, + "failing_since": None, + } + ) # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process @@ -113,18 +126,19 @@ class PusherPool: defer.returnValue(pusher) @defer.inlineCallbacks - def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, - not_user_id): - to_remove = yield self.store.get_pushers_by_app_id_and_pushkey( - app_id, pushkey - ) + def remove_pushers_by_app_id_and_pushkey_not_user( + self, app_id, pushkey, not_user_id + ): + to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: - if p['user_name'] != not_user_id: + if p["user_name"] != not_user_id: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", - app_id, pushkey, p['user_name'] + app_id, + pushkey, + p["user_name"], ) - yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) @defer.inlineCallbacks def remove_pushers_by_access_token(self, user_id, access_tokens): @@ -138,14 +152,14 @@ class PusherPool: """ tokens = set(access_tokens) for p in (yield self.store.get_pushers_by_user_id(user_id)): - if p['access_token'] in tokens: + if p["access_token"] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", - p['app_id'], p['pushkey'], p['user_name'] - ) - yield self.remove_pusher( - p['app_id'], p['pushkey'], p['user_name'], + p["app_id"], + p["pushkey"], + p["user_name"], ) + yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) @defer.inlineCallbacks def on_new_notifications(self, min_stream_id, max_stream_id): @@ -199,13 +213,11 @@ class PusherPool: if not self._should_start_pushers: return - resultlist = yield self.store.get_pushers_by_app_id_and_pushkey( - app_id, pushkey - ) + resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None for r in resultlist: - if r['user_name'] == user_id: + if r["user_name"] == user_id: pusher_dict = r pusher = None @@ -245,9 +257,9 @@ class PusherPool: except PusherConfigException as e: logger.warning( "Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s", - pusherdict.get('user_name'), - pusherdict.get('app_id'), - pusherdict.get('pushkey'), + pusherdict.get("user_name"), + pusherdict.get("app_id"), + pusherdict.get("pushkey"), e, ) return @@ -258,11 +270,8 @@ class PusherPool: if not p: return - appid_pushkey = "%s:%s" % ( - pusherdict['app_id'], - pusherdict['pushkey'], - ) - byuser = self.pushers.setdefault(pusherdict['user_name'], {}) + appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) + byuser = self.pushers.setdefault(pusherdict["user_name"], {}) if appid_pushkey in byuser: byuser[appid_pushkey].on_stop() @@ -275,7 +284,7 @@ class PusherPool: last_stream_ordering = pusherdict["last_stream_ordering"] if last_stream_ordering: have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( - user_id, last_stream_ordering, + user_id, last_stream_ordering ) else: # We always want to default to starting up the pusher rather than diff --git a/synapse/push/rulekinds.py b/synapse/push/rulekinds.py index 4cae48ac07..ce7cc1b4ee 100644 --- a/synapse/push/rulekinds.py +++ b/synapse/push/rulekinds.py @@ -13,10 +13,10 @@ # limitations under the License. PRIORITY_CLASS_MAP = { - 'underride': 1, - 'sender': 2, - 'room': 3, - 'content': 4, - 'override': 5, + "underride": 1, + "sender": 2, + "room": 3, + "content": 4, + "override": 5, } PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()} diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 11ace2bfb1..77e2d1a0e1 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -45,14 +45,11 @@ REQUIREMENTS = [ "signedjson>=1.0.0", "pynacl>=1.2.1", "idna>=2.5", - # validating SSL certs for IP addresses requires service_identity 18.1. "service_identity>=18.1.0", - # our logcontext handling relies on the ability to cancel inlineCallbacks # (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7. "Twisted>=18.7.0", - "treq>=15.1", # Twisted has required pyopenssl 16.0 since about Twisted 16.6. "pyopenssl>=16.0.0", @@ -71,34 +68,27 @@ REQUIREMENTS = [ # prometheus_client 0.4.0 changed the format of counter metrics # (cf https://github.com/matrix-org/synapse/issues/4001) "prometheus_client>=0.0.18,<0.4.0", - # we use attr.s(slots), which arrived in 16.0.0 # Twisted 18.7.0 requires attrs>=17.4.0 "attrs>=17.4.0", - "netaddr>=0.7.18", ] CONDITIONAL_REQUIREMENTS = { "email": ["Jinja2>=2.9", "bleach>=1.4.3"], "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], - # we use execute_batch, which arrived in psycopg 2.7. "postgres": ["psycopg2>=2.7"], - # ConsentResource uses select_autoescape, which arrived in jinja 2.9 "resources.consent": ["Jinja2>=2.9"], - # ACME support is required to provision TLS certificates from authorities # that use the protocol, such as Let's Encrypt. "acme": [ "txacme>=0.9.2", - # txacme depends on eliot. Eliot 1.8.0 is incompatible with # python 3.5.2, as per https://github.com/itamarst/eliot/issues/418 'eliot<1.8.0;python_version<"3.5.3"', ], - "saml2": ["pysaml2>=4.5.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], @@ -121,12 +111,14 @@ def list_requirements(): class DependencyException(Exception): @property def message(self): - return "\n".join([ - "Missing Requirements: %s" % (", ".join(self.dependencies),), - "To install run:", - " pip install --upgrade --force %s" % (" ".join(self.dependencies),), - "", - ]) + return "\n".join( + [ + "Missing Requirements: %s" % (", ".join(self.dependencies),), + "To install run:", + " pip install --upgrade --force %s" % (" ".join(self.dependencies),), + "", + ] + ) @property def dependencies(self): diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 0a432a16fa..fe482e279f 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -83,8 +83,7 @@ class ReplicationEndpoint(object): def __init__(self, hs): if self.CACHE: self.response_cache = ResponseCache( - hs, "repl." + self.NAME, - timeout_ms=30 * 60 * 1000, + hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) assert self.METHOD in ("PUT", "POST", "GET") @@ -134,8 +133,7 @@ class ReplicationEndpoint(object): data = yield cls._serialize_payload(**kwargs) url_args = [ - urllib.parse.quote(kwargs[name], safe='') - for name in cls.PATH_ARGS + urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS ] if cls.CACHE: @@ -156,7 +154,10 @@ class ReplicationEndpoint(object): ) uri = "http://%s:%s/_synapse/replication/%s/%s" % ( - host, port, cls.NAME, "/".join(url_args) + host, + port, + cls.NAME, + "/".join(url_args), ) try: @@ -202,10 +203,7 @@ class ReplicationEndpoint(object): url_args.append("txn_id") args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) - pattern = re.compile("^/_synapse/replication/%s/%s$" % ( - self.NAME, - args - )) + pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) http_server.register_paths(method, [pattern], handler) @@ -219,8 +217,4 @@ class ReplicationEndpoint(object): assert self.CACHE - return self.response_cache.wrap( - txn_id, - self._handle_request, - request, **kwargs - ) + return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 0f0a07c422..61eafbe708 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -68,18 +68,17 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): for event, context in event_and_contexts: serialized_context = yield context.serialize(event, store) - event_payloads.append({ - "event": event.get_pdu_json(), - "event_format_version": event.format_version, - "internal_metadata": event.internal_metadata.get_dict(), - "rejected_reason": event.rejected_reason, - "context": serialized_context, - }) - - payload = { - "events": event_payloads, - "backfilled": backfilled, - } + event_payloads.append( + { + "event": event.get_pdu_json(), + "event_format_version": event.format_version, + "internal_metadata": event.internal_metadata.get_dict(), + "rejected_reason": event.rejected_reason, + "context": serialized_context, + } + ) + + payload = {"events": event_payloads, "backfilled": backfilled} defer.returnValue(payload) @@ -103,18 +102,15 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) context = yield EventContext.deserialize( - self.store, event_payload["context"], + self.store, event_payload["context"] ) event_and_contexts.append((event, context)) - logger.info( - "Got %d events from federation", - len(event_and_contexts), - ) + logger.info("Got %d events from federation", len(event_and_contexts)) yield self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled, + event_and_contexts, backfilled ) defer.returnValue((200, {})) @@ -146,10 +142,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): @staticmethod def _serialize_payload(edu_type, origin, content): - return { - "origin": origin, - "content": content, - } + return {"origin": origin, "content": content} @defer.inlineCallbacks def _handle_request(self, request, edu_type): @@ -159,10 +152,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): origin = content["origin"] edu_content = content["content"] - logger.info( - "Got %r edu from %s", - edu_type, origin, - ) + logger.info("Got %r edu from %s", edu_type, origin) result = yield self.registry.on_edu(edu_type, origin, edu_content) @@ -201,9 +191,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): query_type (str) args (dict): The arguments received for the given query type """ - return { - "args": args, - } + return {"args": args} @defer.inlineCallbacks def _handle_request(self, request, query_type): @@ -212,10 +200,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): args = content["args"] - logger.info( - "Got %r query", - query_type, - ) + logger.info("Got %r query", query_type) result = yield self.registry.on_query(query_type, args) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 63bc0405ea..7c1197e5dd 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -61,13 +61,10 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest = content["is_guest"] device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest, + user_id, device_id, initial_display_name, is_guest ) - defer.returnValue((200, { - "device_id": device_id, - "access_token": access_token, - })) + defer.returnValue((200, {"device_id": device_id, "access_token": access_token})) def register_servlets(hs, http_server): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 81a2b204c7..0a76a3762f 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -40,7 +40,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): """ NAME = "remote_join" - PATH_ARGS = ("room_id", "user_id",) + PATH_ARGS = ("room_id", "user_id") def __init__(self, hs): super(ReplicationRemoteJoinRestServlet, self).__init__(hs) @@ -50,8 +50,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts, - content): + def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): """ Args: requester(Requester) @@ -78,16 +77,10 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): if requester.user: request.authenticated_entity = requester.user.to_string() - logger.info( - "remote_join: %s into room: %s", - user_id, room_id, - ) + logger.info("remote_join: %s into room: %s", user_id, room_id) yield self.federation_handler.do_invite_join( - remote_room_hosts, - room_id, - user_id, - event_content, + remote_room_hosts, room_id, user_id, event_content ) defer.returnValue((200, {})) @@ -107,7 +100,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): """ NAME = "remote_reject_invite" - PATH_ARGS = ("room_id", "user_id",) + PATH_ARGS = ("room_id", "user_id") def __init__(self, hs): super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs) @@ -141,16 +134,11 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): if requester.user: request.authenticated_entity = requester.user.to_string() - logger.info( - "remote_reject_invite: %s out of room: %s", - user_id, room_id, - ) + logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: event = yield self.federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - user_id, + remote_room_hosts, room_id, user_id ) ret = event.get_pdu_json() except Exception as e: @@ -162,9 +150,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warn("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite( - user_id, room_id - ) + yield self.store.locally_reject_invite(user_id, room_id) ret = {} defer.returnValue((200, ret)) @@ -228,7 +214,7 @@ class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint): logger.info("get_or_register_3pid_guest: %r", content) ret = yield self.registeration_handler.get_or_register_3pid_guest( - medium, address, inviter_user_id, + medium, address, inviter_user_id ) defer.returnValue((200, ret)) @@ -264,7 +250,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): user_id (str) change (str): Either "joined" or "left" """ - assert change in ("joined", "left",) + assert change in ("joined", "left") return {} diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 912a5ac341..f81a0f1b8f 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -37,8 +37,16 @@ class ReplicationRegisterServlet(ReplicationEndpoint): @staticmethod def _serialize_payload( - user_id, token, password_hash, was_guest, make_guest, appservice_id, - create_profile_with_displayname, admin, user_type, address, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id, + create_profile_with_displayname, + admin, + user_type, + address, ): """ Args: @@ -85,7 +93,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): create_profile_with_displayname=content["create_profile_with_displayname"], admin=content["admin"], user_type=content["user_type"], - address=content["address"] + address=content["address"], ) defer.returnValue((200, {})) @@ -104,8 +112,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - def _serialize_payload(user_id, auth_result, access_token, bind_email, - bind_msisdn): + def _serialize_payload(user_id, auth_result, access_token, bind_email, bind_msisdn): """ Args: user_id (str): The user ID that consented diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 3635015eda..034763fe99 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -45,6 +45,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "extra_users": [], } """ + NAME = "send_event" PATH_ARGS = ("event_id",) @@ -57,8 +58,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): @staticmethod @defer.inlineCallbacks - def _serialize_payload(event_id, store, event, context, requester, - ratelimit, extra_users): + def _serialize_payload( + event_id, store, event, context, requester, ratelimit, extra_users + ): """ Args: event_id (str) @@ -108,14 +110,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): request.authenticated_entity = requester.user.to_string() logger.info( - "Got event to send with ID: %s into room: %s", - event.event_id, event.room_id, + "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) yield self.event_creation_handler.persist_and_notify_client_event( - requester, event, context, - ratelimit=ratelimit, - extra_users=extra_users, + requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) defer.returnValue((200, {})) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 817d1f67f9..182cb2a1d8 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -37,7 +37,7 @@ class BaseSlavedStore(SQLBaseStore): super(BaseSlavedStore, self).__init__(db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( - db_conn, "cache_invalidation_stream", "stream_id", + db_conn, "cache_invalidation_stream", "stream_id" ) else: self._cache_id_gen = None diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index d9ba6d69b1..3c44d1d48d 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -21,10 +21,9 @@ from synapse.storage.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( - db_conn, "account_data_max_stream_id", "stream_id", + db_conn, "account_data_max_stream_id", "stream_id" ) super(SlavedAccountDataStore, self).__init__(db_conn, hs) @@ -45,24 +44,20 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved self._account_data_id_gen.advance(token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed( - row.user_id, token - ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == "account_data": self._account_data_id_gen.advance(token) for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( - (row.data_type, row.user_id,) + (row.data_type, row.user_id) ) self.get_account_data_for_user.invalidate((row.user_id,)) - self.get_account_data_for_room.invalidate((row.user_id, row.room_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) self.get_account_data_for_room_and_type.invalidate( - (row.user_id, row.room_id, row.data_type,), - ) - self._account_data_stream_cache.entity_has_changed( - row.user_id, token + (row.user_id, row.room_id, row.data_type) ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedAccountDataStore, self).process_replication_rows( stream_name, token, rows ) diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index b53a4c6bd1..cda12ea70d 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -20,6 +20,7 @@ from synapse.storage.appservice import ( ) -class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore, - ApplicationServiceWorkerStore): +class SlavedApplicationServiceStore( + ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore +): pass diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 5b8521c770..14ced32333 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -25,9 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore): super(SlavedClientIpStore, self).__init__(db_conn, hs) self.client_ip_last_seen = Cache( - name="client_ip_last_seen", - keylen=4, - max_entries=50000 * CACHE_SIZE_FACTOR, + name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR ) def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 4d59778863..284fd30d89 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -24,15 +24,15 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( - db_conn, "device_max_stream_id", "stream_id", + db_conn, "device_max_stream_id", "stream_id" ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token() + self._device_inbox_id_gen.get_current_token(), ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token() + self._device_inbox_id_gen.get_current_token(), ) self._last_device_delete_cache = ExpiringCache( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 16c9a162c5..d9300fce33 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -27,14 +27,14 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self.hs = hs self._device_list_id_gen = SlavedIdTracker( - db_conn, "device_lists_stream", "stream_id", + db_conn, "device_lists_stream", "stream_id" ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( - "DeviceListStreamChangeCache", device_list_max, + "DeviceListStreamChangeCache", device_list_max ) self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", device_list_max, + "DeviceListFederationStreamChangeCache", device_list_max ) def stream_positions(self): @@ -46,17 +46,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto if stream_name == "device_lists": self._device_list_id_gen.advance(token) for row in rows: - self._invalidate_caches_for_devices( - token, row.user_id, row.destination, - ) + self._invalidate_caches_for_devices(token, row.user_id, row.destination) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) def _invalidate_caches_for_devices(self, token, user_id, destination): - self._device_list_stream_cache.entity_has_changed( - user_id, token - ) + self._device_list_stream_cache.entity_has_changed(user_id, token) if destination: self._device_list_federation_stream_cache.entity_has_changed( diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index a3952506c1..ab5937e638 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -45,21 +45,20 @@ logger = logging.getLogger(__name__) # the method descriptor on the DataStore and chuck them into our class. -class SlavedEventStore(EventFederationWorkerStore, - RoomMemberWorkerStore, - EventPushActionsWorkerStore, - StreamWorkerStore, - StateGroupWorkerStore, - EventsWorkerStore, - SignatureWorkerStore, - UserErasureWorkerStore, - RelationsWorkerStore, - BaseSlavedStore): - +class SlavedEventStore( + EventFederationWorkerStore, + RoomMemberWorkerStore, + EventPushActionsWorkerStore, + StreamWorkerStore, + StateGroupWorkerStore, + EventsWorkerStore, + SignatureWorkerStore, + UserErasureWorkerStore, + RelationsWorkerStore, + BaseSlavedStore, +): def __init__(self, db_conn, hs): - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", - ) + self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) @@ -90,8 +89,13 @@ class SlavedEventStore(EventFederationWorkerStore, self._backfill_id_gen.advance(-token) for row in rows: self.invalidate_caches_for_event( - -token, row.event_id, row.room_id, row.type, row.state_key, - row.redacts, row.relates_to, + -token, + row.event_id, + row.room_id, + row.type, + row.state_key, + row.redacts, + row.relates_to, backfilled=True, ) return super(SlavedEventStore, self).process_replication_rows( @@ -103,41 +107,48 @@ class SlavedEventStore(EventFederationWorkerStore, if row.type == EventsStreamEventRow.TypeId: self.invalidate_caches_for_event( - token, data.event_id, data.room_id, data.type, data.state_key, - data.redacts, data.relates_to, + token, + data.event_id, + data.room_id, + data.type, + data.state_key, + data.redacts, + data.relates_to, backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: if data.type == EventTypes.Member: self.get_rooms_for_user_with_stream_ordering.invalidate( - (data.state_key, ), + (data.state_key,) ) else: - raise Exception("Unknown events stream row type %s" % (row.type, )) - - def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, - etype, state_key, redacts, relates_to, - backfilled): + raise Exception("Unknown events stream row type %s" % (row.type,)) + + def invalidate_caches_for_event( + self, + stream_ordering, + event_id, + room_id, + etype, + state_key, + redacts, + relates_to, + backfilled, + ): self._invalidate_get_event_cache(event_id) self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_event_push_actions_by_room_for_user.invalidate_many( - (room_id,) - ) + self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: - self._events_stream_cache.entity_has_changed( - room_id, stream_ordering - ) + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) if redacts: self._invalidate_get_event_cache(redacts) if etype == EventTypes.Member: - self._membership_stream_cache.entity_has_changed( - state_key, stream_ordering - ) + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) self.get_invited_rooms_for_user.invalidate((state_key,)) if relates_to: diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index e933b170bb..28a46edd28 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -27,10 +27,11 @@ class SlavedGroupServerStore(BaseSlavedStore): self.hs = hs self._group_updates_id_gen = SlavedIdTracker( - db_conn, "local_group_updates", "stream_id", + db_conn, "local_group_updates", "stream_id" ) self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(), + "_group_updates_stream_cache", + self._group_updates_id_gen.get_current_token(), ) get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user) @@ -46,9 +47,7 @@ class SlavedGroupServerStore(BaseSlavedStore): if stream_name == "groups": self._group_updates_id_gen.advance(token) for row in rows: - self._group_updates_stream_cache.entity_has_changed( - row.user_id, token - ) + self._group_updates_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedGroupServerStore, self).process_replication_rows( stream_name, token, rows diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 0ec1db25ce..82d808af4c 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -24,9 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedPresenceStore, self).__init__(db_conn, hs) - self._presence_id_gen = SlavedIdTracker( - db_conn, "presence_stream", "stream_id", - ) + self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) @@ -55,9 +53,7 @@ class SlavedPresenceStore(BaseSlavedStore): if stream_name == "presence": self._presence_id_gen.advance(token) for row in rows: - self.presence_stream_cache.entity_has_changed( - row.user_id, token - ) + self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) return super(SlavedPresenceStore, self).process_replication_rows( stream_name, token, rows diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 45fc913c52..af7012702e 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -23,7 +23,7 @@ from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def __init__(self, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id", + db_conn, "push_rules_stream", "stream_id" ) super(SlavedPushRuleStore, self).__init__(db_conn, hs) @@ -47,9 +47,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) - self.push_rules_stream_cache.entity_has_changed( - row.user_id, token - ) + self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedPushRuleStore, self).process_replication_rows( stream_name, token, rows ) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 3b2213c0d4..8eeb267d61 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -21,12 +21,10 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): super(SlavedPusherStore, self).__init__(db_conn, hs) self._pushers_id_gen = SlavedIdTracker( - db_conn, "pushers", "id", - extra_tables=[("deleted_pushers", "stream_id")], + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) def stream_positions(self): diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index ed12342f40..91afa5a72b 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -29,7 +29,6 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 0cb474928c..f68b3378e3 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -38,6 +38,4 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication_rows( - stream_name, token, rows - ) + return super(RoomStore, self).process_replication_rows(stream_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 206dc3b397..a44ceb00e7 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -39,6 +39,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): Accepts a handler that will be called when new data is available or data is required. """ + maxDelay = 30 # Try at least once every N seconds def __init__(self, hs, client_name, handler): @@ -64,9 +65,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): def clientConnectionFailed(self, connector, reason): logger.error("Failed to connect to replication: %r", reason) - ReconnectingClientFactory.clientConnectionFailed( - self, connector, reason - ) + ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) class ReplicationClientHandler(object): @@ -74,6 +73,7 @@ class ReplicationClientHandler(object): By default proxies incoming replication data to the SlaveStore. """ + def __init__(self, store): self.store = store diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 2098c32a77..0ff2a7199f 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -23,9 +23,11 @@ import platform if platform.python_implementation() == "PyPy": import json + _json_encoder = json.JSONEncoder() else: import simplejson as json + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) logger = logging.getLogger(__name__) @@ -41,6 +43,7 @@ class Command(object): The default implementation creates a command of form ` ` """ + NAME = None def __init__(self, data): @@ -73,6 +76,7 @@ class ServerCommand(Command): SERVER """ + NAME = "SERVER" @@ -99,6 +103,7 @@ class RdataCommand(Command): RDATA presence batch ["@bar:example.com", "online", ...] RDATA presence 59 ["@baz:example.com", "online", ...] """ + NAME = "RDATA" def __init__(self, stream_name, token, row): @@ -110,17 +115,17 @@ class RdataCommand(Command): def from_line(cls, line): stream_name, token, row_json = line.split(" ", 2) return cls( - stream_name, - None if token == "batch" else int(token), - json.loads(row_json) + stream_name, None if token == "batch" else int(token), json.loads(row_json) ) def to_line(self): - return " ".join(( - self.stream_name, - str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), - )) + return " ".join( + ( + self.stream_name, + str(self.token) if self.token is not None else "batch", + _json_encoder.encode(self.row), + ) + ) def get_logcontext_id(self): return "RDATA-" + self.stream_name @@ -133,6 +138,7 @@ class PositionCommand(Command): Sent to the client after all missing updates for a stream have been sent to the client and they're now up to date. """ + NAME = "POSITION" def __init__(self, stream_name, token): @@ -145,19 +151,21 @@ class PositionCommand(Command): return cls(stream_name, int(token)) def to_line(self): - return " ".join((self.stream_name, str(self.token),)) + return " ".join((self.stream_name, str(self.token))) class ErrorCommand(Command): """Sent by either side if there was an ERROR. The data is a string describing the error. """ + NAME = "ERROR" class PingCommand(Command): """Sent by either side as a keep alive. The data is arbitary (often timestamp) """ + NAME = "PING" @@ -165,6 +173,7 @@ class NameCommand(Command): """Sent by client to inform the server of the client's identity. The data is the name """ + NAME = "NAME" @@ -184,6 +193,7 @@ class ReplicateCommand(Command): REPLICATE ALL NOW """ + NAME = "REPLICATE" def __init__(self, stream_name, token): @@ -200,7 +210,7 @@ class ReplicateCommand(Command): return cls(stream_name, token) def to_line(self): - return " ".join((self.stream_name, str(self.token),)) + return " ".join((self.stream_name, str(self.token))) def get_logcontext_id(self): return "REPLICATE-" + self.stream_name @@ -218,6 +228,7 @@ class UserSyncCommand(Command): Where is either "start" or "stop" """ + NAME = "USER_SYNC" def __init__(self, user_id, is_syncing, last_sync_ms): @@ -235,9 +246,13 @@ class UserSyncCommand(Command): return cls(user_id, state == "start", int(last_sync_ms)) def to_line(self): - return " ".join(( - self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), - )) + return " ".join( + ( + self.user_id, + "start" if self.is_syncing else "end", + str(self.last_sync_ms), + ) + ) class FederationAckCommand(Command): @@ -251,6 +266,7 @@ class FederationAckCommand(Command): FEDERATION_ACK """ + NAME = "FEDERATION_ACK" def __init__(self, token): @@ -268,6 +284,7 @@ class SyncCommand(Command): """Used for testing. The client protocol implementation allows waiting on a SYNC command with a specified data. """ + NAME = "SYNC" @@ -278,6 +295,7 @@ class RemovePusherCommand(Command): REMOVE_PUSHER """ + NAME = "REMOVE_PUSHER" def __init__(self, app_id, push_key, user_id): @@ -309,6 +327,7 @@ class InvalidateCacheCommand(Command): Where is a json list. """ + NAME = "INVALIDATE_CACHE" def __init__(self, cache_func, keys): @@ -322,9 +341,7 @@ class InvalidateCacheCommand(Command): return cls(cache_func, json.loads(keys_json)) def to_line(self): - return " ".join(( - self.cache_func, _json_encoder.encode(self.keys), - )) + return " ".join((self.cache_func, _json_encoder.encode(self.keys))) class UserIpCommand(Command): @@ -334,6 +351,7 @@ class UserIpCommand(Command): USER_IP , , , , , """ + NAME = "USER_IP" def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen): @@ -350,15 +368,22 @@ class UserIpCommand(Command): access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) - return cls( - user_id, access_token, ip, user_agent, device_id, last_seen - ) + return cls(user_id, access_token, ip, user_agent, device_id, last_seen) def to_line(self): - return self.user_id + " " + _json_encoder.encode(( - self.access_token, self.ip, self.user_agent, self.device_id, - self.last_seen, - )) + return ( + self.user_id + + " " + + _json_encoder.encode( + ( + self.access_token, + self.ip, + self.user_agent, + self.device_id, + self.last_seen, + ) + ) + ) # Map of command name to command type. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index b51590cf8f..97efb835ad 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -84,7 +84,8 @@ from .commands import ( from .streams import STREAMS_MAP connection_close_counter = Counter( - "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]) + "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] +) # A list of all connected protocols. This allows us to send metrics about the # connections. @@ -119,7 +120,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) """ - delimiter = b'\n' + + delimiter = b"\n" VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send @@ -183,10 +185,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if now - self.last_sent_command >= PING_TIME: self.send_command(PingCommand(now)) - if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS: + if ( + self.received_ping + and now - self.last_received_command > PING_TIMEOUT_MS + ): logger.info( "[%s] Connection hasn't received command in %r ms. Closing.", - self.id(), now - self.last_received_command + self.id(), + now - self.last_received_command, ) self.send_error("ping timeout") @@ -208,7 +214,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() self.inbound_commands_counter[cmd_name] = ( - self.inbound_commands_counter[cmd_name] + 1) + self.inbound_commands_counter[cmd_name] + 1 + ) cmd_cls = COMMAND_MAP[cmd_name] try: @@ -224,9 +231,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Now lets try and call on_ function run_as_background_process( - "replication-" + cmd.get_logcontext_id(), - self.handle_command, - cmd, + "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd ) def handle_command(self, cmd): @@ -274,8 +279,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): return self.outbound_commands_counter[cmd.NAME] = ( - self.outbound_commands_counter[cmd.NAME] + 1) - string = "%s %s" % (cmd.NAME, cmd.to_line(),) + self.outbound_commands_counter[cmd.NAME] + 1 + ) + string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -283,10 +289,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if len(encoded_string) > self.MAX_LENGTH: raise Exception( - "Failed to send command %s as too long (%d > %d)" % ( - cmd.NAME, - len(encoded_string), self.MAX_LENGTH, - ) + "Failed to send command %s as too long (%d > %d)" + % (cmd.NAME, len(encoded_string), self.MAX_LENGTH) ) self.sendLine(encoded_string) @@ -379,7 +383,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if self.transport: addr = str(self.transport.getPeer()) return "ReplicationConnection" % ( - self.name, self.conn_id, addr, + self.name, + self.conn_id, + addr, ) def id(self): @@ -422,7 +428,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): def on_USER_SYNC(self, cmd): return self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, + self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) def on_REPLICATE(self, cmd): @@ -432,10 +438,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): if stream_name == "ALL": # Subscribe to all streams we're publishing to. deferreds = [ - run_in_background( - self.subscribe_to_stream, - stream, token, - ) + run_in_background(self.subscribe_to_stream, stream, token) for stream in iterkeys(self.streamer.streams_by_name) ] @@ -449,16 +452,18 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): return self.streamer.federation_ack(cmd.token) def on_REMOVE_PUSHER(self, cmd): - return self.streamer.on_remove_pusher( - cmd.app_id, cmd.push_key, cmd.user_id, - ) + return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) def on_INVALIDATE_CACHE(self, cmd): return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) def on_USER_IP(self, cmd): return self.streamer.on_user_ip( - cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, cmd.last_seen, ) @@ -476,7 +481,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): try: # Get missing updates updates, current_token = yield self.streamer.get_stream_updates( - stream_name, token, + stream_name, token ) # Send all the missing updates @@ -608,8 +613,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): row = STREAMS_MAP[stream_name].parse_row(cmd.row) except Exception: logger.exception( - "[%s] Failed to parse RDATA: %r %r", - self.id(), stream_name, cmd.row + "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row ) raise @@ -643,7 +647,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): logger.info( "[%s] Subscribing to replication stream: %r from %r", - self.id(), stream_name, token + self.id(), + stream_name, + token, ) self.streams_connecting.add(stream_name) @@ -661,9 +667,7 @@ pending_commands = LaterGauge( "synapse_replication_tcp_protocol_pending_commands", "", ["name"], - lambda: { - (p.name,): len(p.pending_commands) for p in connected_connections - }, + lambda: {(p.name,): len(p.pending_commands) for p in connected_connections}, ) @@ -678,9 +682,7 @@ transport_send_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_send_buffer", "", ["name"], - lambda: { - (p.name,): transport_buffer_size(p) for p in connected_connections - }, + lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections}, ) @@ -694,7 +696,7 @@ def transport_kernel_read_buffer_size(protocol, read=True): op = SIOCINQ else: op = SIOCOUTQ - size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0] + size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0] return size return 0 @@ -726,7 +728,7 @@ tcp_inbound_commands = LaterGauge( "", ["command", "name"], lambda: { - (k, p.name,): count + (k, p.name): count for p in connected_connections for k, count in iteritems(p.inbound_commands_counter) }, @@ -737,7 +739,7 @@ tcp_outbound_commands = LaterGauge( "", ["command", "name"], lambda: { - (k, p.name,): count + (k, p.name): count for p in connected_connections for k, count in iteritems(p.outbound_commands_counter) }, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index f6a38f5140..d1e98428bc 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -33,13 +33,15 @@ from .protocol import ServerReplicationStreamProtocol from .streams import STREAMS_MAP from .streams.federation import FederationStream -stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", - "", ["stream_name"]) +stream_updates_counter = Counter( + "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] +) user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache", - "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -48,6 +50,7 @@ logger = logging.getLogger(__name__) class ReplicationStreamProtocolFactory(Factory): """Factory for new replication connections. """ + def __init__(self, hs): self.streamer = ReplicationStreamer(hs) self.clock = hs.get_clock() @@ -55,9 +58,7 @@ class ReplicationStreamProtocolFactory(Factory): def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, - self.clock, - self.streamer, + self.server_name, self.clock, self.streamer ) @@ -80,29 +81,39 @@ class ReplicationStreamer(object): # Current connections. self.connections = [] - LaterGauge("synapse_replication_tcp_resource_total_connections", "", [], - lambda: len(self.connections)) + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self.connections), + ) # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. self.streams = [ - stream(hs) for stream in itervalues(STREAMS_MAP) + stream(hs) + for stream in itervalues(STREAMS_MAP) if stream != FederationStream or not hs.config.send_federation ] self.streams_by_name = {stream.NAME: stream for stream in self.streams} LaterGauge( - "synapse_replication_tcp_resource_connections_per_stream", "", + "synapse_replication_tcp_resource_connections_per_stream", + "", ["stream_name"], lambda: { - (stream_name,): len([ - conn for conn in self.connections - if stream_name in conn.replication_streams - ]) + (stream_name,): len( + [ + conn + for conn in self.connections + if stream_name in conn.replication_streams + ] + ) for stream_name in self.streams_by_name - }) + }, + ) self.federation_sender = None if not hs.config.send_federation: @@ -179,7 +190,9 @@ class ReplicationStreamer(object): logger.debug( "Getting stream: %s: %s -> %s", - stream.NAME, stream.last_token, stream.upto_token + stream.NAME, + stream.last_token, + stream.upto_token, ) try: updates, current_token = yield stream.get_updates() @@ -189,7 +202,8 @@ class ReplicationStreamer(object): logger.debug( "Sending %d updates to %d connections", - len(updates), len(self.connections), + len(updates), + len(self.connections), ) if updates: @@ -243,7 +257,7 @@ class ReplicationStreamer(object): """ user_sync_counter.inc() yield self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms, + conn_id, user_id, is_syncing, last_sync_ms ) @measure_func("repl.on_remove_pusher") @@ -272,7 +286,7 @@ class ReplicationStreamer(object): """ user_ip_cache_counter.inc() yield self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen, + user_id, access_token, ip, user_agent, device_id, last_seen ) yield self._server_notices_sender.on_user_ip(user_id) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b6ce7a7bee..7ef67a5a73 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -26,78 +26,75 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 10000 -BackfillStreamRow = namedtuple("BackfillStreamRow", ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional -)) -PresenceStreamRow = namedtuple("PresenceStreamRow", ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool -)) -TypingStreamRow = namedtuple("TypingStreamRow", ( - "room_id", # str - "user_ids", # list(str) -)) -ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict -)) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ( - "user_id", # str -)) -PushersStreamRow = namedtuple("PushersStreamRow", ( - "user_id", # str - "app_id", # str - "pushkey", # str - "deleted", # bool -)) -CachesStreamRow = namedtuple("CachesStreamRow", ( - "cache_func", # str - "keys", # list(str) - "invalidation_ts", # int -)) -PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional -)) -DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ( - "user_id", # str - "destination", # str -)) -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ( - "entity", # str -)) -TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", ( - "user_id", # str - "room_id", # str - "data", # dict -)) -AccountDataStreamRow = namedtuple("AccountDataStream", ( - "user_id", # str - "room_id", # str - "data_type", # str - "data", # dict -)) -GroupsStreamRow = namedtuple("GroupsStreamRow", ( - "group_id", # str - "user_id", # str - "type", # str - "content", # dict -)) +BackfillStreamRow = namedtuple( + "BackfillStreamRow", + ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional + "relates_to", # str, optional + ), +) +PresenceStreamRow = namedtuple( + "PresenceStreamRow", + ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool + ), +) +TypingStreamRow = namedtuple( + "TypingStreamRow", ("room_id", "user_ids") # str # list(str) +) +ReceiptsStreamRow = namedtuple( + "ReceiptsStreamRow", + ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict + ), +) +PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str +PushersStreamRow = namedtuple( + "PushersStreamRow", + ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool +) +CachesStreamRow = namedtuple( + "CachesStreamRow", + ("cache_func", "keys", "invalidation_ts"), # str # list(str) # int +) +PublicRoomsStreamRow = namedtuple( + "PublicRoomsStreamRow", + ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional + ), +) +DeviceListsStreamRow = namedtuple( + "DeviceListsStreamRow", ("user_id", "destination") # str # str +) +ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str +TagAccountDataStreamRow = namedtuple( + "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict +) +AccountDataStreamRow = namedtuple( + "AccountDataStream", + ("user_id", "room_id", "data_type", "data"), # str # str # str # dict +) +GroupsStreamRow = namedtuple( + "GroupsStreamRow", + ("group_id", "user_id", "type", "content"), # str # str # str # dict +) class Stream(object): @@ -106,6 +103,7 @@ class Stream(object): Provides a `get_updates()` function that returns new updates since the last time it was called up until the point `advance_current_token` was called. """ + NAME = None # The name of the stream ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. _LIMITED = True # Whether the update function takes a limit @@ -185,16 +183,13 @@ class Stream(object): if self._LIMITED: rows = yield self.update_function( - from_token, current_token, - limit=MAX_EVENTS_BEHIND + 1, + from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 ) # never turn more than MAX_EVENTS_BEHIND + 1 into updates. rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: - rows = yield self.update_function( - from_token, current_token, - ) + rows = yield self.update_function(from_token, current_token) updates = [(row[0], row[1:]) for row in rows] @@ -230,6 +225,7 @@ class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. """ + NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -286,6 +282,7 @@ class ReceiptsStream(Stream): class PushRulesStream(Stream): """A user has changed their push rules """ + NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -306,6 +303,7 @@ class PushRulesStream(Stream): class PushersStream(Stream): """A user has added/changed/removed a pusher """ + NAME = "pushers" ROW_TYPE = PushersStreamRow @@ -322,6 +320,7 @@ class CachesStream(Stream): """A cache was invalidated on the master and no other stream would invalidate the cache on the workers """ + NAME = "caches" ROW_TYPE = CachesStreamRow @@ -337,6 +336,7 @@ class CachesStream(Stream): class PublicRoomsStream(Stream): """The public rooms list changed """ + NAME = "public_rooms" ROW_TYPE = PublicRoomsStreamRow @@ -352,6 +352,7 @@ class PublicRoomsStream(Stream): class DeviceListsStream(Stream): """Someone added/changed/removed a device """ + NAME = "device_lists" _LIMITED = False ROW_TYPE = DeviceListsStreamRow @@ -368,6 +369,7 @@ class DeviceListsStream(Stream): class ToDeviceStream(Stream): """New to_device messages for a client """ + NAME = "to_device" ROW_TYPE = ToDeviceStreamRow @@ -383,6 +385,7 @@ class ToDeviceStream(Stream): class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ + NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow @@ -398,6 +401,7 @@ class TagAccountDataStream(Stream): class AccountDataStream(Stream): """Global or per room account data was changed """ + NAME = "account_data" ROW_TYPE = AccountDataStreamRow @@ -416,7 +420,7 @@ class AccountDataStream(Stream): results = list(room_results) results.extend( - (stream_id, user_id, None, account_data_type, content,) + (stream_id, user_id, None, account_data_type, content) for stream_id, user_id, account_data_type, content in global_results ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f1290d022a..3d0694bb11 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -52,6 +52,7 @@ data part are: @attr.s(slots=True, frozen=True) class EventsStreamRow(object): """A parsed row from the events replication stream""" + type = attr.ib() # str: the TypeId of one of the *EventsStreamRows data = attr.ib() # BaseEventsStreamRow @@ -80,11 +81,11 @@ class BaseEventsStreamRow(object): class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional + event_id = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str, optional + redacts = attr.ib() # str, optional relates_to = attr.ib() # str, optional @@ -92,24 +93,21 @@ class EventsStreamEventRow(BaseEventsStreamRow): class EventsStreamCurrentStateRow(BaseEventsStreamRow): TypeId = "state" - room_id = attr.ib() # str - type = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str state_key = attr.ib() # str - event_id = attr.ib() # str, optional + event_id = attr.ib() # str, optional TypeToRow = { - Row.TypeId: Row - for Row in ( - EventsStreamEventRow, - EventsStreamCurrentStateRow, - ) + Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow) } class EventsStream(Stream): """We received a new event, or an event went from being an outlier to not """ + NAME = "events" def __init__(self, hs): @@ -121,19 +119,17 @@ class EventsStream(Stream): @defer.inlineCallbacks def update_function(self, from_token, current_token, limit=None): event_rows = yield self._store.get_all_new_forward_event_rows( - from_token, current_token, limit, + from_token, current_token, limit ) event_updates = ( - (row[0], EventsStreamEventRow.TypeId, row[1:]) - for row in event_rows + (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows ) state_rows = yield self._store.get_all_updated_current_state_deltas( from_token, current_token, limit ) state_updates = ( - (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) - for row in state_rows + (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows ) all_updates = heapq.merge(event_updates, state_updates) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 9aa43aa8d2..dc2484109d 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -17,16 +17,20 @@ from collections import namedtuple from ._base import Stream -FederationStreamRow = namedtuple("FederationStreamRow", ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow -)) +FederationStreamRow = namedtuple( + "FederationStreamRow", + ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow + ), +) class FederationStream(Stream): """Data to be sent over federation. Only available when master has federation sending disabled. """ + NAME = "federation" ROW_TYPE = FederationStreamRow diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 195f103cdd..f161bc51a5 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -67,6 +67,7 @@ class ClientRestResource(JsonResource): * /_matrix/client/unstable * etc """ + def __init__(self, hs): JsonResource.__init__(self, hs, canonical_json=False) self.register_servlets(self, hs) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index d6c4dcdb18..9843a902c6 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -61,7 +61,7 @@ def historical_admin_path_patterns(path_regex): "^/_synapse/admin/v1", "^/_matrix/client/api/v1/admin", "^/_matrix/client/unstable/admin", - "^/_matrix/client/r0/admin" + "^/_matrix/client/r0/admin", ) ) @@ -88,12 +88,12 @@ class UsersRestServlet(RestServlet): class VersionServlet(RestServlet): - PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"), ) + PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),) def __init__(self, hs): self.res = { - 'server_version': get_version_string(synapse), - 'python_version': platform.python_version(), + "server_version": get_version_string(synapse), + "python_version": platform.python_version(), } def on_GET(self, request): @@ -107,6 +107,7 @@ class UserRegisterServlet(RestServlet): nonces (dict[str, int]): The nonces that we will accept. A dict of nonce to the time it was generated, in int seconds. """ + PATTERNS = historical_admin_path_patterns("/register") NONCE_TIMEOUT = 60 @@ -146,28 +147,24 @@ class UserRegisterServlet(RestServlet): body = parse_json_object_from_request(request) if "nonce" not in body: - raise SynapseError( - 400, "nonce must be specified", errcode=Codes.BAD_JSON, - ) + raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON) nonce = body["nonce"] if nonce not in self.nonces: - raise SynapseError( - 400, "unrecognised nonce", - ) + raise SynapseError(400, "unrecognised nonce") # Delete the nonce, so it can't be reused, even if it's invalid del self.nonces[nonce] if "username" not in body: raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON, + 400, "username must be specified", errcode=Codes.BAD_JSON ) else: if ( - not isinstance(body['username'], text_type) - or len(body['username']) > 512 + not isinstance(body["username"], text_type) + or len(body["username"]) > 512 ): raise SynapseError(400, "Invalid username") @@ -177,12 +174,12 @@ class UserRegisterServlet(RestServlet): if "password" not in body: raise SynapseError( - 400, "password must be specified", errcode=Codes.BAD_JSON, + 400, "password must be specified", errcode=Codes.BAD_JSON ) else: if ( - not isinstance(body['password'], text_type) - or len(body['password']) > 512 + not isinstance(body["password"], text_type) + or len(body["password"]) > 512 ): raise SynapseError(400, "Invalid password") @@ -202,7 +199,7 @@ class UserRegisterServlet(RestServlet): key=self.hs.config.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) - want_mac.update(nonce.encode('utf8')) + want_mac.update(nonce.encode("utf8")) want_mac.update(b"\x00") want_mac.update(username) want_mac.update(b"\x00") @@ -211,13 +208,10 @@ class UserRegisterServlet(RestServlet): want_mac.update(b"admin" if admin else b"notadmin") if user_type: want_mac.update(b"\x00") - want_mac.update(user_type.encode('utf8')) + want_mac.update(user_type.encode("utf8")) want_mac = want_mac.hexdigest() - if not hmac.compare_digest( - want_mac.encode('ascii'), - got_mac.encode('ascii') - ): + if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication @@ -226,7 +220,7 @@ class UserRegisterServlet(RestServlet): register = RegisterRestServlet(self.hs) (user_id, _) = yield register.registration_handler.register( - localpart=body['username'].lower(), + localpart=body["username"].lower(), password=body["password"], admin=bool(admin), generate_token=False, @@ -308,7 +302,7 @@ class PurgeHistoryRestServlet(RestServlet): # user can provide an event_id in the URL or the request body, or can # provide a timestamp in the request body. if event_id is None: - event_id = body.get('purge_up_to_event_id') + event_id = body.get("purge_up_to_event_id") if event_id is not None: event = yield self.store.get_event(event_id) @@ -318,44 +312,39 @@ class PurgeHistoryRestServlet(RestServlet): token = yield self.store.get_topological_token_for_event(event_id) - logger.info( - "[purge] purging up to token %s (event_id %s)", - token, event_id, - ) - elif 'purge_up_to_ts' in body: - ts = body['purge_up_to_ts'] + logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) + elif "purge_up_to_ts" in body: + ts = body["purge_up_to_ts"] if not isinstance(ts, int): raise SynapseError( - 400, "purge_up_to_ts must be an int", - errcode=Codes.BAD_JSON, + 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON ) - stream_ordering = ( - yield self.store.find_first_stream_ordering_after_ts(ts) - ) + stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts)) r = ( yield self.store.get_room_event_after_stream_ordering( - room_id, stream_ordering, + room_id, stream_ordering ) ) if not r: logger.warn( "[purge] purging events not possible: No event found " "(received_ts %i => stream_ordering %i)", - ts, stream_ordering, + ts, + stream_ordering, ) raise SynapseError( - 404, - "there is no event to be purged", - errcode=Codes.NOT_FOUND, + 404, "there is no event to be purged", errcode=Codes.NOT_FOUND ) (stream, topo, _event_id) = r token = "t%d-%d" % (topo, stream) logger.info( "[purge] purging up to token %s (received_ts %i => " "stream_ordering %i)", - token, ts, stream_ordering, + token, + ts, + stream_ordering, ) else: raise SynapseError( @@ -365,13 +354,10 @@ class PurgeHistoryRestServlet(RestServlet): ) purge_id = yield self.pagination_handler.start_purge_history( - room_id, token, - delete_local_events=delete_local_events, + room_id, token, delete_local_events=delete_local_events ) - defer.returnValue((200, { - "purge_id": purge_id, - })) + defer.returnValue((200, {"purge_id": purge_id})) class PurgeHistoryStatusRestServlet(RestServlet): @@ -421,16 +407,14 @@ class DeactivateAccountRestServlet(RestServlet): UserID.from_string(target_user_id) result = yield self._deactivate_account_handler.deactivate_account( - target_user_id, erase, + target_user_id, erase ) if result: id_server_unbind_result = "success" else: id_server_unbind_result = "no-support" - defer.returnValue((200, { - "id_server_unbind_result": id_server_unbind_result, - })) + defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) class ShutdownRoomRestServlet(RestServlet): @@ -439,6 +423,7 @@ class ShutdownRoomRestServlet(RestServlet): to a new room created by `new_room_user_id` and kicked users will be auto joined to the new room. """ + PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P[^/]+)") DEFAULT_MESSAGE = ( @@ -474,9 +459,7 @@ class ShutdownRoomRestServlet(RestServlet): config={ "preset": "public_chat", "name": room_name, - "power_level_content_override": { - "users_default": -10, - }, + "power_level_content_override": {"users_default": -10}, }, ratelimit=False, ) @@ -485,8 +468,7 @@ class ShutdownRoomRestServlet(RestServlet): requester_user_id = requester.user.to_string() logger.info( - "Shutting down room %r, joining to new room: %r", - room_id, new_room_id, + "Shutting down room %r, joining to new room: %r", room_id, new_room_id ) # This will work even if the room is already blocked, but that is @@ -529,7 +511,7 @@ class ShutdownRoomRestServlet(RestServlet): kicked_users.append(user_id) except Exception: logger.exception( - "Failed to leave old room and join new room for %r", user_id, + "Failed to leave old room and join new room for %r", user_id ) failed_to_kick_users.append(user_id) @@ -550,18 +532,24 @@ class ShutdownRoomRestServlet(RestServlet): room_id, new_room_id, requester_user_id ) - defer.returnValue((200, { - "kicked_users": kicked_users, - "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, - "new_room_id": new_room_id, - })) + defer.returnValue( + ( + 200, + { + "kicked_users": kicked_users, + "failed_to_kick_users": failed_to_kick_users, + "local_aliases": aliases_for_room, + "new_room_id": new_room_id, + }, + ) + ) class QuarantineMediaInRoom(RestServlet): """Quarantines all media in a room so that no one can download it via this server. """ + PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P[^/]+)") def __init__(self, hs): @@ -574,7 +562,7 @@ class QuarantineMediaInRoom(RestServlet): yield assert_user_is_admin(self.auth, requester.user) num_quarantined = yield self.store.quarantine_media_ids_in_room( - room_id, requester.user.to_string(), + room_id, requester.user.to_string() ) defer.returnValue((200, {"num_quarantined": num_quarantined})) @@ -583,6 +571,7 @@ class QuarantineMediaInRoom(RestServlet): class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ + PATTERNS = historical_admin_path_patterns("/room/(?P[^/]+)/media") def __init__(self, hs): @@ -613,7 +602,10 @@ class ResetPasswordRestServlet(RestServlet): Returns: 200 OK with empty object if success otherwise an error. """ - PATTERNS = historical_admin_path_patterns("/reset_password/(?P[^/]*)") + + PATTERNS = historical_admin_path_patterns( + "/reset_password/(?P[^/]*)" + ) def __init__(self, hs): self.store = hs.get_datastore() @@ -633,7 +625,7 @@ class ResetPasswordRestServlet(RestServlet): params = parse_json_object_from_request(request) assert_params_in_dict(params, ["new_password"]) - new_password = params['new_password'] + new_password = params["new_password"] yield self._set_password_handler.set_password( target_user_id, new_password, requester @@ -650,7 +642,10 @@ class GetUsersPaginatedRestServlet(RestServlet): Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = historical_admin_path_patterns("/users_paginate/(?P[^/]*)") + + PATTERNS = historical_admin_path_patterns( + "/users_paginate/(?P[^/]*)" + ) def __init__(self, hs): self.store = hs.get_datastore() @@ -676,9 +671,7 @@ class GetUsersPaginatedRestServlet(RestServlet): logger.info("limit: %s, start: %s", limit, start) - ret = yield self.handlers.admin_handler.get_users_paginate( - order, start, limit - ) + ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) defer.returnValue((200, ret)) @defer.inlineCallbacks @@ -702,13 +695,11 @@ class GetUsersPaginatedRestServlet(RestServlet): order = "name" # order by name in user table params = parse_json_object_from_request(request) assert_params_in_dict(params, ["limit", "start"]) - limit = params['limit'] - start = params['start'] + limit = params["limit"] + start = params["start"] logger.info("limit: %s, start: %s", limit, start) - ret = yield self.handlers.admin_handler.get_users_paginate( - order, start, limit - ) + ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) defer.returnValue((200, ret)) @@ -722,6 +713,7 @@ class SearchUsersRestServlet(RestServlet): Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ + PATTERNS = historical_admin_path_patterns("/search_users/(?P[^/]*)") def __init__(self, hs): @@ -750,15 +742,14 @@ class SearchUsersRestServlet(RestServlet): term = parse_string(request, "term", required=True) logger.info("term: %s ", term) - ret = yield self.handlers.admin_handler.search_users( - term - ) + ret = yield self.handlers.admin_handler.search_users(term) defer.returnValue((200, ret)) class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups """ + PATTERNS = historical_admin_path_patterns("/delete_group/(?P[^/]*)") def __init__(self, hs): @@ -800,15 +791,15 @@ class AccountValidityRenewServlet(RestServlet): raise SynapseError(400, "Missing property 'user_id' in the request body") expiration_ts = yield self.account_activity_handler.renew_account_for_user( - body["user_id"], body.get("expiration_ts"), + body["user_id"], + body.get("expiration_ts"), not body.get("enable_renewal_emails", True), ) - res = { - "expiration_ts": expiration_ts, - } + res = {"expiration_ts": expiration_ts} defer.returnValue((200, res)) + ######################################################################################## # # please don't add more servlets here: this file is already long and unwieldy. Put diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index ae5aca9dac..ee66838a0d 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -46,6 +46,7 @@ class SendServerNoticeServlet(RestServlet): "event_id": "$1895723857jgskldgujpious" } """ + def __init__(self, hs): """ Args: @@ -58,15 +59,9 @@ class SendServerNoticeServlet(RestServlet): def register(self, json_resource): PATTERN = "^/_synapse/admin/v1/send_server_notice" + json_resource.register_paths("POST", (re.compile(PATTERN + "$"),), self.on_POST) json_resource.register_paths( - "POST", - (re.compile(PATTERN + "$"), ), - self.on_POST, - ) - json_resource.register_paths( - "PUT", - (re.compile(PATTERN + "/(?P[^/]*)$",), ), - self.on_PUT, + "PUT", (re.compile(PATTERN + "/(?P[^/]*)$"),), self.on_PUT ) @defer.inlineCallbacks @@ -96,5 +91,5 @@ class SendServerNoticeServlet(RestServlet): def on_PUT(self, request, txn_id): return self.txns.fetch_or_execute_request( - request, self.on_POST, request, txn_id, + request, self.on_POST, request, txn_id ) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 48c17f1b6d..36404b797d 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -26,7 +26,6 @@ CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins class HttpTransactionCache(object): - def __init__(self, hs): self.hs = hs self.auth = self.hs.get_auth() @@ -53,7 +52,7 @@ class HttpTransactionCache(object): str: A transaction key """ token = self.auth.get_access_token_from_request(request) - return request.path.decode('utf8') + "/" + token + return request.path.decode("utf8") + "/" + token def fetch_or_execute_request(self, request, fn, *args, **kwargs): """A helper function for fetch_or_execute which extracts diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 0035182bb9..dd0d38ea5c 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -56,8 +56,9 @@ class ClientDirectoryServer(RestServlet): content = parse_json_object_from_request(request) if "room_id" not in content: - raise SynapseError(400, 'Missing params: ["room_id"]', - errcode=Codes.BAD_JSON) + raise SynapseError( + 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON + ) logger.debug("Got content: %s", content) logger.debug("Got room name: %s", room_alias.to_string()) @@ -89,13 +90,11 @@ class ClientDirectoryServer(RestServlet): try: service = yield self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_appservice_association( - service, room_alias - ) + yield dir_handler.delete_appservice_association(service, room_alias) logger.info( "Application service at %s deleted alias %s", service.url, - room_alias.to_string() + room_alias.to_string(), ) defer.returnValue((200, {})) except AuthError: @@ -107,14 +106,10 @@ class ClientDirectoryServer(RestServlet): room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_association( - requester, room_alias - ) + yield dir_handler.delete_association(requester, room_alias) logger.info( - "User %s deleted alias %s", - user.to_string(), - room_alias.to_string() + "User %s deleted alias %s", user.to_string(), room_alias.to_string() ) defer.returnValue((200, {})) @@ -135,9 +130,9 @@ class ClientDirectoryListServer(RestServlet): if room is None: raise NotFoundError("Unknown room") - defer.returnValue((200, { - "visibility": "public" if room["is_public"] else "private" - })) + defer.returnValue( + (200, {"visibility": "public" if room["is_public"] else "private"}) + ) @defer.inlineCallbacks def on_PUT(self, request, room_id): @@ -147,7 +142,7 @@ class ClientDirectoryListServer(RestServlet): visibility = content.get("visibility", "public") yield self.handlers.directory_handler.edit_published_room_list( - requester, room_id, visibility, + requester, room_id, visibility ) defer.returnValue((200, {})) @@ -157,7 +152,7 @@ class ClientDirectoryListServer(RestServlet): requester = yield self.auth.get_user_by_req(request) yield self.handlers.directory_handler.edit_published_room_list( - requester, room_id, "private", + requester, room_id, "private" ) defer.returnValue((200, {})) @@ -191,7 +186,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): ) yield self.handlers.directory_handler.edit_published_appservice_room_list( - requester.app_service.id, network_id, room_id, visibility, + requester.app_service.id, network_id, room_id, visibility ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 84ca36270b..d6de2b7360 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -38,17 +38,14 @@ class EventStreamRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest room_id = None if is_guest: if b"room_id" not in request.args: raise SynapseError(400, "Guest users must specify room_id param") if b"room_id" in request.args: - room_id = request.args[b"room_id"][0].decode('ascii') + room_id = request.args[b"room_id"][0].decode("ascii") pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 7c86b88f30..6526e1403a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission): to a typed object. """ if "user" in submission: - submission["identifier"] = { - "type": "m.id.user", - "user": submission["user"], - } + submission["identifier"] = {"type": "m.id.user", "user": submission["user"]} del submission["user"] if "medium" in submission and "address" in submission: @@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier): msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"]) - return { - "type": "m.id.thirdparty", - "medium": "msisdn", - "address": msisdn, - } + return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} class LoginRestServlet(RestServlet): @@ -120,9 +113,9 @@ class LoginRestServlet(RestServlet): # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - flows.extend(( - {"type": t} for t in self.auth_handler.get_supported_login_types() - )) + flows.extend( + ({"type": t} for t in self.auth_handler.get_supported_login_types()) + ) return (200, {"flows": flows}) @@ -132,7 +125,8 @@ class LoginRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): self._address_ratelimiter.ratelimit( - request.getClientIP(), time_now_s=self.hs.clock.time(), + request.getClientIP(), + time_now_s=self.hs.clock.time(), rate_hz=self.hs.config.rc_login_address.per_second, burst_count=self.hs.config.rc_login_address.burst_count, update=True, @@ -140,8 +134,9 @@ class LoginRestServlet(RestServlet): login_submission = parse_json_object_from_request(request) try: - if self.jwt_enabled and (login_submission["type"] == - LoginRestServlet.JWT_TYPE): + if self.jwt_enabled and ( + login_submission["type"] == LoginRestServlet.JWT_TYPE + ): result = yield self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) @@ -170,10 +165,10 @@ class LoginRestServlet(RestServlet): # field) logger.info( "Got login request with identifier: %r, medium: %r, address: %r, user: %r", - login_submission.get('identifier'), - login_submission.get('medium'), - login_submission.get('address'), - login_submission.get('user'), + login_submission.get("identifier"), + login_submission.get("medium"), + login_submission.get("address"), + login_submission.get("user"), ) login_submission_legacy_convert(login_submission) @@ -190,13 +185,13 @@ class LoginRestServlet(RestServlet): # convert threepid identifiers to user IDs if identifier["type"] == "m.id.thirdparty": - address = identifier.get('address') - medium = identifier.get('medium') + address = identifier.get("address") + medium = identifier.get("medium") if medium is None or address is None: raise SynapseError(400, "Invalid thirdparty identifier") - if medium == 'email': + if medium == "email": # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) @@ -205,34 +200,28 @@ class LoginRestServlet(RestServlet): # Check for login providers that support 3pid login types canonical_user_id, callback_3pid = ( yield self.auth_handler.check_password_provider_3pid( - medium, - address, - login_submission["password"], + medium, address, login_submission["password"] ) ) if canonical_user_id: # Authentication through password provider and 3pid succeeded result = yield self._register_device_with_callback( - canonical_user_id, login_submission, callback_3pid, + canonical_user_id, login_submission, callback_3pid ) defer.returnValue(result) # No password providers were able to handle this 3pid # Check local store user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - medium, address, + medium, address ) if not user_id: logger.warn( - "unknown 3pid identifier medium %s, address %r", - medium, address, + "unknown 3pid identifier medium %s, address %r", medium, address ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) - identifier = { - "type": "m.id.user", - "user": user_id, - } + identifier = {"type": "m.id.user", "user": user_id} # by this point, the identifier should be an m.id.user: if it's anything # else, we haven't understood it. @@ -242,22 +231,16 @@ class LoginRestServlet(RestServlet): raise SynapseError(400, "User identifier is missing 'user' key") canonical_user_id, callback = yield self.auth_handler.validate_login( - identifier["user"], - login_submission, + identifier["user"], login_submission ) result = yield self._register_device_with_callback( - canonical_user_id, login_submission, callback, + canonical_user_id, login_submission, callback ) defer.returnValue(result) @defer.inlineCallbacks - def _register_device_with_callback( - self, - user_id, - login_submission, - callback=None, - ): + def _register_device_with_callback(self, user_id, login_submission, callback=None): """ Registers a device with a given user_id. Optionally run a callback function after registration has completed. @@ -273,7 +256,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, + user_id, device_id, initial_display_name ) result = { @@ -290,7 +273,7 @@ class LoginRestServlet(RestServlet): @defer.inlineCallbacks def do_token_login(self, login_submission): - token = login_submission['token'] + token = login_submission["token"] auth_handler = self.auth_handler user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) @@ -299,7 +282,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, + user_id, device_id, initial_display_name ) result = { @@ -316,15 +299,16 @@ class LoginRestServlet(RestServlet): token = login_submission.get("token", None) if token is None: raise LoginError( - 401, "Token field for JWT is missing", - errcode=Codes.UNAUTHORIZED + 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED ) import jwt from jwt.exceptions import InvalidTokenError try: - payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm]) + payload = jwt.decode( + token, self.jwt_secret, algorithms=[self.jwt_algorithm] + ) except jwt.ExpiredSignatureError: raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED) except InvalidTokenError: @@ -342,7 +326,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name, + registered_user_id, device_id, initial_display_name ) result = { @@ -358,7 +342,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name, + registered_user_id, device_id, initial_display_name ) result = { @@ -375,21 +359,20 @@ class CasRedirectServlet(RestServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url.encode('ascii') - self.cas_service_url = hs.config.cas_service_url.encode('ascii') + self.cas_server_url = hs.config.cas_server_url.encode("ascii") + self.cas_service_url = hs.config.cas_service_url.encode("ascii") def on_GET(self, request): args = request.args if b"redirectUrl" not in args: return (400, "Redirect URL not specified for CAS auth") - client_redirect_url_param = urllib.parse.urlencode({ - b"redirectUrl": args[b"redirectUrl"][0] - }).encode('ascii') - hs_redirect_url = (self.cas_service_url + - b"/_matrix/client/r0/login/cas/ticket") - service_param = urllib.parse.urlencode({ - b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) - }).encode('ascii') + client_redirect_url_param = urllib.parse.urlencode( + {b"redirectUrl": args[b"redirectUrl"][0]} + ).encode("ascii") + hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" + service_param = urllib.parse.urlencode( + {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} + ).encode("ascii") request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) finish_request(request) @@ -411,7 +394,7 @@ class CasTicketServlet(RestServlet): uri = self.cas_server_url + "/proxyValidate" args = { "ticket": parse_string(request, "ticket", required=True), - "service": self.cas_service_url + "service": self.cas_service_url, } try: body = yield self._http_client.get_raw(uri, args) @@ -438,7 +421,7 @@ class CasTicketServlet(RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) return self._sso_auth_handler.on_successful_auth( - user, request, client_redirect_url, + user, request, client_redirect_url ) def parse_cas_response(self, cas_response_body): @@ -448,7 +431,7 @@ class CasTicketServlet(RestServlet): root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): raise Exception("root of CAS response is not serviceResponse") - success = (root[0].tag.endswith("authenticationSuccess")) + success = root[0].tag.endswith("authenticationSuccess") for child in root[0]: if child.tag.endswith("user"): user = child.text @@ -466,11 +449,11 @@ class CasTicketServlet(RestServlet): raise Exception("CAS response does not contain user") except Exception: logger.error("Error parsing CAS response", exc_info=1) - raise LoginError(401, "Invalid CAS response", - errcode=Codes.UNAUTHORIZED) + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) if not success: - raise LoginError(401, "Unsuccessful CAS response", - errcode=Codes.UNAUTHORIZED) + raise LoginError( + 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED + ) return user, attributes @@ -482,6 +465,7 @@ class SSOAuthHandler(object): Args: hs (synapse.server.HomeServer) """ + def __init__(self, hs): self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() @@ -490,8 +474,7 @@ class SSOAuthHandler(object): @defer.inlineCallbacks def on_successful_auth( - self, username, request, client_redirect_url, - user_display_name=None, + self, username, request, client_redirect_url, user_display_name=None ): """Called once the user has successfully authenticated with the SSO. diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index b8064f261e..cd711be519 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -46,7 +46,8 @@ class LogoutRestServlet(RestServlet): yield self._auth_handler.delete_access_token(access_token) else: yield self._device_handler.delete_device( - requester.user.to_string(), requester.device_id) + requester.user.to_string(), requester.device_id + ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index e263da3cb7..3e87f0fdb3 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -47,7 +47,7 @@ class PresenceStatusRestServlet(RestServlet): if requester.user != user: allowed = yield self.presence_handler.is_visible( - observed_user=user, observer_user=requester.user, + observed_user=user, observer_user=requester.user ) if not allowed: diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 34361697df..dd3962a0cb 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -68,13 +68,10 @@ class ProfileDisplaynameRestServlet(RestServlet): except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.profile_handler.set_displayname( - user, requester, new_name, is_admin) + yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) if self.hs.config.shadow_server: - shadow_user = UserID( - user.localpart, self.hs.config.shadow_server.get("hs") - ) + shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs")) self.shadow_displayname(shadow_user.to_string(), content) defer.returnValue((200, {})) @@ -89,10 +86,9 @@ class ProfileDisplaynameRestServlet(RestServlet): as_token = self.hs.config.shadow_server.get("as_token") yield self.http_client.put_json( - "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s" % ( - shadow_hs_url, user_id, as_token, user_id - ), - body + "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s" + % (shadow_hs_url, user_id, as_token, user_id), + body, ) @@ -143,9 +139,7 @@ class ProfileAvatarURLRestServlet(RestServlet): ) if self.hs.config.shadow_server: - shadow_user = UserID( - user.localpart, self.hs.config.shadow_server.get("hs") - ) + shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs")) self.shadow_avatar_url(shadow_user.to_string(), content) defer.returnValue((200, {})) @@ -160,10 +154,9 @@ class ProfileAvatarURLRestServlet(RestServlet): as_token = self.hs.config.shadow_server.get("as_token") yield self.http_client.put_json( - "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s" % ( - shadow_hs_url, user_id, as_token, user_id - ), - body + "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s" + % (shadow_hs_url, user_id, as_token, user_id), + body, ) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 3d6326fe2f..e635efb420 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -21,7 +21,11 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) -from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string +from synapse.http.servlet import ( + RestServlet, + parse_json_value_from_request, + parse_string, +) from synapse.push.baserules import BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP @@ -32,7 +36,8 @@ from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundExc class PushRuleRestServlet(RestServlet): PATTERNS = client_patterns("/(?Ppushrules/.*)$", v1=True) SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( - "Unrecognised request: You probably wanted a trailing slash") + "Unrecognised request: You probably wanted a trailing slash" + ) def __init__(self, hs): super(PushRuleRestServlet, self).__init__() @@ -54,27 +59,25 @@ class PushRuleRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request) - if '/' in spec['rule_id'] or '\\' in spec['rule_id']: + if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: raise SynapseError(400, "rule_id may not contain slashes") content = parse_json_value_from_request(request) user_id = requester.user.to_string() - if 'attr' in spec: + if "attr" in spec: yield self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) defer.returnValue((200, {})) - if spec['rule_id'].startswith('.'): + if spec["rule_id"].startswith("."): # Rule ids starting with '.' are reserved for server default rules. raise SynapseError(400, "cannot add new rule_ids that start with '.'") try: (conditions, actions) = _rule_tuple_from_request_object( - spec['template'], - spec['rule_id'], - content, + spec["template"], spec["rule_id"], content ) except InvalidRuleException as e: raise SynapseError(400, str(e)) @@ -95,7 +98,7 @@ class PushRuleRestServlet(RestServlet): conditions=conditions, actions=actions, before=before, - after=after + after=after, ) self.notify_user(user_id) except InconsistentRuleException as e: @@ -118,9 +121,7 @@ class PushRuleRestServlet(RestServlet): namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.store.delete_push_rule( - user_id, namespaced_rule_id - ) + yield self.store.delete_push_rule(user_id, namespaced_rule_id) self.notify_user(user_id) defer.returnValue((200, {})) except StoreError as e: @@ -149,10 +150,10 @@ class PushRuleRestServlet(RestServlet): PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": defer.returnValue((200, rules)) - elif path[0] == 'global': - result = _filter_ruleset_with_path(rules['global'], path[1:]) + elif path[0] == "global": + result = _filter_ruleset_with_path(rules["global"], path[1:]) defer.returnValue((200, result)) else: raise UnrecognizedRequestError() @@ -162,12 +163,10 @@ class PushRuleRestServlet(RestServlet): def notify_user(self, user_id): stream_id, _ = self.store.get_push_rules_stream_token() - self.notifier.on_new_event( - "push_rules_key", stream_id, users=[user_id] - ) + self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) def set_rule_attr(self, user_id, spec, val): - if spec['attr'] == 'enabled': + if spec["attr"] == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] if not isinstance(val, bool): @@ -176,14 +175,12 @@ class PushRuleRestServlet(RestServlet): # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - return self.store.set_push_rule_enabled( - user_id, namespaced_rule_id, val - ) - elif spec['attr'] == 'actions': - actions = val.get('actions') + return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val) + elif spec["attr"] == "actions": + actions = val.get("actions") _check_actions(actions) namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec['rule_id'] + rule_id = spec["rule_id"] is_default_rule = rule_id.startswith(".") if is_default_rule: if namespaced_rule_id not in BASE_RULE_IDS: @@ -210,12 +207,12 @@ def _rule_spec_from_path(path): """ if len(path) < 2: raise UnrecognizedRequestError() - if path[0] != 'pushrules': + if path[0] != "pushrules": raise UnrecognizedRequestError() scope = path[1] path = path[2:] - if scope != 'global': + if scope != "global": raise UnrecognizedRequestError() if len(path) == 0: @@ -229,56 +226,40 @@ def _rule_spec_from_path(path): rule_id = path[0] - spec = { - 'scope': scope, - 'template': template, - 'rule_id': rule_id - } + spec = {"scope": scope, "template": template, "rule_id": rule_id} path = path[1:] if len(path) > 0 and len(path[0]) > 0: - spec['attr'] = path[0] + spec["attr"] = path[0] return spec def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): - if rule_template in ['override', 'underride']: - if 'conditions' not in req_obj: + if rule_template in ["override", "underride"]: + if "conditions" not in req_obj: raise InvalidRuleException("Missing 'conditions'") - conditions = req_obj['conditions'] + conditions = req_obj["conditions"] for c in conditions: - if 'kind' not in c: + if "kind" not in c: raise InvalidRuleException("Condition without 'kind'") - elif rule_template == 'room': - conditions = [{ - 'kind': 'event_match', - 'key': 'room_id', - 'pattern': rule_id - }] - elif rule_template == 'sender': - conditions = [{ - 'kind': 'event_match', - 'key': 'user_id', - 'pattern': rule_id - }] - elif rule_template == 'content': - if 'pattern' not in req_obj: + elif rule_template == "room": + conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}] + elif rule_template == "sender": + conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}] + elif rule_template == "content": + if "pattern" not in req_obj: raise InvalidRuleException("Content rule missing 'pattern'") - pat = req_obj['pattern'] + pat = req_obj["pattern"] - conditions = [{ - 'kind': 'event_match', - 'key': 'content.body', - 'pattern': pat - }] + conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}] else: raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) - if 'actions' not in req_obj: + if "actions" not in req_obj: raise InvalidRuleException("No actions found") - actions = req_obj['actions'] + actions = req_obj["actions"] _check_actions(actions) @@ -290,9 +271,9 @@ def _check_actions(actions): raise InvalidRuleException("No actions found") for a in actions: - if a in ['notify', 'dont_notify', 'coalesce']: + if a in ["notify", "dont_notify", "coalesce"]: pass - elif isinstance(a, dict) and 'set_tweak' in a: + elif isinstance(a, dict) and "set_tweak" in a: pass else: raise InvalidRuleException("Unrecognised action") @@ -304,7 +285,7 @@ def _filter_ruleset_with_path(ruleset, path): PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": return ruleset template_kind = path[0] if template_kind not in ruleset: @@ -314,13 +295,13 @@ def _filter_ruleset_with_path(ruleset, path): raise UnrecognizedRequestError( PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": return ruleset[template_kind] rule_id = path[0] the_rule = None for r in ruleset[template_kind]: - if r['rule_id'] == rule_id: + if r["rule_id"] == rule_id: the_rule = r if the_rule is None: raise NotFoundError @@ -339,19 +320,19 @@ def _filter_ruleset_with_path(ruleset, path): def _priority_class_from_spec(spec): - if spec['template'] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec['template'])) - pc = PRIORITY_CLASS_MAP[spec['template']] + if spec["template"] not in PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec["template"])) + pc = PRIORITY_CLASS_MAP[spec["template"]] return pc def _namespaced_rule_id_from_spec(spec): - return _namespaced_rule_id(spec, spec['rule_id']) + return _namespaced_rule_id(spec, spec["rule_id"]) def _namespaced_rule_id(spec, rule_id): - return "global/%s/%s" % (spec['template'], rule_id) + return "global/%s/%s" % (spec["template"], rule_id) class InvalidRuleException(Exception): diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 15d860db37..e9246018df 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -44,9 +44,7 @@ class PushersRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request) user = requester.user - pushers = yield self.hs.get_datastore().get_pushers_by_user_id( - user.to_string() - ) + pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) allowed_keys = [ "app_display_name", @@ -87,50 +85,61 @@ class PushersSetRestServlet(RestServlet): content = parse_json_object_from_request(request) - if ('pushkey' in content and 'app_id' in content - and 'kind' in content and - content['kind'] is None): + if ( + "pushkey" in content + and "app_id" in content + and "kind" in content + and content["kind"] is None + ): yield self.pusher_pool.remove_pusher( - content['app_id'], content['pushkey'], user_id=user.to_string() + content["app_id"], content["pushkey"], user_id=user.to_string() ) defer.returnValue((200, {})) assert_params_in_dict( content, - ['kind', 'app_id', 'app_display_name', - 'device_display_name', 'pushkey', 'lang', 'data'] + [ + "kind", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "lang", + "data", + ], ) - logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) + logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"]) logger.debug("Got pushers request with body: %r", content) append = False - if 'append' in content: - append = content['append'] + if "append" in content: + append = content["append"] if not append: yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( - app_id=content['app_id'], - pushkey=content['pushkey'], - not_user_id=user.to_string() + app_id=content["app_id"], + pushkey=content["pushkey"], + not_user_id=user.to_string(), ) try: yield self.pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, - kind=content['kind'], - app_id=content['app_id'], - app_display_name=content['app_display_name'], - device_display_name=content['device_display_name'], - pushkey=content['pushkey'], - lang=content['lang'], - data=content['data'], - profile_tag=content.get('profile_tag', ""), + kind=content["kind"], + app_id=content["app_id"], + app_display_name=content["app_display_name"], + device_display_name=content["device_display_name"], + pushkey=content["pushkey"], + lang=content["lang"], + data=content["data"], + profile_tag=content.get("profile_tag", ""), ) except PusherConfigException as pce: - raise SynapseError(400, "Config Error: " + str(pce), - errcode=Codes.MISSING_PARAM) + raise SynapseError( + 400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM + ) self.notifier.on_new_replication_data() @@ -144,6 +153,7 @@ class PushersRemoveRestServlet(RestServlet): """ To allow pusher to be delete by clicking a link (ie. GET request) """ + PATTERNS = client_patterns("/pushers/remove$", v1=True) SUCCESS_HTML = b"You have been unsubscribed" @@ -164,9 +174,7 @@ class PushersRemoveRestServlet(RestServlet): try: yield self.pusher_pool.remove_pusher( - app_id=app_id, - pushkey=pushkey, - user_id=user.to_string(), + app_id=app_id, pushkey=pushkey, user_id=user.to_string() ) except StoreError as se: if se.code != 404: @@ -177,9 +185,9 @@ class PushersRemoveRestServlet(RestServlet): request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % ( - len(PushersRemoveRestServlet.SUCCESS_HTML), - )) + request.setHeader( + b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),) + ) request.write(PushersRemoveRestServlet.SUCCESS_HTML) finish_request(request) defer.returnValue(None) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 151b553730..95dc4baec0 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -61,18 +61,16 @@ class RoomCreateRestServlet(TransactionRestServlet): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) # define CORS for all of /rooms in RoomCreateRestServlet for simplicity - http_server.register_paths("OPTIONS", - client_patterns("/rooms(?:/.*)?$", v1=True), - self.on_OPTIONS) + http_server.register_paths( + "OPTIONS", client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS + ) # define CORS for /createRoom[/txnid] - http_server.register_paths("OPTIONS", - client_patterns("/createRoom(?:/.*)?$", v1=True), - self.on_OPTIONS) + http_server.register_paths( + "OPTIONS", client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS + ) def on_PUT(self, request, txn_id): - return self.txns.fetch_or_execute_request( - request, self.on_POST, request - ) + return self.txns.fetch_or_execute_request(request, self.on_POST, request) @defer.inlineCallbacks def on_POST(self, request): @@ -107,21 +105,23 @@ class RoomStateEventRestServlet(TransactionRestServlet): no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" # /room/$roomid/state/$eventtype/$statekey - state_key = ("/rooms/(?P[^/]*)/state/" - "(?P[^/]*)/(?P[^/]*)$") - - http_server.register_paths("GET", - client_patterns(state_key, v1=True), - self.on_GET) - http_server.register_paths("PUT", - client_patterns(state_key, v1=True), - self.on_PUT) - http_server.register_paths("GET", - client_patterns(no_state_key, v1=True), - self.on_GET_no_state_key) - http_server.register_paths("PUT", - client_patterns(no_state_key, v1=True), - self.on_PUT_no_state_key) + state_key = ( + "/rooms/(?P[^/]*)/state/" + "(?P[^/]*)/(?P[^/]*)$" + ) + + http_server.register_paths( + "GET", client_patterns(state_key, v1=True), self.on_GET + ) + http_server.register_paths( + "PUT", client_patterns(state_key, v1=True), self.on_PUT + ) + http_server.register_paths( + "GET", client_patterns(no_state_key, v1=True), self.on_GET_no_state_key + ) + http_server.register_paths( + "PUT", client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key + ) def on_GET_no_state_key(self, request, room_id, event_type): return self.on_GET(request, room_id, event_type, "") @@ -132,8 +132,9 @@ class RoomStateEventRestServlet(TransactionRestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - format = parse_string(request, "format", default="content", - allowed_values=["content", "event"]) + format = parse_string( + request, "format", default="content", allowed_values=["content", "event"] + ) msg_handler = self.message_handler data = yield msg_handler.get_room_data( @@ -145,9 +146,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): ) if not data: - raise SynapseError( - 404, "Event not found.", errcode=Codes.NOT_FOUND - ) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) if format == "event": event = format_event_for_client_v2(data.get_dict()) @@ -182,9 +181,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): ) else: event = yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - event_dict, - txn_id=txn_id, + requester, event_dict, txn_id=txn_id ) ret = {} @@ -195,7 +192,6 @@ class RoomStateEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): - def __init__(self, hs): super(RoomSendEventRestServlet, self).__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() @@ -203,7 +199,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] - PATTERNS = ("/rooms/(?P[^/]*)/send/(?P[^/]*)") + PATTERNS = "/rooms/(?P[^/]*)/send/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) @defer.inlineCallbacks @@ -218,13 +214,11 @@ class RoomSendEventRestServlet(TransactionRestServlet): "sender": requester.user.to_string(), } - if b'ts' in request.args and requester.app_service: - event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) + if b"ts" in request.args and requester.app_service: + event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) event = yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - event_dict, - txn_id=txn_id, + requester, event_dict, txn_id=txn_id ) defer.returnValue((200, {"event_id": event.event_id})) @@ -247,15 +241,12 @@ class JoinRoomAliasServlet(TransactionRestServlet): def register(self, http_server): # /join/$room_identifier[/$txn_id] - PATTERNS = ("/join/(?P[^/]*)") + PATTERNS = "/join/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) try: content = parse_json_object_from_request(request) @@ -268,7 +259,7 @@ class JoinRoomAliasServlet(TransactionRestServlet): room_id = room_identifier try: remote_room_hosts = [ - x.decode('ascii') for x in request.args[b"server_name"] + x.decode("ascii") for x in request.args[b"server_name"] ] except Exception: remote_room_hosts = None @@ -278,9 +269,9 @@ class JoinRoomAliasServlet(TransactionRestServlet): room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) room_id = room_id.to_string() else: - raise SynapseError(400, "%s was not legal room ID or room alias" % ( - room_identifier, - )) + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) yield self.room_member_handler.update_membership( requester=requester, @@ -339,14 +330,11 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: data = yield handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, + server, limit=limit, since_token=since_token ) else: data = yield handler.get_local_public_room_list( - limit=limit, - since_token=since_token, + limit=limit, since_token=since_token ) defer.returnValue((200, data)) @@ -439,16 +427,13 @@ class RoomMemberListRestServlet(RestServlet): chunk = [] for event in events: - if ( - (membership and event['content'].get("membership") != membership) or - (not_membership and event['content'].get("membership") == not_membership) + if (membership and event["content"].get("membership") != membership) or ( + not_membership and event["content"].get("membership") == not_membership ): continue chunk.append(event) - defer.returnValue((200, { - "chunk": chunk - })) + defer.returnValue((200, {"chunk": chunk})) # deprecated in favour of /members?membership=join? @@ -466,12 +451,10 @@ class JoinedRoomMemberListRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request) users_with_profile = yield self.message_handler.get_joined_members( - requester, room_id, + requester, room_id ) - defer.returnValue((200, { - "joined": users_with_profile, - })) + defer.returnValue((200, {"joined": users_with_profile})) # TODO: Needs better unit testing @@ -486,9 +469,7 @@ class RoomMessageListRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request( - request, default_limit=10, - ) + pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args filter_bytes = parse_string(request, b"filter", encoding=None) if filter_bytes: @@ -544,9 +525,7 @@ class RoomInitialSyncRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) content = yield self.initial_sync_handler.room_initial_sync( - room_id=room_id, - requester=requester, - pagin_config=pagination_config, + room_id=room_id, requester=requester, pagin_config=pagination_config ) defer.returnValue((200, content)) @@ -603,30 +582,24 @@ class RoomEventContextServlet(RestServlet): event_filter = None results = yield self.room_context_handler.get_event_context( - requester.user, - room_id, - event_id, - limit, - event_filter, + requester.user, room_id, event_id, limit, event_filter ) if not results: - raise SynapseError( - 404, "Event not found.", errcode=Codes.NOT_FOUND - ) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() results["events_before"] = yield self._event_serializer.serialize_events( - results["events_before"], time_now, + results["events_before"], time_now ) results["event"] = yield self._event_serializer.serialize_event( - results["event"], time_now, + results["event"], time_now ) results["events_after"] = yield self._event_serializer.serialize_events( - results["events_after"], time_now, + results["events_after"], time_now ) results["state"] = yield self._event_serializer.serialize_events( - results["state"], time_now, + results["state"], time_now ) defer.returnValue((200, results)) @@ -639,20 +612,14 @@ class RoomForgetRestServlet(TransactionRestServlet): self.auth = hs.get_auth() def register(self, http_server): - PATTERNS = ("/rooms/(?P[^/]*)/forget") + PATTERNS = "/rooms/(?P[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=False, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=False) - yield self.room_member_handler.forget( - user=requester.user, - room_id=room_id, - ) + yield self.room_member_handler.forget(user=requester.user, room_id=room_id) defer.returnValue((200, {})) @@ -664,7 +631,6 @@ class RoomForgetRestServlet(TransactionRestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): - def __init__(self, hs): super(RoomMembershipRestServlet, self).__init__(hs) self.room_member_handler = hs.get_room_member_handler() @@ -672,20 +638,19 @@ class RoomMembershipRestServlet(TransactionRestServlet): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] - PATTERNS = ("/rooms/(?P[^/]*)/" - "(?Pjoin|invite|leave|ban|unban|kick)") + PATTERNS = ( + "/rooms/(?P[^/]*)/" + "(?Pjoin|invite|leave|ban|unban|kick)" + ) register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { Membership.JOIN, - Membership.LEAVE + Membership.LEAVE, }: raise AuthError(403, "Guest access not allowed") @@ -716,8 +681,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): target = UserID.from_string(content["user_id"]) event_content = None - if 'reason' in content and membership_action in ['kick', 'ban']: - event_content = {'reason': content['reason']} + if "reason" in content and membership_action in ["kick", "ban"]: + event_content = {"reason": content["reason"]} yield self.room_member_handler.update_membership( requester=requester, @@ -756,7 +721,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): self.auth = hs.get_auth() def register(self, http_server): - PATTERNS = ("/rooms/(?P[^/]*)/redact/(?P[^/]*)") + PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks @@ -818,9 +783,7 @@ class RoomTypingRestServlet(RestServlet): ) else: yield self.typing_handler.stopped_typing( - target_user=target_user, - auth_user=requester.user, - room_id=room_id, + target_user=target_user, auth_user=requester.user, room_id=room_id ) defer.returnValue((200, {})) @@ -842,9 +805,7 @@ class SearchRestServlet(RestServlet): batch = parse_string(request, "next_batch") results = yield self.handlers.search_handler.search( - requester.user, - content, - batch, + requester.user, content, batch ) defer.returnValue((200, results)) @@ -880,20 +841,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): with_get: True to also register respective GET paths for the PUTs. """ http_server.register_paths( - "POST", - client_patterns(regex_string + "$", v1=True), - servlet.on_POST + "POST", client_patterns(regex_string + "$", v1=True), servlet.on_POST ) http_server.register_paths( "PUT", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_PUT + servlet.on_PUT, ) if with_get: http_server.register_paths( "GET", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_GET + servlet.on_GET, ) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 6381049210..41b3171ac8 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -34,8 +34,7 @@ class VoipRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req( - request, - self.hs.config.turn_allow_guests + request, self.hs.config.turn_allow_guests ) turnUris = self.hs.config.turn_uris @@ -49,9 +48,7 @@ class VoipRestServlet(RestServlet): username = "%d:%s" % (expiry, requester.user.to_string()) mac = hmac.new( - turnSecret.encode(), - msg=username.encode(), - digestmod=hashlib.sha1 + turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1 ) # We need to use standard padded base64 encoding here # encode_base64 because we need to add the standard padding to get the @@ -65,12 +62,17 @@ class VoipRestServlet(RestServlet): else: defer.returnValue((200, {})) - defer.returnValue((200, { - 'username': username, - 'password': password, - 'ttl': userLifetime / 1000, - 'uris': turnUris, - })) + defer.returnValue( + ( + 200, + { + "username": username, + "password": password, + "ttl": userLifetime / 1000, + "uris": turnUris, + }, + ) + ) def on_OPTIONS(self, request): return (200, {}) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 5236d5d566..e3d59ac3ac 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -52,11 +52,11 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): def set_timeline_upper_limit(filter_json, filter_timeline_limit): if filter_timeline_limit < 0: return # no upper limits - timeline = filter_json.get('room', {}).get('timeline', {}) - if 'limit' in timeline: - filter_json['room']['timeline']["limit"] = min( - filter_json['room']['timeline']['limit'], - filter_timeline_limit) + timeline = filter_json.get("room", {}).get("timeline", {}) + if "limit" in timeline: + filter_json["room"]["timeline"]["limit"] = min( + filter_json["room"]["timeline"]["limit"], filter_timeline_limit + ) def interactive_auth_handler(orig): @@ -74,10 +74,12 @@ def interactive_auth_handler(orig): # ... yield self.auth_handler.check_auth """ + def wrapped(*args, **kwargs): res = defer.maybeDeferred(orig, *args, **kwargs) res.addErrback(_catch_incomplete_interactive_auth) return res + return wrapped diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index b88e58611c..8ba052fd14 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -54,6 +54,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if self.config.email_password_reset_behaviour == "local": from synapse.push.mailer import Mailer, load_jinja2_templates + templates = load_jinja2_templates( config=hs.config, template_html_name=hs.config.email_password_reset_template_html, @@ -74,14 +75,12 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): "User password resets have been disabled due to lack of email config" ) raise SynapseError( - 400, "Email-based password resets have been disabled on this server", + 400, "Email-based password resets have been disabled on this server" ) body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'client_secret', 'email', 'send_attempt' - ]) + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) # Extract params from body client_secret = body["client_secret"] @@ -99,24 +98,24 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', email, + "email", email ) if existingUid is None: raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.email_password_reset_behaviour == "remote": - if 'id_server' not in body: + if "id_server" not in body: raise SynapseError(400, "Missing 'id_server' param in body") # Have the identity server handle the password reset flow ret = yield self.identity_handler.requestEmailToken( - body["id_server"], email, client_secret, send_attempt, next_link, + body["id_server"], email, client_secret, send_attempt, next_link ) else: # Send password reset emails from Synapse sid = yield self.send_password_reset( - email, client_secret, send_attempt, next_link, + email, client_secret, send_attempt, next_link ) # Wrap the session id in a JSON object @@ -125,13 +124,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): defer.returnValue((200, ret)) @defer.inlineCallbacks - def send_password_reset( - self, - email, - client_secret, - send_attempt, - next_link=None, - ): + def send_password_reset(self, email, client_secret, send_attempt, next_link=None): """Send a password reset email Args: @@ -148,14 +141,14 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Check that this email/client_secret/send_attempt combo is new or # greater than what we've seen previously session = yield self.datastore.get_threepid_validation_session( - "email", client_secret, address=email, validated=False, + "email", client_secret, address=email, validated=False ) # Check to see if a session already exists and that it is not yet # marked as validated if session and session.get("validated_at") is None: - session_id = session['session_id'] - last_send_attempt = session['last_send_attempt'] + session_id = session["session_id"] + last_send_attempt = session["last_send_attempt"] # Check that the send_attempt is higher than previous attempts if send_attempt <= last_send_attempt: @@ -173,22 +166,27 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # and session_id try: yield self.mailer.send_password_reset_mail( - email, token, client_secret, session_id, + email, token, client_secret, session_id ) except Exception: - logger.exception( - "Error sending a password reset email to %s", email, - ) + logger.exception("Error sending a password reset email to %s", email) raise SynapseError( 500, "An error was encountered when sending the password reset email" ) - token_expires = (self.hs.clock.time_msec() + - self.config.email_validation_token_lifetime) + token_expires = ( + self.hs.clock.time_msec() + self.config.email_validation_token_lifetime + ) yield self.datastore.start_or_continue_validation_session( - "email", email, session_id, client_secret, send_attempt, - next_link, token, token_expires, + "email", + email, + session_id, + client_secret, + send_attempt, + next_link, + token, + token_expires, ) defer.returnValue(session_id) @@ -207,12 +205,12 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)): raise SynapseError( @@ -223,9 +221,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): assert_valid_client_secret(body["client_secret"]) - existingUid = yield self.datastore.get_user_id_by_threepid( - 'msisdn', msisdn - ) + existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn) if existingUid is None: raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND) @@ -236,10 +232,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): class PasswordResetSubmitTokenServlet(RestServlet): """Handles 3PID validation token submission""" + PATTERNS = client_patterns( - "/password_reset/(?P[^/]*)/submit_token/*$", - releases=(), - unstable=True, + "/password_reset/(?P[^/]*)/submit_token/*$", releases=(), unstable=True ) def __init__(self, hs): @@ -258,8 +253,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): def on_GET(self, request, medium): if medium != "email": raise SynapseError( - 400, - "This medium is currently not supported for password resets", + 400, "This medium is currently not supported for password resets" ) if self.config.email_password_reset_behaviour == "off": if self.config.password_resets_were_disabled_due_to_email_config: @@ -267,7 +261,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): "User password resets have been disabled due to lack of email config" ) raise SynapseError( - 400, "Email-based password resets have been disabled on this server", + 400, "Email-based password resets have been disabled on this server" ) sid = parse_string(request, "sid") @@ -281,10 +275,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): try: # Mark the session as valid next_link = yield self.datastore.validate_threepid_session( - sid, - client_secret, - token, - self.clock.time_msec(), + sid, client_secret, token, self.clock.time_msec() ) # Perform a 302 redirect if next_link is set @@ -307,13 +298,11 @@ class PasswordResetSubmitTokenServlet(RestServlet): html = self.load_jinja2_template( self.config.email_template_dir, self.config.email_password_reset_failure_template, - template_vars={ - "failure_reason": e.msg, - } + template_vars={"failure_reason": e.msg}, ) request.setResponseCode(e.code) - request.write(html.encode('utf-8')) + request.write(html.encode("utf-8")) finish_request(request) defer.returnValue(None) @@ -339,22 +328,16 @@ class PasswordResetSubmitTokenServlet(RestServlet): def on_POST(self, request, medium): if medium != "email": raise SynapseError( - 400, - "This medium is currently not supported for password resets", + 400, "This medium is currently not supported for password resets" ) body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'sid', 'client_secret', 'token', - ]) + assert_params_in_dict(body, ["sid", "client_secret", "token"]) assert_valid_client_secret(body["client_secret"]) valid, _ = yield self.datastore.validate_threepid_session( - body['sid'], - body['client_secret'], - body['token'], - self.clock.time_msec(), + body["sid"], body["client_secret"], body["token"], self.clock.time_msec() ) response_code = 200 if valid else 400 @@ -395,29 +378,30 @@ class PasswordRestServlet(RestServlet): params = body else: params = yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) user_id = requester.user.to_string() else: requester = None result, params, _ = yield self.auth_handler.check_auth( [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], - body, self.hs.get_ip_from_request(request), + body, + self.hs.get_ip_from_request(request), password_servlet=True, ) if LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] - if 'medium' not in threepid or 'address' not in threepid: + if "medium" not in threepid or "address" not in threepid: raise SynapseError(500, "Malformed threepid") - if threepid['medium'] == 'email': + if threepid["medium"] == "email": # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) - threepid['address'] = threepid['address'].lower() + threepid["address"] = threepid["address"].lower() # if using email, we must know about the email they're authing with! threepid_user_id = yield self.datastore.get_user_id_by_threepid( - threepid['medium'], threepid['address'] + threepid["medium"], threepid["address"] ) if not threepid_user_id: raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) @@ -427,11 +411,9 @@ class PasswordRestServlet(RestServlet): raise SynapseError(500, "", Codes.UNKNOWN) assert_params_in_dict(params, ["new_password"]) - new_password = params['new_password'] + new_password = params["new_password"] - yield self._set_password_handler.set_password( - user_id, new_password, requester - ) + yield self._set_password_handler.set_password(user_id, new_password, requester) if self.hs.config.shadow_server: shadow_user = UserID( @@ -451,10 +433,9 @@ class PasswordRestServlet(RestServlet): as_token = self.hs.config.shadow_server.get("as_token") yield self.http_client.post_json_get_json( - "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s" % ( - shadow_hs_url, as_token, user_id, - ), - body + "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, ) @@ -485,25 +466,22 @@ class DeactivateAccountRestServlet(RestServlet): # allow ASes to dectivate their own users if requester.app_service: yield self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, + requester.user.to_string(), erase ) defer.returnValue((200, {})) yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) result = yield self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, - id_server=body.get("id_server"), + requester.user.to_string(), erase, id_server=body.get("id_server") ) if result: id_server_unbind_result = "success" else: id_server_unbind_result = "no-support" - defer.returnValue((200, { - "id_server_unbind_result": id_server_unbind_result, - })) + defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) class EmailThreepidRequestTokenRestServlet(RestServlet): @@ -519,11 +497,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( - body, - ['id_server', 'client_secret', 'email', 'send_attempt'], + body, ["id_server", "client_secret", "email", "send_attempt"] ) - if not (yield check_3pid_allowed(self.hs, "email", body['email'])): + if not (yield check_3pid_allowed(self.hs, "email", body["email"])): raise SynapseError( 403, "Your email is not authorized on this server", @@ -533,7 +510,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): assert_valid_client_secret(body["client_secret"]) existingUid = yield self.datastore.get_user_id_by_threepid( - 'email', body['email'] + "email", body["email"] ) if existingUid is not None: @@ -555,12 +532,12 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)): raise SynapseError( @@ -571,9 +548,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): assert_valid_client_secret(body["client_secret"]) - existingUid = yield self.datastore.get_user_id_by_threepid( - 'msisdn', msisdn - ) + existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn) if existingUid is not None: raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) @@ -598,11 +573,9 @@ class ThreepidRestServlet(RestServlet): def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - threepids = yield self.datastore.user_get_threepids( - requester.user.to_string() - ) + threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) - defer.returnValue((200, {'threepids': threepids})) + defer.returnValue((200, {"threepids": threepids})) @defer.inlineCallbacks def on_POST(self, request): @@ -616,8 +589,8 @@ class ThreepidRestServlet(RestServlet): # skip validation if this is a shadow 3PID from an AS if not requester.app_service: - threePidCreds = body.get('threePidCreds') - threePidCreds = body.get('three_pid_creds', threePidCreds) + threePidCreds = body.get("threePidCreds") + threePidCreds = body.get("three_pid_creds", threePidCreds) if threePidCreds is None: raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) @@ -628,7 +601,7 @@ class ThreepidRestServlet(RestServlet): 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED ) - for reqd in ['medium', 'address', 'validated_at']: + for reqd in ["medium", "address", "validated_at"]: if reqd not in threepid: logger.warn("Couldn't add 3pid: invalid response from ID server") raise SynapseError(500, "Invalid response from ID Server") @@ -637,29 +610,21 @@ class ThreepidRestServlet(RestServlet): # This makes the API entirely change shape when we have an AS token; # it really should be an entirely separate API - perhaps # /account/3pid/replicate or something. - threepid = body.get('threepid') + threepid = body.get("threepid") yield self.auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) - if not requester.app_service and ('bind' in body and body['bind']): - logger.debug( - "Binding threepid %s to %s", - threepid, user_id - ) - yield self.identity_handler.bind_threepid( - threePidCreds, user_id - ) + if not requester.app_service and ("bind" in body and body["bind"]): + logger.debug("Binding threepid %s to %s", threepid, user_id) + yield self.identity_handler.bind_threepid(threePidCreds, user_id) if self.hs.config.shadow_server: shadow_user = UserID( requester.user.localpart, self.hs.config.shadow_server.get("hs") ) - self.shadow_3pid({'threepid': threepid}, shadow_user.to_string()) + self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) defer.returnValue((200, {})) @@ -670,10 +635,9 @@ class ThreepidRestServlet(RestServlet): as_token = self.hs.config.shadow_server.get("as_token") yield self.http_client.post_json_get_json( - "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" % ( - shadow_hs_url, as_token, user_id, - ), - body + "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, ) @@ -693,14 +657,14 @@ class ThreepidDeleteRestServlet(RestServlet): raise SynapseError(400, "3PID changes disabled on this server") body = parse_json_object_from_request(request) - assert_params_in_dict(body, ['medium', 'address']) + assert_params_in_dict(body, ["medium", "address"]) requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() try: ret = yield self.auth_handler.delete_threepid( - user_id, body['medium'], body['address'], body.get("id_server"), + user_id, body["medium"], body["address"], body.get("id_server") ) except Exception: # NB. This endpoint should succeed if there is nothing to @@ -720,9 +684,7 @@ class ThreepidDeleteRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - defer.returnValue((200, { - "id_server_unbind_result": id_server_unbind_result, - })) + defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) @defer.inlineCallbacks def shadow_3pid_delete(self, body, user_id): @@ -731,10 +693,9 @@ class ThreepidDeleteRestServlet(RestServlet): as_token = self.hs.config.shadow_server.get("as_token") yield self.http_client.post_json_get_json( - "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s" % ( - shadow_hs_url, as_token, user_id - ), - body + "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, ) @@ -791,7 +752,7 @@ class ThreepidBulkLookupRestServlet(RestServlet): # Proxy the request to the identity server. lookup_3pid handles checking # if the lookup is allowed so we don't need to do it here. ret = yield self.identity_handler.bulk_lookup_3pid( - body["id_server"], body["threepids"], + body["id_server"], body["threepids"] ) defer.returnValue((200, ret)) @@ -808,7 +769,7 @@ class WhoamiRestServlet(RestServlet): def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - defer.returnValue((200, {'user_id': requester.user.to_string()})) + defer.returnValue((200, {"user_id": requester.user.to_string()})) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 17b967d363..0bfb84bdcd 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -31,6 +31,7 @@ class AccountDataServlet(RestServlet): PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P[^/]*)/account_data/(?P[^/]*)" ) @@ -52,16 +53,14 @@ class AccountDataServlet(RestServlet): if account_data_type == "im.vector.hide_profile": user = UserID.from_string(user_id) - hide_profile = body.get('hide_profile') + hide_profile = body.get("hide_profile") yield self._profile_handler.set_active(user, not hide_profile, True) max_id = yield self.store.add_account_data_for_user( user_id, account_data_type, body ) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -72,7 +71,7 @@ class AccountDataServlet(RestServlet): raise AuthError(403, "Cannot get account data for other users.") event = yield self.store.get_global_account_data_by_type_for_user( - account_data_type, user_id, + account_data_type, user_id ) if event is None: @@ -86,6 +85,7 @@ class RoomAccountDataServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P[^/]*)" "/rooms/(?P[^/]*)" @@ -110,16 +110,14 @@ class RoomAccountDataServlet(RestServlet): raise SynapseError( 405, "Cannot set m.fully_read through this API." - " Use /rooms/!roomId:server.name/read_markers" + " Use /rooms/!roomId:server.name/read_markers", ) max_id = yield self.store.add_account_data_to_room( user_id, room_id, account_data_type, body ) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -130,7 +128,7 @@ class RoomAccountDataServlet(RestServlet): raise AuthError(403, "Cannot get account data for other users.") event = yield self.store.get_account_data_for_room_and_type( - user_id, room_id, account_data_type, + user_id, room_id, account_data_type ) if event is None: diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 8091b78285..33f6a23028 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -28,7 +28,9 @@ logger = logging.getLogger(__name__) class AccountValidityRenewServlet(RestServlet): PATTERNS = client_patterns("/account_validity/renew$") - SUCCESS_HTML = b"Your account has been successfully renewed." + SUCCESS_HTML = ( + b"Your account has been successfully renewed." + ) def __init__(self, hs): """ @@ -86,7 +88,9 @@ class AccountValiditySendMailServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): if not self.account_validity.renew_by_email_enabled: - raise AuthError(403, "Account renewal via email is disabled on this server.") + raise AuthError( + 403, "Account renewal via email is disabled on this server." + ) requester = yield self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 8dfe5cba02..bebc2951e7 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -122,6 +122,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ + PATTERNS = client_patterns(r"/auth/(?P[\w\.]*)/fallback/web") def __init__(self, hs): @@ -138,11 +139,10 @@ class AuthRestServlet(RestServlet): if stagetype == LoginType.RECAPTCHA: html = RECAPTCHA_TEMPLATE % { - 'session': session, - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.RECAPTCHA - ), - 'sitekey': self.hs.config.recaptcha_public_key, + "session": session, + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + "sitekey": self.hs.config.recaptcha_public_key, } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -154,14 +154,11 @@ class AuthRestServlet(RestServlet): return None elif stagetype == LoginType.TERMS: html = TERMS_TEMPLATE % { - 'session': session, - 'terms_url': "%s_matrix/consent?v=%s" % ( - self.hs.config.public_baseurl, - self.hs.config.user_consent_version, - ), - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.TERMS - ), + "session": session, + "terms_url": "%s_matrix/consent?v=%s" + % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -187,26 +184,20 @@ class AuthRestServlet(RestServlet): if not response: raise SynapseError(400, "No captcha response supplied") - authdict = { - 'response': response, - 'session': session, - } + authdict = {"response": response, "session": session} success = yield self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, - authdict, - self.hs.get_ip_from_request(request) + LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) ) if success: html = SUCCESS_TEMPLATE else: html = RECAPTCHA_TEMPLATE % { - 'session': session, - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.RECAPTCHA - ), - 'sitekey': self.hs.config.recaptcha_public_key, + "session": session, + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + "sitekey": self.hs.config.recaptcha_public_key, } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -218,31 +209,28 @@ class AuthRestServlet(RestServlet): defer.returnValue(None) elif stagetype == LoginType.TERMS: - if ('session' not in request.args or - len(request.args['session'])) == 0: + if ("session" not in request.args or len(request.args["session"])) == 0: raise SynapseError(400, "No session supplied") - session = request.args['session'][0] - authdict = {'session': session} + session = request.args["session"][0] + authdict = {"session": session} success = yield self.auth_handler.add_oob_auth( - LoginType.TERMS, - authdict, - self.hs.get_ip_from_request(request) + LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) ) if success: html = SUCCESS_TEMPLATE else: html = TERMS_TEMPLATE % { - 'session': session, - 'terms_url': "%s_matrix/consent?v=%s" % ( + "session": session, + "terms_url": "%s_matrix/consent?v=%s" + % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.TERMS - ), + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), } html_bytes = html.encode("utf8") request.setResponseCode(200) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 78665304a5..d279229d74 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -56,6 +56,7 @@ class DeleteDevicesRestServlet(RestServlet): API for bulk deletion of devices. Accepts a JSON object with a devices key which lists the device_ids to delete. Requires user interactive auth. """ + PATTERNS = client_patterns("/delete_devices") def __init__(self, hs): @@ -84,12 +85,11 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) yield self.device_handler.delete_devices( - requester.user.to_string(), - body['devices'], + requester.user.to_string(), body["devices"] ) defer.returnValue((200, {})) @@ -112,8 +112,7 @@ class DeviceRestServlet(RestServlet): def on_GET(self, request, device_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) device = yield self.device_handler.get_device( - requester.user.to_string(), - device_id, + requester.user.to_string(), device_id ) defer.returnValue((200, device)) @@ -134,12 +133,10 @@ class DeviceRestServlet(RestServlet): raise yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) - yield self.device_handler.delete_device( - requester.user.to_string(), device_id, - ) + yield self.device_handler.delete_device(requester.user.to_string(), device_id) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -148,9 +145,7 @@ class DeviceRestServlet(RestServlet): body = parse_json_object_from_request(request) yield self.device_handler.update_device( - requester.user.to_string(), - device_id, - body + requester.user.to_string(), device_id, body ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 65db48c3cc..3f0adf4a21 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -53,8 +53,7 @@ class GetFilterRestServlet(RestServlet): try: filter = yield self.filtering.get_user_filter( - user_localpart=target_user.localpart, - filter_id=filter_id, + user_localpart=target_user.localpart, filter_id=filter_id ) defer.returnValue((200, filter.get_filter_json())) @@ -84,14 +83,10 @@ class CreateFilterRestServlet(RestServlet): raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) - set_timeline_upper_limit( - content, - self.hs.config.filter_timeline_limit - ) + set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) filter_id = yield self.filtering.add_user_filter( - user_localpart=target_user.localpart, - user_filter=content, + user_localpart=target_user.localpart, user_filter=content ) defer.returnValue((200, {"filter_id": str(filter_id)})) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index d082385ec7..a312dd2593 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -29,6 +29,7 @@ logger = logging.getLogger(__name__) class GroupServlet(RestServlet): """Get the group profile """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") def __init__(self, hs): @@ -43,8 +44,7 @@ class GroupServlet(RestServlet): requester_user_id = requester.user.to_string() group_description = yield self.groups_handler.get_group_profile( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, group_description)) @@ -56,7 +56,7 @@ class GroupServlet(RestServlet): content = parse_json_object_from_request(request) yield self.groups_handler.update_group_profile( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, {})) @@ -65,6 +65,7 @@ class GroupServlet(RestServlet): class GroupSummaryServlet(RestServlet): """Get the full group summary """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") def __init__(self, hs): @@ -79,8 +80,7 @@ class GroupSummaryServlet(RestServlet): requester_user_id = requester.user.to_string() get_group_summary = yield self.groups_handler.get_group_summary( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, get_group_summary)) @@ -93,6 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): - /groups/:group/summary/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/summary" "(/categories/(?P[^/]+))?" @@ -112,7 +113,8 @@ class GroupSummaryRoomsCatServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_summary_room( - group_id, requester_user_id, + group_id, + requester_user_id, room_id=room_id, category_id=category_id, content=content, @@ -126,9 +128,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_summary_room( - group_id, requester_user_id, - room_id=room_id, - category_id=category_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -137,6 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): class GroupCategoryServlet(RestServlet): """Get/add/update/delete a group category """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/categories/(?P[^/]+)$" ) @@ -153,8 +154,7 @@ class GroupCategoryServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_category( - group_id, requester_user_id, - category_id=category_id, + group_id, requester_user_id, category_id=category_id ) defer.returnValue((200, category)) @@ -166,9 +166,7 @@ class GroupCategoryServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_category( - group_id, requester_user_id, - category_id=category_id, - content=content, + group_id, requester_user_id, category_id=category_id, content=content ) defer.returnValue((200, resp)) @@ -179,8 +177,7 @@ class GroupCategoryServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_category( - group_id, requester_user_id, - category_id=category_id, + group_id, requester_user_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -189,9 +186,8 @@ class GroupCategoryServlet(RestServlet): class GroupCategoriesServlet(RestServlet): """Get all group categories """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/categories/$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/categories/$") def __init__(self, hs): super(GroupCategoriesServlet, self).__init__() @@ -205,7 +201,7 @@ class GroupCategoriesServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_categories( - group_id, requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, category)) @@ -214,9 +210,8 @@ class GroupCategoriesServlet(RestServlet): class GroupRoleServlet(RestServlet): """Get/add/update/delete a group role """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/roles/(?P[^/]+)$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/(?P[^/]+)$") def __init__(self, hs): super(GroupRoleServlet, self).__init__() @@ -230,8 +225,7 @@ class GroupRoleServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_role( - group_id, requester_user_id, - role_id=role_id, + group_id, requester_user_id, role_id=role_id ) defer.returnValue((200, category)) @@ -243,9 +237,7 @@ class GroupRoleServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_role( - group_id, requester_user_id, - role_id=role_id, - content=content, + group_id, requester_user_id, role_id=role_id, content=content ) defer.returnValue((200, resp)) @@ -256,8 +248,7 @@ class GroupRoleServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_role( - group_id, requester_user_id, - role_id=role_id, + group_id, requester_user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -266,9 +257,8 @@ class GroupRoleServlet(RestServlet): class GroupRolesServlet(RestServlet): """Get all group roles """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/roles/$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/$") def __init__(self, hs): super(GroupRolesServlet, self).__init__() @@ -282,7 +272,7 @@ class GroupRolesServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_roles( - group_id, requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, category)) @@ -295,6 +285,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): - /groups/:group/summary/users/:room_id - /groups/:group/summary/roles/:role/users/:user_id """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/summary" "(/roles/(?P[^/]+))?" @@ -314,7 +305,8 @@ class GroupSummaryUsersRoleServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_summary_user( - group_id, requester_user_id, + group_id, + requester_user_id, user_id=user_id, role_id=role_id, content=content, @@ -328,9 +320,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_summary_user( - group_id, requester_user_id, - user_id=user_id, - role_id=role_id, + group_id, requester_user_id, user_id=user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -339,6 +329,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): class GroupRoomServlet(RestServlet): """Get all rooms in a group """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/rooms$") def __init__(self, hs): @@ -352,7 +343,9 @@ class GroupRoomServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id) + result = yield self.groups_handler.get_rooms_in_group( + group_id, requester_user_id + ) defer.returnValue((200, result)) @@ -360,6 +353,7 @@ class GroupRoomServlet(RestServlet): class GroupUsersServlet(RestServlet): """Get all users in a group """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") def __init__(self, hs): @@ -373,7 +367,9 @@ class GroupUsersServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id) + result = yield self.groups_handler.get_users_in_group( + group_id, requester_user_id + ) defer.returnValue((200, result)) @@ -381,6 +377,7 @@ class GroupUsersServlet(RestServlet): class GroupInvitedUsersServlet(RestServlet): """Get users invited to a group """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") def __init__(self, hs): @@ -395,8 +392,7 @@ class GroupInvitedUsersServlet(RestServlet): requester_user_id = requester.user.to_string() result = yield self.groups_handler.get_invited_users_in_group( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, result)) @@ -405,6 +401,7 @@ class GroupInvitedUsersServlet(RestServlet): class GroupSettingJoinPolicyServlet(RestServlet): """Set group join policy """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") def __init__(self, hs): @@ -420,9 +417,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.set_group_join_policy( - group_id, - requester_user_id, - content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -431,6 +426,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): class GroupCreateServlet(RestServlet): """Create a group """ + PATTERNS = client_patterns("/create_group$") def __init__(self, hs): @@ -451,9 +447,7 @@ class GroupCreateServlet(RestServlet): group_id = GroupID(localpart, self.server_name).to_string() result = yield self.groups_handler.create_group( - group_id, - requester_user_id, - content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -462,6 +456,7 @@ class GroupCreateServlet(RestServlet): class GroupAdminRoomsServlet(RestServlet): """Add a room to the group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)$" ) @@ -479,7 +474,7 @@ class GroupAdminRoomsServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.add_room_to_group( - group_id, requester_user_id, room_id, content, + group_id, requester_user_id, room_id, content ) defer.returnValue((200, result)) @@ -490,7 +485,7 @@ class GroupAdminRoomsServlet(RestServlet): requester_user_id = requester.user.to_string() result = yield self.groups_handler.remove_room_from_group( - group_id, requester_user_id, room_id, + group_id, requester_user_id, room_id ) defer.returnValue((200, result)) @@ -499,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet): class GroupAdminRoomsConfigServlet(RestServlet): """Update the config of a room in a group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)" "/config/(?P[^/]*)$" @@ -517,7 +513,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content, + group_id, requester_user_id, room_id, config_key, content ) defer.returnValue((200, result)) @@ -526,6 +522,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): class GroupAdminUsersInviteServlet(RestServlet): """Invite a user to the group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/users/invite/(?P[^/]*)$" ) @@ -546,7 +543,7 @@ class GroupAdminUsersInviteServlet(RestServlet): content = parse_json_object_from_request(request) config = content.get("config", {}) result = yield self.groups_handler.invite( - group_id, user_id, requester_user_id, config, + group_id, user_id, requester_user_id, config ) defer.returnValue((200, result)) @@ -555,6 +552,7 @@ class GroupAdminUsersInviteServlet(RestServlet): class GroupAdminUsersKickServlet(RestServlet): """Kick a user from the group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/users/remove/(?P[^/]*)$" ) @@ -572,7 +570,7 @@ class GroupAdminUsersKickServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -581,9 +579,8 @@ class GroupAdminUsersKickServlet(RestServlet): class GroupSelfLeaveServlet(RestServlet): """Leave a joined group """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/leave$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/leave$") def __init__(self, hs): super(GroupSelfLeaveServlet, self).__init__() @@ -598,7 +595,7 @@ class GroupSelfLeaveServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.remove_user_from_group( - group_id, requester_user_id, requester_user_id, content, + group_id, requester_user_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -607,9 +604,8 @@ class GroupSelfLeaveServlet(RestServlet): class GroupSelfJoinServlet(RestServlet): """Attempt to join a group, or knock """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/join$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/join$") def __init__(self, hs): super(GroupSelfJoinServlet, self).__init__() @@ -624,7 +620,7 @@ class GroupSelfJoinServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.join_group( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -633,9 +629,8 @@ class GroupSelfJoinServlet(RestServlet): class GroupSelfAcceptInviteServlet(RestServlet): """Accept a group invite """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/accept_invite$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/accept_invite$") def __init__(self, hs): super(GroupSelfAcceptInviteServlet, self).__init__() @@ -650,7 +645,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.accept_invite( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -659,9 +654,8 @@ class GroupSelfAcceptInviteServlet(RestServlet): class GroupSelfUpdatePublicityServlet(RestServlet): """Update whether we publicise a users membership of a group """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/update_publicity$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/update_publicity$") def __init__(self, hs): super(GroupSelfUpdatePublicityServlet, self).__init__() @@ -676,9 +670,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): content = parse_json_object_from_request(request) publicise = content["publicise"] - yield self.store.update_group_publicity( - group_id, requester_user_id, publicise, - ) + yield self.store.update_group_publicity(group_id, requester_user_id, publicise) defer.returnValue((200, {})) @@ -686,9 +678,8 @@ class GroupSelfUpdatePublicityServlet(RestServlet): class PublicisedGroupsForUserServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_patterns( - "/publicised_groups/(?P[^/]*)$" - ) + + PATTERNS = client_patterns("/publicised_groups/(?P[^/]*)$") def __init__(self, hs): super(PublicisedGroupsForUserServlet, self).__init__() @@ -701,9 +692,7 @@ class PublicisedGroupsForUserServlet(RestServlet): def on_GET(self, request, user_id): yield self.auth.get_user_by_req(request, allow_guest=True) - result = yield self.groups_handler.get_publicised_groups_for_user( - user_id - ) + result = yield self.groups_handler.get_publicised_groups_for_user(user_id) defer.returnValue((200, result)) @@ -711,9 +700,8 @@ class PublicisedGroupsForUserServlet(RestServlet): class PublicisedGroupsForUsersServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_patterns( - "/publicised_groups$" - ) + + PATTERNS = client_patterns("/publicised_groups$") def __init__(self, hs): super(PublicisedGroupsForUsersServlet, self).__init__() @@ -729,9 +717,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): content = parse_json_object_from_request(request) user_ids = content["user_ids"] - result = yield self.groups_handler.bulk_get_publicised_groups( - user_ids - ) + result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) defer.returnValue((200, result)) @@ -739,9 +725,8 @@ class PublicisedGroupsForUsersServlet(RestServlet): class GroupsForUserServlet(RestServlet): """Get all groups the logged in user is joined to """ - PATTERNS = client_patterns( - "/joined_groups$" - ) + + PATTERNS = client_patterns("/joined_groups$") def __init__(self, hs): super(GroupsForUserServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 4cbfbf5631..45c9928b65 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -56,6 +56,7 @@ class KeyUploadServlet(RestServlet): }, } """ + PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") def __init__(self, hs): @@ -76,18 +77,19 @@ class KeyUploadServlet(RestServlet): if device_id is not None: # passing the device_id here is deprecated; however, we allow it # for now for compatibility with older clients. - if (requester.device_id is not None and - device_id != requester.device_id): - logger.warning("Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, device_id) + if requester.device_id is not None and device_id != requester.device_id: + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id if device_id is None: raise SynapseError( - 400, - "To upload keys, you must pass device_id when authenticating" + 400, "To upload keys, you must pass device_id when authenticating" ) result = yield self.e2e_keys_handler.upload_keys_for_user( @@ -159,6 +161,7 @@ class KeyChangesServlet(RestServlet): 200 OK { "changed": ["@foo:example.com"] } """ + PATTERNS = client_patterns("/keys/changes$") def __init__(self, hs): @@ -184,9 +187,7 @@ class KeyChangesServlet(RestServlet): user_id = requester.user.to_string() - results = yield self.device_handler.get_user_ids_changed( - user_id, from_token, - ) + results = yield self.device_handler.get_user_ids_changed(user_id, from_token) defer.returnValue((200, results)) @@ -209,6 +210,7 @@ class OneTimeKeyServlet(RestServlet): } } } } """ + PATTERNS = client_patterns("/keys/claim$") def __init__(self, hs): @@ -221,10 +223,7 @@ class OneTimeKeyServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.claim_one_time_keys( - body, - timeout, - ) + result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) defer.returnValue((200, result)) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 53e666989b..728a52328f 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -51,7 +51,7 @@ class NotificationsServlet(RestServlet): ) receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( - user_id, 'm.read' + user_id, "m.read" ) notif_event_ids = [pa["event_id"] for pa in push_actions] @@ -67,11 +67,13 @@ class NotificationsServlet(RestServlet): "profile_tag": pa["profile_tag"], "actions": pa["actions"], "ts": pa["received_ts"], - "event": (yield self._event_serializer.serialize_event( - notif_events[pa["event_id"]], - self.clock.time_msec(), - event_format=format_event_for_client_v2_without_room_id, - )), + "event": ( + yield self._event_serializer.serialize_event( + notif_events[pa["event_id"]], + self.clock.time_msec(), + event_format=format_event_for_client_v2_without_room_id, + ) + ), } if pa["room_id"] not in receipts_by_room: @@ -80,17 +82,15 @@ class NotificationsServlet(RestServlet): receipt = receipts_by_room[pa["room_id"]] returned_pa["read"] = ( - receipt["topological_ordering"], receipt["stream_ordering"] - ) >= ( - pa["topological_ordering"], pa["stream_ordering"] - ) + receipt["topological_ordering"], + receipt["stream_ordering"], + ) >= (pa["topological_ordering"], pa["stream_ordering"]) returned_push_actions.append(returned_pa) next_token = str(pa["stream_ordering"]) - defer.returnValue((200, { - "notifications": returned_push_actions, - "next_token": next_token, - })) + defer.returnValue( + (200, {"notifications": returned_push_actions, "next_token": next_token}) + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index bb927d9f9d..b1b5385b09 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -56,9 +56,8 @@ class IdTokenServlet(RestServlet): "expires_in": 3600, } """ - PATTERNS = client_patterns( - "/user/(?P[^/]*)/openid/request_token" - ) + + PATTERNS = client_patterns("/user/(?P[^/]*)/openid/request_token") EXPIRES_MS = 3600 * 1000 @@ -84,12 +83,17 @@ class IdTokenServlet(RestServlet): yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) - defer.returnValue((200, { - "access_token": token, - "token_type": "Bearer", - "matrix_server_name": self.server_name, - "expires_in": self.EXPIRES_MS / 1000, - })) + defer.returnValue( + ( + 200, + { + "access_token": token, + "token_type": "Bearer", + "matrix_server_name": self.server_name, + "expires_in": self.EXPIRES_MS / 1000, + }, + ) + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index f4bd0d077f..e75664279b 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet): room_id, "m.read", user_id=requester.user.to_string(), - event_id=read_event_id + event_id=read_event_id, ) read_marker_event_id = body.get("m.fully_read", None) @@ -56,7 +56,7 @@ class ReadMarkerRestServlet(RestServlet): yield self.read_marker_handler.received_client_read_marker( room_id, user_id=requester.user.to_string(), - event_id=read_marker_event_id + event_id=read_marker_event_id, ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index fa12ac3e4d..488905626a 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -49,10 +49,7 @@ class ReceiptRestServlet(RestServlet): yield self.presence_handler.bump_presence_active_time(requester.user) yield self.receipts_handler.received_client_receipt( - room_id, - receipt_type, - user_id=requester.user.to_string(), - event_id=event_id + room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 3d5a198278..bcacd2a2ca 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -55,6 +55,7 @@ from ._base import client_patterns, interactive_auth_handler if hasattr(hmac, "compare_digest"): compare_digest = hmac.compare_digest else: + def compare_digest(a, b): return a == b @@ -78,11 +79,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', 'email', 'send_attempt' - ]) + assert_params_in_dict( + body, ["id_server", "client_secret", "email", "send_attempt"] + ) - if not (yield check_3pid_allowed(self.hs, "email", body['email'])): + if not (yield check_3pid_allowed(self.hs, "email", body["email"])): raise SynapseError( 403, "Your email is not authorized to register on this server", @@ -92,7 +93,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): assert_params_in_dict(body["client_secret"]) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', body['email'] + "email", body["email"] ) if existingUid is not None: @@ -118,13 +119,12 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', - 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) assert_valid_client_secret(body["client_secret"]) @@ -136,7 +136,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'msisdn', msisdn + "msisdn", msisdn ) if existingUid is not None: @@ -172,7 +172,7 @@ class UsernameAvailabilityRestServlet(RestServlet): reject_limit=1, # Allow 1 request at a time concurrent_requests=1, - ) + ), ) @defer.inlineCallbacks @@ -220,7 +220,8 @@ class RegisterRestServlet(RestServlet): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, time_now_s=time_now, + client_addr, + time_now_s=time_now, rate_hz=self.hs.config.rc_registration.per_second, burst_count=self.hs.config.rc_registration.burst_count, update=False, @@ -228,7 +229,7 @@ class RegisterRestServlet(RestServlet): if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) kind = b"user" @@ -247,21 +248,25 @@ class RegisterRestServlet(RestServlet): # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. desired_password = None - if 'password' in body: - if (not isinstance(body['password'], string_types) or - len(body['password']) > 512): + if "password" in body: + if ( + not isinstance(body["password"], string_types) + or len(body["password"]) > 512 + ): raise SynapseError(400, "Invalid password") - self.password_policy_handler.validate_password(body['password']) + self.password_policy_handler.validate_password(body["password"]) desired_password = body["password"] desired_username = None - if 'username' in body: - if (not isinstance(body['username'], string_types) or - len(body['username']) > 512): + if "username" in body: + if ( + not isinstance(body["username"], string_types) + or len(body["username"]) > 512 + ): raise SynapseError(400, "Invalid username") - desired_username = body['username'] + desired_username = body["username"] - desired_display_name = body.get('display_name') + desired_display_name = body.get("display_name") appservice = None if self.auth.has_access_token(request): @@ -286,8 +291,11 @@ class RegisterRestServlet(RestServlet): if isinstance(desired_username, string_types): result = yield self._do_appservice_registration( - desired_username, desired_password, desired_display_name, - access_token, body + desired_username, + desired_password, + desired_display_name, + access_token, + body, ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -302,7 +310,7 @@ class RegisterRestServlet(RestServlet): desired_username = desired_username.lower() # == Shared Secret Registration == (e.g. create new user scripts) - if 'mac' in body: + if "mac" in body: # FIXME: Should we really be determining if this is shared secret # auth based purely on the 'mac' key? result = yield self._do_shared_secret_registration( @@ -317,16 +325,13 @@ class RegisterRestServlet(RestServlet): guest_access_token = body.get("guest_access_token", None) - if ( - 'initial_device_display_name' in body and - 'password' not in body - ): + if "initial_device_display_name" in body and "password" not in body: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out # the original registration params logger.warn("Ignoring initial_device_display_name without password") - del body['initial_device_display_name'] + del body["initial_device_display_name"] session_id = self.auth_handler.get_session_id(body) registered_user_id = None @@ -348,8 +353,8 @@ class RegisterRestServlet(RestServlet): # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one - require_email = 'email' in self.hs.config.registrations_require_3pid - require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + require_email = "email" in self.hs.config.registrations_require_3pid + require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid show_msisdn = True if self.hs.config.disable_msisdn_registration: @@ -374,9 +379,9 @@ class RegisterRestServlet(RestServlet): if not require_email: flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]]) # always let users provide both MSISDN & email - flows.extend([ - [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY], - ]) + flows.extend( + [[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]] + ) else: # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: @@ -390,9 +395,7 @@ class RegisterRestServlet(RestServlet): if not require_email or require_msisdn: flows.extend([[LoginType.MSISDN]]) # always let users provide both MSISDN & email - flows.extend([ - [LoginType.MSISDN, LoginType.EMAIL_IDENTITY] - ]) + flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]) # Append m.login.terms to all flows if we're requiring consent if self.hs.config.user_consent_at_registration: @@ -422,26 +425,24 @@ class RegisterRestServlet(RestServlet): if auth_result: for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: - medium = auth_result[login_type]['medium'] - address = auth_result[login_type]['address'] + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] if not (yield check_3pid_allowed(self.hs, medium, address)): raise SynapseError( 403, - "Third party identifiers (email/phone numbers)" + - " are not authorized on this server", + "Third party identifiers (email/phone numbers)" + + " are not authorized on this server", Codes.THREEPID_DENIED, ) existingUid = yield self.store.get_user_id_by_threepid( - medium, address, + medium, address ) if existingUid is not None: raise SynapseError( - 400, - "%s is already in use" % medium, - Codes.THREEPID_IN_USE, + 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE ) if self.hs.config.register_mxid_from_3pid: @@ -455,12 +456,12 @@ class RegisterRestServlet(RestServlet): # desired_username if auth_result: if ( - self.hs.config.register_mxid_from_3pid == 'email' and - LoginType.EMAIL_IDENTITY in auth_result + self.hs.config.register_mxid_from_3pid == "email" + and LoginType.EMAIL_IDENTITY in auth_result ): - address = auth_result[LoginType.EMAIL_IDENTITY]['address'] + address = auth_result[LoginType.EMAIL_IDENTITY]["address"] desired_username = synapse.types.strip_invalid_mxid_characters( - address.replace('@', '-').lower() + address.replace("@", "-").lower() ) # find a unique mxid for the account, suffixing numbers @@ -476,7 +477,7 @@ class RegisterRestServlet(RestServlet): break except SynapseError as e: if e.errcode == Codes.USER_IN_USE: - m = re.match(r'^(.*?)(\d+)$', desired_username) + m = re.match(r"^(.*?)(\d+)$", desired_username) if m: desired_username = m.group(1) + str( int(m.group(2)) + 1 @@ -493,10 +494,10 @@ class RegisterRestServlet(RestServlet): # Custom mapping between email address and display name desired_display_name = self._map_email_to_displayname(address) elif ( - self.hs.config.register_mxid_from_3pid == 'msisdn' and - LoginType.MSISDN in auth_result + self.hs.config.register_mxid_from_3pid == "msisdn" + and LoginType.MSISDN in auth_result ): - desired_username = auth_result[LoginType.MSISDN]['address'] + desired_username = auth_result[LoginType.MSISDN]["address"] else: raise SynapseError( 400, "Cannot derive mxid from 3pid; no recognised 3pid" @@ -511,8 +512,7 @@ class RegisterRestServlet(RestServlet): if registered_user_id is not None: logger.info( - "Already registered user ID %r for this session", - registered_user_id + "Already registered user ID %r for this session", registered_user_id ) # don't re-register the threepids registered = False @@ -546,11 +546,11 @@ class RegisterRestServlet(RestServlet): # the two activation emails, they would register the same 3pid twice. for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: - medium = auth_result[login_type]['medium'] - address = auth_result[login_type]['address'] + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] existingUid = yield self.store.get_user_id_by_threepid( - medium, address, + medium, address ) if existingUid is not None: @@ -623,19 +623,17 @@ class RegisterRestServlet(RestServlet): ) result = yield self._create_registration_details(user_id, body) - auth_result = body.get('auth_result') + auth_result = body.get("auth_result") if auth_result and LoginType.EMAIL_IDENTITY in auth_result: threepid = auth_result[LoginType.EMAIL_IDENTITY] yield self._register_email_threepid( - user_id, threepid, result["access_token"], - body.get("bind_email") + user_id, threepid, result["access_token"], body.get("bind_email") ) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] yield self._register_msisdn_threepid( - user_id, threepid, result["access_token"], - body.get("bind_msisdn") + user_id, threepid, result["access_token"], body.get("bind_msisdn") ) defer.returnValue(result) @@ -646,7 +644,7 @@ class RegisterRestServlet(RestServlet): raise SynapseError(400, "Shared secret registration is not enabled") if not username: raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON, + 400, "username must be specified", errcode=Codes.BAD_JSON ) # use the username from the original request rather than the @@ -667,12 +665,10 @@ class RegisterRestServlet(RestServlet): ).hexdigest() if not compare_digest(want_mac, got_mac): - raise SynapseError( - 403, "HMAC incorrect", - ) + raise SynapseError(403, "HMAC incorrect") (user_id, _) = yield self.registration_handler.register( - localpart=username, password=password, generate_token=False, + localpart=username, password=password, generate_token=False ) result = yield self._create_registration_details(user_id, body) @@ -691,21 +687,15 @@ class RegisterRestServlet(RestServlet): Returns: defer.Deferred: (object) dictionary for response from /register """ - result = { - "user_id": user_id, - "home_server": self.hs.hostname, - } + result = {"user_id": user_id, "home_server": self.hs.hostname} if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=False, + user_id, device_id, initial_display_name, is_guest=False ) - result.update({ - "access_token": access_token, - "device_id": device_id, - }) + result.update({"access_token": access_token, "device_id": device_id}) defer.returnValue(result) @defer.inlineCallbacks @@ -713,9 +703,7 @@ class RegisterRestServlet(RestServlet): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id, _ = yield self.registration_handler.register( - generate_token=False, - make_guest=True, - address=address, + generate_token=False, make_guest=True, address=address ) # we don't allow guests to specify their own device_id, because @@ -723,15 +711,20 @@ class RegisterRestServlet(RestServlet): device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=True, + user_id, device_id, initial_display_name, is_guest=True ) - defer.returnValue((200, { - "user_id": user_id, - "device_id": device_id, - "access_token": access_token, - "home_server": self.hs.hostname, - })) + defer.returnValue( + ( + 200, + { + "user_id": user_id, + "device_id": device_id, + "access_token": access_token, + "home_server": self.hs.hostname, + }, + ) + ) def cap(name): @@ -764,10 +757,10 @@ def _map_email_to_displayname(address): """ # Split the part before and after the @ in the email. # Replace all . with spaces in the first part - parts = address.replace('.', ' ').split('@') + parts = address.replace(".", " ").split("@") # Figure out which org this email address belongs to - org_parts = parts[1].split(' ') + org_parts = parts[1].split(" ") # If this is a ...matrix.org email, mark them as an Admin if org_parts[-2] == "matrix" and org_parts[-1] == "org": @@ -783,9 +776,7 @@ def _map_email_to_displayname(address): else: org = org_parts[-2] - desired_display_name = ( - cap(parts[0]) + " [" + cap(org) + "]" - ) + desired_display_name = cap(parts[0]) + " [" + cap(org) + "]" return desired_display_name diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index f8f8742bdc..8e362782cc 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -32,7 +32,10 @@ from synapse.http.servlet import ( parse_string, ) from synapse.rest.client.transactions import HttpTransactionCache -from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken +from synapse.storage.relations import ( + AggregationPaginationToken, + RelationPaginationToken, +) from ._base import client_patterns diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 10198662a9..e7578af804 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -33,9 +33,7 @@ logger = logging.getLogger(__name__) class ReportEventRestServlet(RestServlet): - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)/report/(?P[^/]*)$" - ) + PATTERNS = client_patterns("/rooms/(?P[^/]*)/report/(?P[^/]*)$") def __init__(self, hs): super(ReportEventRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 87779645f9..8d1b810565 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -129,22 +129,12 @@ class RoomKeysServlet(RestServlet): version = parse_string(request, "version") if session_id: - body = { - "sessions": { - session_id: body - } - } + body = {"sessions": {session_id: body}} if room_id: - body = { - "rooms": { - room_id: body - } - } + body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys( - user_id, version, body - ) + yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -212,10 +202,10 @@ class RoomKeysServlet(RestServlet): if session_id: # If the client requests a specific session, but that session was # not backed up, then return an M_NOT_FOUND. - if room_keys['rooms'] == {}: + if room_keys["rooms"] == {}: raise NotFoundError("No room_keys found") else: - room_keys = room_keys['rooms'][room_id]['sessions'][session_id] + room_keys = room_keys["rooms"][room_id]["sessions"][session_id] elif room_id: # If the client requests all sessions from a room, but no sessions # are found, then return an empty result rather than an error, so @@ -223,10 +213,10 @@ class RoomKeysServlet(RestServlet): # empty result is valid. (Similarly if the client requests all # sessions from the backup, but in that case, room_keys is already # in the right format, so we don't need to do anything about it.) - if room_keys['rooms'] == {}: - room_keys = {'sessions': {}} + if room_keys["rooms"] == {}: + room_keys = {"sessions": {}} else: - room_keys = room_keys['rooms'][room_id] + room_keys = room_keys["rooms"][room_id] defer.returnValue((200, room_keys)) @@ -256,9 +246,7 @@ class RoomKeysServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet): - PATTERNS = client_patterns( - "/room_keys/version$" - ) + PATTERNS = client_patterns("/room_keys/version$") def __init__(self, hs): """ @@ -304,9 +292,7 @@ class RoomKeysNewVersionServlet(RestServlet): user_id = requester.user.to_string() info = parse_json_object_from_request(request) - new_version = yield self.e2e_room_keys_handler.create_version( - user_id, info - ) + new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) defer.returnValue((200, {"version": new_version})) # we deliberately don't have a PUT /version, as these things really should @@ -314,9 +300,7 @@ class RoomKeysNewVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet): - PATTERNS = client_patterns( - "/room_keys/version(/(?P[^/]+))?$" - ) + PATTERNS = client_patterns("/room_keys/version(/(?P[^/]+))?$") def __init__(self, hs): """ @@ -350,9 +334,7 @@ class RoomKeysVersionServlet(RestServlet): user_id = requester.user.to_string() try: - info = yield self.e2e_room_keys_handler.get_version_info( - user_id, version - ) + info = yield self.e2e_room_keys_handler.get_version_info(user_id, version) except SynapseError as e: if e.code == 404: raise SynapseError(404, "No backup found", Codes.NOT_FOUND) @@ -375,9 +357,7 @@ class RoomKeysVersionServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - yield self.e2e_room_keys_handler.delete_version( - user_id, version - ) + yield self.e2e_room_keys_handler.delete_version(user_id, version) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -407,11 +387,11 @@ class RoomKeysVersionServlet(RestServlet): info = parse_json_object_from_request(request) if version is None: - raise SynapseError(400, "No version specified to update", Codes.MISSING_PARAM) + raise SynapseError( + 400, "No version specified to update", Codes.MISSING_PARAM + ) - yield self.e2e_room_keys_handler.update_version( - user_id, version, info - ) + yield self.e2e_room_keys_handler.update_version(user_id, version, info) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index c621a90fba..d7f7faa029 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -47,9 +47,10 @@ class RoomUpgradeRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ + PATTERNS = client_patterns( # /rooms/$roomid/upgrade - "/rooms/(?P[^/]*)/upgrade$", + "/rooms/(?P[^/]*)/upgrade$" ) def __init__(self, hs): @@ -63,7 +64,7 @@ class RoomUpgradeRestServlet(RestServlet): requester = yield self._auth.get_user_by_req(request) content = parse_json_object_from_request(request) - assert_params_in_dict(content, ("new_version", )) + assert_params_in_dict(content, ("new_version",)) new_version = content["new_version"] if new_version not in KNOWN_ROOM_VERSIONS: @@ -77,9 +78,7 @@ class RoomUpgradeRestServlet(RestServlet): requester, room_id, new_version ) - ret = { - "replacement_room": new_room_id, - } + ret = {"replacement_room": new_room_id} defer.returnValue((200, ret)) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index 120a713361..78075b8fc0 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class SendToDeviceRestServlet(servlet.RestServlet): PATTERNS = client_patterns( - "/sendToDevice/(?P[^/]*)/(?P[^/]*)$", + "/sendToDevice/(?P[^/]*)/(?P[^/]*)$" ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 148fc6c985..02d56dee6c 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -96,44 +96,42 @@ class SyncRestServlet(RestServlet): 400, "'from' is not a valid query parameter. Did you mean 'since'?" ) - requester = yield self.auth.get_user_by_req( - request, allow_guest=True - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) user = requester.user device_id = requester.device_id timeout = parse_integer(request, "timeout", default=0) since = parse_string(request, "since") set_presence = parse_string( - request, "set_presence", default="online", - allowed_values=self.ALLOWED_PRESENCE + request, + "set_presence", + default="online", + allowed_values=self.ALLOWED_PRESENCE, ) filter_id = parse_string(request, "filter", default=None) full_state = parse_boolean(request, "full_state", default=False) logger.debug( "/sync: user=%r, timeout=%r, since=%r," - " set_presence=%r, filter_id=%r, device_id=%r" % ( - user, timeout, since, set_presence, filter_id, device_id - ) + " set_presence=%r, filter_id=%r, device_id=%r" + % (user, timeout, since, set_presence, filter_id, device_id) ) request_key = (user, timeout, since, filter_id, full_state, device_id) if filter_id: - if filter_id.startswith('{'): + if filter_id.startswith("{"): try: filter_object = json.loads(filter_id) - set_timeline_upper_limit(filter_object, - self.hs.config.filter_timeline_limit) + set_timeline_upper_limit( + filter_object, self.hs.config.filter_timeline_limit + ) except Exception: raise SynapseError(400, "Invalid filter JSON") self.filtering.check_valid_filter(filter_object) filter = FilterCollection(filter_object) else: - filter = yield self.filtering.get_user_filter( - user.localpart, filter_id - ) + filter = yield self.filtering.get_user_filter(user.localpart, filter_id) else: filter = DEFAULT_FILTER_COLLECTION @@ -156,15 +154,19 @@ class SyncRestServlet(RestServlet): affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state(user, {"presence": set_presence}, True) + yield self.presence_handler.set_state( + user, {"presence": set_presence}, True + ) context = yield self.presence_handler.user_syncing( - user.to_string(), affect_presence=affect_presence, + user.to_string(), affect_presence=affect_presence ) with context: sync_result = yield self.sync_handler.wait_for_sync_for_user( - sync_config, since_token=since_token, timeout=timeout, - full_state=full_state + sync_config, + since_token=since_token, + timeout=timeout, + full_state=full_state, ) time_now = self.clock.time_msec() @@ -176,53 +178,54 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def encode_response(self, time_now, sync_result, access_token_id, filter): - if filter.event_format == 'client': + if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id - elif filter.event_format == 'federation': + elif filter.event_format == "federation": event_formatter = format_event_raw else: - raise Exception("Unknown event format %s" % (filter.event_format, )) + raise Exception("Unknown event format %s" % (filter.event_format,)) joined = yield self.encode_joined( - sync_result.joined, time_now, access_token_id, + sync_result.joined, + time_now, + access_token_id, filter.event_fields, event_formatter, ) invited = yield self.encode_invited( - sync_result.invited, time_now, access_token_id, - event_formatter, + sync_result.invited, time_now, access_token_id, event_formatter ) archived = yield self.encode_archived( - sync_result.archived, time_now, access_token_id, + sync_result.archived, + time_now, + access_token_id, filter.event_fields, event_formatter, ) - defer.returnValue({ - "account_data": {"events": sync_result.account_data}, - "to_device": {"events": sync_result.to_device}, - "device_lists": { - "changed": list(sync_result.device_lists.changed), - "left": list(sync_result.device_lists.left), - }, - "presence": SyncRestServlet.encode_presence( - sync_result.presence, time_now - ), - "rooms": { - "join": joined, - "invite": invited, - "leave": archived, - }, - "groups": { - "join": sync_result.groups.join, - "invite": sync_result.groups.invite, - "leave": sync_result.groups.leave, - }, - "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "next_batch": sync_result.next_batch.to_string(), - }) + defer.returnValue( + { + "account_data": {"events": sync_result.account_data}, + "to_device": {"events": sync_result.to_device}, + "device_lists": { + "changed": list(sync_result.device_lists.changed), + "left": list(sync_result.device_lists.left), + }, + "presence": SyncRestServlet.encode_presence( + sync_result.presence, time_now + ), + "rooms": {"join": joined, "invite": invited, "leave": archived}, + "groups": { + "join": sync_result.groups.join, + "invite": sync_result.groups.invite, + "leave": sync_result.groups.leave, + }, + "device_one_time_keys_count": sync_result.device_one_time_keys_count, + "next_batch": sync_result.next_batch.to_string(), + } + ) @staticmethod def encode_presence(events, time_now): @@ -262,7 +265,11 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = yield self.encode_room( - room, time_now, token_id, joined=True, only_fields=event_fields, + room, + time_now, + token_id, + joined=True, + only_fields=event_fields, event_formatter=event_formatter, ) @@ -290,7 +297,9 @@ class SyncRestServlet(RestServlet): invited = {} for room in rooms: invite = yield self._event_serializer.serialize_event( - room.invite, time_now, token_id=token_id, + room.invite, + time_now, + token_id=token_id, event_format=event_formatter, is_invite=True, ) @@ -298,9 +307,7 @@ class SyncRestServlet(RestServlet): invite["unsigned"] = unsigned invited_state = list(unsigned.pop("invite_room_state", [])) invited_state.append(invite) - invited[room.room_id] = { - "invite_state": {"events": invited_state} - } + invited[room.room_id] = {"invite_state": {"events": invited_state}} defer.returnValue(invited) @@ -327,7 +334,10 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = yield self.encode_room( - room, time_now, token_id, joined=False, + room, + time_now, + token_id, + joined=False, only_fields=event_fields, event_formatter=event_formatter, ) @@ -336,8 +346,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def encode_room( - self, room, time_now, token_id, joined, - only_fields, event_formatter, + self, room, time_now, token_id, joined, only_fields, event_formatter ): """ Args: @@ -355,9 +364,11 @@ class SyncRestServlet(RestServlet): Returns: dict[str, object]: the room, encoded in our response format """ + def serialize(events): return self._event_serializer.serialize_events( - events, time_now=time_now, + events, + time_now=time_now, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. bundle_aggregations=False, @@ -377,7 +388,9 @@ class SyncRestServlet(RestServlet): if event.room_id != room.room_id: logger.warn( "Event %r is under room %r instead of %r", - event.event_id, room.room_id, event.room_id, + event.event_id, + room.room_id, + event.room_id, ) serialized_state = yield serialize(state_events) diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index ebff7cff45..07b6ede603 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -29,9 +29,8 @@ class TagListServlet(RestServlet): """ GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 """ - PATTERNS = client_patterns( - "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags" - ) + + PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags") def __init__(self, hs): super(TagListServlet, self).__init__() @@ -54,6 +53,7 @@ class TagServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags/(?P[^/]*)" ) @@ -74,9 +74,7 @@ class TagServlet(RestServlet): max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -88,9 +86,7 @@ class TagServlet(RestServlet): max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index e7a987466a..1e66662a05 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -57,7 +57,7 @@ class ThirdPartyProtocolServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols( - only_protocol=protocol, + only_protocol=protocol ) if protocol in protocols: defer.returnValue((200, protocols[protocol])) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 6c366142e1..2da0f55811 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -26,6 +26,7 @@ class TokenRefreshRestServlet(RestServlet): Exchanges refresh tokens for a pair of an access token and a new refresh token. """ + PATTERNS = client_patterns("/tokenrefresh") def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index b6f4d8b3f4..7edc801896 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -64,15 +64,14 @@ class UserDirectorySearchRestServlet(RestServlet): user_id = requester.user.to_string() if not self.hs.config.user_directory_search_enabled: - defer.returnValue((200, { - "limited": False, - "results": [], - })) + defer.returnValue((200, {"limited": False, "results": []})) body = parse_json_object_from_request(request) if self.hs.config.user_directory_defer_to_id_server: - signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0]) + signed_body = sign_json( + body, self.hs.hostname, self.hs.config.signing_key[0] + ) url = "%s/_matrix/identity/api/v1/user_directory/search" % ( self.hs.config.user_directory_defer_to_id_server, ) @@ -88,7 +87,7 @@ class UserDirectorySearchRestServlet(RestServlet): raise SynapseError(400, "`search_term` is required field") results = yield self.user_directory_handler.search_users( - user_id, search_term, limit, + user_id, search_term, limit ) defer.returnValue((200, results)) @@ -98,9 +97,8 @@ class UserInfoServlet(RestServlet): """ GET /user/{user_id}/info HTTP/1.1 """ - PATTERNS = client_patterns( - "/user/(?P[^/]*)/info$" - ) + + PATTERNS = client_patterns("/user/(?P[^/]*)/info$") def __init__(self, hs): super(UserInfoServlet, self).__init__() @@ -113,9 +111,7 @@ class UserInfoServlet(RestServlet): registry = hs.get_federation_registry() if not registry.query_handlers.get("user_info"): - registry.register_query_handler( - "user_info", self._on_federation_query - ) + registry.register_query_handler("user_info", self._on_federation_query) @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -127,7 +123,7 @@ class UserInfoServlet(RestServlet): # Attempt to make a federation request to the server that owns this user args = {"user_id": user_id} res = yield self.transport_layer.make_query( - user.domain, "user_info", args, retry_on_dns_fail=True, + user.domain, "user_info", args, retry_on_dns_fail=True ) defer.returnValue((200, res)) @@ -174,10 +170,7 @@ class UserInfoServlet(RestServlet): expiration_ts is not None and self.clock.time_msec() >= expiration_ts ) - res = { - "expired": is_expired, - "deactivated": is_deactivated, - } + res = {"expired": is_expired, "deactivated": is_deactivated} defer.returnValue(res) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index babbf6a23c..0e09191632 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -25,27 +25,28 @@ class VersionsRestServlet(RestServlet): PATTERNS = [re.compile("^/_matrix/client/versions$")] def on_GET(self, request): - return (200, { - "versions": [ - # XXX: at some point we need to decide whether we need to include - # the previous version numbers, given we've defined r0.3.0 to be - # backwards compatible with r0.2.0. But need to check how - # conscientious we've been in compatibility, and decide whether the - # middle number is the major revision when at 0.X.Y (as opposed to - # X.Y.Z). And we need to decide whether it's fair to make clients - # parse the version string to figure out what's going on. - "r0.0.1", - "r0.1.0", - "r0.2.0", - "r0.3.0", - "r0.4.0", - "r0.5.0", - ], - # as per MSC1497: - "unstable_features": { - "m.lazy_load_members": True, - } - }) + return ( + 200, + { + "versions": [ + # XXX: at some point we need to decide whether we need to include + # the previous version numbers, given we've defined r0.3.0 to be + # backwards compatible with r0.2.0. But need to check how + # conscientious we've been in compatibility, and decide whether the + # middle number is the major revision when at 0.X.Y (as opposed to + # X.Y.Z). And we need to decide whether it's fair to make clients + # parse the version string to figure out what's going on. + "r0.0.1", + "r0.1.0", + "r0.2.0", + "r0.3.0", + "r0.4.0", + "r0.5.0", + ], + # as per MSC1497: + "unstable_features": {"m.lazy_load_members": True}, + }, + ) def register_servlets(http_server): diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 6b371bfa2f..9a32892d8b 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -42,6 +42,7 @@ logger = logging.getLogger(__name__) if hasattr(hmac, "compare_digest"): compare_digest = hmac.compare_digest else: + def compare_digest(a, b): return a == b @@ -80,6 +81,7 @@ class ConsentResource(Resource): For POST: required; gives the value to be recorded in the database against the user. """ + def __init__(self, hs): """ Args: @@ -98,21 +100,20 @@ class ConsentResource(Resource): if self._default_consent_version is None: raise ConfigError( "Consent resource is enabled but user_consent section is " - "missing in config file.", + "missing in config file." ) consent_template_directory = hs.config.user_consent_template_dir loader = jinja2.FileSystemLoader(consent_template_directory) self._jinja_env = jinja2.Environment( - loader=loader, - autoescape=jinja2.select_autoescape(['html', 'htm', 'xml']), + loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) ) if hs.config.form_secret is None: raise ConfigError( "Consent resource is enabled but form_secret is not set in " - "config file. It should be set to an arbitrary secret string.", + "config file. It should be set to an arbitrary secret string." ) self._hmac_secret = hs.config.form_secret.encode("utf-8") @@ -139,7 +140,7 @@ class ConsentResource(Resource): self._check_hash(username, userhmac_bytes) - if username.startswith('@'): + if username.startswith("@"): qualified_user_id = username else: qualified_user_id = UserID(username, self.hs.hostname).to_string() @@ -153,7 +154,8 @@ class ConsentResource(Resource): try: self._render_template( - request, "%s.html" % (version,), + request, + "%s.html" % (version,), user=username, userhmac=userhmac, version=version, @@ -180,7 +182,7 @@ class ConsentResource(Resource): self._check_hash(username, userhmac) - if username.startswith('@'): + if username.startswith("@"): qualified_user_id = username else: qualified_user_id = UserID(username, self.hs.hostname).to_string() @@ -221,11 +223,13 @@ class ConsentResource(Resource): SynapseError if the hash doesn't match """ - want_mac = hmac.new( - key=self._hmac_secret, - msg=userid.encode('utf-8'), - digestmod=sha256, - ).hexdigest().encode('ascii') + want_mac = ( + hmac.new( + key=self._hmac_secret, msg=userid.encode("utf-8"), digestmod=sha256 + ) + .hexdigest() + .encode("ascii") + ) if not compare_digest(want_mac, userhmac): raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect") diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index ec0ec7b431..8e48ef3564 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -80,9 +80,7 @@ class LocalKey(Resource): for key in self.config.signing_key: verify_key_bytes = key.verify_key.encode() key_id = "%s:%s" % (key.alg, key.version) - verify_keys[key_id] = { - u"key": encode_base64(verify_key_bytes) - } + verify_keys[key_id] = {u"key": encode_base64(verify_key_bytes)} old_verify_keys = {} for key_id, key in self.config.old_signing_keys.items(): @@ -102,11 +100,7 @@ class LocalKey(Resource): u"tls_fingerprints": tls_fingerprints, } for key in self.config.signing_key: - json_object = sign_json( - json_object, - self.config.server_name, - key, - ) + json_object = sign_json(json_object, self.config.server_name, key) return json_object def render_GET(self, request): @@ -114,6 +108,4 @@ class LocalKey(Resource): # Update the expiry time if less than half the interval remains. if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts: self.update_response_body(time_now) - return respond_with_json_bytes( - request, 200, self.response_body, - ) + return respond_with_json_bytes(request, 200, self.response_body) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 8a730bbc35..ec8b9d7269 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -103,20 +103,16 @@ class RemoteKey(Resource): def async_render_GET(self, request): if len(request.postpath) == 1: server, = request.postpath - query = {server.decode('ascii'): {}} + query = {server.decode("ascii"): {}} elif len(request.postpath) == 2: server, key_id = request.postpath - minimum_valid_until_ts = parse_integer( - request, "minimum_valid_until_ts" - ) + minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}} + query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}} else: - raise SynapseError( - 404, "Not found %r" % request.postpath, Codes.NOT_FOUND - ) + raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) yield self.query_keys(request, query, query_remote_on_cache_miss=True) @@ -140,8 +136,8 @@ class RemoteKey(Resource): store_queries = [] for server_name, key_ids in query.items(): if ( - self.federation_domain_whitelist is not None and - server_name not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist ): logger.debug("Federation denied with %s", server_name) continue @@ -159,9 +155,7 @@ class RemoteKey(Resource): cache_misses = dict() for (server_name, key_id, from_server), results in cached.items(): - results = [ - (result["ts_added_ms"], result) for result in results - ] + results = [(result["ts_added_ms"], result) for result in results] if not results and key_id is not None: cache_misses.setdefault(server_name, set()).add(key_id) @@ -178,23 +172,30 @@ class RemoteKey(Resource): logger.debug( "Cached response for %r/%r is older than requested" ": valid_until (%r) < minimum_valid_until (%r)", - server_name, key_id, - ts_valid_until_ms, req_valid_until + server_name, + key_id, + ts_valid_until_ms, + req_valid_until, ) miss = True else: logger.debug( "Cached response for %r/%r is newer than requested" ": valid_until (%r) >= minimum_valid_until (%r)", - server_name, key_id, - ts_valid_until_ms, req_valid_until + server_name, + key_id, + ts_valid_until_ms, + req_valid_until, ) elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms: logger.debug( "Cached response for %r/%r is too old" ": (added (%r) + valid_until (%r)) / 2 < now (%r)", - server_name, key_id, - ts_added_ms, ts_valid_until_ms, time_now_ms + server_name, + key_id, + ts_added_ms, + ts_valid_until_ms, + time_now_ms, ) # We more than half way through the lifetime of the # response. We should fetch a fresh copy. @@ -203,8 +204,11 @@ class RemoteKey(Resource): logger.debug( "Cached response for %r/%r is still valid" ": (added (%r) + valid_until (%r)) / 2 < now (%r)", - server_name, key_id, - ts_added_ms, ts_valid_until_ms, time_now_ms + server_name, + key_id, + ts_added_ms, + ts_valid_until_ms, + time_now_ms, ) if miss: @@ -216,12 +220,10 @@ class RemoteKey(Resource): if cache_misses and query_remote_on_cache_miss: yield self.fetcher.get_keys(cache_misses) - yield self.query_keys( - request, query, query_remote_on_cache_miss=False - ) + yield self.query_keys(request, query, query_remote_on_cache_miss=False) else: result_io = BytesIO() - result_io.write(b"{\"server_keys\":") + result_io.write(b'{"server_keys":') sep = b"[" for json_bytes in json_results: result_io.write(sep) @@ -231,6 +233,4 @@ class RemoteKey(Resource): result_io.write(sep) result_io.write(b"]}") - respond_with_json_bytes( - request, 200, result_io.getvalue(), - ) + respond_with_json_bytes(request, 200, result_io.getvalue()) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index 5a426ff2f6..86884c0ef4 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -44,6 +44,7 @@ class ContentRepoResource(resource.Resource): - Content type base64d (so we can return it when clients GET it) """ + isLeaf = True def __init__(self, hs, directory): @@ -56,7 +57,7 @@ class ContentRepoResource(resource.Resource): # servers. # TODO: A little crude here, we could do this better. - filename = request.path.decode('ascii').split('/')[-1] + filename = request.path.decode("ascii").split("/")[-1] # be paranoid filename = re.sub("[^0-9A-z.-_]", "", filename) @@ -69,17 +70,15 @@ class ContentRepoResource(resource.Resource): base64_contentype = filename.split(".")[1] content_type = base64.urlsafe_b64decode(base64_contentype) logger.info("Sending file %s", file_path) - f = open(file_path, 'rb') - request.setHeader('Content-Type', content_type) + f = open(file_path, "rb") + request.setHeader("Content-Type", content_type) # cache for at least a day. # XXX: we might want to turn this off for data we don't want to # recommend caching as it's sensitive or private - or at least # select private. don't bother setting Expires as all our matrix # clients are smart enough to be happy with Cache-Control (right?) - request.setHeader( - b"Cache-Control", b"public,max-age=86400,s-maxage=86400" - ) + request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") d = FileSender().beginFileTransfer(f, request) @@ -87,13 +86,15 @@ class ContentRepoResource(resource.Resource): def cbFinished(ignored): f.close() finish_request(request) + d.addCallback(cbFinished) else: respond_with_json_bytes( request, 404, json.dumps(cs_error("Not found", code=Codes.NOT_FOUND)), - send_cors=True) + send_cors=True, + ) return server.NOT_DONE_YET diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 2dcc8f74d6..3318638d3e 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -38,8 +38,8 @@ def parse_media_id(request): server_name, media_id = request.postpath[:2] if isinstance(server_name, bytes): - server_name = server_name.decode('utf-8') - media_id = media_id.decode('utf8') + server_name = server_name.decode("utf-8") + media_id = media_id.decode("utf8") file_name = None if len(request.postpath) > 2: @@ -120,11 +120,11 @@ def add_file_headers(request, media_type, file_size, upload_name): # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we # may as well just do the filename* version. if _can_encode_filename_as_token(upload_name): - disposition = 'inline; filename=%s' % (upload_name, ) + disposition = "inline; filename=%s" % (upload_name,) else: - disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name), ) + disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) - request.setHeader(b"Content-Disposition", disposition.encode('ascii')) + request.setHeader(b"Content-Disposition", disposition.encode("ascii")) # cache for at least a day. # XXX: we might want to turn this off for data we don't want to @@ -137,10 +137,27 @@ def add_file_headers(request, media_type, file_size, upload_name): # separators as defined in RFC2616. SP and HT are handled separately. # see _can_encode_filename_as_token. -_FILENAME_SEPARATOR_CHARS = set(( - "(", ")", "<", ">", "@", ",", ";", ":", "\\", '"', - "/", "[", "]", "?", "=", "{", "}", -)) +_FILENAME_SEPARATOR_CHARS = set( + ( + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", + ) +) def _can_encode_filename_as_token(x): @@ -271,7 +288,7 @@ def get_filename_from_headers(headers): Returns: A Unicode string of the filename, or None. """ - content_disposition = headers.get(b"Content-Disposition", [b'']) + content_disposition = headers.get(b"Content-Disposition", [b""]) # No header, bail out. if not content_disposition[0]: @@ -293,7 +310,7 @@ def get_filename_from_headers(headers): # Once it is decoded, we can then unquote the %-encoded # parts strictly into a unicode string. upload_name = urllib.parse.unquote( - upload_name_utf8.decode('ascii'), errors="strict" + upload_name_utf8.decode("ascii"), errors="strict" ) except UnicodeDecodeError: # Incorrect UTF-8. @@ -302,7 +319,7 @@ def get_filename_from_headers(headers): # On Python 2, we first unquote the %-encoded parts and then # decode it strictly using UTF-8. try: - upload_name = urllib.parse.unquote(upload_name_utf8).decode('utf8') + upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8") except UnicodeDecodeError: pass @@ -310,7 +327,7 @@ def get_filename_from_headers(headers): if not upload_name: upload_name_ascii = params.get(b"filename", None) if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii.decode('ascii') + upload_name = upload_name_ascii.decode("ascii") # This may be None here, indicating we did not find a matching name. return upload_name @@ -328,19 +345,19 @@ def _parse_header(line): Tuple[bytes, dict[bytes, bytes]]: the main content-type, followed by the parameter dictionary """ - parts = _parseparam(b';' + line) + parts = _parseparam(b";" + line) key = next(parts) pdict = {} for p in parts: - i = p.find(b'=') + i = p.find(b"=") if i >= 0: name = p[:i].strip().lower() - value = p[i + 1:].strip() + value = p[i + 1 :].strip() # strip double-quotes if len(value) >= 2 and value[0:1] == value[-1:] == b'"': value = value[1:-1] - value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"') + value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') pdict[name] = value return key, pdict @@ -357,16 +374,16 @@ def _parseparam(s): Returns: Iterable[bytes]: the split input """ - while s[:1] == b';': + while s[:1] == b";": s = s[1:] # look for the next ; - end = s.find(b';') + end = s.find(b";") # if there is an odd number of " marks between here and the next ;, skip to the # next ; instead while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: - end = s.find(b';', end + 1) + end = s.find(b";", end + 1) if end < 0: end = len(s) diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 77316033f7..fa3d6680fc 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -29,9 +29,7 @@ class MediaConfigResource(Resource): config = hs.get_config() self.clock = hs.get_clock() self.auth = hs.get_auth() - self.limits_dict = { - "m.upload.size": config.max_upload_size, - } + self.limits_dict = {"m.upload.size": config.max_upload_size} def render_GET(self, request): self._async_render_GET(request) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index bdc5daecc1..a21a35f843 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -54,18 +54,20 @@ class DownloadResource(Resource): b" plugin-types application/pdf;" b" style-src 'unsafe-inline';" b" media-src 'self';" - b" object-src 'self';" + b" object-src 'self';", ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: yield self.media_repo.get_local_media(request, media_id, name) else: allow_remote = synapse.http.servlet.parse_boolean( - request, "allow_remote", default=True) + request, "allow_remote", default=True + ) if not allow_remote: logger.info( "Rejecting request for remote media %s/%s due to allow_remote", - server_name, media_id, + server_name, + media_id, ) respond_404(request) return diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index c8586fa280..e25c382c9c 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -24,6 +24,7 @@ def _wrap_in_base_path(func): """Takes a function that returns a relative path and turns it into an absolute path based on the location of the primary media store """ + @functools.wraps(func) def _wrapped(self, *args, **kwargs): path = func(self, *args, **kwargs) @@ -43,125 +44,102 @@ class MediaFilePaths(object): def __init__(self, primary_base_path): self.base_path = primary_base_path - def default_thumbnail_rel(self, default_top_level, default_sub_type, width, - height, content_type, method): + def default_thumbnail_rel( + self, default_top_level, default_sub_type, width, height, content_type, method + ): top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % ( - width, height, top_level_type, sub_type, method - ) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( - "default_thumbnails", default_top_level, - default_sub_type, file_name + "default_thumbnails", default_top_level, default_sub_type, file_name ) default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) def local_media_filepath_rel(self, media_id): - return os.path.join( - "local_content", - media_id[0:2], media_id[2:4], media_id[4:] - ) + return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - def local_media_thumbnail_rel(self, media_id, width, height, content_type, - method): + def local_media_thumbnail_rel(self, media_id, width, height, content_type, method): top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % ( - width, height, top_level_type, sub_type, method - ) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( - "local_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - file_name + "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name ) local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) def remote_media_filepath_rel(self, server_name, file_id): return os.path.join( - "remote_content", server_name, - file_id[0:2], file_id[2:4], file_id[4:] + "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) - def remote_media_thumbnail_rel(self, server_name, file_id, width, height, - content_type, method): + def remote_media_thumbnail_rel( + self, server_name, file_id, width, height, content_type, method + ): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( - "remote_thumbnail", server_name, - file_id[0:2], file_id[2:4], file_id[4:], - file_name + "remote_thumbnail", + server_name, + file_id[0:2], + file_id[2:4], + file_id[4:], + file_name, ) remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) def remote_media_thumbnail_dir(self, server_name, file_id): return os.path.join( - self.base_path, "remote_thumbnail", server_name, - file_id[0:2], file_id[2:4], file_id[4:], + self.base_path, + "remote_thumbnail", + server_name, + file_id[0:2], + file_id[2:4], + file_id[4:], ) def url_cache_filepath_rel(self, media_id): if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf - return os.path.join( - "url_cache", - media_id[:10], media_id[11:] - ) + return os.path.join("url_cache", media_id[:10], media_id[11:]) else: - return os.path.join( - "url_cache", - media_id[0:2], media_id[2:4], media_id[4:], - ) + return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:]) url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) def url_cache_filepath_dirs_to_delete(self, media_id): "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): - return [ - os.path.join( - self.base_path, "url_cache", - media_id[:10], - ), - ] + return [os.path.join(self.base_path, "url_cache", media_id[:10])] else: return [ - os.path.join( - self.base_path, "url_cache", - media_id[0:2], media_id[2:4], - ), - os.path.join( - self.base_path, "url_cache", - media_id[0:2], - ), + os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]), + os.path.join(self.base_path, "url_cache", media_id[0:2]), ] - def url_cache_thumbnail_rel(self, media_id, width, height, content_type, - method): + def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % ( - width, height, top_level_type, sub_type, method - ) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - "url_cache_thumbnails", - media_id[:10], media_id[11:], - file_name + "url_cache_thumbnails", media_id[:10], media_id[11:], file_name ) else: return os.path.join( "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - file_name + media_id[0:2], + media_id[2:4], + media_id[4:], + file_name, ) url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) @@ -172,13 +150,15 @@ class MediaFilePaths(object): if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[:10], media_id[11:], + self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] ) else: return os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], + self.base_path, + "url_cache_thumbnails", + media_id[0:2], + media_id[2:4], + media_id[4:], ) def url_cache_thumbnail_dirs_to_delete(self, media_id): @@ -188,26 +168,21 @@ class MediaFilePaths(object): if NEW_FORMAT_ID_RE.match(media_id): return [ os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[:10], media_id[11:], - ), - os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[:10], + self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] ), + os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]), ] else: return [ os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - ), - os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], + self.base_path, + "url_cache_thumbnails", + media_id[0:2], + media_id[2:4], + media_id[4:], ), os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], + self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4] ), + os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]), ] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index a4929dd5db..df3d985a38 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -100,17 +100,16 @@ class MediaRepository(object): storage_providers.append(provider) self.media_storage = MediaStorage( - self.hs, self.primary_base_path, self.filepaths, storage_providers, + self.hs, self.primary_base_path, self.filepaths, storage_providers ) self.clock.looping_call( - self._start_update_recently_accessed, - UPDATE_RECENTLY_ACCESSED_TS, + self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS ) def _start_update_recently_accessed(self): return run_as_background_process( - "update_recently_accessed_media", self._update_recently_accessed, + "update_recently_accessed_media", self._update_recently_accessed ) @defer.inlineCallbacks @@ -138,8 +137,9 @@ class MediaRepository(object): self.recently_accessed_locals.add(media_id) @defer.inlineCallbacks - def create_content(self, media_type, upload_name, content, content_length, - auth_user): + def create_content( + self, media_type, upload_name, content, content_length, auth_user + ): """Store uploaded content for a local user and return the mxc URL Args: @@ -154,10 +154,7 @@ class MediaRepository(object): """ media_id = random_string(24) - file_info = FileInfo( - server_name=None, - file_id=media_id, - ) + file_info = FileInfo(server_name=None, file_id=media_id) fname = yield self.media_storage.store_file(content, file_info) @@ -172,9 +169,7 @@ class MediaRepository(object): user_id=auth_user, ) - yield self._generate_thumbnails( - None, media_id, media_id, media_type, - ) + yield self._generate_thumbnails(None, media_id, media_id, media_type) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) @@ -205,14 +200,11 @@ class MediaRepository(object): upload_name = name if name else media_info["upload_name"] url_cache = media_info["url_cache"] - file_info = FileInfo( - None, media_id, - url_cache=url_cache, - ) + file_info = FileInfo(None, media_id, url_cache=url_cache) responder = yield self.media_storage.fetch_media(file_info) yield respond_with_responder( - request, responder, media_type, media_length, upload_name, + request, responder, media_type, media_length, upload_name ) @defer.inlineCallbacks @@ -232,8 +224,8 @@ class MediaRepository(object): to request """ if ( - self.federation_domain_whitelist is not None and - server_name not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist ): raise FederationDeniedError(server_name) @@ -244,7 +236,7 @@ class MediaRepository(object): key = (server_name, media_id) with (yield self.remote_media_linearizer.queue(key)): responder, media_info = yield self._get_remote_media_impl( - server_name, media_id, + server_name, media_id ) # We deliberately stream the file outside the lock @@ -253,7 +245,7 @@ class MediaRepository(object): media_length = media_info["media_length"] upload_name = name if name else media_info["upload_name"] yield respond_with_responder( - request, responder, media_type, media_length, upload_name, + request, responder, media_type, media_length, upload_name ) else: respond_404(request) @@ -272,8 +264,8 @@ class MediaRepository(object): Deferred[dict]: The media_info of the file """ if ( - self.federation_domain_whitelist is not None and - server_name not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist ): raise FederationDeniedError(server_name) @@ -282,7 +274,7 @@ class MediaRepository(object): key = (server_name, media_id) with (yield self.remote_media_linearizer.queue(key)): responder, media_info = yield self._get_remote_media_impl( - server_name, media_id, + server_name, media_id ) # Ensure we actually use the responder so that it releases resources @@ -305,9 +297,7 @@ class MediaRepository(object): Returns: Deferred[(Responder, media_info)] """ - media_info = yield self.store.get_cached_remote_media( - server_name, media_id - ) + media_info = yield self.store.get_cached_remote_media(server_name, media_id) # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new @@ -331,9 +321,7 @@ class MediaRepository(object): # Failed to find the file anywhere, lets download it. - media_info = yield self._download_remote_file( - server_name, media_id, file_id - ) + media_info = yield self._download_remote_file(server_name, media_id, file_id) responder = yield self.media_storage.fetch_media(file_info) defer.returnValue((responder, media_info)) @@ -354,54 +342,60 @@ class MediaRepository(object): Deferred[MediaInfo] """ - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - ) + file_info = FileInfo(server_name=server_name, file_id=file_id) with self.media_storage.store_into_file(file_info) as (f, fname, finish): - request_path = "/".join(( - "/_matrix/media/v1/download", server_name, media_id, - )) + request_path = "/".join( + ("/_matrix/media/v1/download", server_name, media_id) + ) try: length, headers = yield self.client.get_file( - server_name, request_path, output_stream=f, - max_size=self.max_upload_size, args={ + server_name, + request_path, + output_stream=f, + max_size=self.max_upload_size, + args={ # tell the remote server to 404 if it doesn't # recognise the server_name, to make sure we don't # end up with a routing loop. - "allow_remote": "false", - } + "allow_remote": "false" + }, ) except RequestSendFailed as e: - logger.warn("Request failed fetching remote media %s/%s: %r", - server_name, media_id, e) + logger.warn( + "Request failed fetching remote media %s/%s: %r", + server_name, + media_id, + e, + ) raise SynapseError(502, "Failed to fetch remote media") except HttpResponseException as e: - logger.warn("HTTP error fetching remote media %s/%s: %s", - server_name, media_id, e.response) + logger.warn( + "HTTP error fetching remote media %s/%s: %s", + server_name, + media_id, + e.response, + ) if e.code == twisted.web.http.NOT_FOUND: raise e.to_synapse_error() raise SynapseError(502, "Failed to fetch remote media") except SynapseError: - logger.warn( - "Failed to fetch remote media %s/%s", - server_name, media_id, - ) + logger.warn("Failed to fetch remote media %s/%s", server_name, media_id) raise except NotRetryingDestination: logger.warn("Not retrying destination %r", server_name) raise SynapseError(502, "Failed to fetch remote media") except Exception: - logger.exception("Failed to fetch remote media %s/%s", - server_name, media_id) + logger.exception( + "Failed to fetch remote media %s/%s", server_name, media_id + ) raise SynapseError(502, "Failed to fetch remote media") yield finish() - media_type = headers[b"Content-Type"][0].decode('ascii') + media_type = headers[b"Content-Type"][0].decode("ascii") upload_name = get_filename_from_headers(headers) time_now_ms = self.clock.time_msec() @@ -425,24 +419,23 @@ class MediaRepository(object): "filesystem_id": file_id, } - yield self._generate_thumbnails( - server_name, media_id, file_id, media_type, - ) + yield self._generate_thumbnails(server_name, media_id, file_id, media_type) defer.returnValue(media_info) def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, thumbnailer, t_width, t_height, - t_method, t_type): + def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): m_width = thumbnailer.width m_height = thumbnailer.height if m_width * m_height >= self.max_image_pixels: logger.info( "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels + m_width, + m_height, + self.max_image_pixels, ) return @@ -462,17 +455,22 @@ class MediaRepository(object): return t_byte_source @defer.inlineCallbacks - def generate_local_exact_thumbnail(self, media_id, t_width, t_height, - t_method, t_type, url_cache): - input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( - None, media_id, url_cache=url_cache, - )) + def generate_local_exact_thumbnail( + self, media_id, t_width, t_height, t_method, t_type, url_cache + ): + input_path = yield self.media_storage.ensure_media_is_in_local_cache( + FileInfo(None, media_id, url_cache=url_cache) + ) thumbnailer = Thumbnailer(input_path) t_byte_source = yield logcontext.defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, - thumbnailer, t_width, t_height, t_method, t_type + thumbnailer, + t_width, + t_height, + t_method, + t_type, ) if t_byte_source: @@ -489,7 +487,7 @@ class MediaRepository(object): ) output_path = yield self.media_storage.store_file( - t_byte_source, file_info, + t_byte_source, file_info ) finally: t_byte_source.close() @@ -505,17 +503,22 @@ class MediaRepository(object): defer.returnValue(output_path) @defer.inlineCallbacks - def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, - t_width, t_height, t_method, t_type): - input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( - server_name, file_id, url_cache=False, - )) + def generate_remote_exact_thumbnail( + self, server_name, file_id, media_id, t_width, t_height, t_method, t_type + ): + input_path = yield self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id, url_cache=False) + ) thumbnailer = Thumbnailer(input_path) t_byte_source = yield logcontext.defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, - thumbnailer, t_width, t_height, t_method, t_type + thumbnailer, + t_width, + t_height, + t_method, + t_type, ) if t_byte_source: @@ -531,7 +534,7 @@ class MediaRepository(object): ) output_path = yield self.media_storage.store_file( - t_byte_source, file_info, + t_byte_source, file_info ) finally: t_byte_source.close() @@ -541,15 +544,22 @@ class MediaRepository(object): t_len = os.path.getsize(output_path) yield self.store.store_remote_media_thumbnail( - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, ) defer.returnValue(output_path) @defer.inlineCallbacks - def _generate_thumbnails(self, server_name, media_id, file_id, media_type, - url_cache=False): + def _generate_thumbnails( + self, server_name, media_id, file_id, media_type, url_cache=False + ): """Generate and store thumbnails for an image. Args: @@ -568,9 +578,9 @@ class MediaRepository(object): if not requirements: return - input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( - server_name, file_id, url_cache=url_cache, - )) + input_path = yield self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id, url_cache=url_cache) + ) thumbnailer = Thumbnailer(input_path) m_width = thumbnailer.width @@ -579,14 +589,15 @@ class MediaRepository(object): if m_width * m_height >= self.max_image_pixels: logger.info( "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels + m_width, + m_height, + self.max_image_pixels, ) return if thumbnailer.transpose_method is not None: m_width, m_height = yield logcontext.defer_to_thread( - self.hs.get_reactor(), - thumbnailer.transpose + self.hs.get_reactor(), thumbnailer.transpose ) # We deduplicate the thumbnail sizes by ignoring the cropped versions if @@ -606,15 +617,11 @@ class MediaRepository(object): # Generate the thumbnail if t_method == "crop": t_byte_source = yield logcontext.defer_to_thread( - self.hs.get_reactor(), - thumbnailer.crop, - t_width, t_height, t_type, + self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type ) elif t_method == "scale": t_byte_source = yield logcontext.defer_to_thread( - self.hs.get_reactor(), - thumbnailer.scale, - t_width, t_height, t_type, + self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type ) else: logger.error("Unrecognized method: %r", t_method) @@ -636,7 +643,7 @@ class MediaRepository(object): ) output_path = yield self.media_storage.store_file( - t_byte_source, file_info, + t_byte_source, file_info ) finally: t_byte_source.close() @@ -646,18 +653,21 @@ class MediaRepository(object): # Write to database if server_name: yield self.store.store_remote_media_thumbnail( - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, ) else: yield self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue({ - "width": m_width, - "height": m_height, - }) + defer.returnValue({"width": m_width, "height": m_height}) @defer.inlineCallbacks def delete_old_remote_media(self, before_ts): @@ -749,11 +759,12 @@ class MediaRepositoryResource(Resource): self.putChild(b"upload", UploadResource(hs, media_repo)) self.putChild(b"download", DownloadResource(hs, media_repo)) - self.putChild(b"thumbnail", ThumbnailResource( - hs, media_repo, media_repo.media_storage, - )) + self.putChild( + b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) + ) if hs.config.url_preview_enabled: - self.putChild(b"preview_url", PreviewUrlResource( - hs, media_repo, media_repo.media_storage, - )) + self.putChild( + b"preview_url", + PreviewUrlResource(hs, media_repo, media_repo.media_storage), + ) self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 896078fe76..eff86836fb 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -66,8 +66,7 @@ class MediaStorage(object): with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository yield logcontext.defer_to_thread( - self.hs.get_reactor(), - _write_file_synchronously, source, f, + self.hs.get_reactor(), _write_file_synchronously, source, f ) yield finish_cb() @@ -179,7 +178,8 @@ class MediaStorage(object): if res: with res: consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.hs.get_reactor()) + open(local_path, "wb"), self.hs.get_reactor() + ) yield res.write_to_consumer(consumer) yield consumer.wait() defer.returnValue(local_path) @@ -217,10 +217,10 @@ class MediaStorage(object): width=file_info.thumbnail_width, height=file_info.thumbnail_height, content_type=file_info.thumbnail_type, - method=file_info.thumbnail_method + method=file_info.thumbnail_method, ) return self.filepaths.remote_media_filepath_rel( - file_info.server_name, file_info.file_id, + file_info.server_name, file_info.file_id ) if file_info.thumbnail: @@ -229,11 +229,9 @@ class MediaStorage(object): width=file_info.thumbnail_width, height=file_info.thumbnail_height, content_type=file_info.thumbnail_type, - method=file_info.thumbnail_method + method=file_info.thumbnail_method, ) - return self.filepaths.local_media_filepath_rel( - file_info.file_id, - ) + return self.filepaths.local_media_filepath_rel(file_info.file_id) def _write_file_synchronously(source, dest): @@ -255,6 +253,7 @@ class FileResponder(Responder): open_file (file): A file like object to be streamed ot the client, is closed when finished streaming. """ + def __init__(self, open_file): self.open_file = open_file diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 85a7c61a24..e7c5f5a8e2 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -94,7 +94,7 @@ class PreviewUrlResource(Resource): ) self._cleaner_loop = self.clock.looping_call( - self._start_expire_url_cache_data, 10 * 1000, + self._start_expire_url_cache_data, 10 * 1000 ) def render_OPTIONS(self, request): @@ -123,16 +123,16 @@ class PreviewUrlResource(Resource): for attrib in entry: pattern = entry[attrib] value = getattr(url_tuple, attrib) - logger.debug(( - "Matching attrib '%s' with value '%s' against" - " pattern '%s'" - ) % (attrib, value, pattern)) + logger.debug( + ("Matching attrib '%s' with value '%s' against" " pattern '%s'") + % (attrib, value, pattern) + ) if value is None: match = False continue - if pattern.startswith('^'): + if pattern.startswith("^"): if not re.match(pattern, getattr(url_tuple, attrib)): match = False continue @@ -141,12 +141,9 @@ class PreviewUrlResource(Resource): match = False continue if match: - logger.warn( - "URL %s blocked by url_blacklist entry %s", url, entry - ) + logger.warn("URL %s blocked by url_blacklist entry %s", url, entry) raise SynapseError( - 403, "URL blocked by url pattern blacklist entry", - Codes.UNKNOWN + 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN ) # the in-memory cache: @@ -158,14 +155,8 @@ class PreviewUrlResource(Resource): observable = self._cache.get(url) if not observable: - download = run_in_background( - self._do_preview, - url, requester.user, ts, - ) - observable = ObservableDeferred( - download, - consumeErrors=True - ) + download = run_in_background(self._do_preview, url, requester.user, ts) + observable = ObservableDeferred(download, consumeErrors=True) self._cache[url] = observable else: logger.info("Returning cached response") @@ -189,15 +180,15 @@ class PreviewUrlResource(Resource): # historical previews, if we have any) cache_result = yield self.store.get_url_cache(url, ts) if ( - cache_result and - cache_result["expires_ts"] > ts and - cache_result["response_code"] / 100 == 2 + cache_result + and cache_result["expires_ts"] > ts + and cache_result["response_code"] / 100 == 2 ): # It may be stored as text in the database, not as bytes (such as # PostgreSQL). If so, encode it back before handing it on. og = cache_result["og"] if isinstance(og, six.text_type): - og = og.encode('utf8') + og = og.encode("utf8") defer.returnValue(og) return @@ -205,33 +196,31 @@ class PreviewUrlResource(Resource): logger.debug("got media_info of '%s'" % media_info) - if _is_media(media_info['media_type']): - file_id = media_info['filesystem_id'] + if _is_media(media_info["media_type"]): + file_id = media_info["filesystem_id"] dims = yield self.media_repo._generate_thumbnails( - None, file_id, file_id, media_info["media_type"], - url_cache=True, + None, file_id, file_id, media_info["media_type"], url_cache=True ) og = { - "og:description": media_info['download_name'], - "og:image": "mxc://%s/%s" % ( - self.server_name, media_info['filesystem_id'] - ), - "og:image:type": media_info['media_type'], - "matrix:image:size": media_info['media_length'], + "og:description": media_info["download_name"], + "og:image": "mxc://%s/%s" + % (self.server_name, media_info["filesystem_id"]), + "og:image:type": media_info["media_type"], + "matrix:image:size": media_info["media_length"], } if dims: - og["og:image:width"] = dims['width'] - og["og:image:height"] = dims['height'] + og["og:image:width"] = dims["width"] + og["og:image:height"] = dims["height"] else: logger.warn("Couldn't get dims for %s" % url) # define our OG response for this media - elif _is_html(media_info['media_type']): + elif _is_html(media_info["media_type"]): # TODO: somehow stop a big HTML tree from exploding synapse's RAM - with open(media_info['filename'], 'rb') as file: + with open(media_info["filename"], "rb") as file: body = file.read() encoding = None @@ -244,45 +233,43 @@ class PreviewUrlResource(Resource): # If we find a match, it should take precedence over the # Content-Type header, so set it here. if match: - encoding = match.group(1).decode('ascii') + encoding = match.group(1).decode("ascii") # If we don't find a match, we'll look at the HTTP Content-Type, and # if that doesn't exist, we'll fall back to UTF-8. if not encoding: - match = _content_type_match.match( - media_info['media_type'] - ) + match = _content_type_match.match(media_info["media_type"]) encoding = match.group(1) if match else "utf-8" - og = decode_and_calc_og(body, media_info['uri'], encoding) + og = decode_and_calc_og(body, media_info["uri"], encoding) # pre-cache the image for posterity # FIXME: it might be cleaner to use the same flow as the main /preview_url # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. - if 'og:image' in og and og['og:image']: + if "og:image" in og and og["og:image"]: image_info = yield self._download_url( - _rebase_url(og['og:image'], media_info['uri']), user + _rebase_url(og["og:image"], media_info["uri"]), user ) - if _is_media(image_info['media_type']): + if _is_media(image_info["media_type"]): # TODO: make sure we don't choke on white-on-transparent images - file_id = image_info['filesystem_id'] + file_id = image_info["filesystem_id"] dims = yield self.media_repo._generate_thumbnails( - None, file_id, file_id, image_info["media_type"], - url_cache=True, + None, file_id, file_id, image_info["media_type"], url_cache=True ) if dims: - og["og:image:width"] = dims['width'] - og["og:image:height"] = dims['height'] + og["og:image:width"] = dims["width"] + og["og:image:height"] = dims["height"] else: logger.warn("Couldn't get dims for %s" % og["og:image"]) og["og:image"] = "mxc://%s/%s" % ( - self.server_name, image_info['filesystem_id'] + self.server_name, + image_info["filesystem_id"], ) - og["og:image:type"] = image_info['media_type'] - og["matrix:image:size"] = image_info['media_length'] + og["og:image:type"] = image_info["media_type"] + og["matrix:image:size"] = image_info["media_length"] else: del og["og:image"] else: @@ -291,7 +278,7 @@ class PreviewUrlResource(Resource): logger.debug("Calculated OG for %s as %s" % (url, og)) - jsonog = json.dumps(og).encode('utf8') + jsonog = json.dumps(og).encode("utf8") # store OG in history-aware DB cache yield self.store.store_url_cache( @@ -312,19 +299,15 @@ class PreviewUrlResource(Resource): # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? - file_id = datetime.date.today().isoformat() + '_' + random_string(16) + file_id = datetime.date.today().isoformat() + "_" + random_string(16) - file_info = FileInfo( - server_name=None, - file_id=file_id, - url_cache=True, - ) + file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: logger.debug("Trying to get url '%s'" % url) length, headers, uri, code = yield self.client.get_file( - url, output_stream=f, max_size=self.max_spider_size, + url, output_stream=f, max_size=self.max_spider_size ) except SynapseError: # Pass SynapseErrors through directly, so that the servlet @@ -336,24 +319,25 @@ class PreviewUrlResource(Resource): # Note: This will also be the case if one of the resolved IP # addresses is blacklisted raise SynapseError( - 502, "DNS resolution failure during URL preview generation", - Codes.UNKNOWN + 502, + "DNS resolution failure during URL preview generation", + Codes.UNKNOWN, ) except Exception as e: # FIXME: pass through 404s and other error messages nicely logger.warn("Error downloading %s: %r", url, e) raise SynapseError( - 500, "Failed to download content: %s" % ( - traceback.format_exception_only(sys.exc_info()[0], e), - ), + 500, + "Failed to download content: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), Codes.UNKNOWN, ) yield finish() try: if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode('ascii') + media_type = headers[b"Content-Type"][0].decode("ascii") else: media_type = "application/octet-stream" time_now_ms = self.clock.time_msec() @@ -377,24 +361,26 @@ class PreviewUrlResource(Resource): # therefore not expire it. raise - defer.returnValue({ - "media_type": media_type, - "media_length": length, - "download_name": download_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - "filename": fname, - "uri": uri, - "response_code": code, - # FIXME: we should calculate a proper expiration based on the - # Cache-Control and Expire headers. But for now, assume 1 hour. - "expires": 60 * 60 * 1000, - "etag": headers["ETag"][0] if "ETag" in headers else None, - }) + defer.returnValue( + { + "media_type": media_type, + "media_length": length, + "download_name": download_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + "filename": fname, + "uri": uri, + "response_code": code, + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + "expires": 60 * 60 * 1000, + "etag": headers["ETag"][0] if "ETag" in headers else None, + } + ) def _start_expire_url_cache_data(self): return run_as_background_process( - "expire_url_cache_data", self._expire_url_cache_data, + "expire_url_cache_data", self._expire_url_cache_data ) @defer.inlineCallbacks @@ -498,7 +484,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None): # blindly try decoding the body as utf-8, which seems to fix # the charset mismatches on https://google.com parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser) + tree = etree.fromstring(body.decode("utf-8", "ignore"), parser) og = _calc_og(tree, media_uri) return og @@ -525,8 +511,8 @@ def _calc_og(tree, media_uri): og = {} for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - if 'content' in tag.attrib: - og[tag.attrib['property']] = tag.attrib['content'] + if "content" in tag.attrib: + og[tag.attrib["property"]] = tag.attrib["content"] # TODO: grab article: meta tags too, e.g.: @@ -537,39 +523,43 @@ def _calc_og(tree, media_uri): # "article:published_time" content="2016-03-31T19:58:24+00:00" /> # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> - if 'og:title' not in og: + if "og:title" not in og: # do some basic spidering of the HTML title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") if title and title[0].text is not None: - og['og:title'] = title[0].text.strip() + og["og:title"] = title[0].text.strip() else: - og['og:title'] = None + og["og:title"] = None - if 'og:image' not in og: + if "og:image" not in og: # TODO: extract a favicon failing all else meta_image = tree.xpath( "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" ) if meta_image: - og['og:image'] = _rebase_url(meta_image[0], media_uri) + og["og:image"] = _rebase_url(meta_image[0], media_uri) else: # TODO: consider inlined CSS styles as well as width & height attribs images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") - images = sorted(images, key=lambda i: ( - -1 * float(i.attrib['width']) * float(i.attrib['height']) - )) + images = sorted( + images, + key=lambda i: ( + -1 * float(i.attrib["width"]) * float(i.attrib["height"]) + ), + ) if not images: images = tree.xpath("//img[@src]") if images: - og['og:image'] = images[0].attrib['src'] + og["og:image"] = images[0].attrib["src"] - if 'og:description' not in og: + if "og:description" not in og: meta_description = tree.xpath( "//*/meta" "[translate(@name, 'DESCRIPTION', 'description')='description']" - "/@content") + "/@content" + ) if meta_description: - og['og:description'] = meta_description[0] + og["og:description"] = meta_description[0] else: # grab any text nodes which are inside the tag... # unless they are within an HTML5 semantic markup tag... @@ -590,18 +580,18 @@ def _calc_og(tree, media_uri): "script", "noscript", "style", - etree.Comment + etree.Comment, ) # Split all the text nodes into paragraphs (by splitting on new # lines) text_nodes = ( - re.sub(r'\s+', '\n', el).strip() + re.sub(r"\s+", "\n", el).strip() for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) ) - og['og:description'] = summarize_paragraphs(text_nodes) + og["og:description"] = summarize_paragraphs(text_nodes) else: - og['og:description'] = summarize_paragraphs([og['og:description']]) + og["og:description"] = summarize_paragraphs([og["og:description"]]) # TODO: delete the url downloads to stop diskfilling, # as we only ever cared about its OG @@ -638,7 +628,7 @@ def _iterate_over_text(tree, *tags_to_ignore): [child, child.tail] if child.tail else [child] for child in el.iterchildren() ), - elements + elements, ) @@ -649,8 +639,8 @@ def _rebase_url(url, base): url[0] = base[0] or "http" if not url[1]: # fix up hostname url[1] = base[1] - if not url[2].startswith('/'): - url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] + if not url[2].startswith("/"): + url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2] return urlparse.urlunparse(url) @@ -661,9 +651,8 @@ def _is_media(content_type): def _is_html(content_type): content_type = content_type.lower() - if ( - content_type.startswith("text/html") or - content_type.startswith("application/xhtml") + if content_type.startswith("text/html") or content_type.startswith( + "application/xhtml" ): return True @@ -673,19 +662,19 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500): # first paragraph and then word boundaries. # TODO: Respect sentences? - description = '' + description = "" # Keep adding paragraphs until we get to the MIN_SIZE. for text_node in text_nodes: if len(description) < min_size: - text_node = re.sub(r'[\t \r\n]+', ' ', text_node) - description += text_node + '\n\n' + text_node = re.sub(r"[\t \r\n]+", " ", text_node) + description += text_node + "\n\n" else: break description = description.strip() - description = re.sub(r'[\t ]+', ' ', description) - description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description) + description = re.sub(r"[\t ]+", " ", description) + description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description) # If the concatenation of paragraphs to get above MIN_SIZE # took us over MAX_SIZE, then we need to truncate mid paragraph diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index d90cbfb56a..359b45ebfc 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -32,6 +32,7 @@ class StorageProvider(object): """A storage provider is a service that can store uploaded media and retrieve them. """ + def store_file(self, path, file_info): """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. @@ -70,6 +71,7 @@ class StorageProviderWrapper(StorageProvider): uploaded, or todo the upload in the backgroud. store_remote (bool): Whether remote media should be uploaded """ + def __init__(self, backend, store_local, store_synchronous, store_remote): self.backend = backend self.store_local = store_local @@ -92,6 +94,7 @@ class StorageProviderWrapper(StorageProvider): return self.backend.store_file(path, file_info) except Exception: logger.exception("Error storing file") + run_in_background(store) return defer.succeed(None) @@ -123,8 +126,7 @@ class FileStorageProviderBackend(StorageProvider): os.makedirs(dirname) return logcontext.defer_to_thread( - self.hs.get_reactor(), - shutil.copyfile, primary_fname, backup_fname, + self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) def fetch(self, path, file_info): diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 35a750923b..ca84c9f139 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -74,19 +74,18 @@ class ThumbnailResource(Resource): else: if self.dynamic_thumbnails: yield self._select_or_generate_remote_thumbnail( - request, server_name, media_id, - width, height, method, m_type + request, server_name, media_id, width, height, method, m_type ) else: yield self._respond_remote_thumbnail( - request, server_name, media_id, - width, height, method, m_type + request, server_name, media_id, width, height, method, m_type ) self.media_repo.mark_recently_accessed(server_name, media_id) @defer.inlineCallbacks - def _respond_local_thumbnail(self, request, media_id, width, height, - method, m_type): + def _respond_local_thumbnail( + self, request, media_id, width, height, method, m_type + ): media_info = yield self.store.get_local_media(media_id) if not media_info: @@ -105,7 +104,8 @@ class ThumbnailResource(Resource): ) file_info = FileInfo( - server_name=None, file_id=media_id, + server_name=None, + file_id=media_id, url_cache=media_info["url_cache"], thumbnail=True, thumbnail_width=thumbnail_info["thumbnail_width"], @@ -124,9 +124,15 @@ class ThumbnailResource(Resource): respond_404(request) @defer.inlineCallbacks - def _select_or_generate_local_thumbnail(self, request, media_id, desired_width, - desired_height, desired_method, - desired_type): + def _select_or_generate_local_thumbnail( + self, + request, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, + ): media_info = yield self.store.get_local_media(media_id) if not media_info: @@ -146,7 +152,8 @@ class ThumbnailResource(Resource): if t_w and t_h and t_method and t_type: file_info = FileInfo( - server_name=None, file_id=media_id, + server_name=None, + file_id=media_id, url_cache=media_info["url_cache"], thumbnail=True, thumbnail_width=info["thumbnail_width"], @@ -167,7 +174,11 @@ class ThumbnailResource(Resource): # Okay, so we generate one. file_path = yield self.media_repo.generate_local_exact_thumbnail( - media_id, desired_width, desired_height, desired_method, desired_type, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, url_cache=media_info["url_cache"], ) @@ -178,13 +189,20 @@ class ThumbnailResource(Resource): respond_404(request) @defer.inlineCallbacks - def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, - desired_width, desired_height, - desired_method, desired_type): + def _select_or_generate_remote_thumbnail( + self, + request, + server_name, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, + ): media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( - server_name, media_id, + server_name, media_id ) file_id = media_info["filesystem_id"] @@ -197,7 +215,8 @@ class ThumbnailResource(Resource): if t_w and t_h and t_method and t_type: file_info = FileInfo( - server_name=server_name, file_id=media_info["filesystem_id"], + server_name=server_name, + file_id=media_info["filesystem_id"], thumbnail=True, thumbnail_width=info["thumbnail_width"], thumbnail_height=info["thumbnail_height"], @@ -217,8 +236,13 @@ class ThumbnailResource(Resource): # Okay, so we generate one. file_path = yield self.media_repo.generate_remote_exact_thumbnail( - server_name, file_id, media_id, desired_width, - desired_height, desired_method, desired_type + server_name, + file_id, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, ) if file_path: @@ -228,15 +252,16 @@ class ThumbnailResource(Resource): respond_404(request) @defer.inlineCallbacks - def _respond_remote_thumbnail(self, request, server_name, media_id, width, - height, method, m_type): + def _respond_remote_thumbnail( + self, request, server_name, media_id, width, height, method, m_type + ): # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( - server_name, media_id, + server_name, media_id ) if thumbnail_infos: @@ -244,7 +269,8 @@ class ThumbnailResource(Resource): width, height, method, m_type, thumbnail_infos ) file_info = FileInfo( - server_name=server_name, file_id=media_info["filesystem_id"], + server_name=server_name, + file_id=media_info["filesystem_id"], thumbnail=True, thumbnail_width=thumbnail_info["thumbnail_width"], thumbnail_height=thumbnail_info["thumbnail_height"], @@ -261,8 +287,14 @@ class ThumbnailResource(Resource): logger.info("Failed to find any generated thumbnails") respond_404(request) - def _select_thumbnail(self, desired_width, desired_height, desired_method, - desired_type, thumbnail_infos): + def _select_thumbnail( + self, + desired_width, + desired_height, + desired_method, + desired_type, + thumbnail_infos, + ): d_w = desired_width d_h = desired_height @@ -280,15 +312,27 @@ class ThumbnailResource(Resource): type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] if t_w >= d_w or t_h >= d_h: - info_list.append(( - aspect_quality, min_quality, size_quality, type_quality, - length_quality, info - )) + info_list.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, + ) + ) else: - info_list2.append(( - aspect_quality, min_quality, size_quality, type_quality, - length_quality, info - )) + info_list2.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, + ) + ) if info_list: return min(info_list)[-1] else: @@ -304,13 +348,11 @@ class ThumbnailResource(Resource): type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] if t_method == "scale" and (t_w >= d_w or t_h >= d_h): - info_list.append(( - size_quality, type_quality, length_quality, info - )) + info_list.append((size_quality, type_quality, length_quality, info)) elif t_method == "scale": - info_list2.append(( - size_quality, type_quality, length_quality, info - )) + info_list2.append( + (size_quality, type_quality, length_quality, info) + ) if info_list: return min(info_list)[-1] else: diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 3efd0d80fc..90d8e6bffe 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -28,16 +28,13 @@ EXIF_TRANSPOSE_MAPPINGS = { 5: Image.TRANSPOSE, 6: Image.ROTATE_270, 7: Image.TRANSVERSE, - 8: Image.ROTATE_90 + 8: Image.ROTATE_90, } class Thumbnailer(object): - FORMATS = { - "image/jpeg": "JPEG", - "image/png": "PNG", - } + FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} def __init__(self, input_path): self.image = Image.open(input_path) @@ -110,17 +107,13 @@ class Thumbnailer(object): """ if width * self.height > height * self.width: scaled_height = (width * self.height) // self.width - scaled_image = self.image.resize( - (width, scaled_height), Image.ANTIALIAS - ) + scaled_image = self.image.resize((width, scaled_height), Image.ANTIALIAS) crop_top = (scaled_height - height) // 2 crop_bottom = height + crop_top cropped = scaled_image.crop((0, crop_top, width, crop_bottom)) else: scaled_width = (height * self.width) // self.height - scaled_image = self.image.resize( - (scaled_width, height), Image.ANTIALIAS - ) + scaled_image = self.image.resize((scaled_width, height), Image.ANTIALIAS) crop_left = (scaled_width - width) // 2 crop_right = width + crop_left cropped = scaled_image.crop((crop_left, 0, crop_right, height)) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index c1240e1963..d1d7e959f0 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -55,48 +55,36 @@ class UploadResource(Resource): requester = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point - content_length = request.getHeader(b"Content-Length").decode('ascii') + content_length = request.getHeader(b"Content-Length").decode("ascii") if content_length is None: - raise SynapseError( - msg="Request must specify a Content-Length", code=400 - ) + raise SynapseError(msg="Request must specify a Content-Length", code=400) if int(content_length) > self.max_upload_size: - raise SynapseError( - msg="Upload request body is too large", - code=413, - ) + raise SynapseError(msg="Upload request body is too large", code=413) upload_name = parse_string(request, b"filename", encoding=None) if upload_name: try: - upload_name = upload_name.decode('utf8') + upload_name = upload_name.decode("utf8") except UnicodeDecodeError: raise SynapseError( - msg="Invalid UTF-8 filename parameter: %r" % (upload_name), - code=400, + msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 ) headers = request.requestHeaders if headers.hasHeader(b"Content-Type"): - media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii') + media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii") else: - raise SynapseError( - msg="Upload request missing 'Content-Type'", - code=400, - ) + raise SynapseError(msg="Upload request missing 'Content-Type'", code=400) # if headers.hasHeader(b"Content-Disposition"): # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion content_uri = yield self.media_repo.create_content( - media_type, upload_name, request.content, - content_length, requester.user + media_type, upload_name, request.content, content_length, requester.user ) logger.info("Uploaded content with URI %r", content_uri) - respond_with_json( - request, 200, {"content_uri": content_uri}, send_cors=True - ) + respond_with_json(request, 200, {"content_uri": content_uri}, send_cors=True) diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/saml2/metadata_resource.py index e8c680aeb4..1e8526e22e 100644 --- a/synapse/rest/saml2/metadata_resource.py +++ b/synapse/rest/saml2/metadata_resource.py @@ -30,7 +30,7 @@ class SAML2MetadataResource(Resource): def render_GET(self, request): metadata_xml = saml2.metadata.create_metadata_string( - configfile=None, config=self.sp_config, + configfile=None, config=self.sp_config ) request.setHeader(b"Content-Type", b"text/xml; charset=utf-8") return metadata_xml diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index 69fb77b322..ab14b70675 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -46,18 +46,16 @@ class SAML2ResponseResource(Resource): @wrap_html_request_handler def _async_render_POST(self, request): - resp_bytes = parse_string(request, 'SAMLResponse', required=True) - relay_state = parse_string(request, 'RelayState', required=True) + resp_bytes = parse_string(request, "SAMLResponse", required=True) + relay_state = parse_string(request, "RelayState", required=True) try: saml2_auth = self._saml_client.parse_authn_request_response( - resp_bytes, saml2.BINDING_HTTP_POST, + resp_bytes, saml2.BINDING_HTTP_POST ) except Exception as e: logger.warning("Exception parsing SAML2 response", exc_info=1) - raise CodeMessageException( - 400, "Unable to parse SAML2 response: %s" % (e,), - ) + raise CodeMessageException(400, "Unable to parse SAML2 response: %s" % (e,)) if saml2_auth.not_signed: raise CodeMessageException(400, "SAML2 response was not signed") @@ -69,6 +67,5 @@ class SAML2ResponseResource(Resource): displayName = saml2_auth.ava.get("displayName", [None])[0] return self._sso_auth_handler.on_successful_auth( - username, request, relay_state, - user_display_name=displayName, + username, request, relay_state, user_display_name=displayName ) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index a7fa4f39af..5e8fda4b65 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -29,6 +29,7 @@ class WellKnownBuilder(object): Args: hs (synapse.server.HomeServer): """ + def __init__(self, hs): self._config = hs.config @@ -37,15 +38,11 @@ class WellKnownBuilder(object): if self._config.public_baseurl is None: return None - result = { - "m.homeserver": { - "base_url": self._config.public_baseurl, - }, - } + result = {"m.homeserver": {"base_url": self._config.public_baseurl}} if self._config.default_identity_server: result["m.identity_server"] = { - "base_url": self._config.default_identity_server, + "base_url": self._config.default_identity_server } return result @@ -66,7 +63,7 @@ class WellKnownResource(Resource): if not r: request.setResponseCode(404) request.setHeader(b"Content-Type", b"text/plain") - return b'.well-known not available' + return b".well-known not available" logger.debug("returning: %s", r) request.setHeader(b"Content-Type", b"application/json") diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py index 212cc212cc..6f2a1931c5 100644 --- a/synapse/rulecheck/domain_rule_checker.py +++ b/synapse/rulecheck/domain_rule_checker.py @@ -61,19 +61,19 @@ class DomainRuleChecker(object): self.default = config["default"] self.can_only_join_rooms_with_invite = config.get( - "can_only_join_rooms_with_invite", False, + "can_only_join_rooms_with_invite", False ) self.can_only_create_one_to_one_rooms = config.get( - "can_only_create_one_to_one_rooms", False, + "can_only_create_one_to_one_rooms", False ) self.can_only_invite_during_room_creation = config.get( - "can_only_invite_during_room_creation", False, + "can_only_invite_during_room_creation", False ) self.can_invite_by_third_party_id = config.get( - "can_invite_by_third_party_id", True, + "can_invite_by_third_party_id", True ) self.domains_prevented_from_being_invited_to_published_rooms = config.get( - "domains_prevented_from_being_invited_to_published_rooms", [], + "domains_prevented_from_being_invited_to_published_rooms", [] ) def check_event_for_spam(self, event): @@ -81,8 +81,15 @@ class DomainRuleChecker(object): """ return False - def user_may_invite(self, inviter_userid, invitee_userid, third_party_invite, - room_id, new_room, published_room=False): + def user_may_invite( + self, + inviter_userid, + invitee_userid, + third_party_invite, + room_id, + new_room, + published_room=False, + ): """Implements synapse.events.SpamChecker.user_may_invite """ if self.can_only_invite_during_room_creation and not new_room: @@ -103,15 +110,17 @@ class DomainRuleChecker(object): return self.default if ( - published_room and - invitee_domain in self.domains_prevented_from_being_invited_to_published_rooms + published_room + and invitee_domain + in self.domains_prevented_from_being_invited_to_published_rooms ): return False return invitee_domain in self.domain_mapping[inviter_domain] - def user_may_create_room(self, userid, invite_list, third_party_invite_list, - cloning): + def user_may_create_room( + self, userid, invite_list, third_party_invite_list, cloning + ): """Implements synapse.events.SpamChecker.user_may_create_room """ @@ -169,4 +178,4 @@ class DomainRuleChecker(object): idx = mxid.find(":") if idx == -1: raise Exception("Invalid ID: %r" % (mxid,)) - return mxid[idx + 1:] + return mxid[idx + 1 :] diff --git a/synapse/secrets.py b/synapse/secrets.py index f6280f951c..0b327a0f82 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py @@ -29,6 +29,7 @@ if sys.version_info[0:2] >= (3, 6): def Secrets(): return secrets + else: import os import binascii @@ -38,4 +39,4 @@ else: return os.urandom(nbytes) def token_hex(self, nbytes=32): - return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii') + return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii") diff --git a/synapse/server.py b/synapse/server.py index 3f3c79498a..dfc619504a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -93,7 +93,9 @@ from synapse.rest.media.v1.media_repository import ( from synapse.secrets import Secrets from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender -from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender +from synapse.server_notices.worker_server_notices_sender import ( + WorkerServerNoticesSender +) from synapse.state import StateHandler, StateResolutionHandler from synapse.streams.events import EventSources from synapse.util import Clock @@ -129,81 +131,78 @@ class HomeServer(object): __metaclass__ = abc.ABCMeta DEPENDENCIES = [ - 'http_client', - 'db_pool', - 'federation_client', - 'federation_server', - 'handlers', - 'auth', - 'room_creation_handler', - 'state_handler', - 'state_resolution_handler', - 'presence_handler', - 'sync_handler', - 'typing_handler', - 'room_list_handler', - 'acme_handler', - 'auth_handler', - 'device_handler', - 'stats_handler', - 'e2e_keys_handler', - 'e2e_room_keys_handler', - 'event_handler', - 'event_stream_handler', - 'initial_sync_handler', - 'application_service_api', - 'application_service_scheduler', - 'application_service_handler', - 'device_message_handler', - 'profile_handler', - 'event_creation_handler', - 'deactivate_account_handler', - 'set_password_handler', - 'notifier', - 'event_sources', - 'keyring', - 'pusherpool', - 'event_builder_factory', - 'filtering', - 'http_client_context_factory', + "http_client", + "db_pool", + "federation_client", + "federation_server", + "handlers", + "auth", + "room_creation_handler", + "state_handler", + "state_resolution_handler", + "presence_handler", + "sync_handler", + "typing_handler", + "room_list_handler", + "acme_handler", + "auth_handler", + "device_handler", + "stats_handler", + "e2e_keys_handler", + "e2e_room_keys_handler", + "event_handler", + "event_stream_handler", + "initial_sync_handler", + "application_service_api", + "application_service_scheduler", + "application_service_handler", + "device_message_handler", + "profile_handler", + "event_creation_handler", + "deactivate_account_handler", + "set_password_handler", + "notifier", + "event_sources", + "keyring", + "pusherpool", + "event_builder_factory", + "filtering", + "http_client_context_factory", "proxied_http_client", - 'simple_http_client', - 'media_repository', - 'media_repository_resource', - 'federation_transport_client', - 'federation_sender', - 'receipts_handler', - 'macaroon_generator', - 'tcp_replication', - 'read_marker_handler', - 'action_generator', - 'user_directory_handler', - 'groups_local_handler', - 'groups_server_handler', - 'groups_attestation_signing', - 'groups_attestation_renewer', - 'secrets', - 'spam_checker', - 'third_party_event_rules', - 'room_member_handler', - 'federation_registry', - 'server_notices_manager', - 'server_notices_sender', - 'message_handler', - 'pagination_handler', - 'room_context_handler', - 'sendmail', - 'registration_handler', - 'account_validity_handler', - 'event_client_serializer', - 'password_policy_handler', - ] - - REQUIRED_ON_MASTER_STARTUP = [ + "simple_http_client", + "media_repository", + "media_repository_resource", + "federation_transport_client", + "federation_sender", + "receipts_handler", + "macaroon_generator", + "tcp_replication", + "read_marker_handler", + "action_generator", "user_directory_handler", - "stats_handler" + "groups_local_handler", + "groups_server_handler", + "groups_attestation_signing", + "groups_attestation_renewer", + "secrets", + "spam_checker", + "third_party_event_rules", + "room_member_handler", + "federation_registry", + "server_notices_manager", + "server_notices_sender", + "message_handler", + "pagination_handler", + "room_context_handler", + "sendmail", + "registration_handler", + "account_validity_handler", + "event_client_serializer", + "password_policy_handler", ] + REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] + # This is overridden in derived application classes # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be # instantiated during setup() for future return by get_datastore() @@ -421,9 +420,7 @@ class HomeServer(object): name = self.db_config["name"] return adbapi.ConnectionPool( - name, - cp_reactor=self.get_reactor(), - **self.db_config.get("args", {}) + name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) ) def get_db_conn(self, run_new_connection=True): @@ -435,7 +432,8 @@ class HomeServer(object): # Any param beginning with cp_ is a parameter for adbapi, and should # not be passed to the database engine. db_params = { - k: v for k, v in self.db_config.get("args", {}).items() + k: v + for k, v in self.db_config.get("args", {}).items() if not k.startswith("cp_") } db_conn = self.database_engine.module.connect(**db_params) @@ -569,9 +567,7 @@ def _make_dependency_method(depname): if builder: # Prevent cyclic dependencies from deadlocking if depname in hs._building: - raise ValueError("Cyclic dependency while building %s" % ( - depname, - )) + raise ValueError("Cyclic dependency while building %s" % (depname,)) hs._building[depname] = 1 dep = builder() @@ -582,9 +578,7 @@ def _make_dependency_method(depname): return dep raise NotImplementedError( - "%s has no %s nor a builder for it" % ( - type(hs).__name__, depname, - ) + "%s has no %s nor a builder for it" % (type(hs).__name__, depname) ) setattr(HomeServer, "get_%s" % (depname), _get) diff --git a/synapse/server.pyi b/synapse/server.pyi index 69263458db..56f9cd06e5 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -23,28 +23,20 @@ class HomeServer(object): @property def config(self) -> synapse.config.homeserver.HomeServerConfig: pass - def get_auth(self) -> synapse.api.auth.Auth: pass - def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: pass - def get_datastore(self) -> synapse.storage.DataStore: pass - def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: pass - def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler: pass - def get_handlers(self) -> synapse.handlers.Handlers: pass - def get_state_handler(self) -> synapse.state.StateHandler: pass - def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient: @@ -55,35 +47,41 @@ class HomeServer(object): """Fetch an HTTP client implementation which doesn't do any blacklisting but does support HTTP_PROXY settings""" pass - def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: + def get_deactivate_account_handler( + self + ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: pass - def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: pass - def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: pass - - def get_event_creation_handler(self) -> synapse.handlers.message.EventCreationHandler: + def get_event_creation_handler( + self + ) -> synapse.handlers.message.EventCreationHandler: pass - - def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler: + def get_set_password_handler( + self + ) -> synapse.handlers.set_password.SetPasswordHandler: pass - def get_federation_sender(self) -> synapse.federation.sender.FederationSender: pass - - def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient: + def get_federation_transport_client( + self + ) -> synapse.federation.transport.client.TransportLayerClient: pass - - def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: + def get_media_repository_resource( + self + ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: pass - - def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository: + def get_media_repository( + self + ) -> synapse.rest.media.v1.media_repository.MediaRepository: pass - - def get_server_notices_manager(self) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: + def get_server_notices_manager( + self + ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: pass - - def get_server_notices_sender(self) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: + def get_server_notices_sender( + self + ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 5e3044d164..415e9c17d8 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -30,6 +30,7 @@ class ConsentServerNotices(object): """Keeps track of whether we need to send users server_notices about privacy policy consent, and sends one if we do. """ + def __init__(self, hs): """ @@ -49,12 +50,11 @@ class ConsentServerNotices(object): if not self._server_notices_manager.is_enabled(): raise ConfigError( "user_consent configuration requires server notices, but " - "server notices are not enabled.", + "server notices are not enabled." ) - if 'body' not in self._server_notice_content: + if "body" not in self._server_notice_content: raise ConfigError( - "user_consent server_notice_consent must contain a 'body' " - "key.", + "user_consent server_notice_consent must contain a 'body' " "key." ) self._consent_uri_builder = ConsentURIBuilder(hs.config) @@ -95,18 +95,14 @@ class ConsentServerNotices(object): # need to send a message. try: consent_uri = self._consent_uri_builder.build_user_consent_uri( - get_localpart_from_id(user_id), + get_localpart_from_id(user_id) ) content = copy_with_str_subst( - self._server_notice_content, { - 'consent_uri': consent_uri, - }, - ) - yield self._server_notices_manager.send_notice( - user_id, content, + self._server_notice_content, {"consent_uri": consent_uri} ) + yield self._server_notices_manager.send_notice(user_id, content) yield self._store.user_set_consent_server_notice_sent( - user_id, self._current_consent_version, + user_id, self._current_consent_version ) except SynapseError as e: logger.error("Error sending server notice about user consent: %s", e) @@ -128,9 +124,7 @@ def copy_with_str_subst(x, substitutions): if isinstance(x, string_types): return x % substitutions if isinstance(x, dict): - return { - k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x) - } + return {k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)} if isinstance(x, (list, tuple)): return [copy_with_str_subst(y) for y in x] diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index af15cba0ee..f183743f31 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -33,6 +33,7 @@ class ResourceLimitsServerNotices(object): """ Keeps track of whether the server has reached it's resource limit and ensures that the client is kept up to date. """ + def __init__(self, hs): """ Args: @@ -104,34 +105,28 @@ class ResourceLimitsServerNotices(object): if currently_blocked and not is_auth_blocking: # Room is notifying of a block, when it ought not to be. # Remove block notification - content = { - "pinned": ref_events - } + content = {"pinned": ref_events} yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, '', + user_id, content, EventTypes.Pinned, "" ) elif not currently_blocked and is_auth_blocking: # Room is not notifying of a block, when it ought to be. # Add block notification content = { - 'body': event_content, - 'msgtype': ServerNoticeMsgType, - 'server_notice_type': ServerNoticeLimitReached, - 'admin_contact': self._config.admin_contact, - 'limit_type': event_limit_type + "body": event_content, + "msgtype": ServerNoticeMsgType, + "server_notice_type": ServerNoticeLimitReached, + "admin_contact": self._config.admin_contact, + "limit_type": event_limit_type, } event = yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Message, + user_id, content, EventTypes.Message ) - content = { - "pinned": [ - event.event_id, - ] - } + content = {"pinned": [event.event_id]} yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, '', + user_id, content, EventTypes.Pinned, "" ) except SynapseError as e: @@ -156,9 +151,7 @@ class ResourceLimitsServerNotices(object): max_id = yield self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) - self._notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) @defer.inlineCallbacks def _is_room_currently_blocked(self, room_id): @@ -188,7 +181,7 @@ class ResourceLimitsServerNotices(object): referenced_events = [] if pinned_state_event is not None: - referenced_events = list(pinned_state_event.content.get('pinned', [])) + referenced_events = list(pinned_state_event.content.get("pinned", [])) events = yield self._store.get_events(referenced_events) for event_id, event in iteritems(events): diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index c5cc6d728e..71e7e75320 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -51,8 +51,7 @@ class ServerNoticesManager(object): @defer.inlineCallbacks def send_notice( - self, user_id, event_content, - type=EventTypes.Message, state_key=None + self, user_id, event_content, type=EventTypes.Message, state_key=None ): """Send a notice to the given user @@ -82,10 +81,10 @@ class ServerNoticesManager(object): } if state_key is not None: - event_dict['state_key'] = state_key + event_dict["state_key"] = state_key res = yield self._event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, ratelimit=False, + requester, event_dict, ratelimit=False ) defer.returnValue(res) @@ -104,11 +103,10 @@ class ServerNoticesManager(object): if not self.is_enabled(): raise Exception("Server notices not enabled") - assert self._is_mine_id(user_id), \ - "Cannot send server notices to remote users" + assert self._is_mine_id(user_id), "Cannot send server notices to remote users" rooms = yield self._store.get_rooms_for_user_where_membership_is( - user_id, [Membership.INVITE, Membership.JOIN], + user_id, [Membership.INVITE, Membership.JOIN] ) system_mxid = self._config.server_notices_mxid for room in rooms: @@ -132,8 +130,8 @@ class ServerNoticesManager(object): # avatar, we have to use both. join_profile = None if ( - self._config.server_notices_mxid_display_name is not None or - self._config.server_notices_mxid_avatar_url is not None + self._config.server_notices_mxid_display_name is not None + or self._config.server_notices_mxid_avatar_url is not None ): join_profile = { "displayname": self._config.server_notices_mxid_display_name, @@ -146,22 +144,18 @@ class ServerNoticesManager(object): config={ "preset": RoomCreationPreset.PRIVATE_CHAT, "name": self._config.server_notices_room_name, - "power_level_content_override": { - "users_default": -10, - }, - "invite": (user_id,) + "power_level_content_override": {"users_default": -10}, + "invite": (user_id,), }, ratelimit=False, creator_join_profile=join_profile, ) - room_id = info['room_id'] + room_id = info["room_id"] max_id = yield self._store.add_tag_to_room( - user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}, - ) - self._notifier.on_new_event( - "account_data_key", max_id, users=[user_id] + user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) + self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) logger.info("Created server notices room %s for %s", room_id, user_id) defer.returnValue(room_id) diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index 6121b2f267..ceff6c8e02 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.resource_limits_server_notices import ( - ResourceLimitsServerNotices, + ResourceLimitsServerNotices ) @@ -24,6 +24,7 @@ class ServerNoticesSender(object): """A centralised place which sends server notices automatically when Certain Events take place """ + def __init__(self, hs): """ @@ -32,7 +33,7 @@ class ServerNoticesSender(object): """ self._server_notices = ( ConsentServerNotices(hs), - ResourceLimitsServerNotices(hs) + ResourceLimitsServerNotices(hs), ) @defer.inlineCallbacks @@ -43,9 +44,7 @@ class ServerNoticesSender(object): user_id (str): mxid of user who synced """ for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user( - user_id, - ) + yield sn.maybe_send_server_notice_to_user(user_id) @defer.inlineCallbacks def on_user_ip(self, user_id): @@ -58,6 +57,4 @@ class ServerNoticesSender(object): # we check for notices to send to the user in on_user_ip as well as # in on_user_syncing for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user( - user_id, - ) + yield sn.maybe_send_server_notice_to_user(user_id) diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py index 4a133026c3..245ec7c64f 100644 --- a/synapse/server_notices/worker_server_notices_sender.py +++ b/synapse/server_notices/worker_server_notices_sender.py @@ -17,6 +17,7 @@ from twisted.internet import defer class WorkerServerNoticesSender(object): """Stub impl of ServerNoticesSender which does nothing""" + def __init__(self, hs): """ Args: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index b914a5ba03..a72f3b1704 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -40,7 +40,8 @@ logger = logging.getLogger(__name__) # Metrics for number of state groups involved in a resolution. state_groups_histogram = Histogram( - "synapse_state_number_state_groups", "", + "synapse_state_number_state_groups", + "", buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), ) @@ -106,8 +107,9 @@ class StateHandler(object): self._state_resolution_handler = hs.get_state_resolution_handler() @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key="", - latest_event_ids=None): + def get_current_state( + self, room_id, event_type=None, state_key="", latest_event_ids=None + ): """ Retrieves the current state for the room. This is done by calling `get_latest_events_in_room` to get the leading edges of the event graph and then resolving any of the state conflicts. @@ -136,8 +138,9 @@ class StateHandler(object): defer.returnValue(event) return - state_map = yield self.store.get_events(list(state.values()), - get_prev_content=False) + state_map = yield self.store.get_events( + list(state.values()), get_prev_content=False + ) state = { key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map } @@ -219,9 +222,7 @@ class StateHandler(object): # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. if old_state: - prev_state_ids = { - (s.type, s.state_key): s.event_id for s in old_state - } + prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): current_state_ids = dict(prev_state_ids) key = (event.type, event.state_key) @@ -247,9 +248,7 @@ class StateHandler(object): # Let's just correctly fill out the context and create a # new state group for it. - prev_state_ids = { - (s.type, s.state_key): s.event_id for s in old_state - } + prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): key = (event.type, event.state_key) @@ -281,7 +280,7 @@ class StateHandler(object): logger.debug("calling resolve_state_groups from compute_event_context") entry = yield self.resolve_state_groups_for_events( - event.room_id, event.prev_event_ids(), + event.room_id, event.prev_event_ids() ) prev_state_ids = entry.state @@ -304,9 +303,7 @@ class StateHandler(object): # If the state at the event has a state group assigned then # we can use that as the prev group prev_group = entry.state_group - delta_ids = { - key: event.event_id - } + delta_ids = {key: event.event_id} elif entry.prev_group: # If the state at the event only has a prev group, then we can # use that as a prev group too. @@ -368,31 +365,31 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.store.get_state_groups_ids( - room_id, event_ids - ) + state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) if len(state_groups_ids) == 0: - defer.returnValue(_StateCacheEntry( - state={}, - state_group=None, - )) + defer.returnValue(_StateCacheEntry(state={}, state_group=None)) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() prev_group, delta_ids = yield self.store.get_state_group_delta(name) - defer.returnValue(_StateCacheEntry( - state=state_list, - state_group=name, - prev_group=prev_group, - delta_ids=delta_ids, - )) + defer.returnValue( + _StateCacheEntry( + state=state_list, + state_group=name, + prev_group=prev_group, + delta_ids=delta_ids, + ) + ) room_version = yield self.store.get_room_version(room_id) result = yield self._state_resolution_handler.resolve_state_groups( - room_id, room_version, state_groups_ids, None, + room_id, + room_version, + state_groups_ids, + None, state_res_store=StateResolutionStore(self.store), ) defer.returnValue(result) @@ -402,27 +399,21 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) - state_set_ids = [{ - (ev.type, ev.state_key): ev.event_id - for ev in st - } for st in state_sets] - - state_map = { - ev.event_id: ev - for st in state_sets - for ev in st - } + state_set_ids = [ + {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets + ] + + state_map = {ev.event_id: ev for st in state_sets for ev in st} with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( - room_version, state_set_ids, + room_version, + state_set_ids, event_map=state_map, state_res_store=StateResolutionStore(self.store), ) - new_state = { - key: state_map[ev_id] for key, ev_id in iteritems(new_state) - } + new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)} defer.returnValue(new_state) @@ -433,6 +424,7 @@ class StateResolutionHandler(object): Note that the storage layer depends on this handler, so all functions must be storage-independent. """ + def __init__(self, hs): self.clock = hs.get_clock() @@ -452,7 +444,7 @@ class StateResolutionHandler(object): @defer.inlineCallbacks @log_function def resolve_state_groups( - self, room_id, room_version, state_groups_ids, event_map, state_res_store, + self, room_id, room_version, state_groups_ids, event_map, state_res_store ): """Resolves conflicts between a set of state groups @@ -479,10 +471,7 @@ class StateResolutionHandler(object): Returns: Deferred[_StateCacheEntry]: resolved state """ - logger.debug( - "resolve_state_groups state_groups %s", - state_groups_ids.keys() - ) + logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) group_names = frozenset(state_groups_ids.keys()) @@ -539,10 +528,7 @@ class StateResolutionHandler(object): defer.returnValue(cache) -def _make_state_cache_entry( - new_state, - state_groups_ids, -): +def _make_state_cache_entry(new_state, state_groups_ids): """Given a resolved state, and a set of input state groups, pick one to base a new state group on (if any), and return an appropriately-constructed _StateCacheEntry. @@ -572,10 +558,7 @@ def _make_state_cache_entry( old_state_event_ids = set(itervalues(state)) if new_state_event_ids == old_state_event_ids: # got an exact match. - return _StateCacheEntry( - state=new_state, - state_group=sg, - ) + return _StateCacheEntry(state=new_state, state_group=sg) # TODO: We want to create a state group for this set of events, to # increase cache hits, but we need to make sure that it doesn't @@ -586,20 +569,13 @@ def _make_state_cache_entry( delta_ids = None for old_group, old_state in iteritems(state_groups_ids): - n_delta_ids = { - k: v - for k, v in iteritems(new_state) - if old_state.get(k) != v - } + n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v} if not delta_ids or len(n_delta_ids) < len(delta_ids): prev_group = old_group delta_ids = n_delta_ids return _StateCacheEntry( - state=new_state, - state_group=None, - prev_group=prev_group, - delta_ids=delta_ids, + state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids ) @@ -628,11 +604,11 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( - state_sets, event_map, state_res_store.get_events, + state_sets, event_map, state_res_store.get_events ) else: return v2.resolve_events_with_store( - room_version, state_sets, event_map, state_res_store, + room_version, state_sets, event_map, state_res_store ) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 29b4e86cfd..88acd4817e 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -57,23 +57,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): if len(state_sets) == 1: defer.returnValue(state_sets[0]) - unconflicted_state, conflicted_state = _seperate( - state_sets, - ) + unconflicted_state, conflicted_state = _seperate(state_sets) needed_events = set( - event_id - for event_ids in itervalues(conflicted_state) - for event_id in event_ids + event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids ) needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) logger.info( - "Asking for %d/%d conflicted events", - len(needed_events), - needed_event_count, + "Asking for %d/%d conflicted events", len(needed_events), needed_event_count ) # dict[str, FrozenEvent]: a map from state event id to event. Only includes @@ -97,17 +91,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): new_needed_events -= set(iterkeys(event_map)) logger.info( - "Asking for %d/%d auth events", - len(new_needed_events), - new_needed_event_count, + "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count ) state_map_new = yield state_map_factory(new_needed_events) state_map.update(state_map_new) - defer.returnValue(_resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - )) + defer.returnValue( + _resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + ) + ) def _seperate(state_sets): @@ -173,8 +167,9 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma return auth_events -def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids, - state_map): +def _resolve_with_state( + unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map +): conflicted_state = {} for key, event_ids in iteritems(conflicted_state_ids): events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] @@ -190,9 +185,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event } try: - resolved_state = _resolve_state_events( - conflicted_state, auth_events - ) + resolved_state = _resolve_state_events(conflicted_state, auth_events) except Exception: logger.exception("Failed to resolve state") raise @@ -218,37 +211,28 @@ def _resolve_state_events(conflicted_state, auth_events): if POWER_KEY in conflicted_state: events = conflicted_state[POWER_KEY] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[POWER_KEY] = _resolve_auth_events( - events, auth_events) + resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) + resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) + resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = _resolve_normal_events( - events, auth_events - ) + resolved_state[key] = _resolve_normal_events(events, auth_events) return resolved_state @@ -257,9 +241,7 @@ def _resolve_auth_events(events, auth_events): reverse = [i for i in reversed(_ordered_events(events))] auth_keys = set( - key - for event in events - for key in event_auth.auth_types_for_event(event) + key for event in events for key in event_auth.auth_types_for_event(event) ) new_auth_events = {} @@ -313,6 +295,6 @@ def _ordered_events(events): def key_func(e): # we have to use utf-8 rather than ascii here because it turns out we allow # people to send us events with non-ascii event IDs :/ - return -int(e.depth), hashlib.sha1(e.event_id.encode('utf-8')).hexdigest() + return -int(e.depth), hashlib.sha1(e.event_id.encode("utf-8")).hexdigest() return sorted(events, key=key_func) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 650995c92c..db969e8997 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -70,19 +70,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = yield _get_auth_chain_difference( - state_sets, event_map, state_res_store, - ) + auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store) - full_conflicted_set = set(itertools.chain( - itertools.chain.from_iterable(itervalues(conflicted_state)), - auth_diff, - )) + full_conflicted_set = set( + itertools.chain( + itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff + ) + ) - events = yield state_res_store.get_events([ - eid for eid in full_conflicted_set - if eid not in event_map - ], allow_rejected=True) + events = yield state_res_store.get_events( + [eid for eid in full_conflicted_set if eid not in event_map], + allow_rejected=True, + ) event_map.update(events) full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) @@ -91,22 +90,21 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # Get and sort all the power events (kicks/bans/etc) power_events = ( - eid for eid in full_conflicted_set - if _is_power_event(event_map[eid]) + eid for eid in full_conflicted_set if _is_power_event(event_map[eid]) ) sorted_power_events = yield _reverse_topological_power_sort( - power_events, - event_map, - state_res_store, - full_conflicted_set, + power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one resolved_state = yield _iterative_auth_checks( - room_version, sorted_power_events, unconflicted_state, event_map, + room_version, + sorted_power_events, + unconflicted_state, + event_map, state_res_store, ) @@ -116,23 +114,20 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # events using the mainline of the resolved power level. leftover_events = [ - ev_id - for ev_id in full_conflicted_set - if ev_id not in sorted_power_events + ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events ] logger.debug("sorting %d remaining events", len(leftover_events)) pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - leftover_events, pl, event_map, state_res_store, + leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") resolved_state = yield _iterative_auth_checks( - room_version, leftover_events, resolved_state, event_map, - state_res_store, + room_version, leftover_events, resolved_state, event_map, state_res_store ) logger.debug("resolved") @@ -209,14 +204,16 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): auth_ids = set( eid for key, eid in iteritems(state_set) - if (key[0] in ( - EventTypes.Member, - EventTypes.ThirdPartyInvite, - ) or key in ( - (EventTypes.PowerLevels, ''), - (EventTypes.Create, ''), - (EventTypes.JoinRules, ''), - )) and eid not in common + if ( + key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) + or key + in ( + (EventTypes.PowerLevels, ""), + (EventTypes.Create, ""), + (EventTypes.JoinRules, ""), + ) + ) + and eid not in common ) auth_chain = yield state_res_store.get_auth_chain(auth_ids) @@ -274,15 +271,16 @@ def _is_power_event(event): return True if event.type == EventTypes.Member: - if event.membership in ('leave', 'ban'): + if event.membership in ("leave", "ban"): return event.sender != event.state_key return False @defer.inlineCallbacks -def _add_event_and_auth_chain_to_graph(graph, event_id, event_map, - state_res_store, auth_diff): +def _add_event_and_auth_chain_to_graph( + graph, event_id, event_map, state_res_store, auth_diff +): """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -327,7 +325,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ graph = {} for event_id in event_ids: yield _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff, + graph, event_id, event_map, state_res_store, auth_diff ) event_to_pl = {} @@ -342,18 +340,16 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ return -pl, ev.origin_server_ts, event_id # Note: graph is modified during the sort - it = lexicographical_topological_sort( - graph, - key=_get_power_order, - ) + it = lexicographical_topological_sort(graph, key=_get_power_order) sorted_events = list(it) defer.returnValue(sorted_events) @defer.inlineCallbacks -def _iterative_auth_checks(room_version, event_ids, base_state, event_map, - state_res_store): +def _iterative_auth_checks( + room_version, event_ids, base_state, event_map, state_res_store +): """Sequentially apply auth checks to each event in given list, updating the state as it goes along. @@ -389,9 +385,11 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map, try: event_auth.check( - room_version, event, auth_events, + room_version, + event, + auth_events, do_sig_check=False, - do_size_check=False + do_size_check=False, ) resolved_state[(event.type, event.state_key)] = event_id @@ -402,8 +400,7 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map, @defer.inlineCallbacks -def _mainline_sort(event_ids, resolved_power_event_id, event_map, - state_res_store): +def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id @@ -436,8 +433,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, order_map = {} for ev_id in event_ids: depth = yield _get_mainline_depth_for_event( - event_map[ev_id], mainline_map, - event_map, state_res_store, + event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 0ca6f6121f..6b0ca80087 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -280,7 +280,7 @@ class DataStore( Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.runInteraction("count_daily_users", self._count_users, yesterday,) + return self.runInteraction("count_daily_users", self._count_users, yesterday) def count_monthly_users(self): """ @@ -291,9 +291,7 @@ class DataStore( """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) return self.runInteraction( - "count_monthly_users", - self._count_users, - thirty_days_ago, + "count_monthly_users", self._count_users, thirty_days_ago ) def _count_users(self, txn, time_from): @@ -361,7 +359,7 @@ class DataStore( txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) for row in txn: - if row[0] == 'unknown': + if row[0] == "unknown": pass results[row[0]] = row[1] @@ -388,7 +386,7 @@ class DataStore( txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) count, = txn.fetchone() - results['all'] = count + results["all"] = count return results diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 537696547c..43413e762c 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -312,9 +312,7 @@ class SQLBaseStore(object): if res: for user in res: self.set_expiration_date_for_user_txn( - txn, - user["name"], - use_delta=True, + txn, user["name"], use_delta=True ) yield self.runInteraction( @@ -344,8 +342,8 @@ class SQLBaseStore(object): self._simple_upsert_txn( txn, "account_validity", - keyvalues={"user_id": user_id, }, - values={"expiration_ts_ms": expiration_ts, "email_sent": False, }, + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, ) def start_profiling(self): @@ -1664,7 +1662,7 @@ def db_to_json(db_content): # Decode it to a Unicode string before feeding it to json.loads, so we # consistenty get a Unicode-containing object out. if isinstance(db_content, (bytes, bytearray)): - db_content = db_content.decode('utf8') + db_content = db_content.decode("utf8") try: return json.loads(db_content) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index b8b8273f73..50f913a414 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -169,7 +169,7 @@ class BackgroundUpdateStore(SQLBaseStore): in_flight = set(update["update_name"] for update in updates) for update in updates: if update["depends_on"] not in in_flight: - self._background_update_queue.append(update['update_name']) + self._background_update_queue.append(update["update_name"]) if not self._background_update_queue: # no work left to do diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d102e07372..3413a46675 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -149,9 +149,7 @@ class DeviceWorkerStore(SQLBaseStore): defer.returnValue((stream_id_cutoff, [])) results = yield self._get_device_update_edus_by_remote( - destination, - from_stream_id, - query_map, + destination, from_stream_id, query_map ) defer.returnValue((now_stream_id, results)) @@ -182,9 +180,7 @@ class DeviceWorkerStore(SQLBaseStore): return list(txn) @defer.inlineCallbacks - def _get_device_update_edus_by_remote( - self, destination, from_stream_id, query_map, - ): + def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): """Returns a list of device update EDUs as well as E2EE keys Args: @@ -210,7 +206,7 @@ class DeviceWorkerStore(SQLBaseStore): # The prev_id for the first row is always the last row before # `from_stream_id` prev_id = yield self._get_last_device_update_for_remote_user( - destination, user_id, from_stream_id, + destination, user_id, from_stream_id ) for device_id, device in iteritems(user_devices): stream_id = query_map[(user_id, device_id)] @@ -238,7 +234,7 @@ class DeviceWorkerStore(SQLBaseStore): defer.returnValue(results) def _get_last_device_update_for_remote_user( - self, destination, user_id, from_stream_id, + self, destination, user_id, from_stream_id ): def f(txn): prev_sent_id_sql = """ diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 521936e3b0..f40ef2ab64 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -87,10 +87,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): }, values={ "version": version, - "first_message_index": room_key['first_message_index'], - "forwarded_count": room_key['forwarded_count'], - "is_verified": room_key['is_verified'], - "session_data": json.dumps(room_key['session_data']), + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), }, lock=False, ) @@ -118,13 +118,13 @@ class EndToEndRoomKeyStore(SQLBaseStore): try: version = int(version) except ValueError: - defer.returnValue({'rooms': {}}) + defer.returnValue({"rooms": {}}) keyvalues = {"user_id": user_id, "version": version} if room_id: - keyvalues['room_id'] = room_id + keyvalues["room_id"] = room_id if session_id: - keyvalues['session_id'] = session_id + keyvalues["session_id"] = session_id rows = yield self._simple_select_list( table="e2e_room_keys", @@ -141,10 +141,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="get_e2e_room_keys", ) - sessions = {'rooms': {}} + sessions = {"rooms": {}} for row in rows: - room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}}) - room_entry['sessions'][row['session_id']] = { + room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) + room_entry["sessions"][row["session_id"]] = { "first_message_index": row["first_message_index"], "forwarded_count": row["forwarded_count"], "is_verified": row["is_verified"], @@ -174,9 +174,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): keyvalues = {"user_id": user_id, "version": int(version)} if room_id: - keyvalues['room_id'] = room_id + keyvalues["room_id"] = room_id if session_id: - keyvalues['session_id'] = session_id + keyvalues["session_id"] = session_id yield self._simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" @@ -191,7 +191,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) row = txn.fetchone() if not row: - raise StoreError(404, 'No current backup version') + raise StoreError(404, "No current backup version") return row[0] def get_e2e_room_keys_version_info(self, user_id, version=None): @@ -255,7 +255,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) current_version = txn.fetchone()[0] if current_version is None: - current_version = '0' + current_version = "0" new_version = str(int(current_version) + 1) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 933bcf42c2..e9b9caa49a 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -85,7 +85,7 @@ class Sqlite3Engine(object): def _parse_match_info(buf): bufsize = len(buf) - return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)] + return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)] def _rank(raw_match_info): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index e8d16edbc8..cb4478342f 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -215,8 +215,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return [room_id for room_id, in txn] return self.runInteraction( - "get_rooms_with_many_extremities", - _get_rooms_with_many_extremities_txn, + "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @cached(max_entries=5000, iterable=True) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index a729f3e067..eca77069fd 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -277,7 +277,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by stream_ordering, oldest first. - notifs.sort(key=lambda r: r['stream_ordering']) + notifs.sort(key=lambda r: r["stream_ordering"]) # Take only up to the limit. We have to stop at the limit because # one of the subqueries may have hit the limit. @@ -379,7 +379,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by received_ts (most recent first) - notifs.sort(key=lambda r: -(r['received_ts'] or 0)) + notifs.sort(key=lambda r: -(r["received_ts"] or 0)) # Now return the first `limit` defer.returnValue(notifs[:limit]) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c4bfa5144c..8ebc9783c9 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -76,15 +76,17 @@ state_delta_reuse_delta_counter = Counter( # The number of forward extremities for each new event. forward_extremities_counter = Histogram( - "synapse_storage_events_forward_extremities_persisted", "", - buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf") + "synapse_storage_events_forward_extremities_persisted", + "", + buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), ) # The number of stale forward extremities for each new event. Stale extremities # are those that were in the previous set of extremities as well as the new. stale_forward_extremities_counter = Histogram( - "synapse_storage_events_stale_forward_extremities_persisted", "", - buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf") + "synapse_storage_events_stale_forward_extremities_persisted", + "", + buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), ) @@ -247,7 +249,7 @@ class EventsStore( BucketCollector( "synapse_forward_extremities", lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"] + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], ) # Read the extrems every 60 minutes diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py index 75c1935bf3..1ce21d190c 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/events_bg_updates.py @@ -64,8 +64,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) self.register_background_update_handler( - self.DELETE_SOFT_FAILED_EXTREMITIES, - self._cleanup_extremities_bg_update, + self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update ) @defer.inlineCallbacks @@ -269,7 +268,8 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): LEFT JOIN events USING (event_id) LEFT JOIN event_json USING (event_id) LEFT JOIN rejections USING (event_id) - """, (batch_size,) + """, + (batch_size,), ) for prev_event_id, event_id, metadata, rejected, outlier in txn: @@ -364,13 +364,12 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): column="event_id", iterable=to_delete, keyvalues={}, - retcols=("room_id",) + retcols=("room_id",), ) room_ids = set(row["room_id"] for row in rows) for room_id in room_ids: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, - (room_id,) + self.get_latest_event_ids_in_room.invalidate, (room_id,) ) self._simple_delete_many_txn( @@ -384,7 +383,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(original_set) num_handled = yield self.runInteraction( - "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn, + "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) if not num_handled: @@ -394,8 +393,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): txn.execute("DROP TABLE _extremities_to_check") yield self.runInteraction( - "_cleanup_extremities_bg_update_drop_table", - _drop_table_txn, + "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) defer.returnValue(num_handled) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 5dc49822b5..e6b6b9d233 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError from synapse.api.room_versions import EventFormatVersions from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 + # these are only included to make the type annotations work from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event @@ -111,8 +112,7 @@ class EventsWorkerStore(SQLBaseStore): return ts return self.runInteraction( - "get_approximate_received_ts", - _get_approximate_received_ts_txn, + "get_approximate_received_ts", _get_approximate_received_ts_txn ) @defer.inlineCallbacks @@ -298,9 +298,7 @@ class EventsWorkerStore(SQLBaseStore): # that expected_domain = get_domain_from_id(entry.event.sender) - if ( - get_domain_from_id(orig_event_info["sender"]) == expected_domain - ): + if get_domain_from_id(orig_event_info["sender"]) == expected_domain: # This redaction event is allowed. Mark as not needing a # recheck. entry.event.internal_metadata.recheck_redaction = False @@ -692,7 +690,8 @@ class EventsWorkerStore(SQLBaseStore): """ return self.runInteraction( "get_total_state_event_counts", - self._get_total_state_event_counts_txn, room_id + self._get_total_state_event_counts_txn, + room_id, ) def _get_current_state_event_counts_txn(self, txn, room_id): @@ -716,7 +715,8 @@ class EventsWorkerStore(SQLBaseStore): """ return self.runInteraction( "get_current_state_event_counts", - self._get_current_state_event_counts_txn, room_id + self._get_current_state_event_counts_txn, + room_id, ) @defer.inlineCallbacks diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py index dce6a43ac1..73e6fc6de2 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/group_server.py @@ -1179,11 +1179,7 @@ class GroupServerStore(SQLBaseStore): for table in tables: self._simple_delete_txn( - txn, - table=table, - keyvalues={"group_id": group_id}, + txn, table=table, keyvalues={"group_id": group_id} ) - return self.runInteraction( - "delete_group", _delete_group_txn - ) + return self.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index e3655ad8d7..e72f89e446 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -131,7 +131,7 @@ class KeyStore(SQLBaseStore): def _invalidate(res): f = self._get_server_verify_key.invalidate for i in invalidations: - f((i, )) + f((i,)) return res return self.runInteraction( diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 3ecf47e7a7..6b1238ce4a 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -22,11 +22,11 @@ class MediaRepositoryStore(BackgroundUpdateStore): super(MediaRepositoryStore, self).__init__(db_conn, hs) self.register_background_index_update( - update_name='local_media_repository_url_idx', - index_name='local_media_repository_url_idx', - table='local_media_repository', - columns=['created_ts'], - where_clause='url_cache IS NOT NULL', + update_name="local_media_repository_url_idx", + index_name="local_media_repository_url_idx", + table="local_media_repository", + columns=["created_ts"], + where_clause="url_cache IS NOT NULL", ) def get_local_media(self, media_id): @@ -108,12 +108,12 @@ class MediaRepositoryStore(BackgroundUpdateStore): return dict( zip( ( - 'response_code', - 'etag', - 'expires_ts', - 'og', - 'media_id', - 'download_ts', + "response_code", + "etag", + "expires_ts", + "og", + "media_id", + "download_ts", ), row, ) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 8aa8abc470..081564360f 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -86,11 +86,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): if len(self.reserved_users) > 0: # questionmarks is a hack to overcome sqlite not supporting # tuples in 'WHERE IN %s' - questionmarks = '?' * len(self.reserved_users) + questionmarks = "?" * len(self.reserved_users) query_args.extend(self.reserved_users) sql = base_sql + """ AND user_id NOT IN ({})""".format( - ','.join(questionmarks) + ",".join(questionmarks) ) else: sql = base_sql @@ -124,7 +124,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): if len(self.reserved_users) > 0: query_args.extend(self.reserved_users) sql = base_sql + """ AND user_id NOT IN ({})""".format( - ','.join(questionmarks) + ",".join(questionmarks) ) else: sql = base_sql diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index f2c1bed487..fc10b9534e 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -146,9 +146,10 @@ def _setup_new_database(cur, database_engine): directory_entries = os.listdir(sql_dir) - for filename in sorted(fnmatch.filter(directory_entries, "*.sql") + fnmatch.filter( - directory_entries, "*.sql." + specific - )): + for filename in sorted( + fnmatch.filter(directory_entries, "*.sql") + + fnmatch.filter(directory_entries, "*.sql." + specific) + ): sql_loc = os.path.join(sql_dir, filename) logger.debug("Applying schema %s", sql_loc) executescript(cur, sql_loc) @@ -313,7 +314,7 @@ def _apply_module_schemas(txn, database_engine, config): application config """ for (mod, _config) in config.password_providers: - if not hasattr(mod, 'get_db_schema_files'): + if not hasattr(mod, "get_db_schema_files"): continue modname = ".".join((mod.__module__, mod.__name__)) _apply_module_schema_files( @@ -343,7 +344,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) continue root_name, ext = os.path.splitext(name) - if ext != '.sql': + if ext != ".sql": raise PrepareDatabaseException( "only .sql files are currently supported for module schemas" ) @@ -407,7 +408,7 @@ def get_statements(f): def executescript(txn, schema_path): - with open(schema_path, 'r') as f: + with open(schema_path, "r") as f: for statement in get_statements(f): txn.execute(statement) diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 38524f2545..9bffd27d9a 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -45,7 +45,7 @@ class ProfileWorkerStore(SQLBaseStore): defer.returnValue( ProfileInfo( - avatar_url=profile['avatar_url'], display_name=profile['displayname'] + avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) ) @@ -69,17 +69,14 @@ class ProfileWorkerStore(SQLBaseStore): def f(txn): txn.execute("SELECT MAX(batch) as maxbatch FROM profiles") rows = self.cursor_to_dict(txn) - return rows[0]['maxbatch'] - return self.runInteraction( - "get_latest_profile_replication_batch_number", f, - ) + return rows[0]["maxbatch"] + + return self.runInteraction("get_latest_profile_replication_batch_number", f) def get_profile_batch(self, batchnum): return self._simple_select_list( table="profiles", - keyvalues={ - "batch": batchnum, - }, + keyvalues={"batch": batchnum}, retcols=("user_id", "displayname", "avatar_url", "active"), desc="get_profile_batch", ) @@ -95,22 +92,24 @@ class ProfileWorkerStore(SQLBaseStore): ) txn.execute(sql, (BATCH_SIZE,)) return txn.rowcount + return self.runInteraction("assign_profile_batch", f) def get_replication_hosts(self): def f(txn): - txn.execute("SELECT host, last_synced_batch FROM profile_replication_status") + txn.execute( + "SELECT host, last_synced_batch FROM profile_replication_status" + ) rows = self.cursor_to_dict(txn) - return {r['host']: r['last_synced_batch'] for r in rows} + return {r["host"]: r["last_synced_batch"] for r in rows} + return self.runInteraction("get_replication_hosts", f) def update_replication_batch_for_host(self, host, last_synced_batch): return self._simple_upsert( table="profile_replication_status", keyvalues={"host": host}, - values={ - "last_synced_batch": last_synced_batch, - }, + values={"last_synced_batch": last_synced_batch}, desc="update_replication_batch_for_host", ) @@ -127,31 +126,22 @@ class ProfileWorkerStore(SQLBaseStore): return self._simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - values={ - "displayname": new_displayname, - "batch": batchnum, - }, + values={"displayname": new_displayname, "batch": batchnum}, desc="set_profile_displayname", - lock=False # we can do this because user_id has a unique index + lock=False, # we can do this because user_id has a unique index ) def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum): return self._simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - values={ - "avatar_url": new_avatar_url, - "batch": batchnum, - }, + values={"avatar_url": new_avatar_url, "batch": batchnum}, desc="set_profile_avatar_url", - lock=False # we can do this because user_id has a unique index + lock=False, # we can do this because user_id has a unique index ) def set_profile_active(self, user_localpart, active, hide, batchnum): - values = { - "active": int(active), - "batch": batchnum, - } + values = {"active": int(active), "batch": batchnum} if not active and not hide: # we are deactivating for real (not in hide mode) # so clear the profile. @@ -162,7 +152,7 @@ class ProfileWorkerStore(SQLBaseStore): keyvalues={"user_id": user_localpart}, values=values, desc="set_profile_active", - lock=False # we can do this because user_id has a unique index + lock=False, # we can do this because user_id has a unique index ) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9e406baafa..98cec8c82b 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -46,12 +46,12 @@ def _load_rules(rawrules, enabled_map): rules = list(list_with_base_rules(ruleslist)) for i, rule in enumerate(rules): - rule_id = rule['rule_id'] + rule_id = rule["rule_id"] if rule_id in enabled_map: - if rule.get('enabled', True) != bool(enabled_map[rule_id]): + if rule.get("enabled", True) != bool(enabled_map[rule_id]): # Rules are cached across users. rule = dict(rule) - rule['enabled'] = bool(enabled_map[rule_id]) + rule["enabled"] = bool(enabled_map[rule_id]) rules[i] = rule return rules @@ -126,12 +126,12 @@ class PushRulesWorkerStore( def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", - keyvalues={'user_name': user_id}, + keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), desc="get_push_rules_enabled_for_user", ) defer.returnValue( - {r['rule_id']: False if r['enabled'] == 0 else True for r in results} + {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} ) def have_push_rules_changed_for_user(self, user_id, last_id): @@ -175,7 +175,7 @@ class PushRulesWorkerStore( rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) for row in rows: - results.setdefault(row['user_name'], []).append(row) + results.setdefault(row["user_name"], []).append(row) enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) @@ -194,7 +194,7 @@ class PushRulesWorkerStore( rule (Dict): A push rule. """ # Create new rule id - rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1]) + rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) new_rule_id = rule_id_scope + "/" + new_room_id # Change room id in each condition @@ -334,8 +334,8 @@ class PushRulesWorkerStore( desc="bulk_get_push_rules_enabled", ) for row in rows: - enabled = bool(row['enabled']) - results.setdefault(row['user_name'], {})[row['rule_id']] = enabled + enabled = bool(row["enabled"]) + results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled defer.returnValue(results) @@ -568,7 +568,7 @@ class PushRuleStore(PushRulesWorkerStore): def delete_push_rule_txn(txn, stream_id, event_stream_ordering): self._simple_delete_one_txn( - txn, "push_rules", {'user_name': user_id, 'rule_id': rule_id} + txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) self._insert_push_rules_update_txn( @@ -605,9 +605,9 @@ class PushRuleStore(PushRulesWorkerStore): self._simple_upsert_txn( txn, "push_rules_enable", - {'user_name': user_id, 'rule_id': rule_id}, - {'enabled': 1 if enabled else 0}, - {'id': new_id}, + {"user_name": user_id, "rule_id": rule_id}, + {"enabled": 1 if enabled else 0}, + {"id": new_id}, ) self._insert_push_rules_update_txn( @@ -645,8 +645,8 @@ class PushRuleStore(PushRulesWorkerStore): self._simple_update_one_txn( txn, "push_rules", - {'user_name': user_id, 'rule_id': rule_id}, - {'actions': actions_json}, + {"user_name": user_id, "rule_id": rule_id}, + {"actions": actions_json}, ) self._insert_push_rules_update_txn( diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 1567e1df48..cfe0a94330 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -37,24 +37,24 @@ else: class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): for r in rows: - dataJson = r['data'] - r['data'] = None + dataJson = r["data"] + r["data"] = None try: if isinstance(dataJson, db_binary_type): dataJson = str(dataJson).decode("UTF8") - r['data'] = json.loads(dataJson) + r["data"] = json.loads(dataJson) except Exception as e: logger.warn( "Invalid JSON in data for pusher %d: %s, %s", - r['id'], + r["id"], dataJson, e.args[0], ) pass - if isinstance(r['pushkey'], db_binary_type): - r['pushkey'] = str(r['pushkey']).decode("UTF8") + if isinstance(r["pushkey"], db_binary_type): + r["pushkey"] = str(r["pushkey"]).decode("UTF8") return rows @@ -195,15 +195,15 @@ class PusherWorkerStore(SQLBaseStore): ) def get_if_users_have_pushers(self, user_ids): rows = yield self._simple_select_many_batch( - table='pushers', - column='user_name', + table="pushers", + column="user_name", iterable=user_ids, - retcols=['user_name'], - desc='get_if_users_have_pushers', + retcols=["user_name"], + desc="get_if_users_have_pushers", ) result = {user_id: False for user_id in user_ids} - result.update({r['user_name']: True for r in rows}) + result.update({r["user_name"]: True for r in rows}) defer.returnValue(result) @@ -299,8 +299,8 @@ class PusherStore(PusherWorkerStore): ): yield self._simple_update_one( "pushers", - {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, - {'last_stream_ordering': last_stream_ordering}, + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + {"last_stream_ordering": last_stream_ordering}, desc="update_pusher_last_stream_ordering", ) @@ -310,10 +310,10 @@ class PusherStore(PusherWorkerStore): ): yield self._simple_update_one( "pushers", - {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, { - 'last_stream_ordering': last_stream_ordering, - 'last_success': last_success, + "last_stream_ordering": last_stream_ordering, + "last_success": last_success, }, desc="update_pusher_last_stream_ordering_and_success", ) @@ -322,8 +322,8 @@ class PusherStore(PusherWorkerStore): def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): yield self._simple_update_one( "pushers", - {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, - {'failing_since': failing_since}, + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + {"failing_since": failing_since}, desc="update_pusher_failing_since", ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index a1647e50a1..b477da12b1 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") - defer.returnValue(set(r['user_id'] for r in receipts)) + defer.returnValue(set(r["user_id"] for r in receipts)) @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 028848cf89..8ecf7c029c 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -116,8 +116,9 @@ class RegistrationWorkerStore(SQLBaseStore): defer.returnValue(res) @defer.inlineCallbacks - def set_account_validity_for_user(self, user_id, expiration_ts, email_sent, - renewal_token=None): + def set_account_validity_for_user( + self, user_id, expiration_ts, email_sent, renewal_token=None + ): """Updates the account validity properties of the given account, with the given values. @@ -131,6 +132,7 @@ class RegistrationWorkerStore(SQLBaseStore): renewal_token (str): Renewal token the user can use to extend the validity of their account. Defaults to no token. """ + def set_account_validity_for_user_txn(txn): self._simple_update_txn( txn=txn, @@ -143,12 +145,11 @@ class RegistrationWorkerStore(SQLBaseStore): }, ) self._invalidate_cache_and_stream( - txn, self.get_expiration_ts_for_user, (user_id,), + txn, self.get_expiration_ts_for_user, (user_id,) ) yield self.runInteraction( - "set_account_validity_for_user", - set_account_validity_for_user_txn, + "set_account_validity_for_user", set_account_validity_for_user_txn ) @defer.inlineCallbacks @@ -158,6 +159,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[list[str]]: List of expired user IDs """ + def get_expired_users_txn(txn, now_ms): sql = """ SELECT user_id from account_validity @@ -168,9 +170,7 @@ class RegistrationWorkerStore(SQLBaseStore): return [row[0] for row in rows] res = yield self.runInteraction( - "get_expired_users", - get_expired_users_txn, - self.clock.time_msec(), + "get_expired_users", get_expired_users_txn, self.clock.time_msec() ) defer.returnValue(res) @@ -240,6 +240,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] """ + def select_users_txn(txn, now_ms, renew_at): sql = ( "SELECT user_id, expiration_ts_ms FROM account_validity" @@ -252,7 +253,8 @@ class RegistrationWorkerStore(SQLBaseStore): res = yield self.runInteraction( "get_users_expiring_soon", select_users_txn, - self.clock.time_msec(), self.config.account_validity.renew_at, + self.clock.time_msec(), + self.config.account_validity.renew_at, ) defer.returnValue(res) @@ -392,7 +394,7 @@ class RegistrationWorkerStore(SQLBaseStore): WHERE creation_ts > ? ) AS t GROUP BY user_type """ - results = {'native': 0, 'guest': 0, 'bridged': 0} + results = {"native": 0, "guest": 0, "bridged": 0} txn.execute(sql, (yesterday,)) for row in txn: results[row[0]] = row[1] @@ -458,7 +460,7 @@ class RegistrationWorkerStore(SQLBaseStore): {"medium": medium, "address": address}, ["guest_access_token"], True, - 'get_3pid_guest_access_token', + "get_3pid_guest_access_token", ) if ret: defer.returnValue(ret["guest_access_token"]) @@ -495,11 +497,11 @@ class RegistrationWorkerStore(SQLBaseStore): txn, "user_threepids", {"medium": medium, "address": address}, - ['user_id'], + ["user_id"], True, ) if ret: - return ret['user_id'] + return ret["user_id"] return None @defer.inlineCallbacks @@ -515,8 +517,8 @@ class RegistrationWorkerStore(SQLBaseStore): ret = yield self._simple_select_list( "user_threepids", {"user_id": user_id}, - ['medium', 'address', 'validated_at', 'added_at'], - 'user_get_threepids', + ["medium", "address", "validated_at", "added_at"], + "user_get_threepids", ) defer.returnValue(ret) @@ -595,11 +597,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ return self._simple_select_onecol( table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - }, + keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", desc="get_id_servers_user_bound", ) @@ -635,16 +633,16 @@ class RegistrationStore( self.register_noop_background_update("refresh_tokens_device_index") self.register_background_update_handler( - "user_threepids_grandfather", self._bg_user_threepids_grandfather, + "user_threepids_grandfather", self._bg_user_threepids_grandfather ) self.register_background_update_handler( - "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag, + "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag ) # Create a background job for culling expired 3PID validity tokens hs.get_clock().looping_call( - self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS, + self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS ) @defer.inlineCallbacks @@ -700,8 +698,7 @@ class RegistrationStore( return False end = yield self.runInteraction( - "users_set_deactivated_flag", - _backgroud_update_set_deactivated_flag_txn, + "users_set_deactivated_flag", _backgroud_update_set_deactivated_flag_txn ) if end: @@ -874,7 +871,7 @@ class RegistrationStore( def user_set_password_hash_txn(txn): self._simple_update_one_txn( - txn, 'users', {'name': user_id}, {'password_hash': password_hash} + txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -895,9 +892,9 @@ class RegistrationStore( def f(txn): self._simple_update_one_txn( txn, - table='users', - keyvalues={'name': user_id}, - updatevalues={'consent_version': consent_version}, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_version": consent_version}, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -919,9 +916,9 @@ class RegistrationStore( def f(txn): self._simple_update_one_txn( txn, - table='users', - keyvalues={'name': user_id}, - updatevalues={'consent_server_notice_sent': consent_version}, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_server_notice_sent": consent_version}, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -1091,7 +1088,7 @@ class RegistrationStore( if id_servers: yield self.runInteraction( - "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn, + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) yield self._end_background_update("user_threepids_grandfather") @@ -1099,12 +1096,7 @@ class RegistrationStore( defer.returnValue(1) def get_threepid_validation_session( - self, - medium, - client_secret, - address=None, - sid=None, - validated=True, + self, medium, client_secret, address=None, sid=None, validated=True ): """Gets a session_id and last_send_attempt (if available) for a client_secret/medium/(address|session_id) combo @@ -1124,23 +1116,22 @@ class RegistrationStore( latest session_id and send_attempt count for this 3PID. Otherwise None if there hasn't been a previous attempt """ - keyvalues = { - "medium": medium, - "client_secret": client_secret, - } + keyvalues = {"medium": medium, "client_secret": client_secret} if address: keyvalues["address"] = address if sid: keyvalues["session_id"] = sid - assert(address or sid) + assert address or sid def get_threepid_validation_session_txn(txn): sql = """ SELECT address, session_id, medium, client_secret, last_send_attempt, validated_at FROM threepid_validation_session WHERE %s - """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),) + """ % ( + " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), + ) if validated is not None: sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") @@ -1155,17 +1146,10 @@ class RegistrationStore( return rows[0] return self.runInteraction( - "get_threepid_validation_session", - get_threepid_validation_session_txn, + "get_threepid_validation_session", get_threepid_validation_session_txn ) - def validate_threepid_session( - self, - session_id, - client_secret, - token, - current_ts, - ): + def validate_threepid_session(self, session_id, client_secret, token, current_ts): """Attempt to validate a threepid session using a token Args: @@ -1197,7 +1181,7 @@ class RegistrationStore( if retrieved_client_secret != client_secret: raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id", + 400, "This client_secret does not match the provided session_id" ) row = self._simple_select_one_txn( @@ -1210,7 +1194,7 @@ class RegistrationStore( if not row: raise ThreepidValidationError( - 400, "Validation token not found or has expired", + 400, "Validation token not found or has expired" ) expires = row["expires"] next_link = row["next_link"] @@ -1221,7 +1205,7 @@ class RegistrationStore( if expires <= current_ts: raise ThreepidValidationError( - 400, "This token has expired. Please request a new one", + 400, "This token has expired. Please request a new one" ) # Looks good. Validate the session @@ -1236,8 +1220,7 @@ class RegistrationStore( # Return next_link if it exists return self.runInteraction( - "validate_threepid_session_txn", - validate_threepid_session_txn, + "validate_threepid_session_txn", validate_threepid_session_txn ) def upsert_threepid_validation_session( @@ -1304,6 +1287,7 @@ class RegistrationStore( token_expires (int): The timestamp for which after the token will no longer be valid """ + def start_or_continue_validation_session_txn(txn): # Create or update a validation session self._simple_upsert_txn( @@ -1337,6 +1321,7 @@ class RegistrationStore( def cull_expired_threepid_validation_tokens(self): """Remove threepid validation tokens with expiry dates that have passed""" + def cull_expired_threepid_validation_tokens_txn(txn, ts): sql = """ DELETE FROM threepid_validation_token WHERE @@ -1358,6 +1343,7 @@ class RegistrationStore( Args: session_id (str): The ID of the session to delete """ + def delete_threepid_session_txn(txn): self._simple_delete_txn( txn, @@ -1371,8 +1357,7 @@ class RegistrationStore( ) return self.runInteraction( - "delete_threepid_session", - delete_threepid_session_txn, + "delete_threepid_session", delete_threepid_session_txn ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): @@ -1383,7 +1368,7 @@ class RegistrationStore( updatevalues={"deactivated": 1 if deactivated else 0}, ) self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,), + txn, self.get_user_deactivated_status, (user_id,) ) @defer.inlineCallbacks @@ -1398,7 +1383,8 @@ class RegistrationStore( yield self.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, - user_id, deactivated, + user_id, + deactivated, ) @cachedInlineCallbacks() diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 4c83800cca..1b01934c19 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -468,9 +468,5 @@ class RelationsStore(RelationsWorkerStore): """ self._simple_delete_txn( - txn, - table="event_relations", - keyvalues={ - "event_id": redacted_event_id, - } + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index db3d052d33..49df35e3de 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -242,10 +242,7 @@ class RoomWorkerStore(SQLBaseStore): # maximum, in order not to filter out events we should filter out when sending to # the client. if not self.config.retention_enabled: - defer.returnValue({ - "min_lifetime": None, - "max_lifetime": None, - }) + defer.returnValue({"min_lifetime": None, "max_lifetime": None}) def get_retention_policy_for_room_txn(txn): txn.execute( @@ -254,23 +251,24 @@ class RoomWorkerStore(SQLBaseStore): INNER JOIN current_state_events USING (event_id, room_id) WHERE room_id = ?; """, - (room_id,) + (room_id,), ) return self.cursor_to_dict(txn) ret = yield self.runInteraction( - "get_retention_policy_for_room", - get_retention_policy_for_room_txn, + "get_retention_policy_for_room", get_retention_policy_for_room_txn ) # If we don't know this room ID, ret will be None, in this case return the default # policy. if not ret: - defer.returnValue({ - "min_lifetime": self.config.retention_default_min_lifetime, - "max_lifetime": self.config.retention_default_max_lifetime, - }) + defer.returnValue( + { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } + ) row = ret[0] @@ -294,7 +292,7 @@ class RoomStore(RoomWorkerStore, SearchStore): self.config = hs.config self.register_background_update_handler( - "insert_room_retention", self._background_insert_retention, + "insert_room_retention", self._background_insert_retention ) @defer.inlineCallbacks @@ -317,8 +315,9 @@ class RoomStore(RoomWorkerStore, SearchStore): WHERE state.room_id > ? AND state.type = '%s' ORDER BY state.room_id ASC LIMIT ?; - """ % EventTypes.Retention, - (last_room, batch_size) + """ + % EventTypes.Retention, + (last_room, batch_size), ) rows = self.cursor_to_dict(txn) @@ -341,15 +340,13 @@ class RoomStore(RoomWorkerStore, SearchStore): "event_id": row["event_id"], "min_lifetime": retention_policy.get("min_lifetime"), "max_lifetime": retention_policy.get("max_lifetime"), - } + }, ) logger.info("Inserted %d rows into room_retention", len(rows)) self._background_update_progress_txn( - txn, "insert_room_retention", { - "room_id": rows[-1]["room_id"], - } + txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} ) if batch_size > len(rows): @@ -358,8 +355,7 @@ class RoomStore(RoomWorkerStore, SearchStore): return False end = yield self.runInteraction( - "insert_room_retention", - _background_insert_retention_txn, + "insert_room_retention", _background_insert_retention_txn ) if end: @@ -603,17 +599,15 @@ class RoomStore(RoomWorkerStore, SearchStore): txn.execute(sql, (event.event_id, event.room_id, event.content[key])) def _store_retention_policy_for_room_txn(self, txn, event): - if ( - hasattr(event, "content") - and ("min_lifetime" in event.content or "max_lifetime" in event.content) + if hasattr(event, "content") and ( + "min_lifetime" in event.content or "max_lifetime" in event.content ): if ( - ("min_lifetime" in event.content and not isinstance( - event.content.get("min_lifetime"), integer_types - )) - or ("max_lifetime" in event.content and not isinstance( - event.content.get("max_lifetime"), integer_types - )) + "min_lifetime" in event.content + and not isinstance(event.content.get("min_lifetime"), integer_types) + ) or ( + "max_lifetime" in event.content + and not isinstance(event.content.get("max_lifetime"), integer_types) ): # Ignore the event if one of the value isn't an integer. return @@ -816,7 +810,9 @@ class RoomStore(RoomWorkerStore, SearchStore): return local_media_mxcs, remote_media_mxcs @defer.inlineCallbacks - def get_rooms_for_retention_period_in_range(self, min_ms, max_ms, include_null=False): + def get_rooms_for_retention_period_in_range( + self, min_ms, max_ms, include_null=False + ): """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 7617913326..8004aeb909 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -420,7 +420,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): table="room_memberships", column="event_id", iterable=missing_member_event_ids, - retcols=('user_id', 'display_name', 'avatar_url'), + retcols=("user_id", "display_name", "avatar_url"), keyvalues={"membership": Membership.JOIN}, batch_size=500, desc="_get_joined_users_from_context", @@ -448,7 +448,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cachedInlineCallbacks(max_entries=10000) def is_host_joined(self, room_id, host): - if '%' in host or '_' in host: + if "%" in host or "_" in host: raise Exception("Invalid host name") sql = """ @@ -490,7 +490,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): Deferred: Resolves to True if the host is/was in the room, otherwise False. """ - if '%' in host or '_' in host: + if "%" in host or "_" in host: raise Exception("Invalid host name") sql = """ @@ -723,7 +723,7 @@ class RoomMemberStore(RoomMemberWorkerStore): room_id = row["room_id"] try: event_json = json.loads(row["json"]) - content = event_json['content'] + content = event_json["content"] except Exception: continue diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py index 147496a38b..3edfcfd783 100644 --- a/synapse/storage/schema/delta/20/pushers.py +++ b/synapse/storage/schema/delta/20/pushers.py @@ -29,7 +29,8 @@ logger = logging.getLogger(__name__) def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table...") - cur.execute(""" + cur.execute( + """ CREATE TABLE IF NOT EXISTS pushers2 ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, @@ -48,27 +49,34 @@ def run_create(cur, database_engine, *args, **kwargs): failing_since BIGINT, UNIQUE (app_id, pushkey, user_name) ) - """) - cur.execute("""SELECT + """ + ) + cur.execute( + """SELECT id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers - """) + """ + ) count = 0 for row in cur.fetchall(): row = list(row) row[8] = bytes(row[8]).decode("utf-8") row[11] = bytes(row[11]).decode("utf-8") - cur.execute(database_engine.convert_param_style(""" + cur.execute( + database_engine.convert_param_style( + """ INSERT into pushers2 ( id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since - ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))), - row + ) values (%s)""" + % (",".join(["?" for _ in range(len(row))])) + ), + row, ) count += 1 cur.execute("DROP TABLE pushers") diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index ef7ec34346..9b95411fb6 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -40,9 +40,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): logger.warning("Could not get app_service_config_files from config") pass - appservices = load_appservices( - config.server_name, config_files - ) + appservices = load_appservices(config.server_name, config_files) owned = {} @@ -53,20 +51,19 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): if user_id in owned.keys(): logger.error( "user_id %s was owned by more than one application" - " service (IDs %s and %s); assigning arbitrarily to %s" % - (user_id, owned[user_id], appservice.id, owned[user_id]) + " service (IDs %s and %s); assigning arbitrarily to %s" + % (user_id, owned[user_id], appservice.id, owned[user_id]) ) owned.setdefault(appservice.id, []).append(user_id) for as_id, user_ids in owned.items(): n = 100 - user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n)) + user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name IN (%s)" % ( - ",".join("?" for _ in chunk), - ) + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" + % (",".join("?" for _ in chunk),) ), - [as_id] + chunk + [as_id] + chunk, ) diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/schema/delta/31/pushers.py index 93367fa09e..9bb504aad5 100644 --- a/synapse/storage/schema/delta/31/pushers.py +++ b/synapse/storage/schema/delta/31/pushers.py @@ -24,12 +24,13 @@ logger = logging.getLogger(__name__) def token_to_stream_ordering(token): - return int(token[1:].split('_')[0]) + return int(token[1:].split("_")[0]) def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table, delta 31...") - cur.execute(""" + cur.execute( + """ CREATE TABLE IF NOT EXISTS pushers2 ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, @@ -48,26 +49,33 @@ def run_create(cur, database_engine, *args, **kwargs): failing_since BIGINT, UNIQUE (app_id, pushkey, user_name) ) - """) - cur.execute("""SELECT + """ + ) + cur.execute( + """SELECT id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers - """) + """ + ) count = 0 for row in cur.fetchall(): row = list(row) row[12] = token_to_stream_ordering(row[12]) - cur.execute(database_engine.convert_param_style(""" + cur.execute( + database_engine.convert_param_style( + """ INSERT into pushers2 ( id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_stream_ordering, last_success, failing_since - ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))), - row + ) values (%s)""" + % (",".join(["?" for _ in range(len(row))])) + ), + row, ) count += 1 cur.execute("DROP TABLE pushers") diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py index 9754d3ccfb..a26057dfb6 100644 --- a/synapse/storage/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/schema/delta/33/remote_media_ts.py @@ -26,5 +26,5 @@ def run_upgrade(cur, database_engine, *args, **kwargs): database_engine.convert_param_style( "UPDATE remote_media_cache SET last_access_ts = ?" ), - (int(time.time() * 1000),) + (int(time.time() * 1000),), ) diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py index f6766501d2..9fd1ccf6f7 100644 --- a/synapse/storage/schema/delta/47/state_group_seq.py +++ b/synapse/storage/schema/delta/47/state_group_seq.py @@ -27,10 +27,7 @@ def run_create(cur, database_engine, *args, **kwargs): else: start_val = row[0] + 1 - cur.execute( - "CREATE SEQUENCE state_group_id_seq START WITH %s", - (start_val, ), - ) + cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,)) def run_upgrade(*args, **kwargs): diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py index 2233af87d7..49f5f2c003 100644 --- a/synapse/storage/schema/delta/48/group_unique_indexes.py +++ b/synapse/storage/schema/delta/48/group_unique_indexes.py @@ -38,16 +38,22 @@ def run_create(cur, database_engine, *args, **kwargs): rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" # remove duplicates from group_users & group_invites tables - cur.execute(""" + cur.execute( + """ DELETE FROM group_users WHERE %s NOT IN ( SELECT min(%s) FROM group_users GROUP BY group_id, user_id ); - """ % (rowid, rowid)) - cur.execute(""" + """ + % (rowid, rowid) + ) + cur.execute( + """ DELETE FROM group_invites WHERE %s NOT IN ( SELECT min(%s) FROM group_invites GROUP BY group_id, user_id ); - """ % (rowid, rowid)) + """ + % (rowid, rowid) + ) for statement in get_statements(FIX_INDEXES.splitlines()): cur.execute(statement) diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/schema/delta/50/make_event_content_nullable.py index 6dd467b6c5..b1684a8441 100644 --- a/synapse/storage/schema/delta/50/make_event_content_nullable.py +++ b/synapse/storage/schema/delta/50/make_event_content_nullable.py @@ -65,14 +65,18 @@ def run_create(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs): if isinstance(database_engine, PostgresEngine): - cur.execute(""" + cur.execute( + """ ALTER TABLE events ALTER COLUMN content DROP NOT NULL; - """) + """ + ) return # sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html - cur.execute("SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'") + cur.execute( + "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'" + ) (oldsql,) = cur.fetchone() sql = oldsql.replace("content TEXT NOT NULL", "content TEXT") @@ -86,7 +90,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute("PRAGMA writable_schema=ON") cur.execute( "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", - (sql, ), + (sql,), ) cur.execute("PRAGMA schema_version=%i" % (oldver + 1,)) cur.execute("PRAGMA writable_schema=OFF") diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 10a27c207a..f3b1cec933 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -31,8 +31,8 @@ from .background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) SearchEntry = namedtuple( - 'SearchEntry', - ['key', 'value', 'event_id', 'room_id', 'stream_ordering', 'origin_server_ts'], + "SearchEntry", + ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], ) @@ -216,7 +216,7 @@ class SearchStore(BackgroundUpdateStore): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - have_added_index = progress['have_added_indexes'] + have_added_index = progress["have_added_indexes"] if not have_added_index: diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index ff266b09b0..1cec84ee2e 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -71,7 +71,8 @@ class StatsStore(StateDeltasStore): # Get all the rooms that we want to process. def _make_staging_area(txn): # Create the temporary tables - stmts = get_statements(""" + stmts = get_statements( + """ -- We just recreate the table, we'll be reinserting the -- correct entries again later anyway. DROP TABLE IF EXISTS {temp}_rooms; @@ -85,7 +86,10 @@ class StatsStore(StateDeltasStore): ON {temp}_rooms(events); CREATE INDEX {temp}_rooms_id ON {temp}_rooms(room_id); - """.format(temp=TEMP_TABLE).splitlines()) + """.format( + temp=TEMP_TABLE + ).splitlines() + ) for statement in stmts: txn.execute(statement) @@ -105,7 +109,9 @@ class StatsStore(StateDeltasStore): LEFT JOIN room_stats_earliest_token AS t USING (room_id) WHERE t.room_id IS NULL GROUP BY c.room_id - """ % (TEMP_TABLE,) + """ % ( + TEMP_TABLE, + ) txn.execute(sql) new_pos = yield self.get_max_stream_id_in_current_state_deltas() @@ -184,7 +190,8 @@ class StatsStore(StateDeltasStore): logger.info( "Processing the next %d rooms of %d remaining", - len(rooms_to_work_on), progress["remaining"], + len(rooms_to_work_on), + progress["remaining"], ) # Number of state events we've processed by going through each room @@ -204,10 +211,17 @@ class StatsStore(StateDeltasStore): avatar_id = current_state_ids.get((EventTypes.RoomAvatar, "")) canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, "")) - state_events = yield self.get_events([ - join_rules_id, history_visibility_id, encryption_id, name_id, - topic_id, avatar_id, canonical_alias_id, - ]) + state_events = yield self.get_events( + [ + join_rules_id, + history_visibility_id, + encryption_id, + name_id, + topic_id, + avatar_id, + canonical_alias_id, + ] + ) def _get_or_none(event_id, arg): event = state_events.get(event_id) @@ -271,7 +285,7 @@ class StatsStore(StateDeltasStore): # We've finished a room. Delete it from the table. self._simple_delete_one_txn( - txn, TEMP_TABLE + "_rooms", {"room_id": room_id}, + txn, TEMP_TABLE + "_rooms", {"room_id": room_id} ) yield self.runInteraction("update_room_stats", _fetch_data) @@ -338,7 +352,7 @@ class StatsStore(StateDeltasStore): "name", "topic", "avatar", - "canonical_alias" + "canonical_alias", ): field = fields.get(col) if field and "\0" in field: diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 6f7f65d96b..d9482a3843 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -65,7 +65,7 @@ _EventDictReturn = namedtuple( def generate_pagination_where_clause( - direction, column_names, from_token, to_token, engine, + direction, column_names, from_token, to_token, engine ): """Creates an SQL expression to bound the columns by the pagination tokens. @@ -153,7 +153,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine): str """ - assert(bound in (">", "<", ">=", "<=")) + assert bound in (">", "<", ">=", "<=") name1, name2 = column_names val1, val2 = values @@ -169,11 +169,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine): # Postgres doesn't optimise ``(x < a) OR (x=a AND y" % ( - id(self), self._result, self._deferred, + id(self), + self._result, + self._deferred, ) @@ -150,10 +153,12 @@ def concurrently_execute(func, args, limit): except StopIteration: pass - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(_concurrently_execute_inner) - for _ in range(limit) - ], consumeErrors=True)).addErrback(unwrapFirstError) + return logcontext.make_deferred_yieldable( + defer.gatherResults( + [run_in_background(_concurrently_execute_inner) for _ in range(limit)], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) def yieldable_gather_results(func, iter, *args, **kwargs): @@ -169,10 +174,12 @@ def yieldable_gather_results(func, iter, *args, **kwargs): Deferred[list]: Resolved when all functions have been invoked, or errors if one of the function calls fails. """ - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(func, item, *args, **kwargs) - for item in iter - ], consumeErrors=True)).addErrback(unwrapFirstError) + return logcontext.make_deferred_yieldable( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) class Linearizer(object): @@ -185,6 +192,7 @@ class Linearizer(object): # do some work. """ + def __init__(self, name=None, max_count=1, clock=None): """ Args: @@ -197,6 +205,7 @@ class Linearizer(object): if not clock: from twisted.internet import reactor + clock = Clock(reactor) self._clock = clock self.max_count = max_count @@ -221,7 +230,7 @@ class Linearizer(object): res = self._await_lock(key) else: logger.debug( - "Acquired uncontended linearizer lock %r for key %r", self.name, key, + "Acquired uncontended linearizer lock %r for key %r", self.name, key ) entry[0] += 1 res = defer.succeed(None) @@ -266,9 +275,7 @@ class Linearizer(object): """ entry = self.key_to_defer[key] - logger.debug( - "Waiting to acquire linearizer lock %r for key %r", self.name, key, - ) + logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) new_defer = make_deferred_yieldable(defer.Deferred()) entry[1][new_defer] = 1 @@ -293,14 +300,14 @@ class Linearizer(object): logger.info("defer %r got err %r", new_defer, e) if isinstance(e, CancelledError): logger.debug( - "Cancelling wait for linearizer lock %r for key %r", - self.name, key, + "Cancelling wait for linearizer lock %r for key %r", self.name, key ) else: logger.warn( "Unexpected exception waiting for linearizer lock %r for key %r", - self.name, key, + self.name, + key, ) # we just have to take ourselves back out of the queue. @@ -438,7 +445,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): try: deferred.cancel() - except: # noqa: E722, if we throw any exception it'll break time outs + except: # noqa: E722, if we throw any exception it'll break time outs logger.exception("Canceller failed during timeout") if not new_d.called: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index f37d5bec08..8271229015 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -104,8 +104,8 @@ def register_cache(cache_type, cache_name, cache): KNOWN_KEYS = { - key: key for key in - ( + key: key + for key in ( "auth_events", "content", "depth", @@ -150,7 +150,7 @@ def intern_dict(dictionary): def _intern_known_values(key, value): - intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",) + intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key") if key in intern_keys: return intern_string(value) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 187510576a..d2f25063aa 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -40,9 +40,7 @@ _CacheSentinel = object() class CacheEntry(object): - __slots__ = [ - "deferred", "callbacks", "invalidated" - ] + __slots__ = ["deferred", "callbacks", "invalidated"] def __init__(self, deferred, callbacks): self.deferred = deferred @@ -73,7 +71,9 @@ class Cache(object): self._pending_deferred_cache = cache_type() self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type, + max_size=max_entries, + keylen=keylen, + cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, ) @@ -133,10 +133,7 @@ class Cache(object): def set(self, key, value, callback=None): callbacks = [callback] if callback else [] self.check_thread() - entry = CacheEntry( - deferred=value, - callbacks=callbacks, - ) + entry = CacheEntry(deferred=value, callbacks=callbacks) existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: @@ -191,9 +188,7 @@ class Cache(object): def invalidate_many(self, key): self.check_thread() if not isinstance(key, tuple): - raise TypeError( - "The cache key must be a tuple not %r" % (type(key),) - ) + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the @@ -244,29 +239,25 @@ class _CacheDescriptorBase(object): raise Exception( "Not enough explicit positional arguments to key off for %r: " "got %i args, but wanted %i. (@cached cannot key off *args or " - "**kwargs)" - % (orig.__name__, len(all_args), num_args) + "**kwargs)" % (orig.__name__, len(all_args), num_args) ) self.num_args = num_args # list of the names of the args used as the cache key - self.arg_names = all_args[1:num_args + 1] + self.arg_names = all_args[1 : num_args + 1] # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value if arg_spec.defaults: - self.arg_defaults = dict(zip( - all_args[-len(arg_spec.defaults):], - arg_spec.defaults - )) + self.arg_defaults = dict( + zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults) + ) else: self.arg_defaults = {} if "cache_context" in self.arg_names: - raise Exception( - "cache_context arg cannot be included among the cache keys" - ) + raise Exception("cache_context arg cannot be included among the cache keys") self.add_cache_context = cache_context @@ -304,12 +295,24 @@ class CacheDescriptor(_CacheDescriptorBase): ``cache_context``) to use as cache keys. Defaults to all named args of the function. """ - def __init__(self, orig, max_entries=1000, num_args=None, tree=False, - inlineCallbacks=False, cache_context=False, iterable=False): + + def __init__( + self, + orig, + max_entries=1000, + num_args=None, + tree=False, + inlineCallbacks=False, + cache_context=False, + iterable=False, + ): super(CacheDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks, - cache_context=cache_context) + orig, + num_args=num_args, + inlineCallbacks=inlineCallbacks, + cache_context=cache_context, + ) max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) @@ -356,7 +359,9 @@ class CacheDescriptor(_CacheDescriptorBase): return args[0] else: return self.arg_defaults[nm] + else: + def get_cache_key(args, kwargs): return tuple(get_cache_key_gen(args, kwargs)) @@ -383,8 +388,7 @@ class CacheDescriptor(_CacheDescriptorBase): except KeyError: ret = defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), - obj, *args, **kwargs + logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs ) def onErr(f): @@ -437,8 +441,9 @@ class CacheListDescriptor(_CacheDescriptorBase): results. """ - def __init__(self, orig, cached_method_name, list_name, num_args=None, - inlineCallbacks=False): + def __init__( + self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False + ): """ Args: orig (function) @@ -451,7 +456,8 @@ class CacheListDescriptor(_CacheDescriptorBase): be wrapped by defer.inlineCallbacks """ super(CacheListDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks) + orig, num_args=num_args, inlineCallbacks=inlineCallbacks + ) self.list_name = list_name @@ -463,7 +469,7 @@ class CacheListDescriptor(_CacheDescriptorBase): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cached_method_name,) + % (self.list_name, cached_method_name) ) def __get__(self, obj, objtype=None): @@ -494,8 +500,10 @@ class CacheListDescriptor(_CacheDescriptorBase): # If the cache takes a single arg then that is used as the key, # otherwise a tuple is used. if num_args == 1: + def arg_to_cache_key(arg): return arg + else: keylist = list(keyargs) @@ -505,8 +513,7 @@ class CacheListDescriptor(_CacheDescriptorBase): for arg in list_args: try: - res = cache.get(arg_to_cache_key(arg), - callback=invalidate_callback) + res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) if not isinstance(res, ObservableDeferred): results[arg] = res elif not res.has_succeeded(): @@ -554,18 +561,15 @@ class CacheListDescriptor(_CacheDescriptorBase): args_to_call = dict(arg_dict) args_to_call[self.list_name] = list(missing) - cached_defers.append(defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), - **args_to_call - ).addCallbacks(complete_all, errback)) + cached_defers.append( + defer.maybeDeferred( + logcontext.preserve_fn(self.function_to_call), **args_to_call + ).addCallbacks(complete_all, errback) + ) if cached_defers: - d = defer.gatherResults( - cached_defers, - consumeErrors=True, - ).addCallbacks( - lambda _: results, - unwrapFirstError + d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( + lambda _: results, unwrapFirstError ) return logcontext.make_deferred_yieldable(d) else: @@ -586,8 +590,9 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, - iterable=False): +def cached( + max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False +): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -598,8 +603,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, ) -def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, - cache_context=False, iterable=False): +def cachedInlineCallbacks( + max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False +): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 6c0b5a4094..6834e6f3ae 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -35,6 +35,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va there. value (dict): The full or partial dict value """ + def __len__(self): return len(self.value) @@ -84,13 +85,15 @@ class DictionaryCache(object): self.metrics.inc_hits() if dict_keys is None: - return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value)) + return DictionaryEntry( + entry.full, entry.known_absent, dict(entry.value) + ) else: - return DictionaryEntry(entry.full, entry.known_absent, { - k: entry.value[k] - for k in dict_keys - if k in entry.value - }) + return DictionaryEntry( + entry.full, + entry.known_absent, + {k: entry.value[k] for k in dict_keys if k in entry.value}, + ) self.metrics.inc_misses() return DictionaryEntry(False, set(), {}) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index f369780277..cddf1ed515 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -28,8 +28,15 @@ SENTINEL = object() class ExpiringCache(object): - def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False, iterable=False): + def __init__( + self, + cache_name, + clock, + max_len=0, + expiry_ms=0, + reset_expiry_on_get=False, + iterable=False, + ): """ Args: cache_name (str): Name of this cache, used for logging. @@ -67,8 +74,7 @@ class ExpiringCache(object): def f(): return run_as_background_process( - "prune_cache_%s" % self._cache_name, - self._prune_cache, + "prune_cache_%s" % self._cache_name, self._prune_cache ) self._clock.looping_call(f, self._expiry_ms / 2) @@ -153,7 +159,9 @@ class ExpiringCache(object): logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self) + self._cache_name, + begin_length, + len(self), ) def __len__(self): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index b684f24e7b..1536cb64f3 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,8 +49,15 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None, - evicted_callback=None): + + def __init__( + self, + max_size, + keylen=1, + cache_type=dict, + size_callback=None, + evicted_callback=None, + ): """ Args: max_size (int): @@ -93,9 +100,12 @@ class LruCache(object): cached_cache_len = [0] if size_callback is not None: + def cache_len(): return cached_cache_len[0] + else: + def cache_len(): return len(cache) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index afb03b2e1b..e493d07b55 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -38,9 +38,7 @@ class ResponseCache(object): self.timeout_sec = timeout_ms / 1000. self._name = name - self._metrics = register_cache( - "response_cache", name, self - ) + self._metrics = register_cache("response_cache", name, self) def size(self): return len(self.pending_result_cache) @@ -100,8 +98,7 @@ class ResponseCache(object): def remove(r): if self.timeout_sec: self.clock.call_later( - self.timeout_sec, - self.pending_result_cache.pop, key, None, + self.timeout_sec, self.pending_result_cache.pop, key, None ) else: self.pending_result_cache.pop(key, None) @@ -147,14 +144,15 @@ class ResponseCache(object): """ result = self.get(key) if not result: - logger.info("[%s]: no cached result for [%s], calculating new one", - self._name, key) + logger.info( + "[%s]: no cached result for [%s], calculating new one", self._name, key + ) d = run_in_background(callback, *args, **kwargs) result = self.set(key, d) elif not isinstance(result, defer.Deferred) or result.called: - logger.info("[%s]: using completed cached result for [%s]", - self._name, key) + logger.info("[%s]: using completed cached result for [%s]", self._name, key) else: - logger.info("[%s]: using incomplete cached result for [%s]", - self._name, key) + logger.info( + "[%s]: using incomplete cached result for [%s]", self._name, key + ) return make_deferred_yieldable(result) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 625aedc940..235f64049c 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -77,9 +77,8 @@ class StreamChangeCache(object): if stream_pos >= self._earliest_known_stream_pos: changed_entities = { - self._cache[k] for k in self._cache.islice( - start=self._cache.bisect_right(stream_pos), - ) + self._cache[k] + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) } result = changed_entities.intersection(entities) @@ -114,8 +113,10 @@ class StreamChangeCache(object): assert type(stream_pos) is int if stream_pos >= self._earliest_known_stream_pos: - return [self._cache[k] for k in self._cache.islice( - start=self._cache.bisect_right(stream_pos))] + return [ + self._cache[k] + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) + ] else: return None @@ -136,7 +137,7 @@ class StreamChangeCache(object): while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) self._earliest_known_stream_pos = max( - k, self._earliest_known_stream_pos, + k, self._earliest_known_stream_pos ) self._entity_to_key.pop(r, None) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index dd4c9e6067..9a72218d85 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -9,6 +9,7 @@ class TreeCache(object): efficiently. Keys must be tuples. """ + def __init__(self): self.size = 0 self.root = {} diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 5ba1862506..2af8ca43b1 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -155,6 +155,7 @@ class TTLCache(object): @attr.s(frozen=True, slots=True) class _CacheEntry(object): """TTLCache entry""" + # expiry_time is the first attribute, so that entries are sorted by expiry. expiry_time = attr.ib() key = attr.ib() diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index e14c8bdfda..5a79db821c 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -51,9 +51,7 @@ class Distributor(object): if name in self.signals: raise KeyError("%r already has a signal named %s" % (self, name)) - self.signals[name] = Signal( - name, - ) + self.signals[name] = Signal(name) if name in self.pre_registration: signal = self.signals[name] @@ -78,11 +76,7 @@ class Distributor(object): if name not in self.signals: raise KeyError("%r does not have a signal named %s" % (self, name)) - run_as_background_process( - name, - self.signals[name].fire, - *args, **kwargs - ) + run_as_background_process(name, self.signals[name].fire, *args, **kwargs) class Signal(object): @@ -118,22 +112,23 @@ class Signal(object): def eb(failure): logger.warning( "%s signal observer %s failed: %r", - self.name, observer, failure, + self.name, + observer, + failure, exc_info=( failure.type, failure.value, - failure.getTracebackObject())) + failure.getTracebackObject(), + ), + ) return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) - deferreds = [ - run_in_background(do, o) - for o in self.observers - ] + deferreds = [run_in_background(do, o) for o in self.observers] - return make_deferred_yieldable(defer.gatherResults( - deferreds, consumeErrors=True, - )) + return make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) def __repr__(self): return "" % (self.name,) diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 014edea971..635b897d6c 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -60,11 +60,10 @@ def _handle_frozendict(obj): # fishing the protected dict out of the object is a bit nasty, # but we don't really want the overhead of copying the dict. return obj._dict - raise TypeError('Object of type %s is not JSON serializable' % - obj.__class__.__name__) + raise TypeError( + "Object of type %s is not JSON serializable" % obj.__class__.__name__ + ) # A JSONEncoder which is capable of encoding frozendics without barfing -frozendict_json_encoder = json.JSONEncoder( - default=_handle_frozendict, -) +frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 2d7ddc1cbe..1a20c596bf 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -45,7 +45,7 @@ def create_resource_tree(desired_tree, root_resource): logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource - for path_seg in full_path.split(b'/')[1:-1]: + for path_seg in full_path.split(b"/")[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" child_resource = NoResource() @@ -60,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource): # =========================== # now attach the actual desired resource - last_path_seg = full_path.split(b'/')[-1] + last_path_seg = full_path.split(b"/")[-1] # if there is already a resource here, thieve its children and # replace it @@ -70,9 +70,7 @@ def create_resource_tree(desired_tree, root_resource): # to be replaced with the desired resource. existing_dummy_resource = resource_mappings[res_id] for child_name in existing_dummy_resource.listNames(): - child_res_id = _resource_id( - existing_dummy_resource, child_name - ) + child_res_id = _resource_id(existing_dummy_resource, child_name) child_resource = resource_mappings[child_res_id] # steal the children res.putChild(child_name, child_resource) diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index d668e5a6b8..6dce03dd3a 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -70,7 +70,8 @@ class JsonEncodedObject(object): dict """ d = { - k: _encode(v) for (k, v) in self.__dict__.items() + k: _encode(v) + for (k, v) in self.__dict__.items() if k in self.valid_keys and k not in self.internal_keys } d.update(self.unrecognized_keys) @@ -78,7 +79,8 @@ class JsonEncodedObject(object): def get_internal_dict(self): d = { - k: _encode(v, internal=True) for (k, v) in self.__dict__.items() + k: _encode(v, internal=True) + for (k, v) in self.__dict__.items() if k in self.valid_keys } d.update(self.unrecognized_keys) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index fe412355d8..413b4aa060 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -42,6 +42,8 @@ try: def get_thread_resource_usage(): return resource.getrusage(RUSAGE_THREAD) + + except Exception: # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we # won't track resource usage by returning None. @@ -64,8 +66,11 @@ class ContextResourceUsage(object): """ __slots__ = [ - "ru_stime", "ru_utime", - "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec", + "ru_stime", + "ru_utime", + "db_txn_count", + "db_txn_duration_sec", + "db_sched_duration_sec", "evt_db_fetch_count", ] @@ -100,15 +105,18 @@ class ContextResourceUsage(object): self.evt_db_fetch_count = 0 def __repr__(self): - return ("") % ( - self.ru_stime, - self.ru_utime, - self.db_txn_count, - self.db_txn_duration_sec, - self.db_sched_duration_sec, - self.evt_db_fetch_count,) + return ( + "" + ) % ( + self.ru_stime, + self.ru_utime, + self.db_txn_count, + self.db_txn_duration_sec, + self.db_sched_duration_sec, + self.evt_db_fetch_count, + ) def __iadd__(self, other): """Add another ContextResourceUsage's stats to this one's. @@ -159,11 +167,15 @@ class LoggingContext(object): """ __slots__ = [ - "previous_context", "name", "parent_context", + "previous_context", + "name", + "parent_context", "_resource_usage", "usage_start", - "main_thread", "alive", - "request", "tag", + "main_thread", + "alive", + "request", + "tag", ] thread_local = threading.local() @@ -196,6 +208,7 @@ class LoggingContext(object): def __nonzero__(self): return False + __bool__ = __nonzero__ # python3 sentinel = Sentinel() @@ -261,7 +274,8 @@ class LoggingContext(object): if self.previous_context != old_context: logger.warn( "Expected previous context %r, found %r", - self.previous_context, old_context + self.previous_context, + old_context, ) self.alive = True @@ -285,9 +299,8 @@ class LoggingContext(object): self.alive = False # if we have a parent, pass our CPU usage stats on - if ( - self.parent_context is not None - and hasattr(self.parent_context, '_resource_usage') + if self.parent_context is not None and hasattr( + self.parent_context, "_resource_usage" ): self.parent_context._resource_usage += self._resource_usage @@ -320,9 +333,7 @@ class LoggingContext(object): # When we stop, let's record the cpu used since we started if not self.usage_start: - logger.warning( - "Called stop on logcontext %s without calling start", self, - ) + logger.warning("Called stop on logcontext %s without calling start", self) return usage_end = get_thread_resource_usage() @@ -381,6 +392,7 @@ class LoggingContextFilter(logging.Filter): **defaults: Default values to avoid formatters complaining about missing fields """ + def __init__(self, **defaults): self.defaults = defaults @@ -416,17 +428,12 @@ class PreserveLoggingContext(object): def __enter__(self): """Captures the current logging context""" - self.current_context = LoggingContext.set_current_context( - self.new_context - ) + self.current_context = LoggingContext.set_current_context(self.new_context) if self.current_context: self.has_parent = self.current_context.previous_context is not None if not self.current_context.alive: - logger.debug( - "Entering dead context: %s", - self.current_context, - ) + logger.debug("Entering dead context: %s", self.current_context) def __exit__(self, type, value, traceback): """Restores the current logging context""" @@ -444,10 +451,7 @@ class PreserveLoggingContext(object): if self.current_context is not LoggingContext.sentinel: if not self.current_context.alive: - logger.debug( - "Restoring dead context: %s", - self.current_context, - ) + logger.debug("Restoring dead context: %s", self.current_context) def nested_logging_context(suffix, parent_context=None): @@ -474,15 +478,16 @@ def nested_logging_context(suffix, parent_context=None): if parent_context is None: parent_context = LoggingContext.current_context() return LoggingContext( - parent_context=parent_context, - request=parent_context.request + "-" + suffix, + parent_context=parent_context, request=parent_context.request + "-" + suffix ) def preserve_fn(f): """Function decorator which wraps the function with run_in_background""" + def g(*args, **kwargs): return run_in_background(f, *args, **kwargs) + return g @@ -502,7 +507,7 @@ def run_in_background(f, *args, **kwargs): current = LoggingContext.current_context() try: res = f(*args, **kwargs) - except: # noqa: E722 + except: # noqa: E722 # the assumption here is that the caller doesn't want to be disturbed # by synchronous exceptions, so let's turn them into Failures. return defer.fail() @@ -639,6 +644,4 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): with LoggingContext(parent_context=logcontext): return f(*args, **kwargs) - return make_deferred_yieldable( - threads.deferToThreadPool(reactor, threadpool, g) - ) + return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g)) diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py index a46bc47ce3..fbf570c756 100644 --- a/synapse/util/logformatter.py +++ b/synapse/util/logformatter.py @@ -29,6 +29,7 @@ class LogFormatter(logging.Formatter): (Normally only stack frames between the point the exception was raised and where it was caught are logged). """ + def __init__(self, *args, **kwargs): super(LogFormatter, self).__init__(*args, **kwargs) @@ -40,7 +41,7 @@ class LogFormatter(logging.Formatter): # check that we actually have an f_back attribute to work around # https://twistedmatrix.com/trac/ticket/9305 - if tb and hasattr(tb.tb_frame, 'f_back'): + if tb and hasattr(tb.tb_frame, "f_back"): sio.write("Capture point (most recent call last):\n") traceback.print_stack(tb.tb_frame.f_back, None, sio) diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index ef31458226..7df0fa6087 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -44,7 +44,7 @@ def _log_debug_as_f(f, msg, msg_args): lineno=lineno, msg=msg, args=msg_args, - exc_info=None + exc_info=None, ) logger.handle(record) @@ -70,20 +70,11 @@ def log_function(f): r = r[:50] + "..." return r - func_args = [ - "%s=%s" % (k, format(v)) for k, v in bound_args.items() - ] + func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()] - msg_args = { - "func_name": func_name, - "args": ", ".join(func_args) - } + msg_args = {"func_name": func_name, "args": ", ".join(func_args)} - _log_debug_as_f( - f, - "Invoked '%(func_name)s' with args: %(args)s", - msg_args - ) + _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args) return f(*args, **kwargs) @@ -103,19 +94,13 @@ def time_function(f): start = time.clock() try: - _log_debug_as_f( - f, - "[FUNC START] {%s-%d}", - (func_name, id), - ) + _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id)) r = f(*args, **kwargs) finally: end = time.clock() _log_debug_as_f( - f, - "[FUNC END] {%s-%d} %.3f sec", - (func_name, id, end - start,), + f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start) ) return r @@ -137,9 +122,8 @@ def trace_function(f): s = inspect.currentframe().f_back to_print = [ - "\t%s:%s %s. Args: args=%s, kwargs=%s" % ( - pathname, linenum, func_name, args, kwargs - ) + "\t%s:%s %s. Args: args=%s, kwargs=%s" + % (pathname, linenum, func_name, args, kwargs) ] while s: if True or s.f_globals["__name__"].startswith("synapse"): @@ -147,9 +131,7 @@ def trace_function(f): args_string = inspect.formatargvalues(*inspect.getargvalues(s)) to_print.append( - "\t%s:%d %s. Args: %s" % ( - filename, lineno, function, args_string - ) + "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string) ) s = s.f_back @@ -163,7 +145,7 @@ def trace_function(f): lineno=lineno, msg=msg, args=None, - exc_info=None + exc_info=None, ) logger.handle(record) @@ -182,13 +164,13 @@ def get_previous_frames(): filename, lineno, function, _, _ = inspect.getframeinfo(s) args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - to_return.append("{{ %s:%d %s - Args: %s }}" % ( - filename, lineno, function, args_string - )) + to_return.append( + "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string) + ) s = s.f_back - return ", ". join(to_return) + return ", ".join(to_return) def get_previous_frame(ignore=[]): @@ -201,7 +183,10 @@ def get_previous_frame(ignore=[]): args_string = inspect.formatargvalues(*inspect.getargvalues(s)) return "{{ %s:%d %s - Args: %s }}" % ( - filename, lineno, function, args_string + filename, + lineno, + function, + args_string, ) s = s.f_back diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 628a2962d9..631654f297 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -74,27 +74,25 @@ def manhole(username, password, globals): twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` """ if not isinstance(password, bytes): - password = password.encode('ascii') + password = password.encode("ascii") - checker = checkers.InMemoryUsernamePasswordDatabaseDontUse( - **{username: password} - ) + checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) rlm = manhole_ssh.TerminalRealm() rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( - SynapseManhole, - dict(globals, __name__="__console__") + SynapseManhole, dict(globals, __name__="__console__") ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY) - factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY) + factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY) + factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY) return factory class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" + def connectionMade(self): super(SynapseManhole, self).connectionMade() @@ -127,7 +125,7 @@ class SynapseManholeInterpreter(ManholeInterpreter): value = SyntaxError(msg, (filename, lineno, offset, line)) sys.last_value = value lines = traceback.format_exception_only(type, value) - self.write(''.join(lines)) + self.write("".join(lines)) def showtraceback(self): """Display the exception that just occurred. @@ -140,6 +138,6 @@ class SynapseManholeInterpreter(ManholeInterpreter): try: # We remove the first stack item because it is our own code. lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) - self.write(''.join(lines)) + self.write("".join(lines)) finally: last_tb = ei = None diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4b4ac5f6c7..01284d3cf8 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -30,25 +30,31 @@ block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"]) block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"]) block_ru_utime = Counter( - "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"]) + "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"] +) block_ru_stime = Counter( - "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"]) + "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"] +) block_db_txn_count = Counter( - "synapse_util_metrics_block_db_txn_count", "", ["block_name"]) + "synapse_util_metrics_block_db_txn_count", "", ["block_name"] +) # seconds spent waiting for db txns, excluding scheduling time, in this block block_db_txn_duration = Counter( - "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"]) + "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"] +) # seconds spent waiting for a db connection, in this block block_db_sched_duration = Counter( - "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]) + "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"] +) # Tracks the number of blocks currently active in_flight = InFlightGauge( - "synapse_util_metrics_block_in_flight", "", + "synapse_util_metrics_block_in_flight", + "", labels=["block_name"], sub_metrics=["real_time_max", "real_time_sum"], ) @@ -62,13 +68,18 @@ def measure_func(name): with Measure(self.clock, name): r = yield func(self, *args, **kwargs) defer.returnValue(r) + return measured_func + return wrapper class Measure(object): __slots__ = [ - "clock", "name", "start_context", "start", + "clock", + "name", + "start_context", + "start", "created_context", "start_usage", ] @@ -108,7 +119,9 @@ class Measure(object): if context != self.start_context: logger.warn( "Context has unexpectedly changed from '%s' to '%s'. (%r)", - self.start_context, context, self.name + self.start_context, + context, + self.name, ) return @@ -126,8 +139,7 @@ class Measure(object): block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: logger.warn( - "Failed to save metrics! OLD: %r, NEW: %r", - self.start_usage, current + "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current ) if self.created_context: diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 4288312b8a..522acd5aa8 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -28,15 +28,13 @@ def load_module(provider): """ # We need to import the module, and then pick the class out of # that, so we split based on the last dot. - module, clz = provider['module'].rsplit(".", 1) + module, clz = provider["module"].rsplit(".", 1) module = importlib.import_module(module) provider_class = getattr(module, clz) try: provider_config = provider_class.parse_config(provider["config"]) except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) + raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e)) return provider_class, provider_config diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py index a6c30e5265..c8bcbe297a 100644 --- a/synapse/util/msisdn.py +++ b/synapse/util/msisdn.py @@ -36,6 +36,6 @@ def phone_number_to_msisdn(country, number): phoneNumber = phonenumbers.parse(number, country) except phonenumbers.NumberParseException: raise SynapseError(400, "Unable to parse phone number") - return phonenumbers.format_number( - phoneNumber, phonenumbers.PhoneNumberFormat.E164 - )[1:] + return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[ + 1: + ] diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index b146d137f4..06defa8199 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -56,11 +56,7 @@ class FederationRateLimiter(object): _PerHostRatelimiter """ return self.ratelimiters.setdefault( - host, - _PerHostRatelimiter( - clock=self.clock, - config=self._config, - ) + host, _PerHostRatelimiter(clock=self.clock, config=self._config) ).ratelimit() @@ -112,8 +108,7 @@ class _PerHostRatelimiter(object): # remove any entries from request_times which aren't within the window self.request_times[:] = [ - r for r in self.request_times - if time_now - r < self.window_size + r for r in self.request_times if time_now - r < self.window_size ] # reject the request if we already have too many queued up (either @@ -121,9 +116,7 @@ class _PerHostRatelimiter(object): queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) if queue_size > self.reject_limit: raise LimitExceededError( - retry_after_ms=int( - self.window_size / self.sleep_limit - ), + retry_after_ms=int(self.window_size / self.sleep_limit) ) self.request_times.append(time_now) @@ -143,22 +136,18 @@ class _PerHostRatelimiter(object): logger.debug( "Ratelimit [%s]: len(self.request_times)=%d", - id(request_id), len(self.request_times), + id(request_id), + len(self.request_times), ) if len(self.request_times) > self.sleep_limit: - logger.debug( - "Ratelimiter: sleeping request for %f sec", self.sleep_sec, - ) + logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec) ret_defer = run_in_background(self.clock.sleep, self.sleep_sec) self.sleeping_requests.add(request_id) def on_wait_finished(_): - logger.debug( - "Ratelimit [%s]: Finished sleeping", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) self.sleeping_requests.discard(request_id) queue_defer = queue_request() return queue_defer @@ -168,10 +157,7 @@ class _PerHostRatelimiter(object): ret_defer = queue_request() def on_start(r): - logger.debug( - "Ratelimit [%s]: Processing req", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Processing req", id(request_id)) self.current_processing.add(request_id) return r @@ -193,10 +179,7 @@ class _PerHostRatelimiter(object): return make_deferred_yieldable(ret_defer) def _on_exit(self, request_id): - logger.debug( - "Ratelimit [%s]: Processed req", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Processed req", id(request_id)) self.current_processing.discard(request_id) try: # start processing the next item on the queue. diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 5fb18ee1f8..31762e095d 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -24,9 +24,7 @@ from six.moves import range from synapse.api.errors import Codes, SynapseError -_string_with_symbols = ( - string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" -) +_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure @@ -37,13 +35,11 @@ client_secret_regex = re.compile(r"^[0-9a-zA-Z.=_-]+$") def random_string(length): - return ''.join(rand.choice(string.ascii_letters) for _ in range(length)) + return "".join(rand.choice(string.ascii_letters) for _ in range(length)) def random_string_with_symbols(length): - return ''.join( - rand.choice(_string_with_symbols) for _ in range(length) - ) + return "".join(rand.choice(_string_with_symbols) for _ in range(length)) def is_ascii(s): @@ -51,7 +47,7 @@ def is_ascii(s): if PY3: if isinstance(s, bytes): try: - s.decode('ascii').encode('ascii') + s.decode("ascii").encode("ascii") except UnicodeDecodeError: return False except UnicodeEncodeError: @@ -116,7 +112,7 @@ def exception_to_unicode(e): msg = e.args[0] if isinstance(msg, bytes): - return msg.decode('utf-8', errors='replace') + return msg.decode("utf-8", errors="replace") else: return msg diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 4cc7d27ce5..34ce7cac16 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -36,25 +36,26 @@ def check_3pid_allowed(hs, medium, address): if hs.config.check_is_for_allowed_local_3pids: data = yield hs.get_simple_http_client().get_json( - "https://%s%s" % ( + "https://%s%s" + % ( hs.config.check_is_for_allowed_local_3pids, - "/_matrix/identity/api/v1/internal-info" + "/_matrix/identity/api/v1/internal-info", ), - {'medium': medium, 'address': address} + {"medium": medium, "address": address}, ) # Check for invalid response - if 'hs' not in data and 'shadow_hs' not in data: + if "hs" not in data and "shadow_hs" not in data: defer.returnValue(False) # Check if this user is intended to register for this homeserver if ( - data.get('hs') != hs.config.server_name - and data.get('shadow_hs') != hs.config.server_name + data.get("hs") != hs.config.server_name + and data.get("shadow_hs") != hs.config.server_name ): defer.returnValue(False) - if data.get('requires_invite', False) and not data.get('invited', False): + if data.get("requires_invite", False) and not data.get("invited", False): # Requires an invite but hasn't been invited defer.returnValue(False) @@ -64,11 +65,13 @@ def check_3pid_allowed(hs, medium, address): for constraint in hs.config.allowed_local_3pids: logger.debug( "Checking 3PID %s (%s) against %s (%s)", - address, medium, constraint['pattern'], constraint['medium'], + address, + medium, + constraint["pattern"], + constraint["medium"], ) - if ( - medium == constraint['medium'] and - re.match(constraint['pattern'], address) + if medium == constraint["medium"] and re.match( + constraint["pattern"], address ): defer.returnValue(True) else: diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index 3baba3225a..a4d9a462f7 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -23,44 +23,53 @@ logger = logging.getLogger(__name__) def get_version_string(module): try: - null = open(os.devnull, 'w') + null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) try: - git_branch = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_branch = ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) git_branch = "b=" + git_branch except subprocess.CalledProcessError: git_branch = "" try: - git_tag = subprocess.check_output( - ['git', 'describe', '--exact-match'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_tag = ( + subprocess.check_output( + ["git", "describe", "--exact-match"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) git_tag = "t=" + git_tag except subprocess.CalledProcessError: git_tag = "" try: - git_commit = subprocess.check_output( - ['git', 'rev-parse', '--short', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_commit = ( + subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) except subprocess.CalledProcessError: git_commit = "" try: dirty_string = "-this_is_a_dirty_checkout" - is_dirty = subprocess.check_output( - ['git', 'describe', '--dirty=' + dirty_string], - stderr=null, - cwd=cwd, - ).strip().decode('ascii').endswith(dirty_string) + is_dirty = ( + subprocess.check_output( + ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + .endswith(dirty_string) + ) git_dirty = "dirty" if is_dirty else "" except subprocess.CalledProcessError: @@ -68,16 +77,10 @@ def get_version_string(module): if git_branch or git_tag or git_commit or git_dirty: git_version = ",".join( - s for s in - (git_branch, git_tag, git_commit, git_dirty,) - if s + s for s in (git_branch, git_tag, git_commit, git_dirty) if s ) - return ( - "%s (%s)" % ( - module.__version__, git_version, - ) - ) + return "%s (%s)" % (module.__version__, git_version) except Exception as e: logger.info("Failed to check for git repository: %s", e) diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 7a9e45aca9..9bf6a44f75 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -69,9 +69,7 @@ class WheelTimer(object): # Add empty entries between the end of the current list and when we want # to insert. This ensures there are no gaps. - self.entries.extend( - _Entry(key) for key in range(last_key, then_key + 1) - ) + self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1)) self.entries[-1].queue.append(obj) diff --git a/synapse/visibility.py b/synapse/visibility.py index 09d8334b26..e50472081a 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -29,12 +29,7 @@ from synapse.types import get_domain_from_id logger = logging.getLogger(__name__) -VISIBILITY_PRIORITY = ( - "world_readable", - "shared", - "invited", - "joined", -) +VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined") MEMBERSHIP_PRIORITY = ( @@ -47,9 +42,14 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks -def filter_events_for_client(store, user_id, events, is_peeking=False, - always_include_ids=frozenset(), - apply_retention_policies=True): +def filter_events_for_client( + store, + user_id, + events, + is_peeking=False, + always_include_ids=frozenset(), + apply_retention_policies=True, +): """ Check which events a user is allowed to see @@ -76,23 +76,21 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, # to clients. events = list(e for e in events if not e.internal_metadata.is_soft_failed()) - types = ( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), - ) + types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) event_id_to_state = yield store.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) ignore_dict_content = yield store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id, + "m.ignored_user_list", user_id ) # FIXME: This will explode if people upload something incorrect. ignore_list = frozenset( ignore_dict_content.get("ignored_users", {}).keys() - if ignore_dict_content else [] + if ignore_dict_content + else [] ) erased_senders = yield store.are_users_erased((e.sender for e in events)) @@ -211,9 +209,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, elif visibility == "invited": # user can also see the event if they were *invited* at the time # of the event. - return ( - event if membership == Membership.INVITE else None - ) + return event if membership == Membership.INVITE else None elif visibility == "shared" and is_peeking: # if the visibility is shared, users cannot see the event unless @@ -246,8 +242,9 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, @defer.inlineCallbacks -def filter_events_for_server(store, server_name, events, redact=True, - check_history_visibility_only=False): +def filter_events_for_server( + store, server_name, events, redact=True, check_history_visibility_only=False +): """Filter a list of events based on whether given server is allowed to see them. @@ -268,15 +265,12 @@ def filter_events_for_server(store, server_name, events, redact=True, def is_sender_erased(event, erased_senders): if erased_senders and erased_senders[event.sender]: - logger.info( - "Sender of %s has been erased, redacting", - event.event_id, - ) + logger.info("Sender of %s has been erased, redacting", event.event_id) return True return False def check_event_is_visible(event, state): - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + history = state.get((EventTypes.RoomHistoryVisibility, ""), None) if history: visibility = history.content.get("history_visibility", "shared") if visibility in ["invited", "joined"]: @@ -313,8 +307,8 @@ def filter_events_for_server(store, server_name, events, redact=True, event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( - types=((EventTypes.RoomHistoryVisibility, ""),), - ) + types=((EventTypes.RoomHistoryVisibility, ""),) + ), ) visibility_ids = set() @@ -335,9 +329,7 @@ def filter_events_for_server(store, server_name, events, redact=True, ) if not check_history_visibility_only: - erased_senders = yield store.are_users_erased( - (e.sender for e in events), - ) + erased_senders = yield store.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -369,11 +361,8 @@ def filter_events_for_server(store, server_name, events, redact=True, event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, None), - ), - ) + types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) + ), ) # We only want to pull out member events that correspond to the @@ -397,13 +386,15 @@ def filter_events_for_server(store, server_name, events, redact=True, idx = state_key.find(":") if idx == -1: return False - return state_key[idx + 1:] == server_name - - event_map = yield store.get_events([ - e_id - for e_id, key in iteritems(event_id_to_state_key) - if include(key[0], key[1]) - ]) + return state_key[idx + 1 :] == server_name + + event_map = yield store.get_events( + [ + e_id + for e_id, key in iteritems(event_id_to_state_key) + if include(key[0], key[1]) + ] + ) event_to_state = { e_id: { diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index d0d36f96fa..d4e75b5b2e 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -172,7 +172,7 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals( - requester.user.to_string(), masquerading_user_id.decode('utf8') + requester.user.to_string(), masquerading_user_id.decode("utf8") ) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): @@ -264,7 +264,7 @@ class AuthTestCase(unittest.TestCase): # check the token works request = Mock(args={}) - request.args[b"access_token"] = [token.encode('ascii')] + request.args[b"access_token"] = [token.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request, allow_guest=True) self.assertEqual(UserID.from_string(USER_ID), requester.user) @@ -277,7 +277,7 @@ class AuthTestCase(unittest.TestCase): # the token should *not* work now request = Mock(args={}) - request.args[b"access_token"] = [guest_tok.encode('ascii')] + request.args[b"access_token"] = [guest_tok.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() with self.assertRaises(AuthError) as cm: @@ -321,11 +321,11 @@ class AuthTestCase(unittest.TestCase): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 1 self.store.get_monthly_active_count = lambda: defer.succeed(2) - threepid = {'medium': 'email', 'address': 'reserved@server.com'} - unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'} + threepid = {"medium": "email", "address": "reserved@server.com"} + unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.hs.config.mau_limits_reserved_threepids = [threepid] - yield self.store.register(user_id='user1', token="123", password_hash=None) + yield self.store.register(user_id="user1", token="123", password_hash=None) with self.assertRaises(ResourceLimitError): yield self.auth.check_auth_blocking() diff --git a/tests/config/test_server.py b/tests/config/test_server.py index de64965a60..1ca5ea54ca 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -20,10 +20,10 @@ from tests import unittest class ServerConfigTestCase(unittest.TestCase): def test_is_threepid_reserved(self): - user1 = {'medium': 'email', 'address': 'user1@example.com'} - user2 = {'medium': 'email', 'address': 'user2@example.com'} - user3 = {'medium': 'email', 'address': 'user3@example.com'} - user1_msisdn = {'medium': 'msisdn', 'address': '447700000000'} + user1 = {"medium": "email", "address": "user1@example.com"} + user2 = {"medium": "email", "address": "user2@example.com"} + user3 = {"medium": "email", "address": "user3@example.com"} + user1_msisdn = {"medium": "msisdn", "address": "447700000000"} config = [user1, user2] self.assertTrue(is_threepid_reserved(config, user1)) diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 40ca428778..0cbbf4e885 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -32,7 +32,7 @@ class TLSConfigTests(TestCase): """ config_dir = self.mktemp() os.mkdir(config_dir) - with open(os.path.join(config_dir, "cert.pem"), 'w') as f: + with open(os.path.join(config_dir, "cert.pem"), "w") as f: f.write( """-----BEGIN CERTIFICATE----- MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 71aa731439..126e176004 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -41,25 +41,25 @@ class EventSigningTestCase(unittest.TestCase): def test_sign_minimal(self): event_dict = { - 'event_id': "$0:domain", - 'origin': "domain", - 'origin_server_ts': 1000000, - 'signatures': {}, - 'type': "X", - 'unsigned': {'age_ts': 1000000}, + "event_id": "$0:domain", + "origin": "domain", + "origin_server_ts": 1000000, + "signatures": {}, + "type": "X", + "unsigned": {"age_ts": 1000000}, } add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key) event = FrozenEvent(event_dict) - self.assertTrue(hasattr(event, 'hashes')) - self.assertIn('sha256', event.hashes) + self.assertTrue(hasattr(event, "hashes")) + self.assertIn("sha256", event.hashes) self.assertEquals( - event.hashes['sha256'], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI" + event.hashes["sha256"], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI" ) - self.assertTrue(hasattr(event, 'signatures')) + self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) self.assertEquals( @@ -70,28 +70,28 @@ class EventSigningTestCase(unittest.TestCase): def test_sign_message(self): event_dict = { - 'content': {'body': "Here is the message content"}, - 'event_id': "$0:domain", - 'origin': "domain", - 'origin_server_ts': 1000000, - 'type': "m.room.message", - 'room_id': "!r:domain", - 'sender': "@u:domain", - 'signatures': {}, - 'unsigned': {'age_ts': 1000000}, + "content": {"body": "Here is the message content"}, + "event_id": "$0:domain", + "origin": "domain", + "origin_server_ts": 1000000, + "type": "m.room.message", + "room_id": "!r:domain", + "sender": "@u:domain", + "signatures": {}, + "unsigned": {"age_ts": 1000000}, } add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key) event = FrozenEvent(event_dict) - self.assertTrue(hasattr(event, 'hashes')) - self.assertIn('sha256', event.hashes) + self.assertTrue(hasattr(event, "hashes")) + self.assertIn("sha256", event.hashes) self.assertEquals( - event.hashes['sha256'], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g" + event.hashes["sha256"], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g" ) - self.assertTrue(hasattr(event, 'signatures')) + self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) self.assertEquals( diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index d0cc492deb..9e3d4d0f47 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -37,88 +37,88 @@ class PruneEventTestCase(unittest.TestCase): def test_minimal(self): self.run_test( - {'type': 'A', 'event_id': '$test:domain'}, + {"type": "A", "event_id": "$test:domain"}, { - 'type': 'A', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "A", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) def test_basic_keys(self): self.run_test( { - 'type': 'A', - 'room_id': '!1:domain', - 'sender': '@2:domain', - 'event_id': '$3:domain', - 'origin': 'domain', + "type": "A", + "room_id": "!1:domain", + "sender": "@2:domain", + "event_id": "$3:domain", + "origin": "domain", }, { - 'type': 'A', - 'room_id': '!1:domain', - 'sender': '@2:domain', - 'event_id': '$3:domain', - 'origin': 'domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "A", + "room_id": "!1:domain", + "sender": "@2:domain", + "event_id": "$3:domain", + "origin": "domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) def test_unsigned_age_ts(self): self.run_test( - {'type': 'B', 'event_id': '$test:domain', 'unsigned': {'age_ts': 20}}, + {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}}, { - 'type': 'B', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {'age_ts': 20}, + "type": "B", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {"age_ts": 20}, }, ) self.run_test( { - 'type': 'B', - 'event_id': '$test:domain', - 'unsigned': {'other_key': 'here'}, + "type": "B", + "event_id": "$test:domain", + "unsigned": {"other_key": "here"}, }, { - 'type': 'B', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "B", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) def test_content(self): self.run_test( - {'type': 'C', 'event_id': '$test:domain', 'content': {'things': 'here'}}, + {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}}, { - 'type': 'C', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "C", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) self.run_test( { - 'type': 'm.room.create', - 'event_id': '$test:domain', - 'content': {'creator': '@2:domain', 'other_field': 'here'}, + "type": "m.room.create", + "event_id": "$test:domain", + "content": {"creator": "@2:domain", "other_field": "here"}, }, { - 'type': 'm.room.create', - 'event_id': '$test:domain', - 'content': {'creator': '@2:domain'}, - 'signatures': {}, - 'unsigned': {}, + "type": "m.room.create", + "event_id": "$test:domain", + "content": {"creator": "@2:domain"}, + "signatures": {}, + "unsigned": {}, }, ) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 1e3e5aec66..a5b03005d7 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -32,7 +32,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): login.register_servlets, ] - def default_config(self, name='test'): + def default_config(self, name="test"): config = super(RoomComplexityTests, self).default_config(name=name) config["limit_large_remote_room_joins"] = True config["limit_large_remote_room_complexity"] = 0.05 diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 7bb106b5f7..cce8d8c6de 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -51,16 +51,16 @@ class FederationSenderTestCases(HomeserverTestCase): json_cb = mock_send_transaction.call_args[0][1] data = json_cb() self.assertEqual( - data['edus'], + data["edus"], [ { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, + "edu_type": "m.receipt", + "content": { + "room_id": { + "m.read": { + "user_id": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, } } } @@ -93,16 +93,16 @@ class FederationSenderTestCases(HomeserverTestCase): json_cb = mock_send_transaction.call_args[0][1] data = json_cb() self.assertEqual( - data['edus'], + data["edus"], [ { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, + "edu_type": "m.receipt", + "content": { + "room_id": { + "m.read": { + "user_id": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, } } } @@ -128,16 +128,16 @@ class FederationSenderTestCases(HomeserverTestCase): json_cb = mock_send_transaction.call_args[0][1] data = json_cb() self.assertEqual( - data['edus'], + data["edus"], [ { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['other_id'], - 'data': {'ts': 1234}, + "edu_type": "m.receipt", + "content": { + "room_id": { + "m.read": { + "user_id": { + "event_ids": ["other_id"], + "data": {"ts": 1234}, } } } diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 1e39fe0ec2..b204a0700d 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -117,7 +117,7 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() @@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) @@ -150,7 +150,7 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) @@ -166,7 +166,7 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) @@ -185,7 +185,7 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 917548bb31..91c7a17070 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -132,7 +132,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", b"directory/room/%23test%3Atest", - ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), ) self.render(request) self.assertEquals(403, channel.code, channel.result) @@ -143,7 +143,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", b"directory/room/%23unofficial_test%3Atest", - ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -158,7 +158,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' + "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -190,7 +190,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list disabled so we shouldn't be allowed to publish rooms room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' + "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) self.render(request) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 2e72a1dd23..c4503c1611 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -395,37 +395,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.upload_room_keys(self.local_user, version, room_keys) new_room_keys = copy.deepcopy(room_keys) - new_room_key = new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33'] + new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"] # test that increasing the message_index doesn't replace the existing session - new_room_key['first_message_index'] = 2 - new_room_key['session_data'] = 'new' + new_room_key["first_message_index"] = 2 + new_room_key["session_data"] = "new" yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) # test that marking the session as verified however /does/ replace it - new_room_key['is_verified'] = True + new_room_key["is_verified"] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" + res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count - new_room_key['forwarded_count'] = 2 - new_room_key['session_data'] = 'other' + new_room_key["forwarded_count"] = 2 + new_room_key["session_data"] = "other" yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" + res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # TODO: check edge cases as well as the common variations here diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py index 99ce94db52..b8ecdc5c9b 100644 --- a/tests/handlers/test_identity.py +++ b/tests/handlers/test_identity.py @@ -38,24 +38,18 @@ class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase): self.rewritten_is_url = "int.testis" config = self.default_config() - config["trusted_third_party_id_servers"] = [ - self.is_server_name, - ] + config["trusted_third_party_id_servers"] = [self.is_server_name] config["rewrite_identity_server_urls"] = { - self.is_server_name: self.rewritten_is_url, + self.is_server_name: self.rewritten_is_url } - mock_http_client = Mock(spec=[ - "post_urlencoded_get_json", - ]) - mock_http_client.post_urlencoded_get_json.return_value = defer.succeed({ - "address": self.address, - "medium": "email", - }) + mock_http_client = Mock(spec=["post_urlencoded_get_json"]) + mock_http_client.post_urlencoded_get_json.return_value = defer.succeed( + {"address": self.address, "medium": "email"} + ) self.hs = self.setup_test_homeserver( - config=config, - simple_http_client=mock_http_client, + config=config, simple_http_client=mock_http_client ) return self.hs @@ -74,35 +68,34 @@ class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase): post_urlenc_get_json = self.hs.get_simple_http_client().post_urlencoded_get_json store = self.hs.get_datastore() - creds = { - "sid": "123", - "client_secret": "some_secret", - } + creds = {"sid": "123", "client_secret": "some_secret"} # Make sure processing the mocked response goes through. - data = self.get_success(handler.bind_threepid( - { - "id_server": self.is_server_name, - "client_secret": creds["client_secret"], - "sid": creds["sid"], - }, - self.user_id, - )) + data = self.get_success( + handler.bind_threepid( + { + "id_server": self.is_server_name, + "client_secret": creds["client_secret"], + "sid": creds["sid"], + }, + self.user_id, + ) + ) self.assertEqual(data.get("address"), self.address) # Check that the request was done against the rewritten server name. post_urlenc_get_json.assert_called_once_with( "https://%s/_matrix/identity/api/v1/3pid/bind" % self.rewritten_is_url, { - 'sid': creds['sid'], - 'client_secret': creds["client_secret"], - 'mxid': self.user_id, - } + "sid": creds["sid"], + "client_secret": creds["client_secret"], + "mxid": self.user_id, + }, ) # Check that the original server name is saved in the database instead of the # rewritten one. - id_servers = self.get_success(store.get_id_servers_user_bound( - self.user_id, "email", self.address - )) + id_servers = self.get_success( + store.get_id_servers_user_bound(self.user_id, "email", self.address) + ) self.assertEqual(id_servers, [self.is_server_name]) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 45cbfeb9a4..2311040201 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -71,9 +71,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_name(self): - yield self.store.set_profile_displayname( - self.frank.localpart, "Frank", 1, - ) + yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) displayname = yield self.handler.get_displayname(self.frank) @@ -127,7 +125,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_avatar(self): yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png", 1, + self.frank.localpart, "http://my.server/me.png", 1 ) avatar_url = yield self.handler.get_avatar_url(self.frank) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index b2aa5c2478..e5a9acd87f 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -53,7 +53,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.mock_distributor.declare("registered_user") self.mock_captcha_client = Mock() self.macaroon_generator = Mock( - generate_access_token=Mock(return_value='secret') + generate_access_token=Mock(return_value="secret") ) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.handler = self.hs.get_registration_handler() @@ -72,7 +72,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) - self.assertEquals(result_token, 'secret') + self.assertEquals(result_token, "secret") def test_if_user_exists(self): store = self.hs.get_datastore() @@ -97,7 +97,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception self.get_success( - self.handler.get_or_create_user(self.requester, 'a', "display_name") + self.handler.get_or_create_user(self.requester, "a", "display_name") ) def test_get_or_create_user_mau_not_blocked(self): @@ -106,7 +106,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - self.get_success(self.handler.get_or_create_user(self.requester, 'c', "User")) + self.get_success(self.handler.get_or_create_user(self.requester, "c", "User")) def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True @@ -114,7 +114,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return_value=defer.succeed(self.lots_of_users) ) self.get_failure( - self.handler.get_or_create_user(self.requester, 'b', "display_name"), + self.handler.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) @@ -122,7 +122,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): return_value=defer.succeed(self.hs.config.max_mau_value) ) self.get_failure( - self.handler.get_or_create_user(self.requester, 'b', "display_name"), + self.handler.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) @@ -145,13 +145,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = self.get_success(self.handler.register(localpart='jeff')) + res = self.get_success(self.handler.register(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(res[0])) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) - self.assertTrue(room_id['room_id'] in rooms) + self.assertTrue(room_id["room_id"] in rooms) self.assertEqual(len(rooms), 1) def test_auto_create_auto_join_rooms_with_no_rooms(self): @@ -174,7 +174,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.hs.config.autocreate_auto_join_rooms = False room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = self.get_success(self.handler.register(localpart='jeff')) + res = self.get_success(self.handler.register(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) @@ -183,7 +183,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.hs.config.auto_join_rooms = [room_alias_str] self.store.is_support_user = Mock(return_value=True) - res = self.get_success(self.handler.register(localpart='support')) + res = self.get_success(self.handler.register(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) directory_handler = self.hs.get_handlers().directory_handler @@ -212,7 +212,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # When:- # * the user is registered and post consent actions are called - res = self.get_success(self.handler.register(localpart='jeff')) + res = self.get_success(self.handler.register(localpart="jeff")) self.get_success(self.handler.post_consent_actions(res[0])) # Then:- @@ -222,37 +222,27 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_register_support_user(self): res = self.get_success( - self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) + self.handler.register(localpart="user", user_type=UserTypes.SUPPORT) ) self.assertTrue(self.store.is_support_user(res[0])) def test_register_not_support_user(self): - res = self.get_success(self.handler.register(localpart='user')) + res = self.get_success(self.handler.register(localpart="user")) self.assertFalse(self.store.is_support_user(res[0])) def test_invalid_user_id_length(self): invalid_user_id = "x" * 256 - self.get_failure( - self.handler.register(localpart=invalid_user_id), - SynapseError - ) + self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError) def test_email_to_displayname_mapping(self): """Test that custom emails are mapped to new user displaynames correctly""" self._check_mapping( - "jack-phillips.rivers@big-org.com", - "Jack-Phillips Rivers [Big-Org]", + "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]" ) - self._check_mapping( - "bob.jones@matrix.org", - "Bob Jones [Tchap Admin]", - ) + self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]") - self._check_mapping( - "bob-jones.blabla@gouv.fr", - "Bob-Jones Blabla [Gouv]", - ) + self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]") # Multibyte unicode characters self._check_mapping( diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 2710c991cf..a8b858eb4f 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -265,10 +265,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): while not self.get_success(self.store.has_completed_background_updates()): self.get_success(self.store.do_next_background_update(100), by=0.1) - events = { - "a1": None, - "a2": {"membership": Membership.JOIN}, - } + events = {"a1": None, "a2": {"membership": Membership.JOIN}} def get_event(event_id, allow_none=True): if events.get(event_id): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index cb8b4d2913..5d5e324df2 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -47,7 +47,7 @@ def _expect_edu_transaction(edu_type, content, origin="test"): def _make_edu_transaction_json(edu_type, content): - return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8') + return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8") class TypingNotificationsTestCase(unittest.HomeserverTestCase): @@ -151,7 +151,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) @@ -209,12 +209,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "typing": True, }, ), - federation_auth_origin=b'farm', + federation_auth_origin=b"farm", ) self.render(request) self.assertEqual(channel.code, 200) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) @@ -247,7 +247,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( @@ -285,7 +285,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) @@ -303,7 +303,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.reactor.pump([16]) - self.on_new_event.assert_has_calls([call('typing_key', 2, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) @@ -320,7 +320,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 3, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 9021e647fe..b135486c48 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -60,15 +60,15 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) profile = self.get_success(self.store.get_user_in_directory(support_user_id)) self.assertTrue(profile is None) - display_name = 'display_name' + display_name = "display_name" - profile_info = ProfileInfo(avatar_url='avatar_url', display_name=display_name) - regular_user_id = '@regular:test' + profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) + regular_user_id = "@regular:test" self.get_success( self.handler.handle_local_profile_change(regular_user_id, profile_info) ) profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) - self.assertTrue(profile['display_name'] == display_name) + self.assertTrue(profile["display_name"] == display_name) def test_handle_user_deactivated_support_user(self): s_user_id = "@support:test" diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 6609d6b366..2fe3567dfa 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -53,13 +53,15 @@ def get_connection_factory(): # this needs to happen once, but not until we are ready to run the first test global test_server_connection_factory if test_server_connection_factory is None: - test_server_connection_factory = TestServerTLSConnectionFactory(sanlist=[ - b'DNS:testserv', - b'DNS:target-server', - b'DNS:xn--bcher-kva.com', - b'IP:1.2.3.4', - b'IP:::1', - ]) + test_server_connection_factory = TestServerTLSConnectionFactory( + sanlist=[ + b"DNS:testserv", + b"DNS:target-server", + b"DNS:xn--bcher-kva.com", + b"IP:1.2.3.4", + b"IP:::1", + ] + ) return test_server_connection_factory @@ -138,7 +140,7 @@ class MatrixFederationAgentTests(TestCase): Sends a simple GET request via the agent, and checks its logcontext management """ with LoggingContext("one") as context: - fetch_d = self.agent.request(b'GET', uri) + fetch_d = self.agent.request(b"GET", uri) # Nothing happened yet self.assertNoResult(fetch_d) @@ -182,9 +184,9 @@ class MatrixFederationAgentTests(TestCase): """Check that an incoming request looks like a valid .well-known request, and send back the response. """ - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/.well-known/matrix/server') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/.well-known/matrix/server") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # send back a response for k, v in headers.items(): request.setHeader(k, v) @@ -207,7 +209,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -215,20 +217,20 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448'] + request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"] ) content = request.content.read() - self.assertEqual(content, b'') + self.assertEqual(content, b"") # Deferred is still without a result self.assertNoResult(test_d) # send the headers - request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json']) - request.write('') + request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"]) + request.write("") self.reactor.pump((0.1,)) @@ -238,7 +240,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(response.code, 200) # Send the body - request.write('{ "a": 1 }'.encode('ascii')) + request.write('{ "a": 1 }'.encode("ascii")) request.finish() self.reactor.pump((0.1,)) @@ -263,7 +265,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -271,9 +273,9 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"1.2.3.4"]) # finish the request request.finish() @@ -298,7 +300,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '::1') + self.assertEqual(host, "::1") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -306,9 +308,9 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]"]) # finish the request request.finish() @@ -333,7 +335,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '::1') + self.assertEqual(host, "::1") self.assertEqual(port, 80) # make a test server, and wire up the client @@ -341,9 +343,9 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]:80"]) # finish the request request.finish() @@ -369,7 +371,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -387,11 +389,11 @@ class MatrixFederationAgentTests(TestCase): # we should fall back to a direct connection self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv1') + http_server = self._make_connection(client_factory, expected_sni=b"testserv1") # there should be no requests self.assertEqual(len(http_server.requests), 0) @@ -418,7 +420,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.5') + self.assertEqual(host, "1.2.3.5") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -452,7 +454,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -470,17 +472,17 @@ class MatrixFederationAgentTests(TestCase): # we should fall back to a direct connection self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -504,7 +506,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( @@ -521,20 +523,20 @@ class MatrixFederationAgentTests(TestCase): # now we should get a connection to the target server self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1::f') + self.assertEqual(host, "1::f") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -566,7 +568,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) redirect_server = self._make_connection( @@ -576,7 +578,7 @@ class MatrixFederationAgentTests(TestCase): # send a 302 redirect self.assertEqual(len(redirect_server.requests), 1) request = redirect_server.requests[0] - request.redirect(b'https://testserv/even_better_known') + request.redirect(b"https://testserv/even_better_known") request.finish() self.reactor.pump((0.1,)) @@ -585,7 +587,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) well_known_server = self._make_connection( @@ -594,8 +596,8 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(well_known_server.requests), 1, "No request after 302") request = well_known_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/even_better_known') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/even_better_known") request.write(b'{ "m.server": "target-server" }') request.finish() @@ -609,20 +611,20 @@ class MatrixFederationAgentTests(TestCase): # now we should get a connection to the target server self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1::f') + self.assertEqual(host, "1::f") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -657,11 +659,11 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", content=b'NOT JSON' + client_factory, expected_sni=b"testserv", content=b"NOT JSON" ) # now there should be a SRV lookup @@ -672,17 +674,17 @@ class MatrixFederationAgentTests(TestCase): # we should fall back to a direct connection self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -717,12 +719,10 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - http_proto = self._make_connection( - client_factory, expected_sni=b"testserv", - ) + http_proto = self._make_connection(client_factory, expected_sni=b"testserv") # there should be no requests self.assertEqual(len(http_proto.requests), 0) @@ -755,17 +755,17 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8443) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -788,7 +788,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self.mock_resolver.resolve_service.side_effect = lambda _: [ @@ -809,20 +809,20 @@ class MatrixFederationAgentTests(TestCase): # now we should get a connection to the target of the SRV record self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '5.6.7.8') + self.assertEqual(host, "5.6.7.8") self.assertEqual(port, 8443) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -851,7 +851,7 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -870,20 +870,20 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'xn--bcher-kva.com' + client_factory, expected_sni=b"xn--bcher-kva.com" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] + request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"] ) # finish the request @@ -912,20 +912,20 @@ class MatrixFederationAgentTests(TestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8443) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'xn--bcher-kva.com' + client_factory, expected_sni=b"xn--bcher-kva.com" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] + request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"] ) # finish the request @@ -946,42 +946,42 @@ class MatrixFederationAgentTests(TestCase): def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) well_known_server = self._handle_well_known_connection( client_factory, expected_sni=b"testserv", - response_headers={b'Cache-Control': b'max-age=10'}, + response_headers={b"Cache-Control": b"max-age=10"}, content=b'{ "m.server": "target-server" }', ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b'target-server') + self.assertEqual(r, b"target-server") # close the tcp connection well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") r = self.successResultOf(fetch_d) - self.assertEqual(r, b'target-server') + self.assertEqual(r, b"target-server") # expire the cache self.reactor.pump((10.0,)) # now it should connect again - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( @@ -991,7 +991,7 @@ class MatrixFederationAgentTests(TestCase): ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b'other-server') + self.assertEqual(r, b"other-server") class TestCachePeriodFromHeaders(TestCase): @@ -999,27 +999,27 @@ class TestCachePeriodFromHeaders(TestCase): # uppercase self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}) + Headers({b"Cache-Control": [b"foo, Max-Age = 100, bar"]}) ), 100, ) # missing value self.assertIsNone( - _cache_period_from_headers(Headers({b'Cache-Control': [b'max-age=, bar']})) + _cache_period_from_headers(Headers({b"Cache-Control": [b"max-age=, bar"]})) ) # hackernews: bogus due to semicolon self.assertIsNone( _cache_period_from_headers( - Headers({b'Cache-Control': [b'private; max-age=0']}) + Headers({b"Cache-Control": [b"private; max-age=0"]}) ) ) # github self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}) + Headers({b"Cache-Control": [b"max-age=0, private, must-revalidate"]}) ), 0, ) @@ -1027,7 +1027,7 @@ class TestCachePeriodFromHeaders(TestCase): # google self.assertEqual( _cache_period_from_headers( - Headers({b'cache-control': [b'private, max-age=0']}) + Headers({b"cache-control": [b"private, max-age=0"]}) ), 0, ) @@ -1035,7 +1035,7 @@ class TestCachePeriodFromHeaders(TestCase): def test_expires(self): self.assertEqual( _cache_period_from_headers( - Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}), + Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}), time_now=lambda: 1548833700, ), 33, @@ -1046,8 +1046,8 @@ class TestCachePeriodFromHeaders(TestCase): _cache_period_from_headers( Headers( { - b'cache-control': [b'max-age=10'], - b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'], + b"cache-control": [b"max-age=10"], + b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"], } ), time_now=lambda: 1548833700, @@ -1056,7 +1056,7 @@ class TestCachePeriodFromHeaders(TestCase): ) # invalid expires means immediate expiry - self.assertEqual(_cache_period_from_headers(Headers({b'Expires': [b'0']})), 0) + self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0) def _check_logcontext(context): diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 034c0db8d2..cf6c6e95b5 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -100,7 +100,7 @@ class SrvResolverTestCase(unittest.TestCase): def test_from_cache(self): clock = MockClock() - dns_client_mock = Mock(spec_set=['lookupService']) + dns_client_mock = Mock(spec_set=["lookupService"]) dns_client_mock.lookupService = Mock(spec_set=[]) service_name = b"test_service.example.com" diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index 3b0155ed03..b2e9533b07 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -20,12 +20,12 @@ from tests import unittest class ServerNameTestCase(unittest.TestCase): def test_parse_server_name(self): test_data = { - 'localhost': ('localhost', None), - 'my-example.com:1234': ('my-example.com', 1234), - '1.2.3.4': ('1.2.3.4', None), - '[0abc:1def::1234]': ('[0abc:1def::1234]', None), - '1.2.3.4:1': ('1.2.3.4', 1), - '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080), + "localhost": ("localhost", None), + "my-example.com:1234": ("my-example.com", 1234), + "1.2.3.4": ("1.2.3.4", None), + "[0abc:1def::1234]": ("[0abc:1def::1234]", None), + "1.2.3.4:1": ("1.2.3.4", 1), + "[0abc:1def::1234]:8080": ("[0abc:1def::1234]", 8080), } for i, o in test_data.items(): diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index ee767f3a5a..c4c0d9b968 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -83,7 +83,7 @@ class FederationClientTests(HomeserverTestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8008) # complete the connection and wire it up to a fake transport @@ -99,7 +99,7 @@ class FederationClientTests(HomeserverTestCase): self.assertNoResult(test_d) # Send it the HTTP response - res_json = '{ "a": 1 }'.encode('ascii') + res_json = '{ "a": 1 }'.encode("ascii") protocol.dataReceived( b"HTTP/1.1 200 OK\r\n" b"Server: Fake\r\n" @@ -138,7 +138,7 @@ class FederationClientTests(HomeserverTestCase): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8008) e = Exception("go away") factory.clientConnectionFailed(None, e) @@ -164,7 +164,7 @@ class FederationClientTests(HomeserverTestCase): # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) - self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][0], "1.2.3.4") self.assertEqual(clients[0][1], 8008) # Deferred is still without a result @@ -194,7 +194,7 @@ class FederationClientTests(HomeserverTestCase): # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) - self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][0], "1.2.3.4") self.assertEqual(clients[0][1], 8008) conn = Mock() @@ -215,10 +215,9 @@ class FederationClientTests(HomeserverTestCase): """Ensure that Synapse does not try to connect to blacklisted IPs""" # Set up the ip_range blacklist - self.hs.config.federation_ip_range_blacklist = IPSet([ - "127.0.0.0/8", - "fe80::/64", - ]) + self.hs.config.federation_ip_range_blacklist = IPSet( + ["127.0.0.0/8", "fe80::/64"] + ) self.reactor.lookups["internal"] = "127.0.0.1" self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337" self.reactor.lookups["fine"] = "10.20.30.40" @@ -382,7 +381,7 @@ class FederationClientTests(HomeserverTestCase): b"Content-Type: application/json\r\n" b"Content-Length: 2\r\n" b"\r\n" - b'{}' + b"{}" ) # We should get a successful response diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 72760a0733..358b593cd4 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -57,7 +57,7 @@ class EmailPusherTests(HomeserverTestCase): config["email"] = { "enable_notifs": True, "template_dir": os.path.abspath( - pkg_resources.resource_filename('synapse', 'res/templates') + pkg_resources.resource_filename("synapse", "res/templates") ), "expiry_template_html": "notice_expiry.html", "expiry_template_text": "notice_expiry.txt", @@ -120,7 +120,7 @@ class EmailPusherTests(HomeserverTestCase): # Create a simple room with two users room = self.helper.create_room_as(self.user_id, tok=self.access_token) self.helper.invite( - room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id, + room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id ) self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token) @@ -141,7 +141,7 @@ class EmailPusherTests(HomeserverTestCase): for other in self.others: self.helper.invite( - room=room, src=self.user_id, tok=self.access_token, targ=other.id, + room=room, src=self.user_id, tok=self.access_token, targ=other.id ) self.helper.join(room=room, user=other.id, tok=other.token) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index e5fc2fcd15..b62ab36fda 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -30,7 +30,7 @@ from tests import unittest class VersionTestCase(unittest.HomeserverTestCase): - url = '/_synapse/admin/v1/server_version' + url = "/_synapse/admin/v1/server_version" def create_test_json_resource(self): resource = JsonResource(self.hs) @@ -43,7 +43,7 @@ class VersionTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - {'server_version', 'python_version'}, set(channel.json_body.keys()) + {"server_version", "python_version"}, set(channel.json_body.keys()) ) @@ -82,12 +82,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ self.hs.config.registration_shared_secret = None - request, channel = self.make_request("POST", self.url, b'{}') + request, channel = self.make_request("POST", self.url, b"{}") self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - 'Shared secret registration is not enabled', channel.json_body["error"] + "Shared secret registration is not enabled", channel.json_body["error"] ) def test_get_nonce(self): @@ -118,20 +118,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.reactor.advance(59) body = json.dumps({"nonce": nonce}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('username must be specified', channel.json_body["error"]) + self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds self.reactor.advance(2) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('unrecognised nonce', channel.json_body["error"]) + self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self): """ @@ -154,7 +154,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -171,7 +171,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update( - nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin\x00support" + nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) want_mac = want_mac.hexdigest() @@ -185,7 +185,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -200,7 +200,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() body = json.dumps( @@ -212,18 +212,18 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('unrecognised nonce', channel.json_body["error"]) + self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self): """ @@ -243,11 +243,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('nonce must be specified', channel.json_body["error"]) + self.assertEqual("nonce must be specified", channel.json_body["error"]) # # Username checks @@ -255,35 +255,35 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce()}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('username must be specified', channel.json_body["error"]) + self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid username', channel.json_body["error"]) + self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid username', channel.json_body["error"]) + self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid username', channel.json_body["error"]) + self.assertEqual("Invalid username", channel.json_body["error"]) # # Password checks @@ -291,37 +291,37 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce(), "username": "a"}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('password must be specified', channel.json_body["error"]) + self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid password', channel.json_body["error"]) + self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = json.dumps( {"nonce": nonce(), "username": "a", "password": u"abcd\u0000"} ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid password', channel.json_body["error"]) + self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid password', channel.json_body["error"]) + self.assertEqual("Invalid password", channel.json_body["error"]) # # user_type check @@ -336,11 +336,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "user_type": "invalid", } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid user type', channel.json_body["error"]) + self.assertEqual("Invalid user type", channel.json_body["error"]) class ShutdownRoomTestCase(unittest.HomeserverTestCase): @@ -396,7 +396,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "admin/shutdown_room/" + room_id request, channel = self.make_request( "POST", - url.encode('ascii'), + url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) @@ -421,7 +421,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "rooms/%s/state/m.room.history_visibility" % (room_id,) request, channel = self.make_request( "PUT", - url.encode('ascii'), + url.encode("ascii"), json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_token, ) @@ -432,7 +432,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "admin/shutdown_room/" + room_id request, channel = self.make_request( "POST", - url.encode('ascii'), + url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) @@ -449,7 +449,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "rooms/%s/initialSync" % (room_id,) request, channel = self.make_request( - "GET", url.encode('ascii'), access_token=self.admin_user_tok + "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( @@ -458,7 +458,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "events?timeout=0&room_id=" + room_id request, channel = self.make_request( - "GET", url.encode('ascii'), access_token=self.admin_user_tok + "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( @@ -486,7 +486,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): # Create a new group request, channel = self.make_request( "POST", - "/create_group".encode('ascii'), + "/create_group".encode("ascii"), access_token=self.admin_user_tok, content={"localpart": "test"}, ) @@ -502,14 +502,14 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) request, channel = self.make_request( - "PUT", url.encode('ascii'), access_token=self.admin_user_tok, content={} + "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) request, channel = self.make_request( - "PUT", url.encode('ascii'), access_token=self.other_user_token, content={} + "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -522,7 +522,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/admin/delete_group/" + group_id request, channel = self.make_request( "POST", - url.encode('ascii'), + url.encode("ascii"), access_token=self.admin_user_tok, content={"localpart": "test"}, ) @@ -544,7 +544,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/groups/%s/profile" % (group_id,) request, channel = self.make_request( - "GET", url.encode('ascii'), access_token=self.admin_user_tok + "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.render(request) @@ -556,7 +556,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): """Returns the list of groups the user is in (given their access token) """ request, channel = self.make_request( - "GET", "/joined_groups".encode('ascii'), access_token=access_token + "GET", "/joined_groups".encode("ascii"), access_token=access_token ) self.render(request) diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index efc5a99db3..6803b372ac 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -42,17 +42,17 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): # Make some temporary templates... temp_consent_path = self.mktemp() os.mkdir(temp_consent_path) - os.mkdir(os.path.join(temp_consent_path, 'en')) + os.mkdir(os.path.join(temp_consent_path, "en")) config["user_consent"] = { "version": "1", "template_dir": os.path.abspath(temp_consent_path), } - with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f: + with open(os.path.join(temp_consent_path, "en/1.html"), "w") as f: f.write("{{version}},{{has_consented}}") - with open(os.path.join(temp_consent_path, "en/success.html"), 'w') as f: + with open(os.path.join(temp_consent_path, "en/success.html"), "w") as f: f.write("yay!") hs = self.setup_test_homeserver(config=config) @@ -88,7 +88,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) # Get the version from the body, and whether we've consented - version, consented = channel.result["body"].decode('ascii').split(",") + version, consented = channel.result["body"].decode("ascii").split(",") self.assertEqual(consented, "False") # POST to the consent page, saying we've agreed @@ -111,6 +111,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): # Get the version from the body, and check that it's the version we # agreed to, and that we've consented to it. - version, consented = channel.result["body"].decode('ascii').split(",") + version, consented = channel.result["body"].decode("ascii").split(",") self.assertEqual(consented, "True") self.assertEqual(version, "1") diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index c9b9eff83e..f81f81602e 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -39,9 +39,7 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config["trusted_third_party_id_servers"] = [ - "testis", - ] + config["trusted_third_party_id_servers"] = ["testis"] config["enable_3pid_lookup"] = False self.hs = self.setup_test_homeserver(config=config) @@ -53,7 +51,7 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase): def test_3pid_invite_disabled(self): request, channel = self.make_request( - b"POST", "/createRoom", b"{}", access_token=self.tok, + b"POST", "/createRoom", b"{}", access_token=self.tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -65,18 +63,18 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase): "address": "test@example.com", } request_data = json.dumps(params) - request_url = ( - "/rooms/%s/invite" % (room_id) - ).encode('ascii') + request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") request, channel = self.make_request( - b"POST", request_url, request_data, access_token=self.tok, + b"POST", request_url, request_data, access_token=self.tok ) self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) def test_3pid_lookup_disabled(self): - url = ("/_matrix/client/unstable/account/3pid/lookup" - "?id_server=testis&medium=email&address=foo@bar.baz") + url = ( + "/_matrix/client/unstable/account/3pid/lookup" + "?id_server=testis&medium=email&address=foo@bar.baz" + ) request, channel = self.make_request("GET", url, access_token=self.tok) self.render(request) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -85,20 +83,11 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase): url = "/_matrix/client/unstable/account/3pid/bulk_lookup" data = { "id_server": "testis", - "threepids": [ - [ - "email", - "foo@bar.baz" - ], - [ - "email", - "john.doe@matrix.org" - ] - ] + "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]], } request_data = json.dumps(data) request, channel = self.make_request( - "POST", url, request_data, access_token=self.tok, + "POST", url, request_data, access_token=self.tok ) self.render(request) self.assertEqual(channel.result["code"], b"403", channel.result) @@ -118,20 +107,14 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase): config = self.default_config() config["enable_3pid_lookup"] = True - config["trusted_third_party_id_servers"] = [ - "testis", - ] - - mock_http_client = Mock(spec=[ - "get_json", - "post_json_get_json", - ]) + config["trusted_third_party_id_servers"] = ["testis"] + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) mock_http_client.get_json.return_value = defer.succeed((200, "{}")) mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) self.hs = self.setup_test_homeserver( - config=config, - simple_http_client=mock_http_client, + config=config, simple_http_client=mock_http_client ) return self.hs @@ -142,7 +125,7 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase): def test_3pid_invite_enabled(self): request, channel = self.make_request( - b"POST", "/createRoom", b"{}", access_token=self.tok, + b"POST", "/createRoom", b"{}", access_token=self.tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -154,70 +137,46 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase): "address": "test@example.com", } request_data = json.dumps(params) - request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii') + request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") request, channel = self.make_request( - b"POST", request_url, request_data, access_token=self.tok, + b"POST", request_url, request_data, access_token=self.tok ) self.render(request) get_json = self.hs.get_simple_http_client().get_json get_json.assert_called_once_with( "https://testis/_matrix/identity/api/v1/lookup", - { - "address": "test@example.com", - "medium": "email", - }, + {"address": "test@example.com", "medium": "email"}, ) def test_3pid_lookup_enabled(self): - url = ("/_matrix/client/unstable/account/3pid/lookup" - "?id_server=testis&medium=email&address=foo@bar.baz") + url = ( + "/_matrix/client/unstable/account/3pid/lookup" + "?id_server=testis&medium=email&address=foo@bar.baz" + ) request, channel = self.make_request("GET", url, access_token=self.tok) self.render(request) get_json = self.hs.get_simple_http_client().get_json get_json.assert_called_once_with( "https://testis/_matrix/identity/api/v1/lookup", - { - "address": "foo@bar.baz", - "medium": "email", - }, + {"address": "foo@bar.baz", "medium": "email"}, ) def test_3pid_bulk_lookup_enabled(self): url = "/_matrix/client/unstable/account/3pid/bulk_lookup" data = { "id_server": "testis", - "threepids": [ - [ - "email", - "foo@bar.baz" - ], - [ - "email", - "john.doe@matrix.org" - ] - ] + "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]], } request_data = json.dumps(data) request, channel = self.make_request( - "POST", url, request_data, access_token=self.tok, + "POST", url, request_data, access_token=self.tok ) self.render(request) post_json = self.hs.get_simple_http_client().post_json_get_json post_json.assert_called_once_with( "https://testis/_matrix/identity/api/v1/bulk_lookup", - { - "threepids": [ - [ - "email", - "foo@bar.baz" - ], - [ - "email", - "john.doe@matrix.org" - ] - ], - }, + {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]}, ) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index d0deff5a3b..4303f95206 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -61,9 +61,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": one_day_ms * 4, - }, + body={"max_lifetime": one_day_ms * 4}, tok=self.token, expect_code=400, ) @@ -71,9 +69,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": one_hour_ms, - }, + body={"max_lifetime": one_hour_ms}, tok=self.token, expect_code=400, ) @@ -89,9 +85,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": lifetime, - }, + body={"max_lifetime": lifetime}, tok=self.token, ) @@ -114,48 +108,32 @@ class RetentionTestCase(unittest.HomeserverTestCase): events = [] # Send a first event, which should be filtered out at the end of the test. - resp = self.helper.send( - room_id=room_id, - body="1", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) # Get the event from the store so that we end up with a FrozenEvent that we can # give to filter_events_for_client. We need to do this now because the event won't # be in the database anymore after it has expired. - events.append(self.get_success( - store.get_event( - resp.get("event_id") - ) - )) + events.append(self.get_success(store.get_event(resp.get("event_id")))) # Advance the time by 2 days. We're using the default retention policy, therefore # after this the first event will still be valid. self.reactor.advance(one_day_ms * 2 / 1000) # Send another event, which shouldn't get filtered out. - resp = self.helper.send( - room_id=room_id, - body="2", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") - events.append(self.get_success( - store.get_event( - valid_event_id - ) - )) + events.append(self.get_success(store.get_event(valid_event_id))) # Advance the time by anothe 2 days. After this, the first event should be # outdated but not the second one. self.reactor.advance(one_day_ms * 2 / 1000) # Run filter_events_for_client with our list of FrozenEvents. - filtered_events = self.get_success(filter_events_for_client( - store, self.user_id, events - )) + filtered_events = self.get_success( + filter_events_for_client(store, self.user_id, events) + ) # We should only get one event back. self.assertEqual(len(filtered_events), 1, filtered_events) @@ -171,28 +149,22 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. - resp = self.helper.send( - room_id=room_id, - body="1", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) expired_event_id = resp.get("event_id") # Check that we can retrieve the event. expired_event = self.get_event(room_id, expired_event_id) - self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) # Advance the time. self.reactor.advance(increment / 1000) # Send another event. We need this because the purge job won't purge the most # recent event in the room. - resp = self.helper.send( - room_id=room_id, - body="2", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") @@ -232,15 +204,12 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() config["default_room_version"] = "1" - config["retention"] = { - "enabled": True, - } + config["retention"] = {"enabled": True} mock_federation_client = Mock(spec=["backfill"]) self.hs = self.setup_test_homeserver( - config=config, - federation_client=mock_federation_client, + config=config, federation_client=mock_federation_client ) return self.hs @@ -267,9 +236,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={ - "max_lifetime": one_day_ms * 35, - }, + body={"max_lifetime": one_day_ms * 35}, tok=self.token, ) @@ -278,28 +245,22 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): def _test_retention(self, room_id, expected_code_for_first_event=200): # Send a first event to the room. This is the event we'll want to be purged at the # end of the test. - resp = self.helper.send( - room_id=room_id, - body="1", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) first_event_id = resp.get("event_id") # Check that we can retrieve the event. expired_event = self.get_event(room_id, first_event_id) - self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) # Advance the time by a month. self.reactor.advance(one_day_ms * 30 / 1000) # Send another event. We need this because the purge job won't purge the most # recent event in the room. - resp = self.helper.send( - room_id=room_id, - body="2", - tok=self.token, - ) + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) second_event_id = resp.get("event_id") @@ -312,7 +273,9 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): ) if expected_code_for_first_event == 200: - self.assertEqual(first_event.get("content", {}).get("body"), "1", first_event) + self.assertEqual( + first_event.get("content", {}).get("body"), "1", first_event + ) # Check that the event that hasn't been purged can still be retrieved. second_event = self.get_event(room_id, second_event_id) diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py index 13caea3b01..c17b0c226d 100644 --- a/tests/rest/client/test_room_access_rules.py +++ b/tests/rest/client/test_room_access_rules.py @@ -49,15 +49,11 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): config["third_party_event_rules"] = { "module": "synapse.third_party_rules.access_rules.RoomAccessRules", "config": { - "domains_forbidden_when_restricted": [ - "forbidden_domain" - ], + "domains_forbidden_when_restricted": ["forbidden_domain"], "id_server": "testis", - } + }, } - config["trusted_third_party_id_servers"] = [ - "testis", - ] + config["trusted_third_party_id_servers"] = ["testis"] def send_invite(destination, room_id, event_id, pdu): return defer.succeed(pdu) @@ -67,31 +63,28 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): return defer.succeed({"hs": address_domain}) def post_urlencoded_get_json(uri, args={}, headers=None): - token = ''.join(random.choice(string.ascii_letters) for _ in range(10)) - return defer.succeed({ - "token": token, - "public_keys": [ - { - "public_key": "serverpublickey", - "key_validity_url": "https://testis/pubkey/isvalid", - }, - { - "public_key": "phemeralpublickey", - "key_validity_url": "https://testis/pubkey/ephemeral/isvalid", - }, - ], - "display_name": "f...@b...", - }) - - mock_federation_client = Mock(spec=[ - "send_invite", - ]) + token = "".join(random.choice(string.ascii_letters) for _ in range(10)) + return defer.succeed( + { + "token": token, + "public_keys": [ + { + "public_key": "serverpublickey", + "key_validity_url": "https://testis/pubkey/isvalid", + }, + { + "public_key": "phemeralpublickey", + "key_validity_url": "https://testis/pubkey/ephemeral/isvalid", + }, + ], + "display_name": "f...@b...", + } + ) + + mock_federation_client = Mock(spec=["send_invite"]) mock_federation_client.send_invite.side_effect = send_invite - mock_http_client = Mock(spec=[ - "get_json", - "post_urlencoded_get_json" - ]) + mock_http_client = Mock(spec=["get_json", "post_urlencoded_get_json"]) # Mocking the response for /info on the IS API. mock_http_client.get_json.side_effect = get_json # Mocking the response for /store-invite on the IS API. @@ -164,74 +157,80 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): # rule to restricted. preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT) self.assertEqual( - self.current_rule_in_room(preset_room_id), ACCESS_RULE_RESTRICTED, + self.current_rule_in_room(preset_room_id), ACCESS_RULE_RESTRICTED ) # Creating a room with the public join rule in its initial state should succeed # and set the access rule to restricted. - init_state_room_id = self.create_room(initial_state=[{ - "type": "m.room.join_rules", - "content": { - "join_rule": JoinRules.PUBLIC, - }, - }]) + init_state_room_id = self.create_room( + initial_state=[ + { + "type": "m.room.join_rules", + "content": {"join_rule": JoinRules.PUBLIC}, + } + ] + ) self.assertEqual( - self.current_rule_in_room(init_state_room_id), ACCESS_RULE_RESTRICTED, + self.current_rule_in_room(init_state_room_id), ACCESS_RULE_RESTRICTED ) # Changing access rule to unrestricted should fail. self.change_rule_in_room( - preset_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403, + preset_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403 ) self.change_rule_in_room( - init_state_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403, + init_state_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403 ) # Changing access rule to direct should fail. + self.change_rule_in_room(preset_room_id, ACCESS_RULE_DIRECT, expected_code=403) self.change_rule_in_room( - preset_room_id, ACCESS_RULE_DIRECT, expected_code=403, - ) - self.change_rule_in_room( - init_state_room_id, ACCESS_RULE_DIRECT, expected_code=403, + init_state_room_id, ACCESS_RULE_DIRECT, expected_code=403 ) # Changing join rule to public in an unrestricted room should fail. self.change_join_rule_in_room( - self.unrestricted_room, JoinRules.PUBLIC, expected_code=403, + self.unrestricted_room, JoinRules.PUBLIC, expected_code=403 ) # Changing join rule to public in an direct room should fail. self.change_join_rule_in_room( - self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403, + self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403 ) # Creating a new room with the public_chat preset and an access rule that isn't # restricted should fail. self.create_room( - preset=RoomCreationPreset.PUBLIC_CHAT, rule=ACCESS_RULE_UNRESTRICTED, + preset=RoomCreationPreset.PUBLIC_CHAT, + rule=ACCESS_RULE_UNRESTRICTED, expected_code=400, ) self.create_room( - preset=RoomCreationPreset.PUBLIC_CHAT, rule=ACCESS_RULE_DIRECT, + preset=RoomCreationPreset.PUBLIC_CHAT, + rule=ACCESS_RULE_DIRECT, expected_code=400, ) # Creating a room with the public join rule in its initial state and an access # rule that isn't restricted should fail. self.create_room( - initial_state=[{ - "type": "m.room.join_rules", - "content": { - "join_rule": JoinRules.PUBLIC, - }, - }], rule=ACCESS_RULE_UNRESTRICTED, expected_code=400, + initial_state=[ + { + "type": "m.room.join_rules", + "content": {"join_rule": JoinRules.PUBLIC}, + } + ], + rule=ACCESS_RULE_UNRESTRICTED, + expected_code=400, ) self.create_room( - initial_state=[{ - "type": "m.room.join_rules", - "content": { - "join_rule": JoinRules.PUBLIC, - }, - }], rule=ACCESS_RULE_DIRECT, expected_code=400, + initial_state=[ + { + "type": "m.room.join_rules", + "content": {"join_rule": JoinRules.PUBLIC}, + } + ], + rule=ACCESS_RULE_DIRECT, + expected_code=400, ) def test_restricted(self): @@ -405,12 +404,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=self.unrestricted_room, event_type=EventTypes.PowerLevels, - body={ - "users": { - self.user_id: 100, - "@test:not_forbidden_domain": 10, - }, - }, + body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}}, tok=self.tok, expect_code=200, ) @@ -421,10 +415,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): room_id=self.unrestricted_room, event_type=EventTypes.PowerLevels, body={ - "users": { - self.user_id: 100, - "@test:not_forbidden_domain": 10, - }, + "users": {self.user_id: 100, "@test:not_forbidden_domain": 10}, "users_default": 10, }, tok=self.tok, @@ -436,12 +427,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): self.helper.send_state( room_id=self.unrestricted_room, event_type=EventTypes.PowerLevels, - body={ - "users": { - self.user_id: 100, - "@test:forbidden_domain": 10, - }, - }, + body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}}, tok=self.tok, expect_code=403, ) @@ -459,9 +445,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): # We can't change the rule from restricted to direct. self.change_rule_in_room( - room_id=self.restricted_room, - new_rule=ACCESS_RULE_DIRECT, - expected_code=403, + room_id=self.restricted_room, new_rule=ACCESS_RULE_DIRECT, expected_code=403 ) # We can't change the rule from unrestricted to restricted. @@ -498,12 +482,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): """ avatar_content = { - "info": { - "h": 398, - "mimetype": "image/jpeg", - "size": 31037, - "w": 394 - }, + "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394}, "url": "mxc://example.org/JWEIFJgwEIhweiWJE", } @@ -536,9 +515,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): chat, in which case it's forbidden. """ - name_content = { - "name": "My super room", - } + name_content = {"name": "My super room"} self.helper.send_state( room_id=self.restricted_room, @@ -569,9 +546,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): direct chat, in which case it's forbidden. """ - topic_content = { - "topic": "Welcome to this room", - } + topic_content = {"topic": "Welcome to this room"} self.helper.send_state( room_id=self.restricted_room, @@ -608,15 +583,15 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): "public_keys": [ { "key_validity_url": "https://validity_url", - "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA" + "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA", }, { "key_validity_url": "https://validity_url", - "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I" - } + "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I", + }, ], "key_validity_url": "https://validity_url", - "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA" + "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA", } self.send_state_with_state_key( @@ -646,22 +621,19 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): ) def create_room( - self, direct=False, rule=None, preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT, - initial_state=None, expected_code=200, + self, + direct=False, + rule=None, + preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT, + initial_state=None, + expected_code=200, ): - content = { - "is_direct": direct, - "preset": preset, - } + content = {"is_direct": direct, "preset": preset} if rule: - content["initial_state"] = [{ - "type": ACCESS_RULES_TYPE, - "state_key": "", - "content": { - "rule": rule, - } - }] + content["initial_state"] = [ + {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}} + ] if initial_state: if "initial_state" not in content: @@ -694,9 +666,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): return channel.json_body["rule"] def change_rule_in_room(self, room_id, new_rule, expected_code=200): - data = { - "rule": new_rule, - } + data = {"rule": new_rule} request, channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE), @@ -708,9 +678,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, expected_code, channel.result) def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200): - data = { - "join_rule": new_join_rule, - } + data = {"join_rule": new_join_rule} request, channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules), @@ -722,11 +690,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, expected_code, channel.result) def send_threepid_invite(self, address, room_id, expected_code=200): - params = { - "id_server": "testis", - "medium": "email", - "address": address, - } + params = {"id_server": "testis", "medium": "email", "address": address} request, channel = self.make_request( "POST", @@ -741,7 +705,9 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): self, room_id, event_type, state_key, body, tok, expect_code=200 ): path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( - room_id, event_type, state_key + room_id, + event_type, + state_key, ) request, channel = self.make_request( diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 6958430608..02b4b8f5eb 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -183,7 +183,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): def test_set_displayname(self): request, channel = self.make_request( "PUT", - "/profile/%s/displayname" % (self.owner, ), + "/profile/%s/displayname" % (self.owner,), content=json.dumps({"displayname": "test"}), access_token=self.owner_tok, ) @@ -197,7 +197,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): """Attempts to set a stupid displayname should get a 400""" request, channel = self.make_request( "PUT", - "/profile/%s/displayname" % (self.owner, ), + "/profile/%s/displayname" % (self.owner,), content=json.dumps({"displayname": "test" * 100}), access_token=self.owner_tok, ) @@ -209,8 +209,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): def get_displayname(self): request, channel = self.make_request( - "GET", - "/profile/%s/displayname" % (self.owner, ), + "GET", "/profile/%s/displayname" % (self.owner,) ) self.render(request) self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2d64b338be..fe741637f5 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -79,7 +79,7 @@ class RoomPermissionsTestCase(RoomBase): # send a message in one of the rooms self.created_rmid_msg_path = ( "rooms/%s/send/m.room.message/a1" % (self.created_rmid) - ).encode('ascii') + ).encode("ascii") request, channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) @@ -89,7 +89,7 @@ class RoomPermissionsTestCase(RoomBase): # set topic for public room request, channel = self.make_request( "PUT", - ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'), + ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) self.render(request) @@ -193,7 +193,7 @@ class RoomPermissionsTestCase(RoomBase): request, channel = self.make_request("GET", topic_path) self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body) + self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) @@ -497,7 +497,7 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self): # missing keys or invalid json - request, channel = self.make_request("PUT", self.path, '{}') + request, channel = self.make_request("PUT", self.path, "{}") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -515,11 +515,11 @@ class RoomTopicTestCase(RoomBase): self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, 'text only') + request, channel = self.make_request("PUT", self.path, "text only") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, '') + request, channel = self.make_request("PUT", self.path, "") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -572,7 +572,7 @@ class RoomMemberStateTestCase(RoomBase): def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - request, channel = self.make_request("PUT", path, '{}') + request, channel = self.make_request("PUT", path, "{}") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -590,11 +590,11 @@ class RoomMemberStateTestCase(RoomBase): self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, 'text only') + request, channel = self.make_request("PUT", path, "text only") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, '') + request, channel = self.make_request("PUT", path, "") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -604,7 +604,7 @@ class RoomMemberStateTestCase(RoomBase): Membership.JOIN, Membership.LEAVE, ) - request, channel = self.make_request("PUT", path, content.encode('ascii')) + request, channel = self.make_request("PUT", path, content.encode("ascii")) self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -616,7 +616,7 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - request, channel = self.make_request("PUT", path, content.encode('ascii')) + request, channel = self.make_request("PUT", path, content.encode("ascii")) self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) @@ -678,7 +678,7 @@ class RoomMessagesTestCase(RoomBase): def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - request, channel = self.make_request("PUT", path, b'{}') + request, channel = self.make_request("PUT", path, b"{}") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -696,11 +696,11 @@ class RoomMessagesTestCase(RoomBase): self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'text only') + request, channel = self.make_request("PUT", path, b"text only") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'') + request, channel = self.make_request("PUT", path, b"") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -786,7 +786,7 @@ class RoomMessageListTestCase(RoomBase): self.render(request) self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body['start']) + self.assertEquals(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -798,7 +798,7 @@ class RoomMessageListTestCase(RoomBase): self.render(request) self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body['start']) + self.assertEquals(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -961,9 +961,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): # Set a profile for the test user self.displayname = "test user" - data = { - "displayname": self.displayname, - } + data = {"displayname": self.displayname} request_data = json.dumps(data) request, channel = self.make_request( "PUT", @@ -977,16 +975,12 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) def test_per_room_profile_forbidden(self): - data = { - "membership": "join", - "displayname": "other test user" - } + data = {"membership": "join", "displayname": "other test user"} request_data = json.dumps(data) request, channel = self.make_request( "PUT", - "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( - self.room_id, self.user_id, - ), + "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" + % (self.room_id, self.user_id), request_data, access_token=self.tok, ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 449a69183f..cdded88b7f 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -44,7 +44,7 @@ class RestHelper(object): path = path + "?access_token=%s" % tok request, channel = make_request( - self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8') + self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8") ) render(request, self.resource, self.hs.get_reactor()) @@ -93,7 +93,7 @@ class RestHelper(object): data = {"membership": membership} request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8') + self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") ) render(request, self.resource, self.hs.get_reactor()) @@ -117,7 +117,7 @@ class RestHelper(object): path = path + "?access_token=%s" % tok request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8') + self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8") ) render(request, self.resource, self.hs.get_reactor()) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 32adf88c35..9fed900f4a 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -135,9 +135,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertEquals(len(self.email_attempts), 1) # Attempt to reset password without clicking the link - self._reset_password( - new_password, session_id, client_secret, expected_code=401, - ) + self._reset_password(new_password, session_id, client_secret, expected_code=401) # Assert we can log in with the old password self.login("kermit", old_password) @@ -172,9 +170,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): session_id = "weasle" # Attempt to reset password without even requesting an email - self._reset_password( - new_password, session_id, client_secret, expected_code=401, - ) + self._reset_password(new_password, session_id, client_secret, expected_code=401) # Assert we can log in with the old password self.login("kermit", old_password) @@ -259,19 +255,18 @@ class DeactivateTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") - request_data = json.dumps({ - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "test", - }, - "erase": False, - }) + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + } + ) request, channel = self.make_request( - "POST", - "account/deactivate", - request_data, - access_token=tok, + "POST", "account/deactivate", request_data, access_token=tok ) self.render(request) self.assertEqual(request.code, 200) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index bce5b0cf4c..b9e01c9418 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -47,15 +47,15 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): request, channel = self.make_request("GET", self.url, access_token=access_token) self.render(request) - capabilities = channel.json_body['capabilities'] + capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) - for room_version in capabilities['m.room_versions']['available'].keys(): + for room_version in capabilities["m.room_versions"]["available"].keys(): self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) self.assertEqual( self.config.default_room_version.identifier, - capabilities['m.room_versions']['default'], + capabilities["m.room_versions"]["default"], ) def test_get_change_password_capabilities(self): @@ -66,16 +66,16 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): request, channel = self.make_request("GET", self.url, access_token=access_token) self.render(request) - capabilities = channel.json_body['capabilities'] + capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) # Test case where password is handled outside of Synapse - self.assertTrue(capabilities['m.change_password']['enabled']) + self.assertTrue(capabilities["m.change_password"]["enabled"]) self.get_success(self.store.user_set_password_hash(user, None)) request, channel = self.make_request("GET", self.url, access_token=access_token) self.render(request) - capabilities = channel.json_body['capabilities'] + capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) - self.assertFalse(capabilities['m.change_password']['enabled']) + self.assertFalse(capabilities["m.change_password"]["enabled"]) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py index 17c22fe751..37f970c6b0 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -60,9 +60,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): } config = self.default_config() - config["password_config"] = { - "policy": self.policy, - } + config["password_config"] = {"policy": self.policy} hs = self.setup_test_homeserver(config=config) return hs @@ -70,17 +68,23 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_get_policy(self): """Tests if the /password_policy endpoint returns the configured policy.""" - request, channel = self.make_request("GET", "/_matrix/client/r0/password_policy") + request, channel = self.make_request( + "GET", "/_matrix/client/r0/password_policy" + ) self.render(request) self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body, { - "m.minimum_length": 10, - "m.require_digit": True, - "m.require_symbol": True, - "m.require_lowercase": True, - "m.require_uppercase": True, - }, channel.result) + self.assertEqual( + channel.json_body, + { + "m.minimum_length": 10, + "m.require_digit": True, + "m.require_symbol": True, + "m.require_lowercase": True, + "m.require_uppercase": True, + }, + channel.result, + ) def test_password_too_short(self): request_data = json.dumps({"username": "kermit", "password": "shorty"}) @@ -89,9 +93,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_TOO_SHORT, - channel.result, + channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result ) def test_password_no_digit(self): @@ -101,9 +103,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_DIGIT, - channel.result, + channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result ) def test_password_no_symbol(self): @@ -113,9 +113,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_SYMBOL, - channel.result, + channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result ) def test_password_no_uppercase(self): @@ -125,9 +123,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_UPPERCASE, - channel.result, + channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result ) def test_password_no_lowercase(self): @@ -137,9 +133,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_LOWERCASE, - channel.result, + channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result ) def test_password_compliant(self): @@ -161,14 +155,16 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", compliant_password) tok = self.login("kermit", compliant_password) - request_data = json.dumps({ - "new_password": not_compliant_password, - "auth": { - "password": compliant_password, - "type": LoginType.PASSWORD, - "user": user_id, + request_data = json.dumps( + { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, } - }) + ) request, channel = self.make_request( "POST", "/_matrix/client/r0/account/password", diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index a5c7aaa9c0..085978b4a6 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -206,9 +206,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): class RegisterHideProfileTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - ] + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor, clock): @@ -219,15 +217,11 @@ class RegisterHideProfileTestCase(unittest.HomeserverTestCase): config["show_users_in_user_directory"] = False config["replicate_user_profiles_to"] = ["fakeserver"] - mock_http_client = Mock(spec=[ - "get_json", - "post_json_get_json", - ]) + mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) self.hs = self.setup_test_homeserver( - config=config, - simple_http_client=mock_http_client, + config=config, simple_http_client=mock_http_client ) return self.hs @@ -376,14 +370,11 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase): config["replicate_user_profiles_to"] = "test.is" # Mock homeserver requests to an identity server - mock_http_client = Mock(spec=[ - "post_json_get_json", - ]) + mock_http_client = Mock(spec=["post_json_get_json"]) mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) self.hs = self.setup_test_homeserver( - config=config, - simple_http_client=mock_http_client, + config=config, simple_http_client=mock_http_client ) return self.hs @@ -524,7 +515,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): config["email"] = { "enable_notifs": True, "template_dir": os.path.abspath( - pkg_resources.resource_filename('synapse', 'res/templates') + pkg_resources.resource_filename("synapse", "res/templates") ), "expiry_template_html": "notice_expiry.html", "expiry_template_text": "notice_expiry.txt", @@ -624,19 +615,18 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): (user_id, tok) = self.create_user() - request_data = json.dumps({ - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "monkey", - }, - "erase": False, - }) + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + } + ) request, channel = self.make_request( - "POST", - "account/deactivate", - request_data, - access_token=tok, + "POST", "account/deactivate", request_data, access_token=tok ) self.render(request) self.assertEqual(request.code, 200, channel.result) @@ -700,9 +690,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - ] + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor, clock): self.validity_period = 10 @@ -711,9 +699,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): config = self.default_config() config["enable_registration"] = True - config["account_validity"] = { - "enabled": False, - } + config["account_validity"] = {"enabled": False} self.hs = self.setup_test_homeserver(config=config) self.hs.config.account_validity.period = self.validity_period diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py index 00688a7325..aa6a5f880c 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/rest/media/v1/test_base.py @@ -24,14 +24,14 @@ class GetFileNameFromHeadersTests(unittest.TestCase): b"inline; filename=abc.txt": u"abc.txt", b'inline; filename="azerty"': u"azerty", b'inline; filename="aze%20rty"': u"aze%20rty", - b'inline; filename="aze\"rty"': u'aze"rty', + b'inline; filename="aze"rty"': u'aze"rty', b'inline; filename="azer;ty"': u"azer;ty", b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar", } def tests(self): for hdr, expected in self.TEST_CASES.items(): - res = get_filename_from_headers({b'Content-Disposition': [hdr]}) + res = get_filename_from_headers({b"Content-Disposition": [hdr]}) self.assertEqual( res, expected, diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 1069a44145..3c6480b3b2 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -143,7 +143,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() - self.download_resource = self.media_repo.children[b'download'] + self.download_resource = self.media_repo.children[b"download"] # smol png self.end_content = unhexlify( @@ -171,7 +171,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): headers = { b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b'image/png'], + b"Content-Type": [b"image/png"], } if content_disposition: headers[b"Content-Disposition"] = [content_disposition] @@ -204,7 +204,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): correctly decode it as the UTF-8 string, and use filename* in the response. """ - filename = parse.quote(u"\u2603".encode('utf8')).encode('ascii') + filename = parse.quote(u"\u2603".encode("utf8")).encode("ascii") channel = self._req(b"inline; filename*=utf-8''" + filename + b".png") headers = channel.headers diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 1ab0f7293a..1fcb0dd4df 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -55,10 +55,10 @@ class URLPreviewTests(unittest.HomeserverTestCase): hijack_auth = True user_id = "@test:user" end_content = ( - b'' + b"" b'' b'' - b'' + b"" ) def make_homeserver(self, reactor, clock): @@ -98,7 +98,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() - self.preview_url = self.media_repo.children[b'preview_url'] + self.preview_url = self.media_repo.children[b"preview_url"] self.lookups = {} @@ -109,7 +109,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): hostName, portNumber=0, addressTypes=None, - transportSemantics='TCP', + transportSemantics="TCP", ): resolution = HostResolution(hostName) @@ -118,7 +118,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): raise DNSLookupError("OH NO") for i in self.lookups[hostName]: - resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber)) + resolutionReceiver.addressResolved(i[0]("TCP", i[1], portNumber)) resolutionReceiver.resolutionComplete() return resolutionReceiver @@ -184,11 +184,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] end_content = ( - b'' + b"" b'' b'' b'' - b'' + b"" ) request, channel = self.make_request( @@ -204,7 +204,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): client.dataReceived( ( b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' ) % (len(end_content),) + end_content @@ -218,10 +218,10 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] end_content = ( - b'' + b"" b'' b'' - b'' + b"" ) request, channel = self.make_request( @@ -237,7 +237,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): client.dataReceived( ( b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n" + b'Content-Type: text/html; charset="windows-1251"\r\n\r\n' ) % (len(end_content),) + end_content @@ -293,8 +293,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -314,8 +314,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -334,8 +334,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + "errcode": "M_UNKNOWN", + "error": "IP address blocked by IP blacklist entry", }, ) self.assertEqual(channel.code, 403) @@ -354,8 +354,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + "errcode": "M_UNKNOWN", + "error": "IP address blocked by IP blacklist entry", }, ) @@ -410,8 +410,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -435,8 +435,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -456,7 +456,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py index 564fad0d77..1accc70dc9 100644 --- a/tests/rulecheck/test_domainrulecheck.py +++ b/tests/rulecheck/test_domainrulecheck.py @@ -33,7 +33,7 @@ class DomainRuleCheckerTestCase(unittest.TestCase): "source_one": ["target_one", "target_two"], "source_two": ["target_two"], }, - "domains_prevented_from_being_invited_to_published_rooms": ["target_two"] + "domains_prevented_from_being_invited_to_published_rooms": ["target_two"], } check = DomainRuleChecker(config) self.assertTrue( @@ -55,14 +55,14 @@ class DomainRuleCheckerTestCase(unittest.TestCase): # User can invite internal user to a published room self.assertTrue( check.user_may_invite( - "test:source_one", "test1:target_one", None, "room", False, True, + "test:source_one", "test1:target_one", None, "room", False, True ) ) # User can invite external user to a non-published room self.assertTrue( check.user_may_invite( - "test:source_one", "test:target_two", None, "room", False, False, + "test:source_one", "test:target_two", None, "room", False, False ) ) @@ -100,7 +100,7 @@ class DomainRuleCheckerTestCase(unittest.TestCase): # User cannot invite external user to a published room self.assertTrue( check.user_may_invite( - "test:source_one", "test:target_two", None, "room", False, True, + "test:source_one", "test:target_two", None, "room", False, True ) ) @@ -165,9 +165,7 @@ class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config["trusted_third_party_id_servers"] = [ - "localhost", - ] + config["trusted_third_party_id_servers"] = ["localhost"] config["spam_checker"] = { "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker", @@ -302,9 +300,7 @@ class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase): ) self.helper.join( - room_id, self.normal_user_id, - tok=self.normal_access_token, - expect_code=200, + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 ) self.helper.invite( @@ -318,11 +314,7 @@ class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", "rooms/%s/invite" % (room_id), - { - "address": "foo@bar.com", - "medium": "email", - "id_server": "localhost" - }, + {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"}, access_token=self.normal_access_token, ) self.render(request) diff --git a/tests/server.py b/tests/server.py index aa0958f6e9..d10c0603e9 100644 --- a/tests/server.py +++ b/tests/server.py @@ -46,7 +46,7 @@ class FakeChannel(object): def json_body(self): if not self.result: raise Exception("No result yet.") - return json.loads(self.result["body"].decode('utf8')) + return json.loads(self.result["body"].decode("utf8")) @property def code(self): @@ -151,10 +151,10 @@ def make_request( Tuple[synapse.http.site.SynapseRequest, channel] """ if not isinstance(method, bytes): - method = method.encode('ascii') + method = method.encode("ascii") if not isinstance(path, bytes): - path = path.encode('ascii') + path = path.encode("ascii") # Decorate it to be the full path, if we're using shorthand if shorthand and not path.startswith(b"/_matrix"): @@ -165,7 +165,7 @@ def make_request( path = b"/" + path if isinstance(content, text_type): - content = content.encode('utf8') + content = content.encode("utf8") site = FakeSite() channel = FakeChannel(reactor) @@ -173,11 +173,11 @@ def make_request( req = request(site, channel) req.process = lambda: b"" req.content = BytesIO(content) - req.postpath = list(map(unquote, path[1:].split(b'/'))) + req.postpath = list(map(unquote, path[1:].split(b"/"))) if access_token: req.requestHeaders.addRawHeader( - b"Authorization", b"Bearer " + access_token.encode('ascii') + b"Authorization", b"Bearer " + access_token.encode("ascii") ) if federation_auth_origin is not None: @@ -242,7 +242,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self.nameResolver = SimpleResolverComplexifier(FakeResolver()) super(ThreadedMemoryReactorClock, self).__init__() - def listenUDP(self, port, protocol, interface='', maxPacketSize=8196): + def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): p = udp.Port(port, protocol, interface, maxPacketSize, self) p.startListening() self._udp.append(p) @@ -371,7 +371,7 @@ class FakeTransport(object): disconnecting = False disconnected = False - buffer = attr.ib(default=b'') + buffer = attr.ib(default=b"") producer = attr.ib(default=None) autoflush = attr.ib(default=True) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 739ee59ce4..e1e0c3184d 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.server_notices.resource_limits_server_notices import ( - ResourceLimitsServerNotices, + ResourceLimitsServerNotices ) from tests import unittest @@ -109,7 +109,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, 'foo') + side_effect=ResourceLimitError(403, "foo") ) mock_event = Mock( @@ -128,7 +128,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, 'foo') + side_effect=ResourceLimitError(403, "foo") ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 9c5311d916..8d3845c870 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -181,7 +181,7 @@ class StateTestCase(unittest.TestCase): id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), ] @@ -229,14 +229,14 @@ class StateTestCase(unittest.TestCase): id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50, CHARLIE: 50}}, ), FakeEvent( id="PC", sender=CHARLIE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50, CHARLIE: 0}}, ), ] @@ -256,7 +256,7 @@ class StateTestCase(unittest.TestCase): id="PA1", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -266,14 +266,14 @@ class StateTestCase(unittest.TestCase): id="PA2", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -296,7 +296,7 @@ class StateTestCase(unittest.TestCase): id="PA", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -326,7 +326,7 @@ class StateTestCase(unittest.TestCase): id="PA1", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -336,14 +336,14 @@ class StateTestCase(unittest.TestCase): id="PA2", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 25a6c89ef5..622b16a071 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -74,7 +74,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): namespaces={}, ) # use the token as the filename - with open(as_token, 'w') as outfile: + with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) @@ -135,7 +135,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): namespaces={}, ) # use the token as the filename - with open(as_token, 'w') as outfile: + with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index b62eae7abc..59c6f8c227 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -94,11 +94,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): result, [ { - 'access_token': 'access_token', - 'ip': 'ip', - 'user_agent': 'user_agent', - 'device_id': None, - 'last_seen': 12345678000, + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": None, + "last_seen": 12345678000, } ], ) @@ -125,11 +125,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): result, [ { - 'access_token': 'access_token', - 'ip': 'ip', - 'user_agent': 'user_agent', - 'device_id': None, - 'last_seen': 12345878000, + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": None, + "last_seen": 12345878000, } ], ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 6396ccddb5..3cc18f9f1c 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -77,12 +77,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # Add two device updates with a single stream_id yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["somehost"], + "user_id", device_ids, ["somehost"] ) # Get all device updates ever meant for this remote now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "somehost", -1, limit=100, + "somehost", -1, limit=100 ) # Check original device_ids are contained within these updates @@ -95,19 +95,19 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # first add one device device_ids1 = ["device_id0"] yield self.store.add_device_change_to_streams( - "user_id", device_ids1, ["someotherhost"], + "user_id", device_ids1, ["someotherhost"] ) # then add 101 device_ids2 = ["device_id" + str(i + 1) for i in range(101)] yield self.store.add_device_change_to_streams( - "user_id", device_ids2, ["someotherhost"], + "user_id", device_ids2, ["someotherhost"] ) # then one more device_ids3 = ["newdevice"] yield self.store.add_device_change_to_streams( - "user_id", device_ids3, ["someotherhost"], + "user_id", device_ids3, ["someotherhost"] ) # @@ -116,20 +116,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # first we should get a single update now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", -1, limit=100, + "someotherhost", -1, limit=100 ) self._check_devices_in_updates(device_ids1, device_updates) # Then we should get an empty list back as the 101 devices broke the limit now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100, + "someotherhost", now_stream_id, limit=100 ) self.assertEqual(len(device_updates), 0) # The 101 devices should've been cleared, so we should now just get one device # update now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100, + "someotherhost", now_stream_id, limit=100 ) self._check_devices_in_updates(device_ids3, device_updates) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index cd2bcd4ca3..c8ece15284 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -80,10 +80,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.store_device("user2", "device1", None) yield self.store.store_device("user2", "device2", None) - yield self.store.set_e2e_device_keys("user1", "device1", now, 'json11') - yield self.store.set_e2e_device_keys("user1", "device2", now, 'json12') - yield self.store.set_e2e_device_keys("user2", "device1", now, 'json21') - yield self.store.set_e2e_device_keys("user2", "device2", now, 'json22') + yield self.store.set_e2e_device_keys("user1", "device1", now, "json11") + yield self.store.set_e2e_device_keys("user1", "device2", now, "json12") + yield self.store.set_e2e_device_keys("user2", "device1", now, "json21") + yield self.store.set_e2e_device_keys("user2", "device2", now, "json22") res = yield self.store.get_e2e_device_keys( (("user1", "device1"), ("user2", "device2")) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 0d4e74d637..86c7ac350d 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -27,11 +27,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_prev_events_for_room(self): - room_id = '@ROOM:local' + room_id = "@ROOM:local" # add a bunch of events and hashes to act as forward extremities def insert_event(txn, i): - event_id = '$event_%i:local' % i + event_id = "$event_%i:local" % i txn.execute( ( @@ -45,19 +45,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): txn.execute( ( - 'INSERT INTO event_forward_extremities (room_id, event_id) ' - 'VALUES (?, ?)' + "INSERT INTO event_forward_extremities (room_id, event_id) " + "VALUES (?, ?)" ), (room_id, event_id), ) txn.execute( ( - 'INSERT INTO event_reference_hashes ' - '(event_id, algorithm, hash) ' + "INSERT INTO event_reference_hashes " + "(event_id, algorithm, hash) " "VALUES (?, 'sha256', ?)" ), - (event_id, b'ffff'), + (event_id, b"ffff"), ) for i in range(0, 11): diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 19f9ccf5e0..d44359ff93 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -61,22 +61,24 @@ class ExtremStatisticsTestCase(HomeserverTestCase): ) ) - expected = set([ - b'synapse_forward_extremities_bucket{le="1.0"} 0.0', - b'synapse_forward_extremities_bucket{le="2.0"} 2.0', - b'synapse_forward_extremities_bucket{le="3.0"} 2.0', - b'synapse_forward_extremities_bucket{le="5.0"} 2.0', - b'synapse_forward_extremities_bucket{le="7.0"} 3.0', - b'synapse_forward_extremities_bucket{le="10.0"} 3.0', - b'synapse_forward_extremities_bucket{le="15.0"} 3.0', - b'synapse_forward_extremities_bucket{le="20.0"} 3.0', - b'synapse_forward_extremities_bucket{le="50.0"} 3.0', - b'synapse_forward_extremities_bucket{le="100.0"} 3.0', - b'synapse_forward_extremities_bucket{le="200.0"} 3.0', - b'synapse_forward_extremities_bucket{le="500.0"} 3.0', - b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', - b'synapse_forward_extremities_count 3.0', - b'synapse_forward_extremities_sum 10.0', - ]) + expected = set( + [ + b'synapse_forward_extremities_bucket{le="1.0"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0"} 3.0', + b'synapse_forward_extremities_bucket{le="50.0"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0"} 3.0', + b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', + b"synapse_forward_extremities_count 3.0", + b"synapse_forward_extremities_sum 10.0", + ] + ) self.assertEqual(items, expected) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index f458c03054..0ce0b991f9 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -46,9 +46,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): user3_email = "user3@matrix.org" threepids = [ - {'medium': 'email', 'address': user1_email}, - {'medium': 'email', 'address': user2_email}, - {'medium': 'email', 'address': user3_email}, + {"medium": "email", "address": user1_email}, + {"medium": "email", "address": user2_email}, + {"medium": "email", "address": user3_email}, ] # -1 because user3 is a support user and does not count user_num = len(threepids) - 1 @@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) - self.store.populate_monthly_active_users('user_id') + self.store.populate_monthly_active_users("user_id") self.pump() self.store.upsert_monthly_active_user.assert_called_once() @@ -188,7 +188,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) - self.store.populate_monthly_active_users('user_id') + self.store.populate_monthly_active_users("user_id") self.pump() self.store.upsert_monthly_active_user.assert_not_called() @@ -198,13 +198,13 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.assertEquals(self.get_success(count), 0) # Test reserved users but no registered users - user1 = '@user1:example.com' - user2 = '@user2:example.com' - user1_email = 'user1@example.com' - user2_email = 'user2@example.com' + user1 = "@user1:example.com" + user2 = "@user2:example.com" + user1_email = "user1@example.com" + user2_email = "user2@example.com" threepids = [ - {'medium': 'email', 'address': user1_email}, - {'medium': 'email', 'address': user2_email}, + {"medium": "email", "address": user1_email}, + {"medium": "email", "address": user2_email}, ] self.hs.config.mau_limits_reserved_threepids = threepids self.store.runInteraction( diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index c125a0d797..13e9f8ec09 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -34,19 +34,16 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_displayname(self): - yield self.store.set_profile_displayname( - self.u_frank.localpart, "Frank", 1, - ) + yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1) self.assertEquals( - "Frank", - (yield self.store.get_profile_displayname(self.u_frank.localpart)) + "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) ) @defer.inlineCallbacks def test_avatar_url(self): yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here", 1, + self.u_frank.localpart, "http://my.site/here", 1 ) self.assertEquals( diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index b6169436de..212a7ae765 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -76,10 +76,10 @@ class StateStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_state_groups_ids(self): e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) state_group_map = yield self.store.get_state_groups_ids( @@ -89,16 +89,16 @@ class StateStoreTestCase(tests.unittest.TestCase): state_map = list(state_group_map.values())[0] self.assertDictEqual( state_map, - {(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id}, + {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, ) @defer.inlineCallbacks def test_get_state_groups(self): e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) @@ -113,10 +113,10 @@ class StateStoreTestCase(tests.unittest.TestCase): # this defaults to a linear DAG as each new injection defaults to whatever # forward extremities are currently in the DB for this room. e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) e3 = yield self.inject_state_event( self.room, @@ -158,7 +158,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can filter to the m.room.name event (with a '' state key) state = yield self.store.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, '')]) + e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) diff --git a/tests/test_server.py b/tests/test_server.py index 08fb3fe02f..3a7cb56479 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -69,7 +69,7 @@ class JsonResourceTests(unittest.TestCase): ) render(request, res, self.reactor) - self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) + self.assertEqual(request.args, {b"a": [u"\N{SNOWMAN}".encode("utf8")]}) self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"}) def test_callback_direct_exception(self): @@ -87,7 +87,7 @@ class JsonResourceTests(unittest.TestCase): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'500') + self.assertEqual(channel.result["code"], b"500") def test_callback_indirect_exception(self): """ @@ -110,7 +110,7 @@ class JsonResourceTests(unittest.TestCase): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'500') + self.assertEqual(channel.result["code"], b"500") def test_callback_synapseerror(self): """ @@ -127,7 +127,7 @@ class JsonResourceTests(unittest.TestCase): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'403') + self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["error"], "Forbidden!!one!") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -148,7 +148,7 @@ class JsonResourceTests(unittest.TestCase): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'400') + self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") @@ -180,7 +180,7 @@ class SiteTestCase(unittest.HomeserverTestCase): # Make a resource and a Site, the resource will hang and allow us to # time out the request while it's 'processing' base_resource = Resource() - base_resource.putChild(b'', HangingResource()) + base_resource.putChild(b"", HangingResource()) site = SynapseSite("test", "site_tag", {}, base_resource, "1.0") server = site.buildProtocol(None) diff --git a/tests/test_state.py b/tests/test_state.py index 6491a7105a..6d33566f47 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -366,11 +366,11 @@ class StateTestCase(unittest.TestCase): def _add_depths(self, nodes, edges): def _get_depth(ev): node = nodes[ev] - if 'depth' not in node: + if "depth" not in node: prevs = edges[ev] depth = max(_get_depth(prev) for prev in prevs) + 1 - node['depth'] = depth - return node['depth'] + node["depth"] = depth + return node["depth"] for n in nodes: _get_depth(n) diff --git a/tests/test_types.py b/tests/test_types.py index 73d3b2cda2..1477d6a67d 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -109,9 +109,9 @@ class MapUsernameTestCase(unittest.TestCase): def testNonAscii(self): # this should work with either a unicode or a bytes - self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast") + self.assertEqual(map_username_to_mxid_localpart(u"têst"), "t=c3=aast") self.assertEqual( - map_username_to_mxid_localpart(u'têst'.encode('utf-8')), "t=c3=aast" + map_username_to_mxid_localpart(u"têst".encode("utf-8")), "t=c3=aast" ) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index fde0baee8e..813f984199 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -27,7 +27,7 @@ class ToTwistedHandler(logging.Handler): def emit(self, record): log_entry = self.format(record) - log_level = record.levelname.lower().replace('warning', 'warn') + log_level = record.levelname.lower().replace("warning", "warn") self.tx_log.emit( twisted.logger.LogLevel.levelWithName(log_level), log_entry.replace("{", r"(").replace("}", r")"), diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 6a180ddc32..118c3bd238 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -265,7 +265,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): pr.disable() with open("filter_events_for_server.profile", "w+") as f: - ps = pstats.Stats(pr, stream=f).sort_stats('cumulative') + ps = pstats.Stats(pr, stream=f).sort_stats("cumulative") ps.print_stats() # the result should be 5 redacted events, and 5 unredacted events. diff --git a/tests/unittest.py b/tests/unittest.py index b6dc7932ce..251b0b850e 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -298,7 +298,7 @@ class HomeserverTestCase(TestCase): Tuple[synapse.http.site.SynapseRequest, channel] """ if isinstance(content, dict): - content = json.dumps(content).encode('utf8') + content = json.dumps(content).encode("utf8") return make_request( self.reactor, @@ -397,13 +397,13 @@ class HomeserverTestCase(TestCase): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')]) + nonce_str = b"\x00".join([username.encode("utf8"), password.encode("utf8")]) if admin: nonce_str += b"\x00admin" else: nonce_str += b"\x00notadmin" - want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str) + want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) want_mac = want_mac.hexdigest() body = json.dumps( @@ -416,7 +416,7 @@ class HomeserverTestCase(TestCase): } ) request, channel = self.make_request( - "POST", "/_matrix/client/r0/admin/register", body.encode('utf8') + "POST", "/_matrix/client/r0/admin/register", body.encode("utf8") ) self.render(request) self.assertEqual(channel.code, 200) @@ -435,7 +435,7 @@ class HomeserverTestCase(TestCase): body["device_id"] = device_id request, channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8') + "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.render(request) self.assertEqual(channel.code, 200, channel.result) @@ -481,9 +481,7 @@ class HomeserverTestCase(TestCase): if soft_failed: event.internal_metadata.soft_failed = True - self.get_success( - event_creator.send_nonmember_event(requester, event, context) - ) + self.get_success(event_creator.send_nonmember_event(requester, event, context)) return event.event_id @@ -508,7 +506,7 @@ class HomeserverTestCase(TestCase): body = {"type": "m.login.password", "user": username, "password": password} request, channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8') + "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.render(request) self.assertEqual(channel.code, 403, channel.result) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 463a737efa..6f8f52537c 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -88,24 +88,24 @@ class DescriptorTestCase(unittest.TestCase): obj = Cls() - obj.mock.return_value = 'fish' + obj.mock.return_value = "fish" r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = 'chips' + obj.mock.return_value = "chips" r = yield obj.fn(1, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_called_once_with(1, 3) obj.mock.reset_mock() # the two values should now be cached r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(1, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_not_called() @defer.inlineCallbacks @@ -121,25 +121,25 @@ class DescriptorTestCase(unittest.TestCase): return self.mock(arg1, arg2) obj = Cls() - obj.mock.return_value = 'fish' + obj.mock.return_value = "fish" r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = 'chips' + obj.mock.return_value = "chips" r = yield obj.fn(2, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_called_once_with(2, 3) obj.mock.reset_mock() # the two values should now be cached; we should be able to vary # the second argument and still get the cached result. r = yield obj.fn(1, 4) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(2, 5) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_not_called() def test_cache_logcontexts(self): @@ -248,30 +248,30 @@ class DescriptorTestCase(unittest.TestCase): obj = Cls() - obj.mock.return_value = 'fish' + obj.mock.return_value = "fish" r = yield obj.fn(1, 2, 3) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_called_once_with(1, 2, 3) obj.mock.reset_mock() # a call with same params shouldn't call the mock again r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_not_called() obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = 'chips' + obj.mock.return_value = "chips" r = yield obj.fn(2, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_called_once_with(2, 3, 3) obj.mock.reset_mock() # the two values should now be cached r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(2, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_not_called() @@ -297,7 +297,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): with logcontext.LoggingContext() as c1: c1.request = "c1" obj = Cls() - obj.mock.return_value = {10: 'fish', 20: 'chips'} + obj.mock.return_value = {10: "fish", 20: "chips"} d1 = obj.list_fn([10, 20], 2) self.assertEqual( logcontext.LoggingContext.current_context(), @@ -306,26 +306,26 @@ class CachedListDescriptorTestCase(unittest.TestCase): r = yield d1 self.assertEqual(logcontext.LoggingContext.current_context(), c1) obj.mock.assert_called_once_with([10, 20], 2) - self.assertEqual(r, {10: 'fish', 20: 'chips'}) + self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = {30: 'peas'} + obj.mock.return_value = {30: "peas"} r = yield obj.list_fn([20, 30], 2) obj.mock.assert_called_once_with([30], 2) - self.assertEqual(r, {20: 'chips', 30: 'peas'}) + self.assertEqual(r, {20: "chips", 30: "peas"}) obj.mock.reset_mock() # all the values should now be cached r = yield obj.fn(10, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(20, 2) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") r = yield obj.fn(30, 2) - self.assertEqual(r, 'peas') + self.assertEqual(r, "peas") r = yield obj.list_fn([10, 20, 30], 2) obj.mock.assert_not_called() - self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'}) + self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"}) @defer.inlineCallbacks def test_invalidate(self): @@ -350,16 +350,16 @@ class CachedListDescriptorTestCase(unittest.TestCase): invalidate1 = mock.Mock() # cache miss - obj.mock.return_value = {10: 'fish', 20: 'chips'} + obj.mock.return_value = {10: "fish", 20: "chips"} r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) obj.mock.assert_called_once_with([10, 20], 2) - self.assertEqual(r1, {10: 'fish', 20: 'chips'}) + self.assertEqual(r1, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # cache hit r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1) obj.mock.assert_not_called() - self.assertEqual(r2, {10: 'fish', 20: 'chips'}) + self.assertEqual(r2, {10: "fish", 20: "chips"}) invalidate0.assert_not_called() invalidate1.assert_not_called() diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py index 03b3c15db6..c94cbb662b 100644 --- a/tests/util/caches/test_ttlcache.py +++ b/tests/util/caches/test_ttlcache.py @@ -27,57 +27,57 @@ class CacheTestCase(unittest.TestCase): def test_get(self): """simple set/get tests""" - self.cache.set('one', '1', 10) - self.cache.set('two', '2', 20) - self.cache.set('three', '3', 30) + self.cache.set("one", "1", 10) + self.cache.set("two", "2", 20) + self.cache.set("three", "3", 30) self.assertEqual(len(self.cache), 3) - self.assertTrue('one' in self.cache) - self.assertEqual(self.cache.get('one'), '1') - self.assertEqual(self.cache['one'], '1') - self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110)) + self.assertTrue("one" in self.cache) + self.assertEqual(self.cache.get("one"), "1") + self.assertEqual(self.cache["one"], "1") + self.assertEqual(self.cache.get_with_expiry("one"), ("1", 110)) self.assertEqual(self.cache._metrics.hits, 3) self.assertEqual(self.cache._metrics.misses, 0) - self.cache.set('two', '2.5', 20) - self.assertEqual(self.cache['two'], '2.5') + self.cache.set("two", "2.5", 20) + self.assertEqual(self.cache["two"], "2.5") self.assertEqual(self.cache._metrics.hits, 4) # non-existent-item tests - self.assertEqual(self.cache.get('four', '4'), '4') - self.assertIs(self.cache.get('four', None), None) + self.assertEqual(self.cache.get("four", "4"), "4") + self.assertIs(self.cache.get("four", None), None) with self.assertRaises(KeyError): - self.cache['four'] + self.cache["four"] with self.assertRaises(KeyError): - self.cache.get('four') + self.cache.get("four") with self.assertRaises(KeyError): - self.cache.get_with_expiry('four') + self.cache.get_with_expiry("four") self.assertEqual(self.cache._metrics.hits, 4) self.assertEqual(self.cache._metrics.misses, 5) def test_expiry(self): - self.cache.set('one', '1', 10) - self.cache.set('two', '2', 20) - self.cache.set('three', '3', 30) + self.cache.set("one", "1", 10) + self.cache.set("two", "2", 20) + self.cache.set("three", "3", 30) self.assertEqual(len(self.cache), 3) - self.assertEqual(self.cache['one'], '1') - self.assertEqual(self.cache['two'], '2') + self.assertEqual(self.cache["one"], "1") + self.assertEqual(self.cache["two"], "2") # enough for the first entry to expire, but not the rest self.mock_timer.side_effect = lambda: 110.0 self.assertEqual(len(self.cache), 2) - self.assertFalse('one' in self.cache) - self.assertEqual(self.cache['two'], '2') - self.assertEqual(self.cache['three'], '3') + self.assertFalse("one" in self.cache) + self.assertEqual(self.cache["two"], "2") + self.assertEqual(self.cache["three"], "3") - self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120)) + self.assertEqual(self.cache.get_with_expiry("two"), ("2", 120)) self.assertEqual(self.cache._metrics.hits, 5) self.assertEqual(self.cache._metrics.misses, 0) diff --git a/tests/utils.py b/tests/utils.py index f8c7ad2604..bd2c7c954c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -358,9 +358,9 @@ def setup_test_homeserver( # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest() + hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest() hs.get_auth_handler().validate_hash = ( - lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h + lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h ) fed = kargs.get("resource_for_federation", None) @@ -407,7 +407,7 @@ class MockHttpResource(HttpServer): def trigger_get(self, path): return self.trigger(b"GET", path, None) - @patch('twisted.web.http.Request') + @patch("twisted.web.http.Request") @defer.inlineCallbacks def trigger( self, http_method, path, content, mock_request, federation_auth_origin=None @@ -431,12 +431,12 @@ class MockHttpResource(HttpServer): # annoyingly we return a twisted http request which has chained calls # to get at the http content, hence mock it here. mock_content = Mock() - config = {'read.return_value': content} + config = {"read.return_value": content} mock_content.configure_mock(**config) mock_request.content = mock_content - mock_request.method = http_method.encode('ascii') - mock_request.uri = path.encode('ascii') + mock_request.method = http_method.encode("ascii") + mock_request.uri = path.encode("ascii") mock_request.getClientIP.return_value = "-" @@ -452,14 +452,14 @@ class MockHttpResource(HttpServer): # add in query params to the right place try: - mock_request.args = urlparse.parse_qs(path.split('?')[1]) - mock_request.path = path.split('?')[0] + mock_request.args = urlparse.parse_qs(path.split("?")[1]) + mock_request.path = path.split("?")[0] path = mock_request.path except Exception: pass if isinstance(path, bytes): - path = path.decode('utf8') + path = path.decode("utf8") for (method, pattern, func) in self.callbacks: if http_method != method: -- cgit 1.5.1 From d8994942f28f5028e560f6aba52512fae3ca1a6a Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 12 Feb 2020 18:14:10 +0000 Subject: Return a 404 for admin api user lookup if user not found (#6901) --- changelog.d/6901.misc | 1 + synapse/rest/admin/users.py | 5 ++++- tests/rest/admin/test_user.py | 16 ++++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6901.misc (limited to 'tests') diff --git a/changelog.d/6901.misc b/changelog.d/6901.misc new file mode 100644 index 0000000000..b2f12bbe86 --- /dev/null +++ b/changelog.d/6901.misc @@ -0,0 +1 @@ +Return a 404 instead of 200 for querying information of a non-existant user through the admin API. \ No newline at end of file diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index e75c5f1370..2107b5dc56 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -21,7 +21,7 @@ from six import text_type from six.moves import http_client from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -152,6 +152,9 @@ class UserRestServletV2(RestServlet): ret = await self.admin_handler.get_user(target_user) + if not ret: + raise NotFoundError("User not found") + return 200, ret async def on_PUT(self, request, user_id): diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 3b5169b38d..490ce8f55d 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -401,6 +401,22 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("You are not a server admin", channel.json_body["error"]) + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + self.hs.config.registration_shared_secret = None + + request, channel = self.make_request( + "GET", + "/_synapse/admin/v2/users/@unknown_person:test", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) + def test_requester_is_admin(self): """ If the user is a server admin, a new user is created. -- cgit 1.5.1 From c9aab2de34ee36cf10eb6007a657b70dfe860469 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 14 Feb 2020 12:03:12 +0000 Subject: Fix signedjson deleted method --- tests/storage/test_keys.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index e07ff01201..cef2112974 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -14,6 +14,7 @@ # limitations under the License. import signedjson.key +import unpaddedbase64 from twisted.internet.defer import Deferred @@ -21,11 +22,17 @@ from synapse.storage.keys import FetchKeyResult import tests.unittest -KEY_1 = signedjson.key.decode_verify_key_base64( - "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" + +def decode_verify_key_base64(key_id: str, key_base64: str): + key_bytes = unpaddedbase64.decode_base64(key_base64) + return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) + + +KEY_1 = decode_verify_key_base64( + "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" ) -KEY_2 = signedjson.key.decode_verify_key_base64( - "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" +KEY_2 = decode_verify_key_base64( + "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" ) @@ -121,3 +128,4 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): res2 = res[("srv1", key_id_2)] self.assertEqual(res2.verify_key, new_key_2) self.assertEqual(res2.valid_until_ts, 300) + -- cgit 1.5.1 From 49f877d32efc79cb40b2766cb052cf35bad31de5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Feb 2020 07:17:54 -0500 Subject: Filter the results of user directory searching via the spam checker (#6888) Add a method to the spam checker to filter the user directory results. --- changelog.d/6888.feature | 1 + docs/spam_checker.md | 3 ++ synapse/events/spamcheck.py | 27 ++++++++++ synapse/handlers/user_directory.py | 14 +++++- tests/handlers/test_user_directory.py | 92 +++++++++++++++++++++++++++++++++++ 5 files changed, 135 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6888.feature (limited to 'tests') diff --git a/changelog.d/6888.feature b/changelog.d/6888.feature new file mode 100644 index 0000000000..1b7ac0c823 --- /dev/null +++ b/changelog.d/6888.feature @@ -0,0 +1 @@ +The result of a user directory search can now be filtered via the spam checker. diff --git a/docs/spam_checker.md b/docs/spam_checker.md index 97ff17f952..5b5f5000b7 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -54,6 +54,9 @@ class ExampleSpamChecker: def user_may_publish_room(self, userid, room_id): return True # allow publishing of all rooms + + def check_username_for_spam(self, user_profile): + return False # allow all usernames ``` ## Configuration diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 5a907718d6..0a13fca9a4 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,6 +15,7 @@ # limitations under the License. import inspect +from typing import Dict from synapse.spam_checker_api import SpamCheckerApi @@ -125,3 +126,29 @@ class SpamChecker(object): return True return self.spam_checker.user_may_publish_room(userid, room_id) + + def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: + """Checks if a user ID or display name are considered "spammy" by this server. + + If the server considers a username spammy, then it will not be included in + user directory results. + + Args: + user_profile: The user information to check, it contains the keys: + * user_id + * display_name + * avatar_url + + Returns: + True if the user is spammy. + """ + if self.spam_checker is None: + return False + + # For backwards compatibility, if the method does not exist on the spam checker, fallback to not interfering. + checker = getattr(self.spam_checker, "check_username_for_spam", None) + if not checker: + return False + # Make a copy of the user profile object to ensure the spam checker + # cannot modify it. + return checker(user_profile.copy()) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 81aa58dc8c..722760c59d 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -52,6 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.update_user_directory self.search_all_users = hs.config.user_directory_search_all_users + self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream self.pos = None @@ -65,7 +66,7 @@ class UserDirectoryHandler(StateDeltasHandler): # we start populating the user directory self.clock.call_later(0, self.notify_new_event) - def search_users(self, user_id, search_term, limit): + async def search_users(self, user_id, search_term, limit): """Searches for users in directory Returns: @@ -82,7 +83,16 @@ class UserDirectoryHandler(StateDeltasHandler): ] } """ - return self.store.search_user_dir(user_id, search_term, limit) + results = await self.store.search_user_dir(user_id, search_term, limit) + + # Remove any spammy users from the results. + results["results"] = [ + user + for user in results["results"] + if not self.spam_checker.check_username_for_spam(user) + ] + + return results def notify_new_event(self): """Called when there may be more deltas to process diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 26071059d2..0a4765fff4 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -147,6 +147,98 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user3", 10)) self.assertEqual(len(s["results"]), 0) + def test_spam_checker(self): + """ + A user which fails to the spam checks will not appear in search results. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # We do not add users to the directory until they join a room. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Check we have populated the database correctly. + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + self.assertEqual( + self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + ) + self.assertEqual(public_users, []) + + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + # Configure a spam checker that does not filter any users. + spam_checker = self.hs.get_spam_checker() + + class AllowAll(object): + def check_username_for_spam(self, user_profile): + # Allow all users. + return False + + spam_checker.spam_checker = AllowAll() + + # The results do not change: + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + # Configure a spam checker that filters all users. + class BlockAll(object): + def check_username_for_spam(self, user_profile): + # All users are spammy. + return True + + spam_checker.spam_checker = BlockAll() + + # User1 now gets no search results for any of the other users. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + def test_legacy_spam_checker(self): + """ + A spam checker without the expected method should be ignored. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # We do not add users to the directory until they join a room. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Check we have populated the database correctly. + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + self.assertEqual( + self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + ) + self.assertEqual(public_users, []) + + # Configure a spam checker. + spam_checker = self.hs.get_spam_checker() + # The spam checker doesn't need any methods, so create a bare object. + spam_checker.spam_checker = object() + + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + def _compress_shared(self, shared): """ Compress a list of users who share rooms dicts to a list of tuples. -- cgit 1.5.1 From 02e89021f58f931068ab0337de039181cc7f6569 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Feb 2020 09:05:43 -0500 Subject: Convert the directory handler tests to use HomeserverTestCase (#6919) Convert directory handler tests to use HomeserverTestCase. --- changelog.d/6919.misc | 1 + tests/handlers/test_directory.py | 41 +++++++++++++++++----------------------- 2 files changed, 18 insertions(+), 24 deletions(-) create mode 100644 changelog.d/6919.misc (limited to 'tests') diff --git a/changelog.d/6919.misc b/changelog.d/6919.misc new file mode 100644 index 0000000000..aa2cd89998 --- /dev/null +++ b/changelog.d/6919.misc @@ -0,0 +1 @@ +Convert the directory handler tests to use HomeserverTestCase. diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 91c7a17070..ee88cf5a4b 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -19,24 +19,16 @@ from mock import Mock from twisted.internet import defer from synapse.config.room_directory import RoomDirectoryConfig -from synapse.handlers.directory import DirectoryHandler from synapse.rest.client.v1 import directory, room from synapse.types import RoomAlias from tests import unittest -from tests.utils import setup_test_homeserver -class DirectoryHandlers(object): - def __init__(self, hs): - self.directory_handler = DirectoryHandler(hs) - - -class DirectoryTestCase(unittest.TestCase): +class DirectoryTestCase(unittest.HomeserverTestCase): """ Tests the directory service. """ - @defer.inlineCallbacks - def setUp(self): + def make_homeserver(self, reactor, clock): self.mock_federation = Mock() self.mock_registry = Mock() @@ -47,14 +39,12 @@ class DirectoryTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler - hs = yield setup_test_homeserver( - self.addCleanup, + hs = self.setup_test_homeserver( http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, federation_registry=self.mock_registry, ) - hs.handlers = DirectoryHandlers(hs) self.handler = hs.get_handlers().directory_handler @@ -64,23 +54,25 @@ class DirectoryTestCase(unittest.TestCase): self.your_room = RoomAlias.from_string("#your-room:test") self.remote_room = RoomAlias.from_string("#another:remote") - @defer.inlineCallbacks + return hs + def test_get_local_association(self): - yield self.store.create_room_alias_association( - self.my_room, "!8765qwer:test", ["test"] + self.get_success( + self.store.create_room_alias_association( + self.my_room, "!8765qwer:test", ["test"] + ) ) - result = yield self.handler.get_association(self.my_room) + result = self.get_success(self.handler.get_association(self.my_room)) self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) - @defer.inlineCallbacks def test_get_remote_association(self): self.mock_federation.make_query.return_value = defer.succeed( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) - result = yield self.handler.get_association(self.remote_room) + result = self.get_success(self.handler.get_association(self.remote_room)) self.assertEquals( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result @@ -93,14 +85,15 @@ class DirectoryTestCase(unittest.TestCase): ignore_backoff=True, ) - @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_room_alias_association( - self.your_room, "!8765asdf:test", ["test"] + self.get_success( + self.store.create_room_alias_association( + self.your_room, "!8765asdf:test", ["test"] + ) ) - response = yield self.query_handlers["directory"]( - {"room_alias": "#your-room:test"} + response = self.get_success( + self.handler.on_directory_query({"room_alias": "#your-room:test"}) ) self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) -- cgit 1.5.1 From 622d0be1b303eaca3df917fe89b7355c559dccec Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 14 Feb 2020 13:17:02 +0000 Subject: Fix with isort==4.3.21 and black==19.3b0 --- docs/sphinx/conf.py | 8 ++++---- synapse/handlers/pagination.py | 2 +- synapse/handlers/profile.py | 3 +-- synapse/handlers/room_member.py | 6 +++--- synapse/http/client.py | 2 +- synapse/rest/client/v2_alpha/account_validity.py | 7 +++---- synapse/server.py | 2 +- synapse/server_notices/server_notices_sender.py | 2 +- synapse/storage/events_worker.py | 2 -- tests/handlers/test_register.py | 4 ++-- tests/server_notices/test_resource_limits_server_notices.py | 2 +- tests/storage/test_keys.py | 1 - tox.ini | 2 +- 13 files changed, 19 insertions(+), 24 deletions(-) (limited to 'tests') diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py index ca4b879526..5c5a115ca9 100644 --- a/docs/sphinx/conf.py +++ b/docs/sphinx/conf.py @@ -12,8 +12,8 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys import os +import sys # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the @@ -191,11 +191,11 @@ htmlhelp_basename = "Synapsedoc" latex_elements = { # The paper size ('letterpaper' or 'a4paper'). - #'papersize': 'letterpaper', + # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). - #'pointsize': '10pt', + # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. - #'preamble': '', + # 'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index bc9c9a056f..1d1fe8996d 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -22,8 +22,8 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.logging.context import run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index e439294ddf..7dfdce9757 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,7 +17,6 @@ import logging from six import raise_from - from six.moves import range from signedjson.sign import sign_json @@ -32,9 +31,9 @@ from synapse.api.errors import ( StoreError, SynapseError, ) +from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID, get_domain_from_id -from synapse.logging.context import run_in_background from ._base import BaseHandler diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0a2ff75629..2f52f2a8f2 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -599,9 +599,9 @@ class RoomMemberHandler(object): if requester is not None: sender = UserID.from_string(event.sender) - assert sender == requester.user, ( - "Sender (%s) must be same as requester (%s)" % (sender, requester.user) - ) + assert ( + sender == requester.user + ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: requester = synapse.types.create_requester(target_user) diff --git a/synapse/http/client.py b/synapse/http/client.py index e58ecc6786..e498ba0a6c 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -45,8 +45,8 @@ from synapse.http import ( cancelled_to_request_timed_out_error, redact_uri, ) -from synapse.logging.context import make_deferred_yieldable from synapse.http.proxyagent import ProxyAgent +from synapse.logging.context import make_deferred_yieldable from synapse.util.async_helpers import timeout_deferred from synapse.util.caches import CACHE_SIZE_FACTOR diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 348d2a103b..77a59461bc 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -15,9 +15,10 @@ import logging -from twisted.internet import defer from six import ensure_binary +from twisted.internet import defer + from synapse.api.errors import AuthError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import RestServlet @@ -65,9 +66,7 @@ class AccountValidityRenewServlet(RestServlet): request.setResponseCode(status_code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader( - b"Content-Length", b"%d" % (len(response),) - ) + request.setHeader(b"Content-Length", b"%d" % (len(response),)) request.write(ensure_binary(response)) finish_request(request) defer.returnValue(None) diff --git a/synapse/server.py b/synapse/server.py index 8aa572d289..23be3560fa 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -94,7 +94,7 @@ from synapse.secrets import Secrets from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender from synapse.server_notices.worker_server_notices_sender import ( - WorkerServerNoticesSender + WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler from synapse.streams.events import EventSources diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index ceff6c8e02..652bab58e3 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.resource_limits_server_notices import ( - ResourceLimitsServerNotices + ResourceLimitsServerNotices, ) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 280f2297b9..bd3ebddcd8 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -27,8 +27,6 @@ from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError from synapse.api.room_versions import EventFormatVersions from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 - -# these are only included to make the type annotations work from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event from synapse.logging.context import ( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e5a9acd87f..2a60e81cbf 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -246,8 +246,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Multibyte unicode characters self._check_mapping( - u"j\u030a\u0065an-poppy.seed@example.com", - u"J\u030a\u0065an-Poppy Seed [Example]", + "j\u030a\u0065an-poppy.seed@example.com", + "J\u030a\u0065an-Poppy Seed [Example]", ) def _check_mapping(self, i, expected): diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index e1e0c3184d..984feb623f 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.server_notices.resource_limits_server_notices import ( - ResourceLimitsServerNotices + ResourceLimitsServerNotices, ) from tests import unittest diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index cef2112974..95f309fbbc 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -128,4 +128,3 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): res2 = res[("srv1", key_id_2)] self.assertEqual(res2.verify_key, new_key_2) self.assertEqual(res2.valid_until_ts, 300) - diff --git a/tox.ini b/tox.ini index 5b8a8b3c8e..6b5190e810 100644 --- a/tox.ini +++ b/tox.ini @@ -117,7 +117,7 @@ skip_install = True basepython = python3.6 deps = flake8 - black + black==18.6b4 commands = python -m black --check --diff . /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}" -- cgit 1.5.1 From 43b2be9764bf244122f25a0f2bf8ec8c761f5cad Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 7 Feb 2020 13:08:34 +0000 Subject: Replace _event_dict_property with DictProperty this amounts to the same thing, but replaces `_event_dict` with `_dict`, and removes some of the function layers generated by `property`. --- synapse/events/__init__.py | 144 ++++++++++++++++++++++------------------ tests/storage/test_redaction.py | 2 +- 2 files changed, 80 insertions(+), 66 deletions(-) (limited to 'tests') diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index a842661a90..512254f65d 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -37,6 +37,65 @@ from synapse.util.frozenutils import freeze USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) +class DictProperty: + """An object property which delegates to the `_dict` within its parent object.""" + + __slots__ = ["key"] + + def __init__(self, key: str): + self.key = key + + def __get__(self, instance, owner=None): + # if the property is accessed as a class property rather than an instance + # property, return the property itself rather than the value + if instance is None: + return self + try: + return instance._dict[self.key] + except KeyError as e1: + # We want this to look like a regular attribute error (mostly so that + # hasattr() works correctly), so we convert the KeyError into an + # AttributeError. + # + # To exclude the KeyError from the traceback, we explicitly + # 'raise from e1.__context__' (which is better than 'raise from None', + # becuase that would omit any *earlier* exceptions). + # + raise AttributeError( + "'%s' has no '%s' property" % (type(instance), self.key) + ) from e1.__context__ + + def __set__(self, instance, v): + instance._dict[self.key] = v + + def __delete__(self, instance): + try: + del instance._dict[self.key] + except KeyError as e1: + raise AttributeError( + "'%s' has no '%s' property" % (type(instance), self.key) + ) from e1.__context__ + + +class DefaultDictProperty(DictProperty): + """An extension of DictProperty which provides a default if the property is + not present in the parent's _dict. + + Note that this means that hasattr() on the property always returns True. + """ + + __slots__ = ["default"] + + def __init__(self, key, default): + super().__init__(key) + self.default = default + + def __get__(self, instance, owner=None): + if instance is None: + return self + return instance._dict.get(self.key, self.default) + + class _EventInternalMetadata(object): def __init__(self, internal_metadata_dict): self.__dict__ = dict(internal_metadata_dict) @@ -117,51 +176,6 @@ class _EventInternalMetadata(object): return getattr(self, "redacted", False) -_SENTINEL = object() - - -def _event_dict_property(key, default=_SENTINEL): - """Creates a new property for the given key that delegates access to - `self._event_dict`. - - The default is used if the key is missing from the `_event_dict`, if given, - otherwise an AttributeError will be raised. - - Note: If a default is given then `hasattr` will always return true. - """ - - # We want to be able to use hasattr with the event dict properties. - # However, (on python3) hasattr expects AttributeError to be raised. Hence, - # we need to transform the KeyError into an AttributeError - - def getter_raises(self): - try: - return self._event_dict[key] - except KeyError: - raise AttributeError(key) - - def getter_default(self): - return self._event_dict.get(key, default) - - def setter(self, v): - try: - self._event_dict[key] = v - except KeyError: - raise AttributeError(key) - - def delete(self): - try: - del self._event_dict[key] - except KeyError: - raise AttributeError(key) - - if default is _SENTINEL: - # No default given, so use the getter that raises - return property(getter_raises, setter, delete) - else: - return property(getter_default, setter, delete) - - class EventBase(object): def __init__( self, @@ -175,23 +189,23 @@ class EventBase(object): self.unsigned = unsigned self.rejected_reason = rejected_reason - self._event_dict = event_dict + self._dict = event_dict self.internal_metadata = _EventInternalMetadata(internal_metadata_dict) - auth_events = _event_dict_property("auth_events") - depth = _event_dict_property("depth") - content = _event_dict_property("content") - hashes = _event_dict_property("hashes") - origin = _event_dict_property("origin") - origin_server_ts = _event_dict_property("origin_server_ts") - prev_events = _event_dict_property("prev_events") - redacts = _event_dict_property("redacts", None) - room_id = _event_dict_property("room_id") - sender = _event_dict_property("sender") - state_key = _event_dict_property("state_key") - type = _event_dict_property("type") - user_id = _event_dict_property("sender") + auth_events = DictProperty("auth_events") + depth = DictProperty("depth") + content = DictProperty("content") + hashes = DictProperty("hashes") + origin = DictProperty("origin") + origin_server_ts = DictProperty("origin_server_ts") + prev_events = DictProperty("prev_events") + redacts = DefaultDictProperty("redacts", None) + room_id = DictProperty("room_id") + sender = DictProperty("sender") + state_key = DictProperty("state_key") + type = DictProperty("type") + user_id = DictProperty("sender") @property def event_id(self) -> str: @@ -205,13 +219,13 @@ class EventBase(object): return hasattr(self, "state_key") and self.state_key is not None def get_dict(self) -> JsonDict: - d = dict(self._event_dict) + d = dict(self._dict) d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)}) return d def get(self, key, default=None): - return self._event_dict.get(key, default) + return self._dict.get(key, default) def get_internal_metadata_dict(self): return self.internal_metadata.get_dict() @@ -233,16 +247,16 @@ class EventBase(object): raise AttributeError("Unrecognized attribute %s" % (instance,)) def __getitem__(self, field): - return self._event_dict[field] + return self._dict[field] def __contains__(self, field): - return field in self._event_dict + return field in self._dict def items(self): - return list(self._event_dict.items()) + return list(self._dict.items()) def keys(self): - return six.iterkeys(self._event_dict) + return six.iterkeys(self._dict) def prev_event_ids(self): """Returns the list of prev event IDs. The order matches the order diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index b9ee6ec1ec..db3667dc43 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -240,7 +240,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): built_event = yield self._base_builder.build(prev_event_ids) built_event._event_id = self._event_id - built_event._event_dict["event_id"] = self._event_id + built_event._dict["event_id"] = self._event_id assert built_event.event_id == self._event_id return built_event -- cgit 1.5.1 From 3404ad289b1d2e5bc5c7f277f519b9698dbdaa15 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 17 Feb 2020 13:23:37 +0000 Subject: Raise the default power levels for invites, tombstones and server acls (#6834) --- changelog.d/6834.misc | 1 + synapse/handlers/room.py | 10 +++++++++- tests/rest/client/v1/test_rooms.py | 4 +++- 3 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6834.misc (limited to 'tests') diff --git a/changelog.d/6834.misc b/changelog.d/6834.misc new file mode 100644 index 0000000000..79acebe516 --- /dev/null +++ b/changelog.d/6834.misc @@ -0,0 +1 @@ +Change the default power levels of invites, tombstones and server ACLs for new rooms. \ No newline at end of file diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ab07edd2fc..033083acac 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -64,18 +64,21 @@ class RoomCreationHandler(BaseHandler): "history_visibility": "shared", "original_invitees_have_ops": False, "guest_can_join": True, + "power_level_content_override": {"invite": 0}, }, RoomCreationPreset.TRUSTED_PRIVATE_CHAT: { "join_rules": JoinRules.INVITE, "history_visibility": "shared", "original_invitees_have_ops": True, "guest_can_join": True, + "power_level_content_override": {"invite": 0}, }, RoomCreationPreset.PUBLIC_CHAT: { "join_rules": JoinRules.PUBLIC, "history_visibility": "shared", "original_invitees_have_ops": False, "guest_can_join": False, + "power_level_content_override": {}, }, } @@ -829,19 +832,24 @@ class RoomCreationHandler(BaseHandler): # This will be reudundant on pre-MSC2260 rooms, since the # aliases event is special-cased. EventTypes.Aliases: 0, + EventTypes.Tombstone: 100, + EventTypes.ServerACL: 100, }, "events_default": 0, "state_default": 50, "ban": 50, "kick": 50, "redact": 50, - "invite": 0, + "invite": 50, } if config["original_invitees_have_ops"]: for invitee in invite_list: power_level_content["users"][invitee] = 100 + # Power levels overrides are defined per chat preset + power_level_content.update(config["power_level_content_override"]) + if power_level_content_override: power_level_content.update(power_level_content_override) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e3af280ba6..fb681a1db9 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1612,7 +1612,9 @@ class ContextTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.user_id = self.register_user("user", "password") self.tok = self.login("user", "password") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + self.room_id = self.helper.create_room_as( + self.user_id, tok=self.tok, is_public=False + ) self.other_user_id = self.register_user("user2", "password") self.other_tok = self.login("user2", "password") -- cgit 1.5.1 From fe3941f6e33a17fa7cdf209a4370f4e805341db4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Feb 2020 07:29:44 -0500 Subject: Stop sending events when creating or deleting aliases (#6904) Stop sending events when creating or deleting associations (room aliases). Send an updated canonical alias event if one of the alt_aliases is deleted. --- changelog.d/6904.removal | 1 + synapse/handlers/directory.py | 75 ++++++++++--------- synapse/handlers/room.py | 6 +- tests/handlers/test_directory.py | 154 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 194 insertions(+), 42 deletions(-) create mode 100644 changelog.d/6904.removal (limited to 'tests') diff --git a/changelog.d/6904.removal b/changelog.d/6904.removal new file mode 100644 index 0000000000..a5cc0c3605 --- /dev/null +++ b/changelog.d/6904.removal @@ -0,0 +1 @@ +Stop sending alias events during adding / removing aliases. Check alt_aliases in the latest canonical aliases event when deleting an alias. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 8c5980cb0c..f718388884 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -81,13 +81,7 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def create_association( - self, - requester, - room_alias, - room_id, - servers=None, - send_event=True, - check_membership=True, + self, requester, room_alias, room_id, servers=None, check_membership=True, ): """Attempt to create a new alias @@ -97,7 +91,6 @@ class DirectoryHandler(BaseHandler): room_id (str) servers (list[str]|None): List of servers that others servers should try and join via - send_event (bool): Whether to send an updated m.room.aliases event check_membership (bool): Whether to check if the user is in the room before the alias can be set (if the server's config requires it). @@ -150,16 +143,9 @@ class DirectoryHandler(BaseHandler): ) yield self._create_association(room_alias, room_id, servers, creator=user_id) - if send_event: - try: - yield self.send_room_alias_update_event(requester, room_id) - except AuthError as e: - # sending the aliases event may fail due to the user not having - # permission in the room; this is permitted. - logger.info("Skipping updating aliases event due to auth error %s", e) @defer.inlineCallbacks - def delete_association(self, requester, room_alias, send_event=True): + def delete_association(self, requester, room_alias): """Remove an alias from the directory (this is only meant for human users; AS users should call @@ -168,9 +154,6 @@ class DirectoryHandler(BaseHandler): Args: requester (Requester): room_alias (RoomAlias): - send_event (bool): Whether to send an updated m.room.aliases event. - Note that, if we delete the canonical alias, we will always attempt - to send an m.room.canonical_alias event Returns: Deferred[unicode]: room id that the alias used to point to @@ -206,9 +189,6 @@ class DirectoryHandler(BaseHandler): room_id = yield self._delete_association(room_alias) try: - if send_event: - yield self.send_room_alias_update_event(requester, room_id) - yield self._update_canonical_alias( requester, requester.user.to_string(), room_id, room_alias ) @@ -319,25 +299,50 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def _update_canonical_alias(self, requester, user_id, room_id, room_alias): + """ + Send an updated canonical alias event if the removed alias was set as + the canonical alias or listed in the alt_aliases field. + """ alias_event = yield self.state.get_current_state( room_id, EventTypes.CanonicalAlias, "" ) - alias_str = room_alias.to_string() - if not alias_event or alias_event.content.get("alias", "") != alias_str: + # There is no canonical alias, nothing to do. + if not alias_event: return - yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.CanonicalAlias, - "state_key": "", - "room_id": room_id, - "sender": user_id, - "content": {}, - }, - ratelimit=False, - ) + # Obtain a mutable version of the event content. + content = dict(alias_event.content) + send_update = False + + # Remove the alias property if it matches the removed alias. + alias_str = room_alias.to_string() + if alias_event.content.get("alias", "") == alias_str: + send_update = True + content.pop("alias", "") + + # Filter alt_aliases for the removed alias. + alt_aliases = content.pop("alt_aliases", None) + # If the aliases are not a list (or not found) do not attempt to modify + # the list. + if isinstance(alt_aliases, list): + send_update = True + alt_aliases = [alias for alias in alt_aliases if alias != alias_str] + if alt_aliases: + content["alt_aliases"] = alt_aliases + + if send_update: + yield self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.CanonicalAlias, + "state_key": "", + "room_id": room_id, + "sender": user_id, + "content": content, + }, + ratelimit=False, + ) @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 033083acac..49ec2f48bc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -478,9 +478,7 @@ class RoomCreationHandler(BaseHandler): for alias_str in aliases: alias = RoomAlias.from_string(alias_str) try: - yield directory_handler.delete_association( - requester, alias, send_event=False - ) + yield directory_handler.delete_association(requester, alias) removed_aliases.append(alias_str) except SynapseError as e: logger.warning("Unable to remove alias %s from old room: %s", alias, e) @@ -511,7 +509,6 @@ class RoomCreationHandler(BaseHandler): RoomAlias.from_string(alias), new_room_id, servers=(self.hs.hostname,), - send_event=False, check_membership=False, ) logger.info("Moved alias %s to new room", alias) @@ -664,7 +661,6 @@ class RoomCreationHandler(BaseHandler): room_id=room_id, room_alias=room_alias, servers=[self.hs.hostname], - send_event=False, check_membership=False, ) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index ee88cf5a4b..27b916aed4 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,9 +18,11 @@ from mock import Mock from twisted.internet import defer +import synapse.api.errors +from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig -from synapse.rest.client.v1 import directory, room -from synapse.types import RoomAlias +from synapse.rest.client.v1 import directory, login, room +from synapse.types import RoomAlias, create_requester from tests import unittest @@ -85,6 +87,38 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) + def test_delete_alias_not_allowed(self): + room_id = "!8765qwer:test" + self.get_success( + self.store.create_room_alias_association(self.my_room, room_id, ["test"]) + ) + + self.get_failure( + self.handler.delete_association( + create_requester("@user:test"), self.my_room + ), + synapse.api.errors.AuthError, + ) + + def test_delete_alias(self): + room_id = "!8765qwer:test" + user_id = "@user:test" + self.get_success( + self.store.create_room_alias_association( + self.my_room, room_id, ["test"], user_id + ) + ) + + result = self.get_success( + self.handler.delete_association(create_requester(user_id), self.my_room) + ) + self.assertEquals(room_id, result) + + # The alias should not be found. + self.get_failure( + self.handler.get_association(self.my_room), synapse.api.errors.SynapseError + ) + def test_incoming_fed_query(self): self.get_success( self.store.create_room_alias_association( @@ -99,6 +133,122 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) +class CanonicalAliasTestCase(unittest.HomeserverTestCase): + """Test modifications of the canonical alias when delete aliases. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_handlers().directory_handler + self.state_handler = hs.get_state_handler() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a new alias to this room. + self.get_success( + self.store.create_room_alias_association( + self.room_alias, self.room_id, ["test"], self.admin_user + ) + ) + + def test_remove_alias(self): + """Removing an alias that is the canonical alias should remove it there too.""" + # Set this new alias as the canonical alias for this room + self.helper.send_state( + self.room_id, + "m.room.canonical_alias", + {"alias": self.test_alias, "alt_aliases": [self.test_alias]}, + tok=self.admin_user_tok, + ) + + data = self.get_success( + self.state_handler.get_current_state( + self.room_id, EventTypes.CanonicalAlias, "" + ) + ) + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + + # Finally, delete the alias. + self.get_success( + self.handler.delete_association( + create_requester(self.admin_user), self.room_alias + ) + ) + + data = self.get_success( + self.state_handler.get_current_state( + self.room_id, EventTypes.CanonicalAlias, "" + ) + ) + self.assertNotIn("alias", data["content"]) + self.assertNotIn("alt_aliases", data["content"]) + + def test_remove_other_alias(self): + """Removing an alias listed as in alt_aliases should remove it there too.""" + # Create a second alias. + other_test_alias = "#test2:test" + other_room_alias = RoomAlias.from_string(other_test_alias) + self.get_success( + self.store.create_room_alias_association( + other_room_alias, self.room_id, ["test"], self.admin_user + ) + ) + + # Set the alias as the canonical alias for this room. + self.helper.send_state( + self.room_id, + "m.room.canonical_alias", + { + "alias": self.test_alias, + "alt_aliases": [self.test_alias, other_test_alias], + }, + tok=self.admin_user_tok, + ) + + data = self.get_success( + self.state_handler.get_current_state( + self.room_id, EventTypes.CanonicalAlias, "" + ) + ) + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual( + data["content"]["alt_aliases"], [self.test_alias, other_test_alias] + ) + + # Delete the second alias. + self.get_success( + self.handler.delete_association( + create_requester(self.admin_user), other_room_alias + ) + ) + + data = self.get_success( + self.state_handler.get_current_state( + self.room_id, EventTypes.CanonicalAlias, "" + ) + ) + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + + class TestCreateAliasACL(unittest.HomeserverTestCase): user_id = "@test:test" -- cgit 1.5.1 From adfaea8c698a38ffe14ac682a946abc9f8152635 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 18 Feb 2020 16:23:25 +0000 Subject: Implement GET /_matrix/client/r0/rooms/{roomId}/aliases (#6939) per matrix-org/matrix-doc#2432 --- changelog.d/6939.feature | 1 + synapse/handlers/directory.py | 17 ++++++++- synapse/rest/client/v1/room.py | 23 +++++++++++++ tests/rest/client/v1/test_rooms.py | 70 +++++++++++++++++++++++++++++++++++++- tests/unittest.py | 28 ++++++++++----- 5 files changed, 128 insertions(+), 11 deletions(-) create mode 100644 changelog.d/6939.feature (limited to 'tests') diff --git a/changelog.d/6939.feature b/changelog.d/6939.feature new file mode 100644 index 0000000000..40fe7fc9a9 --- /dev/null +++ b/changelog.d/6939.feature @@ -0,0 +1 @@ +Implement `GET /_matrix/client/r0/rooms/{roomId}/aliases` endpoint as per [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432). diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index f718388884..3f8c792149 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -16,6 +16,7 @@ import logging import string +from typing import List from twisted.internet import defer @@ -28,7 +29,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) -from synapse.types import RoomAlias, UserID, get_domain_from_id +from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler @@ -452,3 +453,17 @@ class DirectoryHandler(BaseHandler): yield self.store.set_room_is_public_appservice( room_id, appservice_id, network_id, visibility == "public" ) + + async def get_aliases_for_room( + self, requester: Requester, room_id: str + ) -> List[str]: + """ + Get a list of the aliases that currently point to this room on this server + """ + # allow access to server admins and current members of the room + is_admin = await self.auth.is_server_admin(requester.user) + if not is_admin: + await self.auth.check_joined_room(room_id, requester.user.to_string()) + + aliases = await self.store.get_aliases_for_room(room_id) + return aliases diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6f31584c51..143dc738c6 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -45,6 +45,10 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID +MYPY = False +if MYPY: + import synapse.server + logger = logging.getLogger(__name__) @@ -843,6 +847,24 @@ class RoomTypingRestServlet(RestServlet): return 200, {} +class RoomAliasListServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/aliases", unstable=False) + + def __init__(self, hs: "synapse.server.HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.directory_handler = hs.get_handlers().directory_handler + + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + + alias_list = await self.directory_handler.get_aliases_for_room( + requester, room_id + ) + + return 200, {"aliases": alias_list} + + class SearchRestServlet(RestServlet): PATTERNS = client_patterns("/search$", v1=True) @@ -931,6 +953,7 @@ def register_servlets(hs, http_server): JoinedRoomsRestServlet(hs).register(http_server) RoomEventServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) + RoomAliasListServlet(hs).register(http_server) def register_deprecated_servlets(hs, http_server): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index fb681a1db9..fb08a45d27 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -28,8 +28,9 @@ from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus -from synapse.rest.client.v1 import login, profile, room +from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account +from synapse.types import JsonDict, RoomAlias from synapse.util.stringutils import random_string from tests import unittest @@ -1726,3 +1727,70 @@ class ContextTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events_after), 2, events_after) self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) self.assertEqual(events_after[1].get("content"), {}, events_after[1]) + + +class DirectoryTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + def test_no_aliases(self): + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(res["aliases"], []) + + def test_not_in_room(self): + self.register_user("user", "test") + user_tok = self.login("user", "test") + res = self._get_aliases(user_tok, expected_code=403) + self.assertEqual(res["errcode"], "M_FORBIDDEN") + + def test_with_aliases(self): + alias1 = self._random_alias() + alias2 = self._random_alias() + + self._set_alias_via_directory(alias1) + self._set_alias_via_directory(alias2) + + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(set(res["aliases"]), {alias1, alias2}) + + def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/aliases" % (self.room_id,), + access_token=access_token, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + if expected_code == 200: + self.assertIsInstance(res["aliases"], list) + return res + + def _random_alias(self) -> str: + return RoomAlias(random_string(5), self.hs.hostname).to_string() + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) diff --git a/tests/unittest.py b/tests/unittest.py index 98bf27d39c..8816a4d152 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -21,6 +21,7 @@ import hmac import inspect import logging import time +from typing import Optional, Tuple, Type, TypeVar, Union from mock import Mock @@ -42,7 +43,13 @@ from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter -from tests.server import get_clock, make_request, render, setup_test_homeserver +from tests.server import ( + FakeChannel, + get_clock, + make_request, + render, + setup_test_homeserver, +) from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -71,6 +78,9 @@ def around(target): return _around +T = TypeVar("T") + + class TestCase(unittest.TestCase): """A subclass of twisted.trial's TestCase which looks for 'loglevel' attributes on both itself and its individual test methods, to override the @@ -334,14 +344,14 @@ class HomeserverTestCase(TestCase): def make_request( self, - method, - path, - content=b"", - access_token=None, - request=SynapseRequest, - shorthand=True, - federation_auth_origin=None, - ): + method: Union[bytes, str], + path: Union[bytes, str], + content: Union[bytes, dict] = b"", + access_token: Optional[str] = None, + request: Type[T] = SynapseRequest, + shorthand: bool = True, + federation_auth_origin: str = None, + ) -> Tuple[T, FakeChannel]: """ Create a SynapseRequest at the path using the method and containing the given content. -- cgit 1.5.1 From b58d17e44f9d9ff7e70578e0f4e328bb9113ec7e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 18 Feb 2020 23:13:29 +0000 Subject: Refactor the membership check methods in Auth these were getting a bit unwieldy, so let's combine `check_joined_room` and `check_user_was_in_room` into a single `check_user_in_room`. --- synapse/api/auth.py | 80 +++++++++++++++++++--------------------- synapse/handlers/initial_sync.py | 31 +++------------- synapse/handlers/typing.py | 4 +- tests/handlers/test_typing.py | 4 +- 4 files changed, 46 insertions(+), 73 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 8b1277ad02..de7b75ca36 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Optional from six import itervalues @@ -35,6 +36,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.server import is_threepid_reserved +from synapse.events import EventBase from synapse.types import StateMap, UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache @@ -92,20 +94,34 @@ class Auth(object): ) @defer.inlineCallbacks - def check_joined_room(self, room_id, user_id, current_state=None): - """Check if the user is currently joined in the room + def check_user_in_room( + self, + room_id: str, + user_id: str, + current_state: Optional[StateMap[EventBase]] = None, + allow_departed_users: bool = False, + ): + """Check if the user is in the room, or was at some point. Args: - room_id(str): The room to check. - user_id(str): The user to check. - current_state(dict): Optional map of the current state of the room. + room_id: The room to check. + + user_id: The user to check. + + current_state: Optional map of the current state of the room. If provided then that map is used to check whether they are a member of the room. Otherwise the current membership is loaded from the database. + + allow_departed_users: if True, accept users that were previously + members but have now departed. + Raises: - AuthError if the user is not in the room. + AuthError if the user is/was not in the room. Returns: - A deferred membership event for the user if the user is in - the room. + Deferred[Optional[EventBase]]: + Membership event for the user if the user was in the + room. This will be the join event if they are currently joined to + the room. This will be the leave event if they have left the room. """ if current_state: member = current_state.get((EventTypes.Member, user_id), None) @@ -113,37 +129,19 @@ class Auth(object): member = yield self.state.get_current_state( room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) - - self._check_joined_room(member, user_id, room_id) - return member - - @defer.inlineCallbacks - def check_user_was_in_room(self, room_id, user_id): - """Check if the user was in the room at some point. - Args: - room_id(str): The room to check. - user_id(str): The user to check. - Raises: - AuthError if the user was never in the room. - Returns: - A deferred membership event for the user if the user was in the - room. This will be the join event if they are currently joined to - the room. This will be the leave event if they have left the room. - """ - member = yield self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id - ) membership = member.membership if member else None - if membership not in (Membership.JOIN, Membership.LEAVE): - raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) + if membership == Membership.JOIN: + return member - if membership == Membership.LEAVE: + # XXX this looks totally bogus. Why do we not allow users who have been banned, + # or those who were members previously and have been re-invited? + if allow_departed_users and membership == Membership.LEAVE: forgot = yield self.store.did_forget(user_id, room_id) - if forgot: - raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) + if not forgot: + return member - return member + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) @defer.inlineCallbacks def check_host_in_room(self, room_id, host): @@ -151,12 +149,6 @@ class Auth(object): latest_event_ids = yield self.store.is_host_joined(room_id, host) return latest_event_ids - def _check_joined_room(self, member, user_id, room_id): - if not member or member.membership != Membership.JOIN: - raise AuthError( - 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) - ) - def can_federate(self, event, auth_events): creation_event = auth_events.get((EventTypes.Create, "")) @@ -560,7 +552,7 @@ class Auth(object): return True user_id = user.to_string() - yield self.check_joined_room(room_id, user_id) + yield self.check_user_in_room(room_id, user_id) # We currently require the user is a "moderator" in the room. We do this # by checking if they would (theoretically) be able to change the @@ -645,12 +637,14 @@ class Auth(object): """ try: - # check_user_was_in_room will return the most recent membership + # check_user_in_room will return the most recent membership # event for the user if: # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = yield self.check_user_was_in_room(room_id, user_id) + member_event = yield self.check_user_in_room( + room_id, user_id, allow_departed_users=True + ) return member_event.membership, member_event.event_id except AuthError: visibility = yield self.state.get_current_state( diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 2e6755f19c..b7c6a921d9 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -18,7 +18,7 @@ import logging from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import SynapseError from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -274,9 +274,10 @@ class InitialSyncHandler(BaseHandler): user_id = requester.user.to_string() - membership, member_event_id = await self._check_in_room_or_world_readable( - room_id, user_id - ) + ( + membership, + member_event_id, + ) = await self.auth.check_user_in_room_or_world_readable(room_id, user_id) is_peeking = member_event_id is None if membership == Membership.JOIN: @@ -433,25 +434,3 @@ class InitialSyncHandler(BaseHandler): ret["membership"] = membership return ret - - async def _check_in_room_or_world_readable(self, room_id, user_id): - try: - # check_user_was_in_room will return the most recent membership - # event for the user if: - # * The user is a non-guest user, and was ever in the room - # * The user is a guest user, and has joined the room - # else it will throw. - member_event = await self.auth.check_user_was_in_room(room_id, user_id) - return member_event.membership, member_event.event_id - except AuthError: - visibility = await self.state_handler.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" - ) - if ( - visibility - and visibility.content["history_visibility"] == "world_readable" - ): - return Membership.JOIN, None - raise AuthError( - 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN - ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d5ca9cb07b..5406618431 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -125,7 +125,7 @@ class TypingHandler(object): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user_id) + yield self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has started typing in %s", target_user_id, room_id) @@ -155,7 +155,7 @@ class TypingHandler(object): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user_id) + yield self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has stopped typing in %s", target_user_id, room_id) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 2767b0497a..140cc0a3c2 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -122,11 +122,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - def check_joined_room(room_id, user_id): + def check_user_in_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") - hs.get_auth().check_joined_room = check_joined_room + hs.get_auth().check_user_in_room = check_user_in_room def get_joined_hosts_for_room(room_id): return set(member.domain for member in self.room_members) -- cgit 1.5.1 From 709e81f5183d8ff67d86f4569234cb4a8be7a8d4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 18 Feb 2020 23:15:54 +0000 Subject: Make room alias lists peekable As per https://github.com/matrix-org/matrix-doc/pull/2432#pullrequestreview-360566830, make room alias lists accessible to users outside world_readable rooms. --- synapse/handlers/directory.py | 4 +++- tests/rest/client/v1/test_rooms.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 3f8c792149..db2104c5f6 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -463,7 +463,9 @@ class DirectoryHandler(BaseHandler): # allow access to server admins and current members of the room is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: - await self.auth.check_joined_room(room_id, requester.user.to_string()) + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string() + ) aliases = await self.store.get_aliases_for_room(room_id) return aliases diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index fb08a45d27..8e389eb6c9 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1766,6 +1766,23 @@ class DirectoryTestCase(unittest.HomeserverTestCase): res = self._get_aliases(self.room_owner_tok) self.assertEqual(set(res["aliases"]), {alias1, alias2}) + def test_peekable_room(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.helper.send_state( + self.room_id, + EventTypes.RoomHistoryVisibility, + body={"history_visibility": "world_readable"}, + tok=self.room_owner_tok, + ) + + self.register_user("user", "test") + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: """Calls the endpoint under test. returns the json response object.""" request, channel = self.make_request( -- cgit 1.5.1 From 880aaac1d82695b1a89f22f1f86c7f295ca205e0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 19 Feb 2020 10:40:27 +0000 Subject: Move MSC2432 stuff onto unstable prefix (#6948) it's not in the spec yet, so needs to be unstable. Also add a feature flag for it. Also add a test for admin users. --- changelog.d/6948.feature | 1 + synapse/rest/client/v1/room.py | 8 +++++++- synapse/rest/client/versions.py | 2 ++ tests/rest/client/v1/test_rooms.py | 16 +++++++++++++--- 4 files changed, 23 insertions(+), 4 deletions(-) create mode 100644 changelog.d/6948.feature (limited to 'tests') diff --git a/changelog.d/6948.feature b/changelog.d/6948.feature new file mode 100644 index 0000000000..40fe7fc9a9 --- /dev/null +++ b/changelog.d/6948.feature @@ -0,0 +1 @@ +Implement `GET /_matrix/client/r0/rooms/{roomId}/aliases` endpoint as per [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432). diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 143dc738c6..64f51406fb 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -16,6 +16,7 @@ """ This module contains REST servlets to do with rooms: /rooms/ """ import logging +import re from typing import List, Optional from six.moves.urllib import parse as urlparse @@ -848,7 +849,12 @@ class RoomTypingRestServlet(RestServlet): class RoomAliasListServlet(RestServlet): - PATTERNS = client_patterns("/rooms/(?P[^/]*)/aliases", unstable=False) + PATTERNS = [ + re.compile( + r"^/_matrix/client/unstable/org\.matrix\.msc2432" + r"/rooms/(?P[^/]*)/aliases" + ), + ] def __init__(self, hs: "synapse.server.HomeServer"): super().__init__() diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 3eeb3607f4..d90a6a890b 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -72,6 +72,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.label_based_filtering": True, # Implements support for cross signing as described in MSC1756 "org.matrix.e2e_cross_signing": True, + # Implements additional endpoints as described in MSC2432 + "org.matrix.msc2432": True, }, }, ) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index fb08a45d27..f82655677c 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1729,8 +1729,7 @@ class ContextTestCase(unittest.HomeserverTestCase): self.assertEqual(events_after[1].get("content"), {}, events_after[1]) -class DirectoryTestCase(unittest.HomeserverTestCase): - +class RoomAliasListTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, directory.register_servlets, @@ -1756,6 +1755,16 @@ class DirectoryTestCase(unittest.HomeserverTestCase): res = self._get_aliases(user_tok, expected_code=403) self.assertEqual(res["errcode"], "M_FORBIDDEN") + def test_admin_user(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.register_user("user", "test", admin=True) + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + def test_with_aliases(self): alias1 = self._random_alias() alias2 = self._random_alias() @@ -1770,7 +1779,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase): """Calls the endpoint under test. returns the json response object.""" request, channel = self.make_request( "GET", - "/_matrix/client/r0/rooms/%s/aliases" % (self.room_id,), + "/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases" + % (self.room_id,), access_token=access_token, ) self.render(request) -- cgit 1.5.1 From 2b37eabca1e9355e2e2ab8f65bbdda12431ecc28 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 19 Feb 2020 15:04:47 +0000 Subject: Reduce auth chains fetched during v2 state res. (#6952) The state res v2 algorithm only cares about the difference between auth chains, so we can pass in the known common state to the `get_auth_chain` storage function so that it can ignore those events. --- changelog.d/6952.misc | 1 + synapse/state/__init__.py | 15 ++++++++---- synapse/state/v2.py | 2 +- .../storage/data_stores/main/event_federation.py | 28 ++++++++++++++++++---- tests/state/test_v2.py | 6 +++-- 5 files changed, 39 insertions(+), 13 deletions(-) create mode 100644 changelog.d/6952.misc (limited to 'tests') diff --git a/changelog.d/6952.misc b/changelog.d/6952.misc new file mode 100644 index 0000000000..e26dc5cab8 --- /dev/null +++ b/changelog.d/6952.misc @@ -0,0 +1 @@ +Improve perf of v2 state res for large rooms. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index fdd6bef6b4..df7a4f6a89 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Optional, Set from six import iteritems, itervalues @@ -662,7 +662,7 @@ class StateResolutionStore(object): allow_rejected=allow_rejected, ) - def get_auth_chain(self, event_ids): + def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]): """Gets the full auth chain for a set of events (including rejected events). @@ -674,11 +674,16 @@ class StateResolutionStore(object): presence of rejected events Args: - event_ids (list): The event IDs of the events to fetch the auth - chain for. Must be state events. + event_ids: The event IDs of the events to fetch the auth chain for. + Must be state events. + ignore_events: Set of events to exclude from the returned auth + chain. + Returns: Deferred[list[str]]: List of event IDs of the auth chain. """ - return self.store.get_auth_chain_ids(event_ids, include_given=True) + return self.store.get_auth_chain_ids( + event_ids, include_given=True, ignore_events=ignore_events, + ) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 531018c6a5..75fe58305a 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -248,7 +248,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): and eid not in common ) - auth_chain = yield state_res_store.get_auth_chain(auth_ids) + auth_chain = yield state_res_store.get_auth_chain(auth_ids, common) auth_ids.update(auth_chain) auth_sets.append(auth_ids) diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 60c67457b4..e16da2577d 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -14,6 +14,7 @@ # limitations under the License. import itertools import logging +from typing import List, Optional, Set from six.moves import range from six.moves.queue import Empty, PriorityQueue @@ -46,21 +47,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_ids, include_given=include_given ).addCallback(self.get_events_as_list) - def get_auth_chain_ids(self, event_ids, include_given=False): + def get_auth_chain_ids( + self, + event_ids: List[str], + include_given: bool = False, + ignore_events: Optional[Set[str]] = None, + ): """Get auth events for given event_ids. The events *must* be state events. Args: - event_ids (list): state events - include_given (bool): include the given events in result + event_ids: state events + include_given: include the given events in result + ignore_events: Set of events to exclude from the returned auth + chain. This is useful if the caller will just discard the + given events anyway, and saves us from figuring out their auth + chains if not required. Returns: list of event_ids """ return self.db.runInteraction( - "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given + "get_auth_chain_ids", + self._get_auth_chain_ids_txn, + event_ids, + include_given, + ignore_events, ) - def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): + def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events): + if ignore_events is None: + ignore_events = set() + if include_given: results = set(event_ids) else: @@ -80,6 +97,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(base_sql + clause, list(args)) new_front.update([r[0] for r in txn]) + new_front -= ignore_events new_front -= results front = new_front diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 5bafad9f19..5059ade850 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -603,7 +603,7 @@ class TestStateResolutionStore(object): return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} - def get_auth_chain(self, event_ids): + def get_auth_chain(self, event_ids, ignore_events): """Gets the full auth chain for a set of events (including rejected events). @@ -617,6 +617,8 @@ class TestStateResolutionStore(object): Args: event_ids (list): The event IDs of the events to fetch the auth chain for. Must be state events. + ignore_events: Set of events to exclude from the returned auth + chain. Returns: Deferred[list[str]]: List of event IDs of the auth chain. @@ -627,7 +629,7 @@ class TestStateResolutionStore(object): stack = list(event_ids) while stack: event_id = stack.pop() - if event_id in result: + if event_id in result or event_id in ignore_events: continue result.add(event_id) -- cgit 1.5.1 From 509e381afa8c656e72f5fef3d651a9819794174a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Feb 2020 07:15:07 -0500 Subject: Clarify list/set/dict/tuple comprehensions and enforce via flake8 (#6957) Ensure good comprehension hygiene using flake8-comprehensions. --- CONTRIBUTING.md | 2 +- changelog.d/6957.misc | 1 + docs/code_style.md | 2 +- scripts-dev/convert_server_keys.py | 2 +- synapse/app/_base.py | 2 +- synapse/app/federation_sender.py | 4 +-- synapse/app/pusher.py | 2 +- synapse/config/server.py | 4 +-- synapse/config/tls.py | 2 +- synapse/crypto/keyring.py | 6 ++-- synapse/federation/send_queue.py | 4 +-- synapse/groups/groups_server.py | 2 +- synapse/handlers/device.py | 2 +- synapse/handlers/directory.py | 4 +-- synapse/handlers/federation.py | 18 +++++----- synapse/handlers/presence.py | 6 ++-- synapse/handlers/receipts.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/search.py | 8 ++--- synapse/handlers/sync.py | 22 ++++++------ synapse/handlers/typing.py | 4 +-- synapse/logging/utils.py | 2 +- synapse/metrics/__init__.py | 2 +- synapse/metrics/background_process_metrics.py | 4 +-- synapse/push/bulk_push_rule_evaluator.py | 8 ++--- synapse/push/emailpusher.py | 2 +- synapse/push/mailer.py | 20 +++++------ synapse/push/pusherpool.py | 2 +- synapse/rest/admin/_base.py | 4 +-- synapse/rest/client/v1/push_rule.py | 6 ++-- synapse/rest/client/v1/pusher.py | 4 +-- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/key/v2/remote_key_resource.py | 2 +- synapse/rest/media/v1/_base.py | 40 ++++++++++------------ synapse/state/v1.py | 10 +++--- synapse/state/v2.py | 8 ++--- synapse/storage/_base.py | 2 +- synapse/storage/background_updates.py | 2 +- synapse/storage/data_stores/main/appservice.py | 14 ++++---- synapse/storage/data_stores/main/client_ips.py | 4 +-- synapse/storage/data_stores/main/devices.py | 13 ++++--- .../storage/data_stores/main/event_federation.py | 2 +- synapse/storage/data_stores/main/events.py | 8 ++--- .../storage/data_stores/main/events_bg_updates.py | 2 +- synapse/storage/data_stores/main/events_worker.py | 6 ++-- synapse/storage/data_stores/main/push_rule.py | 8 ++--- synapse/storage/data_stores/main/receipts.py | 4 +-- synapse/storage/data_stores/main/roommember.py | 4 +-- synapse/storage/data_stores/main/state.py | 8 ++--- synapse/storage/data_stores/main/stream.py | 8 ++--- .../storage/data_stores/main/user_erasure_store.py | 4 +-- synapse/storage/data_stores/state/store.py | 4 +-- synapse/storage/database.py | 4 +-- synapse/storage/persist_events.py | 8 ++--- synapse/storage/prepare_database.py | 6 ++-- synapse/util/frozenutils.py | 2 +- synapse/visibility.py | 4 +-- tests/config/test_generate.py | 2 +- tests/federation/test_federation_server.py | 2 +- tests/handlers/test_presence.py | 4 +-- tests/handlers/test_typing.py | 6 ++-- tests/handlers/test_user_directory.py | 12 +++---- tests/push/test_email.py | 6 ++-- tests/push/test_http.py | 8 ++--- tests/rest/client/v2_alpha/test_sync.py | 28 ++++++++------- tests/storage/test__base.py | 4 +-- tests/storage/test_appservice.py | 36 +++++++++---------- tests/storage/test_cleanup_extrems.py | 10 +++--- tests/storage/test_event_metrics.py | 36 +++++++++---------- tests/storage/test_state.py | 2 +- tests/test_state.py | 18 +++------- tests/util/test_stream_change_cache.py | 18 +++------- tox.ini | 1 + 73 files changed, 251 insertions(+), 276 deletions(-) create mode 100644 changelog.d/6957.misc (limited to 'tests') diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4b01b6ac8c..253a0ca648 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -60,7 +60,7 @@ python 3.6 and to install each tool: ``` # Install the dependencies -pip install -U black flake8 isort +pip install -U black flake8 flake8-comprehensions isort # Run the linter script ./scripts-dev/lint.sh diff --git a/changelog.d/6957.misc b/changelog.d/6957.misc new file mode 100644 index 0000000000..4f98030110 --- /dev/null +++ b/changelog.d/6957.misc @@ -0,0 +1 @@ +Use flake8-comprehensions to enforce good hygiene of list/set/dict comprehensions. diff --git a/docs/code_style.md b/docs/code_style.md index 71aecd41f7..6ef6f80290 100644 --- a/docs/code_style.md +++ b/docs/code_style.md @@ -30,7 +30,7 @@ The necessary tools are detailed below. Install `flake8` with: - pip install --upgrade flake8 + pip install --upgrade flake8 flake8-comprehensions Check all application and test code with: diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index 179be61c30..06b4c1e2ff 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -103,7 +103,7 @@ def main(): yaml.safe_dump(result, sys.stdout, default_flow_style=False) - rows = list(row for server, json in result.items() for row in rows_v2(server, json)) + rows = [row for server, json in result.items() for row in rows_v2(server, json)] cursor = connection.cursor() cursor.executemany( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 109b1e2fb5..9ffd23c6df 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -141,7 +141,7 @@ def start_reactor( def quit_with_error(error_string): message_lines = error_string.split("\n") - line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 + line_length = max(len(l) for l in message_lines if len(l) < 80) + 2 sys.stderr.write("*" * line_length + "\n") for line in message_lines: sys.stderr.write(" %s\n" % (line.rstrip(),)) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 63a91f1177..b7fcf80ddc 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -262,7 +262,7 @@ class FederationSenderHandler(object): # ... as well as device updates and messages elif stream_name == DeviceListsStream.NAME: - hosts = set(row.destination for row in rows) + hosts = {row.destination for row in rows} for host in hosts: self.federation_sender.send_device_messages(host) @@ -270,7 +270,7 @@ class FederationSenderHandler(object): # The to_device stream includes stuff to be pushed to both local # clients and remote servers, so we ignore entities that start with # '@' (since they'll be local users rather than destinations). - hosts = set(row.entity for row in rows if not row.entity.startswith("@")) + hosts = {row.entity for row in rows if not row.entity.startswith("@")} for host in hosts: self.federation_sender.send_device_messages(host) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index e46b6ac598..84e9f8d5e2 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -158,7 +158,7 @@ class PusherReplicationHandler(ReplicationClientHandler): yield self.pusher_pool.on_new_notifications(token, token) elif stream_name == "receipts": yield self.pusher_pool.on_new_receipts( - token, token, set(row.room_id for row in rows) + token, token, {row.room_id for row in rows} ) except Exception: logger.exception("Error poking pushers") diff --git a/synapse/config/server.py b/synapse/config/server.py index 0ec1b0fadd..7525765fee 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1066,12 +1066,12 @@ KNOWN_RESOURCES = ( def _check_resource_config(listeners): - resource_names = set( + resource_names = { res_name for listener in listeners for res in listener.get("resources", []) for res_name in res.get("names", []) - ) + } for resource in resource_names: if resource not in KNOWN_RESOURCES: diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 97a12d51f6..a65538562b 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -260,7 +260,7 @@ class TlsConfig(Config): crypto.FILETYPE_ASN1, self.tls_certificate ) sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest()) - sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints) + sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints} if sha256_fingerprint not in sha256_fingerprints: self.tls_fingerprints.append({"sha256": sha256_fingerprint}) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 6fe5a6a26a..983f0ead8c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -326,9 +326,7 @@ class Keyring(object): verify_requests (list[VerifyJsonRequest]): list of verify requests """ - remaining_requests = set( - (rq for rq in verify_requests if not rq.key_ready.called) - ) + remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} @defer.inlineCallbacks def do_iterations(): @@ -396,7 +394,7 @@ class Keyring(object): results = yield fetcher.get_keys(missing_keys) - completed = list() + completed = [] for verify_request in remaining_requests: server_name = verify_request.server_name diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 001bb304ae..876fb0e245 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -129,9 +129,9 @@ class FederationRemoteSendQueue(object): for key in keys[:i]: del self.presence_changed[key] - user_ids = set( + user_ids = { user_id for uids in self.presence_changed.values() for user_id in uids - ) + } keys = self.presence_destinations.keys() i = self.presence_destinations.bisect_left(position_to_delete) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index c106abae21..4f0dc0a209 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -608,7 +608,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): user_results = yield self.store.get_users_in_group( group_id, include_private=True ) - if user_id in [user_result["user_id"] for user_result in user_results]: + if user_id in (user_result["user_id"] for user_result in user_results): raise SynapseError(400, "User already in group") content = { diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 50cea3f378..a514c30714 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -742,6 +742,6 @@ class DeviceListUpdater(object): # We clobber the seen updates since we've re-synced from a given # point. - self._seen_updates[user_id] = set([stream_id]) + self._seen_updates[user_id] = {stream_id} defer.returnValue(result) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 921d887b24..0b23ca919a 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -72,7 +72,7 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Check if there is a current association. if not servers: users = yield self.state.get_current_users_in_room(room_id) - servers = set(get_domain_from_id(u) for u in users) + servers = {get_domain_from_id(u) for u in users} if not servers: raise SynapseError(400, "Failed to get server list") @@ -255,7 +255,7 @@ class DirectoryHandler(BaseHandler): ) users = yield self.state.get_current_users_in_room(room_id) - extra_servers = set(get_domain_from_id(u) for u in users) + extra_servers = {get_domain_from_id(u) for u in users} servers = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index eb20ef4aec..a689065f89 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -659,11 +659,11 @@ class FederationHandler(BaseHandler): # this can happen if a remote server claims that the state or # auth_events at an event in room A are actually events in room B - bad_events = list( + bad_events = [ (event_id, event.room_id) for event_id, event in fetched_events.items() if event.room_id != room_id - ) + ] for bad_event_id, bad_room_id in bad_events: # This is a bogus situation, but since we may only discover it a long time @@ -856,7 +856,7 @@ class FederationHandler(BaseHandler): # Don't bother processing events we already have. seen_events = await self.store.have_events_in_timeline( - set(e.event_id for e in events) + {e.event_id for e in events} ) events = [e for e in events if e.event_id not in seen_events] @@ -866,7 +866,7 @@ class FederationHandler(BaseHandler): event_map = {e.event_id: e for e in events} - event_ids = set(e.event_id for e in events) + event_ids = {e.event_id for e in events} # build a list of events whose prev_events weren't in the batch. # (XXX: this will include events whose prev_events we already have; that doesn't @@ -892,13 +892,13 @@ class FederationHandler(BaseHandler): state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state - required_auth = set( + required_auth = { a_id for event in events + list(state_events.values()) + list(auth_events.values()) for a_id in event.auth_event_ids() - ) + } auth_events.update( {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} ) @@ -1247,7 +1247,7 @@ class FederationHandler(BaseHandler): async def on_event_auth(self, event_id: str) -> List[EventBase]: event = await self.store.get_event(event_id) auth = await self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], include_given=True + list(event.auth_event_ids()), include_given=True ) return list(auth) @@ -2152,7 +2152,7 @@ class FederationHandler(BaseHandler): # Now get the current auth_chain for the event. local_auth_chain = await self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], include_given=True + list(event.auth_event_ids()), include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell @@ -2654,7 +2654,7 @@ class FederationHandler(BaseHandler): member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) else: - destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) + destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} yield self.federation_client.forward_third_party_invite( destinations, room_id, event_dict ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 202aa9294f..0d6cf2b008 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -313,7 +313,7 @@ class PresenceHandler(object): notified_presence_counter.inc(len(to_notify)) yield self._persist_and_notify(list(to_notify.values())) - self.unpersisted_users_changes |= set(s.user_id for s in new_states) + self.unpersisted_users_changes |= {s.user_id for s in new_states} self.unpersisted_users_changes -= set(to_notify.keys()) to_federation_ping = { @@ -698,7 +698,7 @@ class PresenceHandler(object): updates = yield self.current_state_for_users(target_user_ids) updates = list(updates.values()) - for user_id in set(target_user_ids) - set(u.user_id for u in updates): + for user_id in set(target_user_ids) - {u.user_id for u in updates}: updates.append(UserPresenceState.default(user_id)) now = self.clock.time_msec() @@ -886,7 +886,7 @@ class PresenceHandler(object): hosts = yield self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. - hosts = set(host for host in hosts if host != self.server_name) + hosts = {host for host in hosts if host != self.server_name} self.federation.send_presence_to_destinations( states=[state], destinations=hosts diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 9283c039e3..8bc100db42 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -94,7 +94,7 @@ class ReceiptsHandler(BaseHandler): # no new receipts return False - affected_room_ids = list(set([r.room_id for r in receipts])) + affected_room_ids = list({r.room_id for r in receipts}) self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 76e8f61b74..8ee870f0bb 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -355,7 +355,7 @@ class RoomCreationHandler(BaseHandler): # If so, mark the new room as non-federatable as well creation_content["m.federate"] = False - initial_state = dict() + initial_state = {} # Replicate relevant room events types_to_copy = ( diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 110097eab9..ec1542d416 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -184,7 +184,7 @@ class SearchHandler(BaseHandler): membership_list=[Membership.JOIN], # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], ) - room_ids = set(r.room_id for r in rooms) + room_ids = {r.room_id for r in rooms} # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well @@ -374,12 +374,12 @@ class SearchHandler(BaseHandler): ).to_string() if include_profile: - senders = set( + senders = { ev.sender for ev in itertools.chain( res["events_before"], [event], res["events_after"] ) - ) + } if res["events_after"]: last_event_id = res["events_after"][-1].event_id @@ -421,7 +421,7 @@ class SearchHandler(BaseHandler): state_results = {} if include_state: - rooms = set(e.room_id for e in allowed_events) + rooms = {e.room_id for e in allowed_events} for room_id in rooms: state = yield self.state_handler.get_current_state(room_id) state_results[room_id] = list(state.values()) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4324bc702e..669dbc8a48 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -682,11 +682,9 @@ class SyncHandler(object): # FIXME: order by stream ordering rather than as returned by SQL if joined_user_ids or invited_user_ids: - summary["m.heroes"] = sorted( - [user_id for user_id in (joined_user_ids + invited_user_ids)] - )[0:5] + summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5] else: - summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5] + summary["m.heroes"] = sorted(gone_user_ids)[0:5] if not sync_config.filter_collection.lazy_load_members(): return summary @@ -697,9 +695,9 @@ class SyncHandler(object): # track which members the client should already know about via LL: # Ones which are already in state... - existing_members = set( + existing_members = { user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member - ) + } # ...or ones which are in the timeline... for ev in batch.events: @@ -773,10 +771,10 @@ class SyncHandler(object): # We only request state for the members needed to display the # timeline: - members_to_fetch = set( + members_to_fetch = { event.sender # FIXME: we also care about invite targets etc. for event in batch.events - ) + } if full_state: # always make sure we LL ourselves so we know we're in the room @@ -1993,10 +1991,10 @@ def _calculate_state( ) } - c_ids = set(e for e in itervalues(current)) - ts_ids = set(e for e in itervalues(timeline_start)) - p_ids = set(e for e in itervalues(previous)) - tc_ids = set(e for e in itervalues(timeline_contains)) + c_ids = set(itervalues(current)) + ts_ids = set(itervalues(timeline_start)) + p_ids = set(itervalues(previous)) + tc_ids = set(itervalues(timeline_contains)) # If we are lazyloading room members, we explicitly add the membership events # for the senders in the timeline into the state block returned by /sync, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 5406618431..391bceb0c4 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -198,7 +198,7 @@ class TypingHandler(object): now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) - for domain in set(get_domain_from_id(u) for u in users): + for domain in {get_domain_from_id(u) for u in users}: if domain != self.server_name: logger.debug("sending typing update to %s", domain) self.federation.build_and_send_edu( @@ -231,7 +231,7 @@ class TypingHandler(object): return users = yield self.state.get_current_users_in_room(room_id) - domains = set(get_domain_from_id(u) for u in users) + domains = {get_domain_from_id(u) for u in users} if self.server_name in domains: logger.info("Got typing update from %s: %r", user_id, content) diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 6073fc2725..0c2527bd86 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -148,7 +148,7 @@ def trace_function(f): pathname=pathname, lineno=lineno, msg=msg, - args=tuple(), + args=(), exc_info=None, ) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 0b45e1f52a..0dba997a23 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -240,7 +240,7 @@ class BucketCollector(object): res.append(["+Inf", sum(data.values())]) metric = HistogramMetricFamily( - self.name, "", buckets=res, sum_value=sum([x * y for x, y in data.items()]) + self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items()) ) yield metric diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index c53d2a0d40..b65bcd8806 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -80,13 +80,13 @@ _background_process_db_sched_duration = Counter( # map from description to a counter, so that we can name our logcontexts # incrementally. (It actually duplicates _background_process_start_count, but # it's much simpler to do so than to try to combine them.) -_background_process_counts = dict() # type: dict[str, int] +_background_process_counts = {} # type: dict[str, int] # map from description to the currently running background processes. # # it's kept as a dict of sets rather than a big set so that we can keep track # of process descriptions that no longer have any active processes. -_background_processes = dict() # type: dict[str, set[_BackgroundProcess]] +_background_processes = {} # type: dict[str, set[_BackgroundProcess]] # A lock that covers the above dicts _bg_metrics_lock = threading.Lock() diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7d9f5a38d9..433ca2f416 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -400,11 +400,11 @@ class RulesForRoom(object): if logger.isEnabledFor(logging.DEBUG): logger.debug("Found members %r: %r", self.room_id, members.values()) - interested_in_user_ids = set( + interested_in_user_ids = { user_id for user_id, membership in itervalues(members) if membership == Membership.JOIN - ) + } logger.debug("Joined: %r", interested_in_user_ids) @@ -412,9 +412,9 @@ class RulesForRoom(object): interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) - user_ids = set( + user_ids = { uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher - ) + } logger.debug("With pushers: %r", user_ids) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 8c818a86bf..ba4551d619 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -204,7 +204,7 @@ class EmailPusher(object): yield self.send_notification(unprocessed, reason) yield self.save_last_stream_ordering_and_success( - max([ea["stream_ordering"] for ea in unprocessed]) + max(ea["stream_ordering"] for ea in unprocessed) ) # we update the throttle on all the possible unprocessed push actions diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index b13b646bfd..4ccaf178ce 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -526,12 +526,10 @@ class Mailer(object): # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" sender_ids = list( - set( - [ - notif_events[n["event_id"]].sender - for n in notifs_by_room[room_id] - ] - ) + { + notif_events[n["event_id"]].sender + for n in notifs_by_room[room_id] + } ) member_events = yield self.store.get_events( @@ -558,12 +556,10 @@ class Mailer(object): # If the reason room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" sender_ids = list( - set( - [ - notif_events[n["event_id"]].sender - for n in notifs_by_room[reason["room_id"]] - ] - ) + { + notif_events[n["event_id"]].sender + for n in notifs_by_room[reason["room_id"]] + } ) member_events = yield self.store.get_events( diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index b9dca5bc63..01789a9fb4 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -191,7 +191,7 @@ class PusherPool: min_stream_id - 1, max_stream_id ) # This returns a tuple, user_id is at index 3 - users_affected = set([r[3] for r in updated_receipts]) + users_affected = {r[3] for r in updated_receipts} for u in users_affected: if u in self.pushers: diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index 459482eb6d..a96f75ce26 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -29,7 +29,7 @@ def historical_admin_path_patterns(path_regex): Note that this should only be used for existing endpoints: new ones should just register for the /_synapse/admin path. """ - return list( + return [ re.compile(prefix + path_regex) for prefix in ( "^/_synapse/admin/v1", @@ -37,7 +37,7 @@ def historical_admin_path_patterns(path_regex): "^/_matrix/client/unstable/admin", "^/_matrix/client/r0/admin", ) - ) + ] def admin_patterns(path_regex: str): diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 4f74600239..9fd4908136 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -49,7 +49,7 @@ class PushRuleRestServlet(RestServlet): if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") - spec = _rule_spec_from_path([x for x in path.split("/")]) + spec = _rule_spec_from_path(path.split("/")) try: priority_class = _priority_class_from_spec(spec) except InvalidRuleException as e: @@ -110,7 +110,7 @@ class PushRuleRestServlet(RestServlet): if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") - spec = _rule_spec_from_path([x for x in path.split("/")]) + spec = _rule_spec_from_path(path.split("/")) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -138,7 +138,7 @@ class PushRuleRestServlet(RestServlet): rules = format_push_rules_for_user(requester.user, rules) - path = [x for x in path.split("/")][1:] + path = path.split("/")[1:] if path == []: # we're a reference impl: pedantry is our job. diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 6f6b7aed6e..550a2f1b44 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -54,9 +54,9 @@ class PushersRestServlet(RestServlet): pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - filtered_pushers = list( + filtered_pushers = [ {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers - ) + ] return 200, {"pushers": filtered_pushers} diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index d8292ce29f..8fa68dd37f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -72,7 +72,7 @@ class SyncRestServlet(RestServlet): """ PATTERNS = client_patterns("/sync$") - ALLOWED_PRESENCE = set(["online", "offline", "unavailable"]) + ALLOWED_PRESENCE = {"online", "offline", "unavailable"} def __init__(self, hs): super(SyncRestServlet, self).__init__() diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 9d6813a047..4b6d030a57 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -149,7 +149,7 @@ class RemoteKey(DirectServeResource): time_now_ms = self.clock.time_msec() - cache_misses = dict() # type: Dict[str, Set[str]] + cache_misses = {} # type: Dict[str, Set[str]] for (server_name, key_id, from_server), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 65bbf00073..ba28dd089d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -135,27 +135,25 @@ def add_file_headers(request, media_type, file_size, upload_name): # separators as defined in RFC2616. SP and HT are handled separately. # see _can_encode_filename_as_token. -_FILENAME_SEPARATOR_CHARS = set( - ( - "(", - ")", - "<", - ">", - "@", - ",", - ";", - ":", - "\\", - '"', - "/", - "[", - "]", - "?", - "=", - "{", - "}", - ) -) +_FILENAME_SEPARATOR_CHARS = { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", +} def _can_encode_filename_as_token(x): diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 24b7c0faef..9bf98d06f2 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -69,9 +69,9 @@ def resolve_events_with_store( unconflicted_state, conflicted_state = _seperate(state_sets) - needed_events = set( + needed_events = { event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids - ) + } needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) @@ -261,11 +261,11 @@ def _resolve_state_events(conflicted_state, auth_events): def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] + reverse = list(reversed(_ordered_events(events))) - auth_keys = set( + auth_keys = { key for event in events for key in event_auth.auth_types_for_event(event) - ) + } new_auth_events = {} for key in auth_keys: diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 75fe58305a..0ffe6d8c14 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -105,7 +105,7 @@ def resolve_events_with_store( % (room_id, event.event_id, event.room_id,) ) - full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) + full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map} logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) @@ -233,7 +233,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): auth_sets = [] for state_set in state_sets: - auth_ids = set( + auth_ids = { eid for key, eid in iteritems(state_set) if ( @@ -246,7 +246,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): ) ) and eid not in common - ) + } auth_chain = yield state_res_store.get_auth_chain(auth_ids, common) auth_ids.update(auth_chain) @@ -275,7 +275,7 @@ def _seperate(state_sets): conflicted_state = {} for key in set(itertools.chain.from_iterable(state_sets)): - event_ids = set(state_set.get(key) for state_set in state_sets) + event_ids = {state_set.get(key) for state_set in state_sets} if len(event_ids) == 1: unconflicted_state[key] = event_ids.pop() else: diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index da3b99f93d..13de5f1f62 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -56,7 +56,7 @@ class SQLBaseStore(metaclass=ABCMeta): members_changed (iterable[str]): The user_ids of members that have changed """ - for host in set(get_domain_from_id(u) for u in members_changed): + for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("was_host_joined", (room_id, host)) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index bd547f35cf..eb1a7e5002 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -189,7 +189,7 @@ class BackgroundUpdater(object): keyvalues=None, retcols=("update_name", "depends_on"), ) - in_flight = set(update["update_name"] for update in updates) + in_flight = {update["update_name"] for update in updates} for update in updates: if update["depends_on"] not in in_flight: self._background_update_queue.append(update["update_name"]) diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index b2f39649fd..efbc06c796 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -135,7 +135,7 @@ class ApplicationServiceTransactionWorkerStore( may be empty. """ results = yield self.db.simple_select_list( - "application_services_state", dict(state=state), ["as_id"] + "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() @@ -158,7 +158,7 @@ class ApplicationServiceTransactionWorkerStore( """ result = yield self.db.simple_select_one( "application_services_state", - dict(as_id=service.id), + {"as_id": service.id}, ["state"], allow_none=True, desc="get_appservice_state", @@ -177,7 +177,7 @@ class ApplicationServiceTransactionWorkerStore( A Deferred which resolves when the state was set successfully. """ return self.db.simple_upsert( - "application_services_state", dict(as_id=service.id), dict(state=state) + "application_services_state", {"as_id": service.id}, {"state": state} ) def create_appservice_txn(self, service, events): @@ -253,13 +253,15 @@ class ApplicationServiceTransactionWorkerStore( self.db.simple_upsert_txn( txn, "application_services_state", - dict(as_id=service.id), - dict(last_txn=txn_id), + {"as_id": service.id}, + {"last_txn": txn_id}, ) # Delete txn self.db.simple_delete_txn( - txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) + txn, + "application_services_txns", + {"txn_id": txn_id, "as_id": service.id}, ) return self.db.runInteraction( diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 13f4c9c72e..e1ccb27142 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -530,7 +530,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) for row in rows ) - return list( + return [ { "access_token": access_token, "ip": ip, @@ -538,7 +538,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): "last_seen": last_seen, } for (access_token, ip), (user_agent, last_seen) in iteritems(results) - ) + ] @wrap_as_background_process("prune_old_user_ips") async def _prune_old_user_ips(self): diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index b7617efb80..d55733a4cd 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -137,7 +137,7 @@ class DeviceWorkerStore(SQLBaseStore): # get the cross-signing keys of the users in the list, so that we can # determine which of the device changes were cross-signing keys - users = set(r[0] for r in updates) + users = {r[0] for r in updates} master_key_by_user = {} self_signing_key_by_user = {} for user in users: @@ -446,7 +446,7 @@ class DeviceWorkerStore(SQLBaseStore): a set of user_ids and results_map is a mapping of user_id -> device_id -> device_info """ - user_ids = set(user_id for user_id, _ in query_list) + user_ids = {user_id for user_id, _ in query_list} user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) # We go and check if any of the users need to have their device lists @@ -454,10 +454,9 @@ class DeviceWorkerStore(SQLBaseStore): users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( user_ids ) - user_ids_in_cache = ( - set(user_id for user_id, stream_id in user_map.items() if stream_id) - - users_needing_resync - ) + user_ids_in_cache = { + user_id for user_id, stream_id in user_map.items() if stream_id + } - users_needing_resync user_ids_not_in_cache = user_ids - user_ids_in_cache results = {} @@ -604,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore): rows = yield self.db.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) - return set(user for row in rows for user in json.loads(row[0])) + return {user for row in rows for user in json.loads(row[0])} else: return set() diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 750ec1b70d..49a7b8b433 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -426,7 +426,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas query, (room_id, event_id, False, limit - len(event_results)) ) - new_results = set(t[0] for t in txn) - seen_events + new_results = {t[0] for t in txn} - seen_events new_front |= new_results seen_events |= new_results diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index c9d0d68c3a..8ae23df00a 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -145,7 +145,7 @@ class EventsStore( return txn.fetchall() res = yield self.db.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) + self._current_forward_extremities_amount = c_counter([x[0] for x in res]) @_retry_on_integrity_error @defer.inlineCallbacks @@ -598,11 +598,11 @@ class EventsStore( # We find out which membership events we may have deleted # and which we have added, then we invlidate the caches for all # those users. - members_changed = set( + members_changed = { state_key for ev_type, state_key in itertools.chain(to_delete, to_insert) if ev_type == EventTypes.Member - ) + } for member in members_changed: txn.call_after( @@ -1615,7 +1615,7 @@ class EventsStore( """ ) - referenced_state_groups = set(sg for sg, in txn) + referenced_state_groups = {sg for sg, in txn} logger.info( "[purge] found %i referenced state groups", len(referenced_state_groups) ) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 5177b71016..f54c8b1ee0 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -402,7 +402,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): keyvalues={}, retcols=("room_id",), ) - room_ids = set(row["room_id"] for row in rows) + room_ids = {row["room_id"] for row in rows} for room_id in room_ids: txn.call_after( self.get_latest_event_ids_in_room.invalidate, (room_id,) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 7251e819f5..47a3a26072 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -494,9 +494,9 @@ class EventsWorkerStore(SQLBaseStore): """ with Measure(self._clock, "_fetch_event_list"): try: - events_to_fetch = set( + events_to_fetch = { event_id for events, _ in event_list for event_id in events - ) + } row_dict = self.db.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch @@ -804,7 +804,7 @@ class EventsWorkerStore(SQLBaseStore): desc="have_events_in_timeline", ) - return set(r["event_id"] for r in rows) + return {r["event_id"] for r in rows} @defer.inlineCallbacks def have_seen_events(self, event_ids): diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index e2673ae073..62ac88d9f2 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -276,21 +276,21 @@ class PushRulesWorkerStore( # We ignore app service users for now. This is so that we don't fill # up the `get_if_users_have_pushers` cache with AS entries that we # know don't have pushers, nor even read receipts. - local_users_in_room = set( + local_users_in_room = { u for u in users_in_room if self.hs.is_mine_id(u) and not self.get_if_app_services_interested_in_user(u) - ) + } # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( local_users_in_room, on_invalidate=cache_context.invalidate ) - user_ids = set( + user_ids = { uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - ) + } users_with_receipts = yield self.get_users_with_read_receipts_in_room( room_id, on_invalidate=cache_context.invalidate diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 96e54d145e..0d932a0672 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") - return set(r["user_id"] for r in receipts) + return {r["user_id"] for r in receipts} @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): @@ -283,7 +283,7 @@ class ReceiptsWorkerStore(SQLBaseStore): args.append(limit) txn.execute(sql, args) - return list(r[0:5] + (json.loads(r[5]),) for r in txn) + return [r[0:5] + (json.loads(r[5]),) for r in txn] return self.db.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index d5ced05701..d5bd0cb5cf 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -465,7 +465,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql % (clause,), args) - return set(row[0] for row in txn) + return {row[0] for row in txn} return await self.db.runInteraction( "get_users_server_still_shares_room_with", @@ -826,7 +826,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): GROUP BY room_id, user_id; """ txn.execute(sql, (user_id,)) - return set(row[0] for row in txn if row[1] == 0) + return {row[0] for row in txn if row[1] == 0} return self.db.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 3d34103e67..3a3b9a8e72 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -321,7 +321,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_referenced_state_groups", ) - return set(row["state_group"] for row in rows) + return {row["state_group"] for row in rows} class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): @@ -367,7 +367,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): """ txn.execute(sql, (last_room_id, batch_size)) - room_ids = list(row[0] for row in txn) + room_ids = [row[0] for row in txn] if not room_ids: return True, set() @@ -384,7 +384,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name)) - joined_room_ids = set(row[0] for row in txn) + joined_room_ids = {row[0] for row in txn} left_rooms = set(room_ids) - joined_room_ids @@ -404,7 +404,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): retcols=("state_key",), ) - potentially_left_users = set(row["state_key"] for row in rows) + potentially_left_users = {row["state_key"] for row in rows} # Now lets actually delete the rooms from the DB. self.db.simple_delete_many_txn( diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 056b25b13a..ada5cce6c2 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -346,11 +346,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_key (str): The room_key portion of a StreamToken """ from_key = RoomStreamToken.parse_stream_token(from_key).stream - return set( + return { room_id for room_id in room_ids if self._events_stream_cache.has_entity_changed(room_id, from_key) - ) + } @defer.inlineCallbacks def get_room_events_stream_for_room( @@ -679,11 +679,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) events_before = yield self.get_events_as_list( - [e for e in results["before"]["event_ids"]], get_prev_content=True + list(results["before"]["event_ids"]), get_prev_content=True ) events_after = yield self.get_events_as_list( - [e for e in results["after"]["event_ids"]], get_prev_content=True + list(results["after"]["event_ids"]), get_prev_content=True ) return { diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index af8025bc17..ec6b8a4ffd 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -63,9 +63,9 @@ class UserErasureWorkerStore(SQLBaseStore): retcols=("user_id",), desc="are_users_erased", ) - erased_users = set(row["user_id"] for row in rows) + erased_users = {row["user_id"] for row in rows} - res = dict((u, u in erased_users) for u in user_ids) + res = {u: u in erased_users for u in user_ids} return res diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index c4ee9b7ccb..57a5267663 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -520,11 +520,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): retcols=("state_group",), ) - remaining_state_groups = set( + remaining_state_groups = { row["state_group"] for row in rows if row["state_group"] not in state_groups_to_delete - ) + } logger.info( "[purge] de-delta-ing %i remaining state groups", diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6dcb5c04da..1953614401 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -554,8 +554,8 @@ class Database(object): Returns: A list of dicts where the key is the column header. """ - col_headers = list(intern(str(column[0])) for column in cursor.description) - results = list(dict(zip(col_headers, row)) for row in cursor) + col_headers = [intern(str(column[0])) for column in cursor.description] + results = [dict(zip(col_headers, row)) for row in cursor] return results def execute(self, desc, decoder, query, *args): diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index b950550f23..0f9ac1cf09 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -602,14 +602,14 @@ class EventsPersistenceStorage(object): event_id_to_state_group.update(event_to_groups) # State groups of old_latest_event_ids - old_state_groups = set( + old_state_groups = { event_id_to_state_group[evid] for evid in old_latest_event_ids - ) + } # State groups of new_latest_event_ids - new_state_groups = set( + new_state_groups = { event_id_to_state_group[evid] for evid in new_latest_event_ids - ) + } # If they old and new groups are the same then we don't need to do # anything. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index c285ef52a0..fc69c32a0a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -345,9 +345,9 @@ def _upgrade_existing_database( "Could not open delta dir for version %d: %s" % (v, directory) ) - duplicates = set( + duplicates = { file_name for file_name, count in file_name_counter.items() if count > 1 - ) + } if duplicates: # We don't support using the same file name in the same delta version. raise PrepareDatabaseException( @@ -454,7 +454,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) ), (modname,), ) - applied_deltas = set(d for d, in cur) + applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: if name in applied_deltas: continue diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 635b897d6c..f2ccd5e7c6 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -30,7 +30,7 @@ def freeze(o): return o try: - return tuple([freeze(i) for i in o]) + return tuple(freeze(i) for i in o) except TypeError: pass diff --git a/synapse/visibility.py b/synapse/visibility.py index d0abd8f04f..e60d9756b7 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -75,7 +75,7 @@ def filter_events_for_client( """ # Filter out events that have been soft failed so that we don't relay them # to clients. - events = list(e for e in events if not e.internal_metadata.is_soft_failed()) + events = [e for e in events if not e.internal_metadata.is_soft_failed()] types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) event_id_to_state = yield storage.state.get_state_for_events( @@ -97,7 +97,7 @@ def filter_events_for_client( erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) if apply_retention_policies: - room_ids = set(e.room_id for e in events) + room_ids = {e.room_id for e in events} retention_policies = {} for room_id in room_ids: diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 2684e662de..463855ecc8 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -48,7 +48,7 @@ class ConfigGenerationTestCase(unittest.TestCase): ) self.assertSetEqual( - set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]), + {"homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"}, set(os.listdir(self.dir)), ) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index e7d8699040..296dc887be 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -83,7 +83,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): ) ) - self.assertEqual(members, set(["@user:other.example.com", u1])) + self.assertEqual(members, {"@user:other.example.com", u1}) self.assertEqual(len(channel.json_body["pdus"]), 6) def test_needs_to_be_in_room(self): diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index c171038df8..64915bafcd 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -338,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, syncing_user_ids=set([user_id]), now=now + state, is_mine=True, syncing_user_ids={user_id}, now=now ) self.assertIsNotNone(new_state) @@ -579,7 +579,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=set(("server2", "server3")), states=[expected_state] + destinations={"server2", "server3"}, states=[expected_state] ) def _add_new_user(self, room_id, user_id): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 140cc0a3c2..07b204666e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -129,12 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): hs.get_auth().check_user_in_room = check_user_in_room def get_joined_hosts_for_room(room_id): - return set(member.domain for member in self.room_members) + return {member.domain for member in self.room_members} self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room def get_current_users_in_room(room_id): - return set(str(u) for u in self.room_members) + return {str(u) for u in self.room_members} hs.get_state_handler().get_current_users_in_room = get_current_users_in_room @@ -257,7 +257,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): member = RoomMember(ROOM_ID, U_APPLE.to_string()) self.handler._member_typing_until[member] = 1002000 - self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()]) + self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()} self.assertEquals(self.event_source.get_current_key(), 0) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 0a4765fff4..7b92bdbc47 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -114,7 +114,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -226,7 +226,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -358,12 +358,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() # User 1 and User 2 are in the same public room - self.assertEqual(set(public_users), set([(u1, room), (u2, room)])) + self.assertEqual(set(public_users), {(u1, room), (u2, room)}) # User 1 and User 3 share private rooms self.assertEqual( self._compress_shared(shares_private), - set([(u1, u3, private_room), (u3, u1, private_room)]), + {(u1, u3, private_room), (u3, u1, private_room)}, ) def test_initial_share_all_users(self): @@ -398,7 +398,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # No users share rooms self.assertEqual(public_users, []) - self.assertEqual(self._compress_shared(shares_private), set([])) + self.assertEqual(self._compress_shared(shares_private), set()) # Despite not sharing a room, search_all_users means we get a search # result. diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 80187406bc..83032cc9ea 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -163,7 +163,7 @@ class EmailPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -174,7 +174,7 @@ class EmailPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -192,7 +192,7 @@ class EmailPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index fe3441f081..baf9c785f4 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -102,7 +102,7 @@ class HTTPPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -113,7 +113,7 @@ class HTTPPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -132,7 +132,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -152,7 +152,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased, again pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 9c13a13786..fa3a3ec1bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -40,16 +40,14 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( - set( - [ - "next_batch", - "rooms", - "presence", - "account_data", - "to_device", - "device_lists", - ] - ).issubset(set(channel.json_body.keys())) + { + "next_batch", + "rooms", + "presence", + "account_data", + "to_device", + "device_lists", + }.issubset(set(channel.json_body.keys())) ) def test_sync_presence_disabled(self): @@ -63,9 +61,13 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( - set( - ["next_batch", "rooms", "account_data", "to_device", "device_lists"] - ).issubset(set(channel.json_body.keys())) + { + "next_batch", + "rooms", + "account_data", + "to_device", + "device_lists", + }.issubset(set(channel.json_body.keys())) ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index d491ea2924..e37260a820 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -373,7 +373,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, "user2", "there")]), + {(1, "user1", "hello"), (2, "user2", "there")}, ) # Update only user2 @@ -400,5 +400,5 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, "user2", "bleb")]), + {(1, "user1", "hello"), (2, "user2", "bleb")}, ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index fd52512696..31710949a8 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -69,14 +69,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): pass def _add_appservice(self, as_token, id, url, hs_token, sender): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token=hs_token, - id=id, - sender_localpart=sender, - namespaces={}, - ) + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": hs_token, + "id": id, + "sender_localpart": sender, + "namespaces": {}, + } # use the token as the filename with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) @@ -135,14 +135,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) def _add_service(self, url, as_token, id): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token="something", - id=id, - sender_localpart="a_sender", - namespaces={}, - ) + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": "something", + "id": id, + "sender_localpart": "a_sender", + "namespaces": {}, + } # use the token as the filename with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) @@ -384,8 +384,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) self.assertEquals(2, len(services)) self.assertEquals( - set([self.as_list[2]["id"], self.as_list[0]["id"]]), - set([services[0].id, services[1].id]), + {self.as_list[2]["id"], self.as_list[0]["id"]}, + {services[0].id, services[1].id}, ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 029ac26454..0e04b2cf92 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -134,7 +134,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -172,7 +172,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -227,9 +227,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual( - set(latest_event_ids), set((event_id_a, event_id_b, event_id_c)) - ) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c}) # Run the background update and check it did the right thing self.run_background_update() @@ -237,7 +235,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c])) + self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c}) class CleanupExtremDummyEventsTestCase(HomeserverTestCase): diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index f26ff57a18..a7b7fd36d3 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -59,24 +59,22 @@ class ExtremStatisticsTestCase(HomeserverTestCase): ) ) - expected = set( - [ - b'synapse_forward_extremities_bucket{le="1.0"} 0.0', - b'synapse_forward_extremities_bucket{le="2.0"} 2.0', - b'synapse_forward_extremities_bucket{le="3.0"} 2.0', - b'synapse_forward_extremities_bucket{le="5.0"} 2.0', - b'synapse_forward_extremities_bucket{le="7.0"} 3.0', - b'synapse_forward_extremities_bucket{le="10.0"} 3.0', - b'synapse_forward_extremities_bucket{le="15.0"} 3.0', - b'synapse_forward_extremities_bucket{le="20.0"} 3.0', - b'synapse_forward_extremities_bucket{le="50.0"} 3.0', - b'synapse_forward_extremities_bucket{le="100.0"} 3.0', - b'synapse_forward_extremities_bucket{le="200.0"} 3.0', - b'synapse_forward_extremities_bucket{le="500.0"} 3.0', - b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', - b"synapse_forward_extremities_count 3.0", - b"synapse_forward_extremities_sum 10.0", - ] - ) + expected = { + b'synapse_forward_extremities_bucket{le="1.0"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0"} 3.0', + b'synapse_forward_extremities_bucket{le="50.0"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0"} 3.0', + b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', + b"synapse_forward_extremities_count 3.0", + b"synapse_forward_extremities_sum 10.0", + } self.assertEqual(items, expected) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 04d58fbf24..0b88308ff4 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -394,7 +394,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, False) - self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) + self.assertEqual(known_absent, {(e1.type, e1.state_key)}) self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) ############################################ diff --git a/tests/test_state.py b/tests/test_state.py index d1578fe581..66f22f6813 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -254,9 +254,7 @@ class StateTestCase(unittest.TestCase): ctx_d = context_store["D"] prev_state_ids = yield ctx_d.get_prev_state_ids() - self.assertSetEqual( - {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} - ) + self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) @@ -313,9 +311,7 @@ class StateTestCase(unittest.TestCase): ctx_e = context_store["E"] prev_state_ids = yield ctx_e.get_prev_state_ids() - self.assertSetEqual( - {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} - ) + self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) @@ -388,9 +384,7 @@ class StateTestCase(unittest.TestCase): ctx_d = context_store["D"] prev_state_ids = yield ctx_d.get_prev_state_ids() - self.assertSetEqual( - {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} - ) + self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())) self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) @@ -482,7 +476,7 @@ class StateTestCase(unittest.TestCase): current_state_ids = yield context.get_current_state_ids() self.assertEqual( - set([e.event_id for e in old_state]), set(current_state_ids.values()) + {e.event_id for e in old_state}, set(current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @@ -513,9 +507,7 @@ class StateTestCase(unittest.TestCase): prev_state_ids = yield context.get_prev_state_ids() - self.assertEqual( - set([e.event_id for e in old_state]), set(prev_state_ids.values()) - ) + self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values())) self.assertIsNotNone(context.state_group) diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index f2be63706b..72a9de5370 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -67,7 +67,7 @@ class StreamChangeCacheTests(unittest.TestCase): # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) self.assertEqual( - set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key) + {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) def test_get_all_entities_changed(self): @@ -137,7 +137,7 @@ class StreamChangeCacheTests(unittest.TestCase): cache.get_entities_changed( ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 ), - set(["bar@baz.net", "user@elsewhere.org"]), + {"bar@baz.net", "user@elsewhere.org"}, ) # Query all the entries mid-way through the stream, but include one @@ -153,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase): ], stream_pos=2, ), - set(["bar@baz.net", "user@elsewhere.org"]), + {"bar@baz.net", "user@elsewhere.org"}, ) # Query all the entries, but before the first known point. We will get @@ -168,21 +168,13 @@ class StreamChangeCacheTests(unittest.TestCase): ], stream_pos=0, ), - set( - [ - "user@foo.com", - "bar@baz.net", - "user@elsewhere.org", - "not@here.website", - ] - ), + {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"}, ) # Query a subset of the entries mid-way through the stream. We should # only get back the subset. self.assertEqual( - cache.get_entities_changed(["bar@baz.net"], stream_pos=2), - set(["bar@baz.net"]), + cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"}, ) def test_max_pos(self): diff --git a/tox.ini b/tox.ini index b9132a3177..b715ea0bff 100644 --- a/tox.ini +++ b/tox.ini @@ -123,6 +123,7 @@ skip_install = True basepython = python3.6 deps = flake8 + flake8-comprehensions black==19.10b0 # We pin so that our tests don't start failing on new releases of black. commands = python -m black --check --diff . -- cgit 1.5.1 From 3a59bd253ed5a6e4ccb2094492111410a7352d6e Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 24 Feb 2020 17:23:46 +0000 Subject: Revert "Use the v2 lookup API for 3PID invites (#5897)" This reverts commit 978f263e7c5d1eb440efaf07abc5009408ade25d, reversing changes made to 4f6ee99818d9c338944a10585d0aea4c7349d456. --- changelog.d/5897.feature | 1 - synapse/events/spamcheck.py | 1 + synapse/handlers/identity.py | 182 ++++++++++++++++---- synapse/handlers/room_member.py | 250 ++++------------------------ synapse/rest/client/v2_alpha/account.py | 6 +- synapse/util/hash.py | 33 ---- tests/rest/client/test_identity.py | 18 +- tests/rest/client/test_room_access_rules.py | 6 - tests/rest/client/v2_alpha/test_register.py | 1 + tests/rulecheck/test_domainrulecheck.py | 5 - 10 files changed, 183 insertions(+), 320 deletions(-) delete mode 100644 changelog.d/5897.feature delete mode 100644 synapse/util/hash.py (limited to 'tests') diff --git a/changelog.d/5897.feature b/changelog.d/5897.feature deleted file mode 100644 index 7b10774c96..0000000000 --- a/changelog.d/5897.feature +++ /dev/null @@ -1 +0,0 @@ -Switch to the v2 lookup API for 3PID invites. \ No newline at end of file diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index c194ba53f4..f0de4d961f 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + class SpamChecker(object): def __init__(self, hs): self.spam_checker = None diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 4cea75723b..339e0dd04d 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -20,13 +20,18 @@ import logging from canonicaljson import json +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from unpaddedbase64 import decode_base64 from twisted.internet import defer from synapse.api.errors import ( + AuthError, CodeMessageException, Codes, HttpResponseException, + ProxiedRequestError, SynapseError, ) @@ -43,9 +48,26 @@ class IdentityHandler(BaseHandler): self.federation_http_client = hs.get_http_client() self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers) + self.trust_any_id_server_just_for_testing_do_not_use = ( + hs.config.use_insecure_ssl_client_just_for_testing_do_not_use + ) self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls self._enable_lookup = hs.config.enable_3pid_lookup + def _should_trust_id_server(self, id_server): + if id_server not in self.trusted_id_servers: + if self.trust_any_id_server_just_for_testing_do_not_use: + logger.warn( + "Trusting untrustworthy ID server %r even though it isn't" + " in the trusted id list for testing because" + " 'use_insecure_ssl_client_just_for_testing_do_not_use'" + " is set in the config", + id_server, + ) + else: + return False + return True + @defer.inlineCallbacks def threepid_from_creds(self, creds): if "id_server" in creds: @@ -62,17 +84,16 @@ class IdentityHandler(BaseHandler): else: raise SynapseError(400, "No client_secret in creds") - if not should_trust_id_server(self.hs, id_server): + if not self._should_trust_id_server(id_server): logger.warn( "%s is not a trusted ID server: rejecting 3pid " + "credentials", id_server, ) return None - # if we have a rewrite rule set for the identity server, # apply it now. - id_server = self.rewrite_identity_server_urls.get(id_server, id_server) - + if id_server in self.rewrite_identity_server_urls: + id_server = self.rewrite_identity_server_urls[id_server] try: data = yield self.http_client.get_json( "https://%s%s" @@ -109,7 +130,10 @@ class IdentityHandler(BaseHandler): # if we have a rewrite rule set for the identity server, # apply it now, but only for sending the request (not # storing in the database). - id_server_host = self.rewrite_identity_server_urls.get(id_server, id_server) + if id_server in self.rewrite_identity_server_urls: + id_server_host = self.rewrite_identity_server_urls[id_server] + else: + id_server_host = id_server try: data = yield self.http_client.post_json_get_json( @@ -181,6 +205,7 @@ class IdentityHandler(BaseHandler): Deferred[bool]: True on success, otherwise False if the identity server doesn't support unbinding """ + url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) content = { "mxid": mxid, "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, @@ -203,7 +228,8 @@ class IdentityHandler(BaseHandler): # # Note that destination_is has to be the real id_server, not # the server we connect to. - id_server = self.rewrite_identity_server_urls.get(id_server, id_server) + if id_server in self.rewrite_identity_server_urls: + id_server = self.rewrite_identity_server_urls[id_server] url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) @@ -232,7 +258,7 @@ class IdentityHandler(BaseHandler): def requestEmailToken( self, id_server, email, client_secret, send_attempt, next_link=None ): - if not should_trust_id_server(self.hs, id_server): + if not self._should_trust_id_server(id_server): raise SynapseError( 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) @@ -243,8 +269,10 @@ class IdentityHandler(BaseHandler): "send_attempt": send_attempt, } - # Rewrite id_server URL if necessary - id_server = self.rewrite_identity_server_urls.get(id_server, id_server) + # if we have a rewrite rule set for the identity server, + # apply it now. + if id_server in self.rewrite_identity_server_urls: + id_server = self.rewrite_identity_server_urls[id_server] if next_link: params.update({"next_link": next_link}) @@ -264,14 +292,11 @@ class IdentityHandler(BaseHandler): def requestMsisdnToken( self, id_server, country, phone_number, client_secret, send_attempt, **kwargs ): - if not should_trust_id_server(self.hs, id_server): + if not self._should_trust_id_server(id_server): raise SynapseError( 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) - # Rewrite id_server URL if necessary - id_server = self.rewrite_identity_server_urls.get(id_server, id_server) - params = { "country": country, "phone_number": phone_number, @@ -294,30 +319,119 @@ class IdentityHandler(BaseHandler): logger.info("Proxied requestToken failed: %r", e) raise e.to_synapse_error() + @defer.inlineCallbacks + def lookup_3pid(self, id_server, medium, address): + """Looks up a 3pid in the passed identity server. -def should_trust_id_server(hs, id_server): - if id_server not in hs.config.trusted_third_party_id_servers: - if hs.trust_any_id_server_just_for_testing_do_not_use: - logger.warn( - "Trusting untrustworthy ID server %r even though it isn't" - " in the trusted id list for testing because" - " 'use_insecure_ssl_client_just_for_testing_do_not_use'" - " is set in the config", - id_server, + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + medium (str): The type of the third party identifier (e.g. "email"). + address (str): The third party identifier (e.g. "foo@example.com"). + + Returns: + Deferred[dict]: The result of the lookup. See + https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup + for details + """ + if not self._should_trust_id_server(id_server): + raise SynapseError( + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED + ) + + if not self._enable_lookup: + raise AuthError( + 403, "Looking up third-party identifiers is denied from this server" ) - else: - return False - return True + target = self.rewrite_identity_server_urls.get(id_server, id_server) -class LookupAlgorithm: - """ - Supported hashing algorithms when performing a 3PID lookup. + try: + data = yield self.http_client.get_json( + "https://%s/_matrix/identity/api/v1/lookup" % (target,), + {"medium": medium, "address": address}, + ) - SHA256 - Hashing an (address, medium, pepper) combo with sha256, then url-safe base64 - encoding - NONE - Not performing any hashing. Simply sending an (address, medium) combo in plaintext - """ + if "mxid" in data: + if "signatures" not in data: + raise AuthError(401, "No signatures on 3pid binding") + yield self._verify_any_signature(data, id_server) + + except HttpResponseException as e: + logger.info("Proxied lookup failed: %r", e) + raise e.to_synapse_error() + except IOError as e: + logger.info("Failed to contact %r: %s", id_server, e) + raise ProxiedRequestError(503, "Failed to contact identity server") + + defer.returnValue(data) + + @defer.inlineCallbacks + def bulk_lookup_3pid(self, id_server, threepids): + """Looks up given 3pids in the passed identity server. + + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + threepids ([[str, str]]): The third party identifiers to lookup, as + a list of 2-string sized lists ([medium, address]). + + Returns: + Deferred[dict]: The result of the lookup. See + https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup + for details + """ + if not self._should_trust_id_server(id_server): + raise SynapseError( + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED + ) + + if not self._enable_lookup: + raise AuthError( + 403, "Looking up third-party identifiers is denied from this server" + ) + + target = self.rewrite_identity_server_urls.get(id_server, id_server) + + try: + data = yield self.http_client.post_json_get_json( + "https://%s/_matrix/identity/api/v1/bulk_lookup" % (target,), + {"threepids": threepids}, + ) + + except HttpResponseException as e: + logger.info("Proxied lookup failed: %r", e) + raise e.to_synapse_error() + except IOError as e: + logger.info("Failed to contact %r: %s", id_server, e) + raise ProxiedRequestError(503, "Failed to contact identity server") + + defer.returnValue(data) + + @defer.inlineCallbacks + def _verify_any_signature(self, data, server_hostname): + if server_hostname not in data["signatures"]: + raise AuthError(401, "No signature from server %s" % (server_hostname,)) + + for key_name, signature in data["signatures"][server_hostname].items(): + target = self.rewrite_identity_server_urls.get( + server_hostname, server_hostname + ) + + key_data = yield self.http_client.get_json( + "https://%s/_matrix/identity/api/v1/pubkey/%s" % (target, key_name) + ) + if "public_key" not in key_data: + raise AuthError( + 401, "No public key named %s from %s" % (key_name, server_hostname) + ) + verify_signed_json( + data, + server_hostname, + decode_verify_key_bytes( + key_name, decode_base64(key_data["public_key"]) + ), + ) + return - SHA256 = "sha256" - NONE = "none" + raise AuthError(401, "No signature from server %s" % (server_hostname,)) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 398c377216..3c00151042 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -20,14 +20,11 @@ import logging from six.moves import http_client -from signedjson.key import decode_verify_key_bytes -from signedjson.sign import verify_signed_json -from unpaddedbase64 import decode_base64 - from twisted.internet import defer from synapse import types from synapse.api.constants import EventTypes, Membership +from synapse.api.ratelimiting import Ratelimiter from synapse.api.errors import ( AuthError, Codes, @@ -35,13 +32,9 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) -from synapse.handlers.identity import LookupAlgorithm, should_trust_id_server from synapse.types import RoomID, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room -from synapse.util.hash import sha256_and_url_safe_base64 - -from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -85,11 +78,7 @@ class RoomMemberHandler(object): self.rewrite_identity_server_urls = self.config.rewrite_identity_server_urls self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles - - # This is only used to get at ratelimit function, and - # maybe_kick_guest_users. It's fine there are multiple of these as - # it doesn't store state. - self.base_handler = BaseHandler(hs) + self.ratelimiter = Ratelimiter() @abc.abstractmethod def _remote_join(self, requester, remote_room_hosts, room_id, user, content): @@ -583,7 +572,7 @@ class RoomMemberHandler(object): event (SynapseEvent): The membership event. context: The context of the event. is_guest (bool): Whether the sender is a guest. - remote_room_hosts (list[str]|None): Homeservers which are likely to already be in + room_hosts ([str]): Homeservers which are likely to already be in the room, and could be danced with in order to join this homeserver for the first time. ratelimit (bool): Whether to rate limit this request. @@ -694,7 +683,7 @@ class RoomMemberHandler(object): servers.remove(room_alias.domain) servers.insert(0, room_alias.domain) - return RoomID.from_string(room_id), servers + return (RoomID.from_string(room_id), servers) @defer.inlineCallbacks def _get_inviter(self, user_id, room_id): @@ -725,7 +714,7 @@ class RoomMemberHandler(object): # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. - self.base_handler.ratelimiter.ratelimit( + self.ratelimiter.ratelimit( requester.user.to_string(), time_now_s=self.hs.clock.time(), rate_hz=self.hs.config.rc_third_party_invite.per_second, @@ -753,7 +742,7 @@ class RoomMemberHandler(object): Codes.FORBIDDEN, ) - invitee = yield self.lookup_3pid(id_server, medium, address) + invitee = yield self._lookup_3pid(id_server, medium, address) is_published = yield self.store.is_room_published(room_id) @@ -777,60 +766,23 @@ class RoomMemberHandler(object): requester, id_server, medium, address, room_id, inviter, txn_id=txn_id ) - @defer.inlineCallbacks - def lookup_3pid(self, id_server, medium, address): - """Looks up a 3pid in the passed identity server. + def _get_id_server_target(self, id_server): + """Looks up an id_server's actual http endpoint Args: - id_server (str): The server name (including port, if required) - of the identity server to use. - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). + id_server (str): the server name to lookup. Returns: - str: the matrix ID of the 3pid, or None if it is not recognized. + the http endpoint to connect to. """ - if not should_trust_id_server(self.hs, id_server): - raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED - ) - - if not self._enable_lookup: - raise SynapseError( - 403, "Looking up third-party identifiers is denied from this server" - ) + if id_server in self.rewrite_identity_server_urls: + return self.rewrite_identity_server_urls[id_server] - # Rewrite id_server URL if necessary - id_server = self.rewrite_identity_server_urls.get(id_server, id_server) - - # Check what hashing details are supported by this identity server - use_v1 = False - hash_details = None - try: - hash_details = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server) - ) - except (HttpResponseException, ValueError) as e: - # Catch HttpResponseExcept for a non-200 response code - # Catch ValueError for non-JSON response body - - # Check if this identity server does not know about v2 lookups - if e.code == 404: - # This is an old identity server that does not yet support v2 lookups - use_v1 = True - else: - logger.warn("Error when looking up hashing details: %s" % (e,)) - return None - - if use_v1: - return (yield self._lookup_3pid_v1(id_server, medium, address)) - - return (yield self._lookup_3pid_v2(id_server, medium, address, hash_details)) + return id_server @defer.inlineCallbacks - def _lookup_3pid_v1(self, id_server, medium, address): - """Looks up a 3pid in the passed identity server using v1 lookup. + def _lookup_3pid(self, id_server, medium, address): + """Looks up a 3pid in the passed identity server. Args: id_server (str): The server name (including port, if required) @@ -842,164 +794,12 @@ class RoomMemberHandler(object): str: the matrix ID of the 3pid, or None if it is not recognized. """ try: - data = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), - {"medium": medium, "address": address}, - ) - - if "mxid" in data: - if "signatures" not in data: - raise AuthError(401, "No signatures on 3pid binding") - yield self._verify_any_signature(data, id_server) - return data["mxid"] - + data = yield self.identity_handler.lookup_3pid(id_server, medium, address) + return data.get("mxid") except ProxiedRequestError as e: logger.warn("Error from identity server lookup: %s" % (e,)) - - except IOError as e: - logger.warn("Error from identity server lookup: %s" % (e,)) - - return None - - @defer.inlineCallbacks - def _lookup_3pid_v2(self, id_server, medium, address, hash_details): - """Looks up a 3pid in the passed identity server using v2 lookup. - - Args: - id_server (str): The server name (including port, if required) - of the identity server to use. - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). - hash_details (dict[str, str|list]): A dictionary containing hashing information - provided by an identity server. - - Returns: - Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised. - """ - # Extract information from hash_details - supported_lookup_algorithms = hash_details["algorithms"] - lookup_pepper = hash_details["lookup_pepper"] - - # Check if any of the supported lookup algorithms are present - if LookupAlgorithm.SHA256 in supported_lookup_algorithms: - # Perform a hashed lookup - lookup_algorithm = LookupAlgorithm.SHA256 - - # Hash address, medium and the pepper with sha256 - to_hash = "%s %s %s" % (address, medium, lookup_pepper) - lookup_value = sha256_and_url_safe_base64(to_hash) - - elif LookupAlgorithm.NONE in supported_lookup_algorithms: - # Perform a non-hashed lookup - lookup_algorithm = LookupAlgorithm.NONE - - # Combine together plaintext address and medium - lookup_value = "%s %s" % (address, medium) - - else: - logger.warn( - "None of the provided lookup algorithms of %s%s are supported: %s", - id_server_scheme, - id_server, - hash_details["algorithms"], - ) - raise SynapseError( - 400, - "Provided identity server does not support any v2 lookup " - "algorithms that this homeserver supports.", - ) - - try: - lookup_results = yield self.simple_http_client.post_json_get_json( - "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server), - { - "addresses": [lookup_value], - "algorithm": lookup_algorithm, - "pepper": lookup_pepper, - }, - ) - except (HttpResponseException, ValueError) as e: - # Catch HttpResponseExcept for a non-200 response code - # Catch ValueError for non-JSON response body - logger.warn("Error when performing a 3pid lookup: %s" % (e,)) - return None - - # Check for a mapping from what we looked up to an MXID - if "mappings" not in lookup_results or not isinstance( - lookup_results["mappings"], dict - ): - logger.debug("No results from 3pid lookup") return None - # Return the MXID if it's available, or None otherwise - mxid = lookup_results["mappings"].get(lookup_value) - return mxid - - @defer.inlineCallbacks - def bulk_lookup_3pid(self, id_server, threepids): - """Looks up given 3pids in the passed identity server. - Args: - id_server (str): The server name (including port, if required) - of the identity server to use. - threepids ([[str, str]]): The third party identifiers to lookup, as - a list of 2-string sized lists ([medium, address]). - Returns: - Deferred[dict]: The result of the lookup. See - https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup - for details - """ - if not should_trust_id_server(self.hs, id_server): - raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED - ) - - if not self._enable_lookup: - raise AuthError( - 403, "Looking up third-party identifiers is denied from this server", - ) - - target = self.rewrite_identity_server_urls.get(id_server, id_server) - - try: - data = yield self.simple_http_client.post_json_get_json( - "https://%s/_matrix/identity/api/v1/bulk_lookup" % (target,), - { - "threepids": threepids, - } - ) - - except HttpResponseException as e: - logger.info("Proxied lookup failed: %r", e) - raise e.to_synapse_error() - except IOError as e: - logger.info("Failed to contact %r: %s", id_server, e) - raise ProxiedRequestError(503, "Failed to contact identity server") - - defer.returnValue(data) - - @defer.inlineCallbacks - def _verify_any_signature(self, data, server_hostname): - if server_hostname not in data["signatures"]: - raise AuthError(401, "No signature from server %s" % (server_hostname,)) - for key_name, signature in data["signatures"][server_hostname].items(): - key_data = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/api/v1/pubkey/%s" - % (id_server_scheme, server_hostname, key_name) - ) - if "public_key" not in key_data: - raise AuthError( - 401, "No public key named %s from %s" % (key_name, server_hostname) - ) - verify_signed_json( - data, - server_hostname, - decode_verify_key_bytes( - key_name, decode_base64(key_data["public_key"]) - ), - ) - return - @defer.inlineCallbacks def _make_and_store_3pid_invite( self, requester, id_server, medium, address, room_id, user, txn_id @@ -1117,10 +917,11 @@ class RoomMemberHandler(object): display_name (str): A user-friendly name to represent the invited user. """ - id_server = self.rewrite_identity_server_urls.get(id_server, id_server) + + target = self._get_id_server_target(id_server) is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( id_server_scheme, - id_server, + target, ) invite_config = { @@ -1160,7 +961,7 @@ class RoomMemberHandler(object): fallback_public_key = { "public_key": data["public_key"], "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" - % (id_server_scheme, id_server), + % (id_server_scheme, target), } else: fallback_public_key = public_keys[0] @@ -1227,7 +1028,9 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) if complexity: - return complexity["v1"] > max_complexity + if complexity["v1"] > max_complexity: + return True + return False return None @defer.inlineCallbacks @@ -1243,7 +1046,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): max_complexity = self.hs.config.limit_remote_rooms.complexity complexity = yield self.store.get_room_complexity(room_id) - return complexity["v1"] > max_complexity + if complexity["v1"] > max_complexity: + return True + + return False @defer.inlineCallbacks def _remote_join(self, requester, remote_room_hosts, room_id, user, content): diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 986f56d3d7..b83704fed4 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -706,7 +706,6 @@ class ThreepidLookupRestServlet(RestServlet): super(ThreepidLookupRestServlet, self).__init__() self.auth = hs.get_auth() self.identity_handler = hs.get_handlers().identity_handler - self.room_member_handler = hs.get_room_member_handler() @defer.inlineCallbacks def on_GET(self, request): @@ -726,7 +725,7 @@ class ThreepidLookupRestServlet(RestServlet): # Proxy the request to the identity server. lookup_3pid handles checking # if the lookup is allowed so we don't need to do it here. - ret = yield self.room_member_handler.lookup_3pid(id_server, medium, address) + ret = yield self.identity_handler.lookup_3pid(id_server, medium, address) defer.returnValue((200, ret)) @@ -738,7 +737,6 @@ class ThreepidBulkLookupRestServlet(RestServlet): super(ThreepidBulkLookupRestServlet, self).__init__() self.auth = hs.get_auth() self.identity_handler = hs.get_handlers().identity_handler - self.room_member_handler = hs.get_room_member_handler() @defer.inlineCallbacks def on_POST(self, request): @@ -753,7 +751,7 @@ class ThreepidBulkLookupRestServlet(RestServlet): # Proxy the request to the identity server. lookup_3pid handles checking # if the lookup is allowed so we don't need to do it here. - ret = yield self.room_member_handler.bulk_lookup_3pid( + ret = yield self.identity_handler.bulk_lookup_3pid( body["id_server"], body["threepids"] ) diff --git a/synapse/util/hash.py b/synapse/util/hash.py deleted file mode 100644 index 359168704e..0000000000 --- a/synapse/util/hash.py +++ /dev/null @@ -1,33 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import hashlib - -import unpaddedbase64 - - -def sha256_and_url_safe_base64(input_text): - """SHA256 hash an input string, encode the digest as url-safe base64, and - return - - :param input_text: string to hash - :type input_text: str - - :returns a sha256 hashed and url-safe base64 encoded digest - :rtype: str - """ - digest = hashlib.sha256(input_text.encode()).digest() - return unpaddedbase64.encode_base64(digest, urlsafe=True) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index a7f08a3f19..f81f81602e 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -20,7 +20,6 @@ from mock import Mock from twisted.internet import defer import synapse.rest.admin -from synapse.api.errors import HttpResponseException from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account @@ -110,15 +109,8 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase): config["enable_3pid_lookup"] = True config["trusted_third_party_id_servers"] = ["testis"] - def get_json(uri, args={}, headers=None): - # TODO: Mock v2 hash_details endpoints and don't just run the v1 code - if "/hash_details" in uri: - raise HttpResponseException(404, "I am a happy v1 server", b"{}") - - return defer.succeed((200, "{}")) - mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) - mock_http_client.get_json.side_effect = get_json + mock_http_client.get_json.return_value = defer.succeed((200, "{}")) mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) self.hs = self.setup_test_homeserver( @@ -151,10 +143,8 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase): ) self.render(request) - # There will be a call to hash_details, which will 404 - # then /lookup. Thus we don't use `assert_called_with_once` here get_json = self.hs.get_simple_http_client().get_json - get_json.assert_called_with( + get_json.assert_called_once_with( "https://testis/_matrix/identity/api/v1/lookup", {"address": "test@example.com", "medium": "email"}, ) @@ -167,10 +157,8 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase): request, channel = self.make_request("GET", url, access_token=self.tok) self.render(request) - # There will be a call to hash_details, which will 404 - # then /lookup. Thus we don't use `assert_called_with_once` here get_json = self.hs.get_simple_http_client().get_json - get_json.assert_called_with( + get_json.assert_called_once_with( "https://testis/_matrix/identity/api/v1/lookup", {"address": "foo@bar.baz", "medium": "email"}, ) diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py index 69920853e6..d44f5c2c8c 100644 --- a/tests/rest/client/test_room_access_rules.py +++ b/tests/rest/client/test_room_access_rules.py @@ -22,7 +22,6 @@ from mock import Mock from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset -from synapse.api.errors import HttpResponseException from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.third_party_rules.access_rules import ( @@ -59,10 +58,6 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): return defer.succeed(pdu) def get_json(uri, args={}, headers=None): - # TODO: Mock v2 hash_details endpoints and don't just run the v1 code - if "/hash_details" in uri: - raise HttpResponseException(404, "I am a happy v1 server", b"{}") - address_domain = args["address"].split("@")[1] return defer.succeed({"hs": address_domain}) @@ -337,7 +332,6 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): self.hs.config.rc_third_party_invite.burst_count = 10 self.hs.config.rc_third_party_invite.per_second = 0.1 - # Note: These calls trigger the mocked get_json method above # We can't send a 3PID invite to a room that already has two members. self.send_threepid_invite( address="test@allowed_domain", diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 9b7eb1e856..9e72cae62a 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -52,6 +52,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): return self.hs + @unittest.DEBUG def test_POST_appservice_registration_valid(self): user_id = "@as_user_kermit:test" as_token = "i_am_an_app_service" diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py index 145d6cba5b..1accc70dc9 100644 --- a/tests/rulecheck/test_domainrulecheck.py +++ b/tests/rulecheck/test_domainrulecheck.py @@ -15,8 +15,6 @@ import json -from mock import Mock -from twisted.internet import defer import synapse.rest.admin from synapse.config._base import ConfigError @@ -289,9 +287,6 @@ class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase): def test_cannot_3pid_invite(self): """Test that unbound 3pid invites get rejected. """ - self.hs.get_room_member_handler().lookup_3pid = Mock() - self.hs.get_room_member_handler().lookup_3pid.return_value = defer.succeed(None) - channel = self._create_room(self.admin_access_token) assert channel.result["code"] == b"200", channel.result -- cgit 1.5.1 From 476932cdef7779a1d68e20074843fe280517b102 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 25 Feb 2020 11:14:11 +0000 Subject: Increase expected state events in tests for new room by one --- tests/handlers/test_stats.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 7569b6fab5..207bf88c01 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -21,8 +21,12 @@ from tests import unittest # The expected number of state events in a fresh public room. EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5 + # The expected number of state events in a fresh private room. -EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6 +# +# Note: we increase this by 1 on the dinsic branch as we send +# a "im.vector.room.access_rules" state event into new private rooms +EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7 class StatsRoomTests(unittest.HomeserverTestCase): @@ -631,6 +635,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(u1stats["joined_rooms"], 1) + @unittest.DEBUG def test_incomplete_stats(self): """ This tests that we track incomplete statistics. -- cgit 1.5.1 From 3c34ddab01cc17671d5dc10eb41abed4d8b7cc0f Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 25 Feb 2020 11:36:04 +0000 Subject: Remove extraneous unittest.DEBUG's --- tests/handlers/test_stats.py | 1 - tests/rest/client/v2_alpha/test_register.py | 1 - 2 files changed, 2 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 207bf88c01..89ec8636a6 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -635,7 +635,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(u1stats["joined_rooms"], 1) - @unittest.DEBUG def test_incomplete_stats(self): """ This tests that we track incomplete statistics. diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index b6093a0012..9a5d275d06 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -52,7 +52,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): return self.hs - @unittest.DEBUG def test_POST_appservice_registration_valid(self): user_id = "@as_user_kermit:test" as_token = "i_am_an_app_service" -- cgit 1.5.1 From bbf8886a05be6a929556d6f09a1b6ce053a3c403 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 25 Feb 2020 16:56:55 +0000 Subject: Merge worker apps into one. (#6964) --- changelog.d/6964.misc | 1 + synapse/app/appservice.py | 156 +---- synapse/app/client_reader.py | 190 +----- synapse/app/event_creator.py | 186 +----- synapse/app/federation_reader.py | 172 +----- synapse/app/federation_sender.py | 303 +-------- synapse/app/frontend_proxy.py | 236 +------ synapse/app/generic_worker.py | 917 ++++++++++++++++++++++++++++ synapse/app/media_repository.py | 157 +---- synapse/app/pusher.py | 209 +------ synapse/app/synchrotron.py | 449 +------------- synapse/app/user_dir.py | 211 +------ synapse/replication/slave/storage/events.py | 20 + synapse/storage/data_stores/main/pusher.py | 156 ++--- tests/app/test_frontend_proxy.py | 12 +- tests/app/test_openid_listener.py | 4 +- 16 files changed, 1052 insertions(+), 2327 deletions(-) create mode 100644 changelog.d/6964.misc create mode 100644 synapse/app/generic_worker.py (limited to 'tests') diff --git a/changelog.d/6964.misc b/changelog.d/6964.misc new file mode 100644 index 0000000000..ec5c004bbe --- /dev/null +++ b/changelog.d/6964.misc @@ -0,0 +1 @@ +Merge worker apps together. diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 2217d4a4fb..add43147b3 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -13,161 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys - -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.directory import DirectoryStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.server import HomeServer -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.appservice") - - -class AppserviceSlaveStore( - DirectoryStore, - SlavedEventStore, - SlavedApplicationServiceStore, - SlavedRegistrationStore, -): - pass - - -class AppserviceServer(HomeServer): - DATASTORE_CLASS = AppserviceSlaveStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse appservice now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warning("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ASReplicationHandler(self) +import sys -class ASReplicationHandler(ReplicationClientHandler): - def __init__(self, hs): - super(ASReplicationHandler, self).__init__(hs.get_datastore()) - self.appservice_handler = hs.get_application_service_handler() - - async def on_rdata(self, stream_name, token, rows): - await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) - - if stream_name == "events": - max_stream_id = self.store.get_room_max_stream_ordering() - run_in_background(self._notify_app_services, max_stream_id) - - @defer.inlineCallbacks - def _notify_app_services(self, room_stream_id): - try: - yield self.appservice_handler.notify_interested_services(room_stream_id) - except Exception: - logger.exception("Error notifying application services of event") - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse appservice", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.appservice" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - if config.notify_appservices: - sys.stderr.write( - "\nThe appservices must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``notify_appservices: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.notify_appservices = True - - ps = AppserviceServer( - config.server_name, - config=config, - version_string="Synapse/" + get_version_string(synapse), - ) - - setup_logging(ps, config, use_worker_options=True) - - ps.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ps, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-appservice", config) - +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 7fa91a3b11..add43147b3 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -13,195 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys - -from twisted.internet import reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.server import JsonResource -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.directory import DirectoryStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.groups import SlavedGroupServerStore -from synapse.replication.slave.storage.keys import SlavedKeyStore -from synapse.replication.slave.storage.profile import SlavedProfileStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.client.v1.login import LoginRestServlet -from synapse.rest.client.v1.push_rule import PushRuleRestServlet -from synapse.rest.client.v1.room import ( - JoinedRoomMemberListRestServlet, - PublicRoomListRestServlet, - RoomEventContextServlet, - RoomMemberListRestServlet, - RoomMessageListRestServlet, - RoomStateRestServlet, -) -from synapse.rest.client.v1.voip import VoipRestServlet -from synapse.rest.client.v2_alpha import groups -from synapse.rest.client.v2_alpha.account import ThreepidRestServlet -from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet -from synapse.rest.client.v2_alpha.register import RegisterRestServlet -from synapse.rest.client.versions import VersionsRestServlet -from synapse.server import HomeServer -from synapse.storage.data_stores.main.monthly_active_users import ( - MonthlyActiveUsersWorkerStore, -) -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.client_reader") - - -class ClientReaderSlavedStore( - SlavedDeviceInboxStore, - SlavedDeviceStore, - SlavedReceiptsStore, - SlavedPushRuleStore, - SlavedGroupServerStore, - SlavedAccountDataStore, - SlavedEventStore, - SlavedKeyStore, - RoomStore, - DirectoryStore, - SlavedApplicationServiceStore, - SlavedRegistrationStore, - SlavedTransactionStore, - SlavedProfileStore, - SlavedClientIpStore, - MonthlyActiveUsersWorkerStore, - BaseSlavedStore, -): - pass - - -class ClientReaderServer(HomeServer): - DATASTORE_CLASS = ClientReaderSlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "client": - resource = JsonResource(self, canonical_json=False) - - PublicRoomListRestServlet(self).register(resource) - RoomMemberListRestServlet(self).register(resource) - JoinedRoomMemberListRestServlet(self).register(resource) - RoomStateRestServlet(self).register(resource) - RoomEventContextServlet(self).register(resource) - RoomMessageListRestServlet(self).register(resource) - RegisterRestServlet(self).register(resource) - LoginRestServlet(self).register(resource) - ThreepidRestServlet(self).register(resource) - KeyQueryServlet(self).register(resource) - KeyChangesServlet(self).register(resource) - VoipRestServlet(self).register(resource) - PushRuleRestServlet(self).register(resource) - VersionsRestServlet(self).register(resource) - - groups.register_servlets(self, resource) - - resources.update({"/_matrix/client": resource}) - - root_resource = create_resource_tree(resources, NoResource()) - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse client reader now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warning("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ReplicationClientHandler(self.get_datastore()) - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse client reader", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.client_reader" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - ss = ClientReaderServer( - config.server_name, - config=config, - version_string="Synapse/" + get_version_string(synapse), - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-client-reader", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index 58e5b354f6..e9c098c4e7 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -13,191 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys - -from twisted.internet import reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.server import JsonResource -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.directory import DirectoryStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.profile import SlavedProfileStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.pushers import SlavedPusherStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.client.v1.profile import ( - ProfileAvatarURLRestServlet, - ProfileDisplaynameRestServlet, - ProfileRestServlet, -) -from synapse.rest.client.v1.room import ( - JoinRoomAliasServlet, - RoomMembershipRestServlet, - RoomSendEventRestServlet, - RoomStateEventRestServlet, -) -from synapse.server import HomeServer -from synapse.storage.data_stores.main.monthly_active_users import ( - MonthlyActiveUsersWorkerStore, -) -from synapse.storage.data_stores.main.user_directory import UserDirectoryStore -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.event_creator") - - -class EventCreatorSlavedStore( - # FIXME(#3714): We need to add UserDirectoryStore as we write directly - # rather than going via the correct worker. - UserDirectoryStore, - DirectoryStore, - SlavedTransactionStore, - SlavedProfileStore, - SlavedAccountDataStore, - SlavedPusherStore, - SlavedReceiptsStore, - SlavedPushRuleStore, - SlavedDeviceStore, - SlavedClientIpStore, - SlavedApplicationServiceStore, - SlavedEventStore, - SlavedRegistrationStore, - RoomStore, - MonthlyActiveUsersWorkerStore, - BaseSlavedStore, -): - pass - - -class EventCreatorServer(HomeServer): - DATASTORE_CLASS = EventCreatorSlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "client": - resource = JsonResource(self, canonical_json=False) - RoomSendEventRestServlet(self).register(resource) - RoomMembershipRestServlet(self).register(resource) - RoomStateEventRestServlet(self).register(resource) - JoinRoomAliasServlet(self).register(resource) - ProfileAvatarURLRestServlet(self).register(resource) - ProfileDisplaynameRestServlet(self).register(resource) - ProfileRestServlet(self).register(resource) - resources.update( - { - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - } - ) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse event creator now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warning("Unrecognized listener type: %s", listener["type"]) - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ReplicationClientHandler(self.get_datastore()) - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse event creator", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.event_creator" - - assert config.worker_replication_http_port is not None - - # This should only be done on the user directory worker or the master - config.update_user_directory = False - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - ss = EventCreatorServer( - config.server_name, - config=config, - version_string="Synapse/" + get_version_string(synapse), - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-event-creator", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index d055d11b23..add43147b3 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -13,177 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys - -from twisted.internet import reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.api.urls import FEDERATION_PREFIX, SERVER_KEY_V2_PREFIX -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.federation.transport.server import TransportLayerServer -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.directory import DirectoryStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.groups import SlavedGroupServerStore -from synapse.replication.slave.storage.keys import SlavedKeyStore -from synapse.replication.slave.storage.profile import SlavedProfileStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.pushers import SlavedPusherStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.server import HomeServer -from synapse.storage.data_stores.main.monthly_active_users import ( - MonthlyActiveUsersWorkerStore, -) -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.federation_reader") - - -class FederationReaderSlavedStore( - SlavedAccountDataStore, - SlavedProfileStore, - SlavedApplicationServiceStore, - SlavedPusherStore, - SlavedPushRuleStore, - SlavedReceiptsStore, - SlavedEventStore, - SlavedKeyStore, - SlavedRegistrationStore, - SlavedGroupServerStore, - SlavedDeviceStore, - RoomStore, - DirectoryStore, - SlavedTransactionStore, - MonthlyActiveUsersWorkerStore, - BaseSlavedStore, -): - pass - - -class FederationReaderServer(HomeServer): - DATASTORE_CLASS = FederationReaderSlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "federation": - resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) - if name == "openid" and "federation" not in res["names"]: - # Only load the openid resource separately if federation resource - # is not specified since federation resource includes openid - # resource. - resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } - ) - - if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - reactor=self.get_reactor(), - ) - logger.info("Synapse federation reader now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warning("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ReplicationClientHandler(self.get_datastore()) - - -def start(config_options): - try: - config = HomeServerConfig.load_config( - "Synapse federation reader", config_options - ) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.federation_reader" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - ss = FederationReaderServer( - config.server_name, - config=config, - version_string="Synapse/" + get_version_string(synapse), - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-federation-reader", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index b7fcf80ddc..add43147b3 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -13,308 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys - -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.federation import send_queue -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.presence import SlavedPresenceStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.replication.tcp.streams._base import ( - DeviceListsStream, - ReceiptsStream, - ToDeviceStream, -) -from synapse.server import HomeServer -from synapse.storage.database import Database -from synapse.types import ReadReceipt -from synapse.util.async_helpers import Linearizer -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.federation_sender") - - -class FederationSenderSlaveStore( - SlavedDeviceInboxStore, - SlavedTransactionStore, - SlavedReceiptsStore, - SlavedEventStore, - SlavedRegistrationStore, - SlavedDeviceStore, - SlavedPresenceStore, -): - def __init__(self, database: Database, db_conn, hs): - super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs) - - # We pull out the current federation stream position now so that we - # always have a known value for the federation position in memory so - # that we don't have to bounce via a deferred once when we start the - # replication streams. - self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) - - def _get_federation_out_pos(self, db_conn): - sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?" - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, ("federation",)) - rows = txn.fetchall() - txn.close() - - return rows[0][0] if rows else -1 - - -class FederationSenderServer(HomeServer): - DATASTORE_CLASS = FederationSenderSlaveStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse federation_sender now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warning("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return FederationSenderReplicationHandler(self) - - -class FederationSenderReplicationHandler(ReplicationClientHandler): - def __init__(self, hs): - super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) - self.send_handler = FederationSenderHandler(hs, self) - - async def on_rdata(self, stream_name, token, rows): - await super(FederationSenderReplicationHandler, self).on_rdata( - stream_name, token, rows - ) - self.send_handler.process_replication_rows(stream_name, token, rows) - - def get_streams_to_replicate(self): - args = super( - FederationSenderReplicationHandler, self - ).get_streams_to_replicate() - args.update(self.send_handler.stream_positions()) - return args - - def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - - # Let's wake up the transaction queue for the server in case we have - # pending stuff to send to it. - self.send_handler.wake_destination(server) - - -def start(config_options): - try: - config = HomeServerConfig.load_config( - "Synapse federation sender", config_options - ) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - assert config.worker_app == "synapse.app.federation_sender" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - if config.send_federation: - sys.stderr.write( - "\nThe send_federation must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``send_federation: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.send_federation = True - - ss = FederationSenderServer( - config.server_name, - config=config, - version_string="Synapse/" + get_version_string(synapse), - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-federation-sender", config) - - -class FederationSenderHandler(object): - """Processes the replication stream and forwards the appropriate entries - to the federation sender. - """ - - def __init__(self, hs: FederationSenderServer, replication_client): - self.store = hs.get_datastore() - self._is_mine_id = hs.is_mine_id - self.federation_sender = hs.get_federation_sender() - self.replication_client = replication_client - - self.federation_position = self.store.federation_out_pos_startup - self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") - - self._last_ack = self.federation_position - - self._room_serials = {} - self._room_typing = {} - - def on_start(self): - # There may be some events that are persisted but haven't been sent, - # so send them now. - self.federation_sender.notify_new_events( - self.store.get_room_max_stream_ordering() - ) - - def wake_destination(self, server: str): - self.federation_sender.wake_destination(server) - - def stream_positions(self): - return {"federation": self.federation_position} - - def process_replication_rows(self, stream_name, token, rows): - # The federation stream contains things that we want to send out, e.g. - # presence, typing, etc. - if stream_name == "federation": - send_queue.process_rows_for_federation(self.federation_sender, rows) - run_in_background(self.update_token, token) - - # We also need to poke the federation sender when new events happen - elif stream_name == "events": - self.federation_sender.notify_new_events(token) - - # ... and when new receipts happen - elif stream_name == ReceiptsStream.NAME: - run_as_background_process( - "process_receipts_for_federation", self._on_new_receipts, rows - ) - - # ... as well as device updates and messages - elif stream_name == DeviceListsStream.NAME: - hosts = {row.destination for row in rows} - for host in hosts: - self.federation_sender.send_device_messages(host) - - elif stream_name == ToDeviceStream.NAME: - # The to_device stream includes stuff to be pushed to both local - # clients and remote servers, so we ignore entities that start with - # '@' (since they'll be local users rather than destinations). - hosts = {row.entity for row in rows if not row.entity.startswith("@")} - for host in hosts: - self.federation_sender.send_device_messages(host) - - @defer.inlineCallbacks - def _on_new_receipts(self, rows): - """ - Args: - rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]): - new receipts to be processed - """ - for receipt in rows: - # we only want to send on receipts for our own users - if not self._is_mine_id(receipt.user_id): - continue - receipt_info = ReadReceipt( - receipt.room_id, - receipt.receipt_type, - receipt.user_id, - [receipt.event_id], - receipt.data, - ) - yield self.federation_sender.send_read_receipt(receipt_info) - - @defer.inlineCallbacks - def update_token(self, token): - try: - self.federation_position = token - - # We linearize here to ensure we don't have races updating the token - with (yield self._fed_position_linearizer.queue(None)): - if self._last_ack < self.federation_position: - yield self.store.update_federation_out_pos( - "federation", self.federation_position - ) - - # We ACK this token over replication so that the master can drop - # its in memory queues - self.replication_client.send_federation_ack( - self.federation_position - ) - self._last_ack = self.federation_position - except Exception: - logger.exception("Error updating federation stream position") +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 30e435eead..add43147b3 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -13,241 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys - -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.api.errors import HttpResponseException, SynapseError -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.server import JsonResource -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.server import HomeServer -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.frontend_proxy") - - -class PresenceStatusStubServlet(RestServlet): - PATTERNS = client_patterns("/presence/(?P[^/]*)/status") - - def __init__(self, hs): - super(PresenceStatusStubServlet, self).__init__() - self.http_client = hs.get_simple_http_client() - self.auth = hs.get_auth() - self.main_uri = hs.config.worker_main_http_uri - - @defer.inlineCallbacks - def on_GET(self, request, user_id): - # Pass through the auth headers, if any, in case the access token - # is there. - auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) - headers = {"Authorization": auth_headers} - - try: - result = yield self.http_client.get_json( - self.main_uri + request.uri.decode("ascii"), headers=headers - ) - except HttpResponseException as e: - raise e.to_synapse_error() - - return 200, result - - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - yield self.auth.get_user_by_req(request) - return 200, {} - - -class KeyUploadServlet(RestServlet): - PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super(KeyUploadServlet, self).__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.http_client = hs.get_simple_http_client() - self.main_uri = hs.config.worker_main_http_uri - - @defer.inlineCallbacks - def on_POST(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - if device_id is not None: - # passing the device_id here is deprecated; however, we allow it - # for now for compatibility with older clients. - if requester.device_id is not None and device_id != requester.device_id: - logger.warning( - "Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, - device_id, - ) - else: - device_id = requester.device_id - - if device_id is None: - raise SynapseError( - 400, "To upload keys, you must pass device_id when authenticating" - ) - - if body: - # They're actually trying to upload something, proxy to main synapse. - # Pass through the auth headers, if any, in case the access token - # is there. - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) - headers = {"Authorization": auth_headers} - result = yield self.http_client.post_json_get_json( - self.main_uri + request.uri.decode("ascii"), body, headers=headers - ) - - return 200, result - else: - # Just interested in counts. - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) - return 200, {"one_time_key_counts": result} - - -class FrontendProxySlavedStore( - SlavedDeviceStore, - SlavedClientIpStore, - SlavedApplicationServiceStore, - SlavedRegistrationStore, - BaseSlavedStore, -): - pass +import sys -class FrontendProxyServer(HomeServer): - DATASTORE_CLASS = FrontendProxySlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "client": - resource = JsonResource(self, canonical_json=False) - KeyUploadServlet(self).register(resource) - - # If presence is disabled, use the stub servlet that does - # not allow sending presence - if not self.config.use_presence: - PresenceStatusStubServlet(self).register(resource) - - resources.update( - { - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - } - ) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - reactor=self.get_reactor(), - ) - - logger.info("Synapse client reader now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warning("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ReplicationClientHandler(self.get_datastore()) - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse frontend proxy", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.frontend_proxy" - - assert config.worker_main_http_uri is not None - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - ss = FrontendProxyServer( - config.server_name, - config=config, - version_string="Synapse/" + get_version_string(synapse), - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-frontend-proxy", config) - +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py new file mode 100644 index 0000000000..30efd39092 --- /dev/null +++ b/synapse/app/generic_worker.py @@ -0,0 +1,917 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import logging +import sys + +from twisted.internet import defer, reactor +from twisted.web.resource import NoResource + +import synapse +import synapse.events +from synapse.api.constants import EventTypes +from synapse.api.errors import HttpResponseException, SynapseError +from synapse.api.urls import ( + CLIENT_API_PREFIX, + FEDERATION_PREFIX, + LEGACY_MEDIA_PREFIX, + MEDIA_PREFIX, + SERVER_KEY_V2_PREFIX, +) +from synapse.app import _base +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.federation import send_queue +from synapse.federation.transport.server import TransportLayerServer +from synapse.handlers.presence import PresenceHandler, get_interested_parties +from synapse.http.server import JsonResource +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext, run_in_background +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ +from synapse.replication.slave.storage.account_data import SlavedAccountDataStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore +from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.slave.storage.directory import DirectoryStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.filtering import SlavedFilteringStore +from synapse.replication.slave.storage.groups import SlavedGroupServerStore +from synapse.replication.slave.storage.keys import SlavedKeyStore +from synapse.replication.slave.storage.presence import SlavedPresenceStore +from synapse.replication.slave.storage.profile import SlavedProfileStore +from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore +from synapse.replication.slave.storage.pushers import SlavedPusherStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.slave.storage.transactions import SlavedTransactionStore +from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.streams._base import ( + DeviceListsStream, + ReceiptsStream, + ToDeviceStream, +) +from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow +from synapse.rest.admin import register_servlets_for_media_repo +from synapse.rest.client.v1 import events +from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet +from synapse.rest.client.v1.login import LoginRestServlet +from synapse.rest.client.v1.profile import ( + ProfileAvatarURLRestServlet, + ProfileDisplaynameRestServlet, + ProfileRestServlet, +) +from synapse.rest.client.v1.push_rule import PushRuleRestServlet +from synapse.rest.client.v1.room import ( + JoinedRoomMemberListRestServlet, + JoinRoomAliasServlet, + PublicRoomListRestServlet, + RoomEventContextServlet, + RoomInitialSyncRestServlet, + RoomMemberListRestServlet, + RoomMembershipRestServlet, + RoomMessageListRestServlet, + RoomSendEventRestServlet, + RoomStateEventRestServlet, + RoomStateRestServlet, +) +from synapse.rest.client.v1.voip import VoipRestServlet +from synapse.rest.client.v2_alpha import groups, sync, user_directory +from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client.v2_alpha.account import ThreepidRestServlet +from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet +from synapse.rest.client.v2_alpha.register import RegisterRestServlet +from synapse.rest.client.versions import VersionsRestServlet +from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.server import HomeServer +from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore +from synapse.storage.data_stores.main.monthly_active_users import ( + MonthlyActiveUsersWorkerStore, +) +from synapse.storage.data_stores.main.presence import UserPresenceState +from synapse.storage.data_stores.main.user_directory import UserDirectoryStore +from synapse.types import ReadReceipt +from synapse.util.async_helpers import Linearizer +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.manhole import manhole +from synapse.util.stringutils import random_string +from synapse.util.versionstring import get_version_string + +logger = logging.getLogger("synapse.app.generic_worker") + + +class PresenceStatusStubServlet(RestServlet): + """If presence is disabled this servlet can be used to stub out setting + presence status, while proxying the getters to the master instance. + """ + + PATTERNS = client_patterns("/presence/(?P[^/]*)/status") + + def __init__(self, hs): + super(PresenceStatusStubServlet, self).__init__() + self.http_client = hs.get_simple_http_client() + self.auth = hs.get_auth() + self.main_uri = hs.config.worker_main_http_uri + + async def on_GET(self, request, user_id): + # Pass through the auth headers, if any, in case the access token + # is there. + auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) + headers = {"Authorization": auth_headers} + + try: + result = await self.http_client.get_json( + self.main_uri + request.uri.decode("ascii"), headers=headers + ) + except HttpResponseException as e: + raise e.to_synapse_error() + + return 200, result + + async def on_PUT(self, request, user_id): + await self.auth.get_user_by_req(request) + return 200, {} + + +class KeyUploadServlet(RestServlet): + """An implementation of the `KeyUploadServlet` that responds to read only + requests, but otherwise proxies through to the master instance. + """ + + PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(KeyUploadServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.http_client = hs.get_simple_http_client() + self.main_uri = hs.config.worker_main_http_uri + + async def on_POST(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + if device_id is not None: + # passing the device_id here is deprecated; however, we allow it + # for now for compatibility with older clients. + if requester.device_id is not None and device_id != requester.device_id: + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) + else: + device_id = requester.device_id + + if device_id is None: + raise SynapseError( + 400, "To upload keys, you must pass device_id when authenticating" + ) + + if body: + # They're actually trying to upload something, proxy to main synapse. + # Pass through the auth headers, if any, in case the access token + # is there. + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) + headers = {"Authorization": auth_headers} + result = await self.http_client.post_json_get_json( + self.main_uri + request.uri.decode("ascii"), body, headers=headers + ) + + return 200, result + else: + # Just interested in counts. + result = await self.store.count_e2e_one_time_keys(user_id, device_id) + return 200, {"one_time_key_counts": result} + + +UPDATE_SYNCING_USERS_MS = 10 * 1000 + + +class GenericWorkerPresence(object): + def __init__(self, hs): + self.hs = hs + self.is_mine_id = hs.is_mine_id + self.http_client = hs.get_simple_http_client() + self.store = hs.get_datastore() + self.user_to_num_current_syncs = {} + self.clock = hs.get_clock() + self.notifier = hs.get_notifier() + + active_presence = self.store.take_presence_startup_info() + self.user_to_current_state = {state.user_id: state for state in active_presence} + + # user_id -> last_sync_ms. Lists the users that have stopped syncing + # but we haven't notified the master of that yet + self.users_going_offline = {} + + self._send_stop_syncing_loop = self.clock.looping_call( + self.send_stop_syncing, UPDATE_SYNCING_USERS_MS + ) + + self.process_id = random_string(16) + logger.info("Presence process_id is %r", self.process_id) + + def send_user_sync(self, user_id, is_syncing, last_sync_ms): + if self.hs.config.use_presence: + self.hs.get_tcp_replication().send_user_sync( + user_id, is_syncing, last_sync_ms + ) + + def mark_as_coming_online(self, user_id): + """A user has started syncing. Send a UserSync to the master, unless they + had recently stopped syncing. + + Args: + user_id (str) + """ + going_offline = self.users_going_offline.pop(user_id, None) + if not going_offline: + # Safe to skip because we haven't yet told the master they were offline + self.send_user_sync(user_id, True, self.clock.time_msec()) + + def mark_as_going_offline(self, user_id): + """A user has stopped syncing. We wait before notifying the master as + its likely they'll come back soon. This allows us to avoid sending + a stopped syncing immediately followed by a started syncing notification + to the master + + Args: + user_id (str) + """ + self.users_going_offline[user_id] = self.clock.time_msec() + + def send_stop_syncing(self): + """Check if there are any users who have stopped syncing a while ago + and haven't come back yet. If there are poke the master about them. + """ + now = self.clock.time_msec() + for user_id, last_sync_ms in list(self.users_going_offline.items()): + if now - last_sync_ms > UPDATE_SYNCING_USERS_MS: + self.users_going_offline.pop(user_id, None) + self.send_user_sync(user_id, False, last_sync_ms) + + def set_state(self, user, state, ignore_status_msg=False): + # TODO Hows this supposed to work? + return defer.succeed(None) + + get_states = __func__(PresenceHandler.get_states) + get_state = __func__(PresenceHandler.get_state) + current_state_for_users = __func__(PresenceHandler.current_state_for_users) + + def user_syncing(self, user_id, affect_presence): + if affect_presence: + curr_sync = self.user_to_num_current_syncs.get(user_id, 0) + self.user_to_num_current_syncs[user_id] = curr_sync + 1 + + # If we went from no in flight sync to some, notify replication + if self.user_to_num_current_syncs[user_id] == 1: + self.mark_as_coming_online(user_id) + + def _end(): + # We check that the user_id is in user_to_num_current_syncs because + # user_to_num_current_syncs may have been cleared if we are + # shutting down. + if affect_presence and user_id in self.user_to_num_current_syncs: + self.user_to_num_current_syncs[user_id] -= 1 + + # If we went from one in flight sync to non, notify replication + if self.user_to_num_current_syncs[user_id] == 0: + self.mark_as_going_offline(user_id) + + @contextlib.contextmanager + def _user_syncing(): + try: + yield + finally: + _end() + + return defer.succeed(_user_syncing()) + + @defer.inlineCallbacks + def notify_from_replication(self, states, stream_id): + parties = yield get_interested_parties(self.store, states) + room_ids_to_states, users_to_states = parties + + self.notifier.on_new_event( + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=users_to_states.keys(), + ) + + @defer.inlineCallbacks + def process_replication_rows(self, token, rows): + states = [ + UserPresenceState( + row.user_id, + row.state, + row.last_active_ts, + row.last_federation_update_ts, + row.last_user_sync_ts, + row.status_msg, + row.currently_active, + ) + for row in rows + ] + + for state in states: + self.user_to_current_state[state.user_id] = state + + stream_id = token + yield self.notify_from_replication(states, stream_id) + + def get_currently_syncing_users(self): + if self.hs.config.use_presence: + return [ + user_id + for user_id, count in self.user_to_num_current_syncs.items() + if count > 0 + ] + else: + return set() + + +class GenericWorkerTyping(object): + def __init__(self, hs): + self._latest_room_serial = 0 + self._reset() + + def _reset(self): + """ + Reset the typing handler's data caches. + """ + # map room IDs to serial numbers + self._room_serials = {} + # map room IDs to sets of users currently typing + self._room_typing = {} + + def stream_positions(self): + # We must update this typing token from the response of the previous + # sync. In particular, the stream id may "reset" back to zero/a low + # value which we *must* use for the next replication request. + return {"typing": self._latest_room_serial} + + def process_replication_rows(self, token, rows): + if self._latest_room_serial > token: + # The master has gone backwards. To prevent inconsistent data, just + # clear everything. + self._reset() + + # Set the latest serial token to whatever the server gave us. + self._latest_room_serial = token + + for row in rows: + self._room_serials[row.room_id] = token + self._room_typing[row.room_id] = row.user_ids + + +class GenericWorkerSlavedStore( + # FIXME(#3714): We need to add UserDirectoryStore as we write directly + # rather than going via the correct worker. + UserDirectoryStore, + SlavedDeviceInboxStore, + SlavedDeviceStore, + SlavedReceiptsStore, + SlavedPushRuleStore, + SlavedGroupServerStore, + SlavedAccountDataStore, + SlavedPusherStore, + SlavedEventStore, + SlavedKeyStore, + RoomStore, + DirectoryStore, + SlavedApplicationServiceStore, + SlavedRegistrationStore, + SlavedTransactionStore, + SlavedProfileStore, + SlavedClientIpStore, + SlavedPresenceStore, + SlavedFilteringStore, + MonthlyActiveUsersWorkerStore, + MediaRepositoryStore, + BaseSlavedStore, +): + def __init__(self, database, db_conn, hs): + super(GenericWorkerSlavedStore, self).__init__(database, db_conn, hs) + + # We pull out the current federation stream position now so that we + # always have a known value for the federation position in memory so + # that we don't have to bounce via a deferred once when we start the + # replication streams. + self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) + + def _get_federation_out_pos(self, db_conn): + sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?" + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, ("federation",)) + rows = txn.fetchall() + txn.close() + + return rows[0][0] if rows else -1 + + +class GenericWorkerServer(HomeServer): + DATASTORE_CLASS = GenericWorkerSlavedStore + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_addresses = listener_config["bind_addresses"] + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) + elif name == "client": + resource = JsonResource(self, canonical_json=False) + + PublicRoomListRestServlet(self).register(resource) + RoomMemberListRestServlet(self).register(resource) + JoinedRoomMemberListRestServlet(self).register(resource) + RoomStateRestServlet(self).register(resource) + RoomEventContextServlet(self).register(resource) + RoomMessageListRestServlet(self).register(resource) + RegisterRestServlet(self).register(resource) + LoginRestServlet(self).register(resource) + ThreepidRestServlet(self).register(resource) + KeyQueryServlet(self).register(resource) + KeyChangesServlet(self).register(resource) + VoipRestServlet(self).register(resource) + PushRuleRestServlet(self).register(resource) + VersionsRestServlet(self).register(resource) + RoomSendEventRestServlet(self).register(resource) + RoomMembershipRestServlet(self).register(resource) + RoomStateEventRestServlet(self).register(resource) + JoinRoomAliasServlet(self).register(resource) + ProfileAvatarURLRestServlet(self).register(resource) + ProfileDisplaynameRestServlet(self).register(resource) + ProfileRestServlet(self).register(resource) + KeyUploadServlet(self).register(resource) + + sync.register_servlets(self, resource) + events.register_servlets(self, resource) + InitialSyncRestServlet(self).register(resource) + RoomInitialSyncRestServlet(self).register(resource) + + user_directory.register_servlets(self, resource) + + # If presence is disabled, use the stub servlet that does + # not allow sending presence + if not self.config.use_presence: + PresenceStatusStubServlet(self).register(resource) + + groups.register_servlets(self, resource) + + resources.update({CLIENT_API_PREFIX: resource}) + elif name == "federation": + resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) + elif name == "media": + media_repo = self.get_media_repository_resource() + + # We need to serve the admin servlets for media on the + # worker. + admin_resource = JsonResource(self, canonical_json=False) + register_servlets_for_media_repo(self, admin_resource) + + resources.update( + { + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + "/_synapse/admin": admin_resource, + } + ) + + if name == "openid" and "federation" not in res["names"]: + # Only load the openid resource separately if federation resource + # is not specified since federation resource includes openid + # resource. + resources.update( + { + FEDERATION_PREFIX: TransportLayerServer( + self, servlet_groups=["openid"] + ) + } + ) + + if name in ["keys", "federation"]: + resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + + root_resource = create_resource_tree(resources, NoResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, + ), + reactor=self.get_reactor(), + ) + + logger.info("Synapse worker now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", password="rabbithole", globals={"hs": self} + ), + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warning( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) + else: + _base.listen_metrics(listener["bind_addresses"], listener["port"]) + else: + logger.warning("Unrecognized listener type: %s", listener["type"]) + + self.get_tcp_replication().start_replication(self) + + def remove_pusher(self, app_id, push_key, user_id): + self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) + + def build_tcp_replication(self): + return GenericWorkerReplicationHandler(self) + + def build_presence_handler(self): + return GenericWorkerPresence(self) + + def build_typing_handler(self): + return GenericWorkerTyping(self) + + +class GenericWorkerReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore()) + + self.store = hs.get_datastore() + self.typing_handler = hs.get_typing_handler() + # NB this is a SynchrotronPresence, not a normal PresenceHandler + self.presence_handler = hs.get_presence_handler() + self.notifier = hs.get_notifier() + + self.notify_pushers = hs.config.start_pushers + self.pusher_pool = hs.get_pusherpool() + + if hs.config.send_federation: + self.send_handler = FederationSenderHandler(hs, self) + else: + self.send_handler = None + + async def on_rdata(self, stream_name, token, rows): + await super(GenericWorkerReplicationHandler, self).on_rdata( + stream_name, token, rows + ) + run_in_background(self.process_and_notify, stream_name, token, rows) + + def get_streams_to_replicate(self): + args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate() + args.update(self.typing_handler.stream_positions()) + if self.send_handler: + args.update(self.send_handler.stream_positions()) + return args + + def get_currently_syncing_users(self): + return self.presence_handler.get_currently_syncing_users() + + async def process_and_notify(self, stream_name, token, rows): + try: + if self.send_handler: + self.send_handler.process_replication_rows(stream_name, token, rows) + + if stream_name == "events": + # We shouldn't get multiple rows per token for events stream, so + # we don't need to optimise this for multiple rows. + for row in rows: + if row.type != EventsStreamEventRow.TypeId: + continue + assert isinstance(row, EventsStreamRow) + + event = await self.store.get_event( + row.data.event_id, allow_rejected=True + ) + if event.rejected_reason: + continue + + extra_users = () + if event.type == EventTypes.Member: + extra_users = (event.state_key,) + max_token = self.store.get_room_max_stream_ordering() + self.notifier.on_new_room_event( + event, token, max_token, extra_users + ) + + await self.pusher_pool.on_new_notifications(token, token) + elif stream_name == "push_rules": + self.notifier.on_new_event( + "push_rules_key", token, users=[row.user_id for row in rows] + ) + elif stream_name in ("account_data", "tag_account_data"): + self.notifier.on_new_event( + "account_data_key", token, users=[row.user_id for row in rows] + ) + elif stream_name == "receipts": + self.notifier.on_new_event( + "receipt_key", token, rooms=[row.room_id for row in rows] + ) + await self.pusher_pool.on_new_receipts( + token, token, {row.room_id for row in rows} + ) + elif stream_name == "typing": + self.typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows] + ) + elif stream_name == "to_device": + entities = [row.entity for row in rows if row.entity.startswith("@")] + if entities: + self.notifier.on_new_event("to_device_key", token, users=entities) + elif stream_name == "device_lists": + all_room_ids = set() + for row in rows: + room_ids = await self.store.get_rooms_for_user(row.user_id) + all_room_ids.update(room_ids) + self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) + elif stream_name == "presence": + await self.presence_handler.process_replication_rows(token, rows) + elif stream_name == "receipts": + self.notifier.on_new_event( + "groups_key", token, users=[row.user_id for row in rows] + ) + elif stream_name == "pushers": + for row in rows: + if row.deleted: + self.stop_pusher(row.user_id, row.app_id, row.pushkey) + else: + await self.start_pusher(row.user_id, row.app_id, row.pushkey) + except Exception: + logger.exception("Error processing replication") + + def stop_pusher(self, user_id, app_id, pushkey): + if not self.notify_pushers: + return + + key = "%s:%s" % (app_id, pushkey) + pushers_for_user = self.pusher_pool.pushers.get(user_id, {}) + pusher = pushers_for_user.pop(key, None) + if pusher is None: + return + logger.info("Stopping pusher %r / %r", user_id, key) + pusher.on_stop() + + async def start_pusher(self, user_id, app_id, pushkey): + if not self.notify_pushers: + return + + key = "%s:%s" % (app_id, pushkey) + logger.info("Starting pusher %r / %r", user_id, key) + return await self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) + + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + + # Let's wake up the transaction queue for the server in case we have + # pending stuff to send to it. + if self.send_handler: + self.send_handler.wake_destination(server) + + +class FederationSenderHandler(object): + """Processes the replication stream and forwards the appropriate entries + to the federation sender. + """ + + def __init__(self, hs: GenericWorkerServer, replication_client): + self.store = hs.get_datastore() + self._is_mine_id = hs.is_mine_id + self.federation_sender = hs.get_federation_sender() + self.replication_client = replication_client + + self.federation_position = self.store.federation_out_pos_startup + self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") + + self._last_ack = self.federation_position + + self._room_serials = {} + self._room_typing = {} + + def on_start(self): + # There may be some events that are persisted but haven't been sent, + # so send them now. + self.federation_sender.notify_new_events( + self.store.get_room_max_stream_ordering() + ) + + def wake_destination(self, server: str): + self.federation_sender.wake_destination(server) + + def stream_positions(self): + return {"federation": self.federation_position} + + def process_replication_rows(self, stream_name, token, rows): + # The federation stream contains things that we want to send out, e.g. + # presence, typing, etc. + if stream_name == "federation": + send_queue.process_rows_for_federation(self.federation_sender, rows) + run_in_background(self.update_token, token) + + # We also need to poke the federation sender when new events happen + elif stream_name == "events": + self.federation_sender.notify_new_events(token) + + # ... and when new receipts happen + elif stream_name == ReceiptsStream.NAME: + run_as_background_process( + "process_receipts_for_federation", self._on_new_receipts, rows + ) + + # ... as well as device updates and messages + elif stream_name == DeviceListsStream.NAME: + hosts = {row.destination for row in rows} + for host in hosts: + self.federation_sender.send_device_messages(host) + + elif stream_name == ToDeviceStream.NAME: + # The to_device stream includes stuff to be pushed to both local + # clients and remote servers, so we ignore entities that start with + # '@' (since they'll be local users rather than destinations). + hosts = {row.entity for row in rows if not row.entity.startswith("@")} + for host in hosts: + self.federation_sender.send_device_messages(host) + + async def _on_new_receipts(self, rows): + """ + Args: + rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]): + new receipts to be processed + """ + for receipt in rows: + # we only want to send on receipts for our own users + if not self._is_mine_id(receipt.user_id): + continue + receipt_info = ReadReceipt( + receipt.room_id, + receipt.receipt_type, + receipt.user_id, + [receipt.event_id], + receipt.data, + ) + await self.federation_sender.send_read_receipt(receipt_info) + + async def update_token(self, token): + try: + self.federation_position = token + + # We linearize here to ensure we don't have races updating the token + with (await self._fed_position_linearizer.queue(None)): + if self._last_ack < self.federation_position: + await self.store.update_federation_out_pos( + "federation", self.federation_position + ) + + # We ACK this token over replication so that the master can drop + # its in memory queues + self.replication_client.send_federation_ack( + self.federation_position + ) + self._last_ack = self.federation_position + except Exception: + logger.exception("Error updating federation stream position") + + +def start(config_options): + try: + config = HomeServerConfig.load_config("Synapse worker", config_options) + except ConfigError as e: + sys.stderr.write("\n" + str(e) + "\n") + sys.exit(1) + + # For backwards compatibility let any of the old app names. + assert config.worker_app in ( + "synapse.app.appservice", + "synapse.app.client_reader", + "synapse.app.event_creator", + "synapse.app.federation_reader", + "synapse.app.federation_sender", + "synapse.app.frontend_proxy", + "synapse.app.generic_worker", + "synapse.app.media_repository", + "synapse.app.pusher", + "synapse.app.synchrotron", + "synapse.app.user_dir", + ) + + if config.worker_app == "synapse.app.appservice": + if config.notify_appservices: + sys.stderr.write( + "\nThe appservices must be disabled in the main synapse process" + "\nbefore they can be run in a separate worker." + "\nPlease add ``notify_appservices: false`` to the main config" + "\n" + ) + sys.exit(1) + + # Force the appservice to start since they will be disabled in the main config + config.notify_appservices = True + + if config.worker_app == "synapse.app.pusher": + if config.start_pushers: + sys.stderr.write( + "\nThe pushers must be disabled in the main synapse process" + "\nbefore they can be run in a separate worker." + "\nPlease add ``start_pushers: false`` to the main config" + "\n" + ) + sys.exit(1) + + # Force the pushers to start since they will be disabled in the main config + config.start_pushers = True + + if config.worker_app == "synapse.app.user_dir": + if config.update_user_directory: + sys.stderr.write( + "\nThe update_user_directory must be disabled in the main synapse process" + "\nbefore they can be run in a separate worker." + "\nPlease add ``update_user_directory: false`` to the main config" + "\n" + ) + sys.exit(1) + + # Force the pushers to start since they will be disabled in the main config + config.update_user_directory = True + + if config.worker_app == "synapse.app.federation_sender": + if config.send_federation: + sys.stderr.write( + "\nThe send_federation must be disabled in the main synapse process" + "\nbefore they can be run in a separate worker." + "\nPlease add ``send_federation: false`` to the main config" + "\n" + ) + sys.exit(1) + + # Force the pushers to start since they will be disabled in the main config + config.send_federation = True + + synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts + + ss = GenericWorkerServer( + config.server_name, + config=config, + version_string="Synapse/" + get_version_string(synapse), + ) + + setup_logging(ss, config, use_worker_options=True) + + ss.setup() + reactor.addSystemEventTrigger( + "before", "startup", _base.start, ss, config.worker_listeners + ) + + _base.start_worker_reactor("synapse-generic-worker", config) + + +if __name__ == "__main__": + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 5b5832214a..add43147b3 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -13,162 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys - -from twisted.internet import reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.api.urls import LEGACY_MEDIA_PREFIX, MEDIA_PREFIX -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.server import JsonResource -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.admin import register_servlets_for_media_repo -from synapse.server import HomeServer -from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.media_repository") - - -class MediaRepositorySlavedStore( - RoomStore, - SlavedApplicationServiceStore, - SlavedRegistrationStore, - SlavedClientIpStore, - SlavedTransactionStore, - BaseSlavedStore, - MediaRepositoryStore, -): - pass - - -class MediaRepositoryServer(HomeServer): - DATASTORE_CLASS = MediaRepositorySlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "media": - media_repo = self.get_media_repository_resource() - - # We need to serve the admin servlets for media on the - # worker. - admin_resource = JsonResource(self, canonical_json=False) - register_servlets_for_media_repo(self, admin_resource) - - resources.update( - { - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - "/_synapse/admin": admin_resource, - } - ) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - logger.info("Synapse media repository now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warning("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ReplicationClientHandler(self.get_datastore()) - - -def start(config_options): - try: - config = HomeServerConfig.load_config( - "Synapse media repository", config_options - ) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.media_repository" - - if config.enable_media_repo: - _base.quit_with_error( - "enable_media_repo must be disabled in the main synapse process\n" - "before the media repo can be run in a separate worker.\n" - "Please add ``enable_media_repo: false`` to the main config\n" - ) - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - ss = MediaRepositoryServer( - config.server_name, - config=config, - version_string="Synapse/" + get_version_string(synapse), - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-media-repository", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 84e9f8d5e2..add43147b3 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -13,213 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys - -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import __func__ -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.pushers import SlavedPusherStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.server import HomeServer -from synapse.storage import DataStore -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.pusher") - - -class PusherSlaveStore( - SlavedEventStore, - SlavedPusherStore, - SlavedReceiptsStore, - SlavedAccountDataStore, - RoomStore, -): - update_pusher_last_stream_ordering_and_success = __func__( - DataStore.update_pusher_last_stream_ordering_and_success - ) - - update_pusher_failing_since = __func__(DataStore.update_pusher_failing_since) - - update_pusher_last_stream_ordering = __func__( - DataStore.update_pusher_last_stream_ordering - ) - - get_throttle_params_by_room = __func__(DataStore.get_throttle_params_by_room) - - set_throttle_params = __func__(DataStore.set_throttle_params) - - get_time_of_last_push_action_before = __func__( - DataStore.get_time_of_last_push_action_before - ) - - get_profile_displayname = __func__(DataStore.get_profile_displayname) - - -class PusherServer(HomeServer): - DATASTORE_CLASS = PusherSlaveStore - - def remove_pusher(self, app_id, push_key, user_id): - self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse pusher now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warning("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - def build_tcp_replication(self): - return PusherReplicationHandler(self) - - -class PusherReplicationHandler(ReplicationClientHandler): - def __init__(self, hs): - super(PusherReplicationHandler, self).__init__(hs.get_datastore()) - - self.pusher_pool = hs.get_pusherpool() - - async def on_rdata(self, stream_name, token, rows): - await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) - run_in_background(self.poke_pushers, stream_name, token, rows) - - @defer.inlineCallbacks - def poke_pushers(self, stream_name, token, rows): - try: - if stream_name == "pushers": - for row in rows: - if row.deleted: - yield self.stop_pusher(row.user_id, row.app_id, row.pushkey) - else: - yield self.start_pusher(row.user_id, row.app_id, row.pushkey) - elif stream_name == "events": - yield self.pusher_pool.on_new_notifications(token, token) - elif stream_name == "receipts": - yield self.pusher_pool.on_new_receipts( - token, token, {row.room_id for row in rows} - ) - except Exception: - logger.exception("Error poking pushers") - - def stop_pusher(self, user_id, app_id, pushkey): - key = "%s:%s" % (app_id, pushkey) - pushers_for_user = self.pusher_pool.pushers.get(user_id, {}) - pusher = pushers_for_user.pop(key, None) - if pusher is None: - return - logger.info("Stopping pusher %r / %r", user_id, key) - pusher.on_stop() - - def start_pusher(self, user_id, app_id, pushkey): - key = "%s:%s" % (app_id, pushkey) - logger.info("Starting pusher %r / %r", user_id, key) - return self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse pusher", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.pusher" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - if config.start_pushers: - sys.stderr.write( - "\nThe pushers must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``start_pushers: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.start_pushers = True - - ps = PusherServer( - config.server_name, - config=config, - version_string="Synapse/" + get_version_string(synapse), - ) - - setup_logging(ps, config, use_worker_options=True) - - ps.setup() - - def start(): - _base.start(ps, config.worker_listeners) - ps.get_pusherpool().start() - - reactor.addSystemEventTrigger("before", "startup", start) - - _base.start_worker_reactor("synapse-pusher", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): - ps = start(sys.argv[1:]) + start(sys.argv[1:]) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 8982c0676e..add43147b3 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -13,454 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import logging -import sys - -from six import iteritems - -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse.api.constants import EventTypes -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.handlers.presence import PresenceHandler, get_interested_parties -from synapse.http.server import JsonResource -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.filtering import SlavedFilteringStore -from synapse.replication.slave.storage.groups import SlavedGroupServerStore -from synapse.replication.slave.storage.presence import SlavedPresenceStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow -from synapse.rest.client.v1 import events -from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet -from synapse.rest.client.v1.room import RoomInitialSyncRestServlet -from synapse.rest.client.v2_alpha import sync -from synapse.server import HomeServer -from synapse.storage.data_stores.main.monthly_active_users import ( - MonthlyActiveUsersWorkerStore, -) -from synapse.storage.data_stores.main.presence import UserPresenceState -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.stringutils import random_string -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.synchrotron") - - -class SynchrotronSlavedStore( - SlavedReceiptsStore, - SlavedAccountDataStore, - SlavedApplicationServiceStore, - SlavedRegistrationStore, - SlavedFilteringStore, - SlavedPresenceStore, - SlavedGroupServerStore, - SlavedDeviceInboxStore, - SlavedDeviceStore, - SlavedPushRuleStore, - SlavedEventStore, - SlavedClientIpStore, - RoomStore, - MonthlyActiveUsersWorkerStore, - BaseSlavedStore, -): - pass - - -UPDATE_SYNCING_USERS_MS = 10 * 1000 - - -class SynchrotronPresence(object): - def __init__(self, hs): - self.hs = hs - self.is_mine_id = hs.is_mine_id - self.http_client = hs.get_simple_http_client() - self.store = hs.get_datastore() - self.user_to_num_current_syncs = {} - self.clock = hs.get_clock() - self.notifier = hs.get_notifier() - - active_presence = self.store.take_presence_startup_info() - self.user_to_current_state = {state.user_id: state for state in active_presence} - - # user_id -> last_sync_ms. Lists the users that have stopped syncing - # but we haven't notified the master of that yet - self.users_going_offline = {} - - self._send_stop_syncing_loop = self.clock.looping_call( - self.send_stop_syncing, 10 * 1000 - ) - - self.process_id = random_string(16) - logger.info("Presence process_id is %r", self.process_id) - - def send_user_sync(self, user_id, is_syncing, last_sync_ms): - if self.hs.config.use_presence: - self.hs.get_tcp_replication().send_user_sync( - user_id, is_syncing, last_sync_ms - ) - - def mark_as_coming_online(self, user_id): - """A user has started syncing. Send a UserSync to the master, unless they - had recently stopped syncing. - - Args: - user_id (str) - """ - going_offline = self.users_going_offline.pop(user_id, None) - if not going_offline: - # Safe to skip because we haven't yet told the master they were offline - self.send_user_sync(user_id, True, self.clock.time_msec()) - - def mark_as_going_offline(self, user_id): - """A user has stopped syncing. We wait before notifying the master as - its likely they'll come back soon. This allows us to avoid sending - a stopped syncing immediately followed by a started syncing notification - to the master - - Args: - user_id (str) - """ - self.users_going_offline[user_id] = self.clock.time_msec() - - def send_stop_syncing(self): - """Check if there are any users who have stopped syncing a while ago - and haven't come back yet. If there are poke the master about them. - """ - now = self.clock.time_msec() - for user_id, last_sync_ms in list(self.users_going_offline.items()): - if now - last_sync_ms > 10 * 1000: - self.users_going_offline.pop(user_id, None) - self.send_user_sync(user_id, False, last_sync_ms) - - def set_state(self, user, state, ignore_status_msg=False): - # TODO Hows this supposed to work? - return defer.succeed(None) - - get_states = __func__(PresenceHandler.get_states) - get_state = __func__(PresenceHandler.get_state) - current_state_for_users = __func__(PresenceHandler.current_state_for_users) - - def user_syncing(self, user_id, affect_presence): - if affect_presence: - curr_sync = self.user_to_num_current_syncs.get(user_id, 0) - self.user_to_num_current_syncs[user_id] = curr_sync + 1 - - # If we went from no in flight sync to some, notify replication - if self.user_to_num_current_syncs[user_id] == 1: - self.mark_as_coming_online(user_id) - - def _end(): - # We check that the user_id is in user_to_num_current_syncs because - # user_to_num_current_syncs may have been cleared if we are - # shutting down. - if affect_presence and user_id in self.user_to_num_current_syncs: - self.user_to_num_current_syncs[user_id] -= 1 - - # If we went from one in flight sync to non, notify replication - if self.user_to_num_current_syncs[user_id] == 0: - self.mark_as_going_offline(user_id) - - @contextlib.contextmanager - def _user_syncing(): - try: - yield - finally: - _end() - - return defer.succeed(_user_syncing()) - - @defer.inlineCallbacks - def notify_from_replication(self, states, stream_id): - parties = yield get_interested_parties(self.store, states) - room_ids_to_states, users_to_states = parties - - self.notifier.on_new_event( - "presence_key", - stream_id, - rooms=room_ids_to_states.keys(), - users=users_to_states.keys(), - ) - - @defer.inlineCallbacks - def process_replication_rows(self, token, rows): - states = [ - UserPresenceState( - row.user_id, - row.state, - row.last_active_ts, - row.last_federation_update_ts, - row.last_user_sync_ts, - row.status_msg, - row.currently_active, - ) - for row in rows - ] - - for state in states: - self.user_to_current_state[state.user_id] = state - - stream_id = token - yield self.notify_from_replication(states, stream_id) - - def get_currently_syncing_users(self): - if self.hs.config.use_presence: - return [ - user_id - for user_id, count in iteritems(self.user_to_num_current_syncs) - if count > 0 - ] - else: - return set() - -class SynchrotronTyping(object): - def __init__(self, hs): - self._latest_room_serial = 0 - self._reset() - - def _reset(self): - """ - Reset the typing handler's data caches. - """ - # map room IDs to serial numbers - self._room_serials = {} - # map room IDs to sets of users currently typing - self._room_typing = {} - - def stream_positions(self): - # We must update this typing token from the response of the previous - # sync. In particular, the stream id may "reset" back to zero/a low - # value which we *must* use for the next replication request. - return {"typing": self._latest_room_serial} - - def process_replication_rows(self, token, rows): - if self._latest_room_serial > token: - # The master has gone backwards. To prevent inconsistent data, just - # clear everything. - self._reset() - - # Set the latest serial token to whatever the server gave us. - self._latest_room_serial = token - - for row in rows: - self._room_serials[row.room_id] = token - self._room_typing[row.room_id] = row.user_ids - - -class SynchrotronApplicationService(object): - def notify_interested_services(self, event): - pass - - -class SynchrotronServer(HomeServer): - DATASTORE_CLASS = SynchrotronSlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "client": - resource = JsonResource(self, canonical_json=False) - sync.register_servlets(self, resource) - events.register_servlets(self, resource) - InitialSyncRestServlet(self).register(resource) - RoomInitialSyncRestServlet(self).register(resource) - resources.update( - { - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - } - ) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse synchrotron now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warning("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return SyncReplicationHandler(self) - - def build_presence_handler(self): - return SynchrotronPresence(self) - - def build_typing_handler(self): - return SynchrotronTyping(self) - - -class SyncReplicationHandler(ReplicationClientHandler): - def __init__(self, hs): - super(SyncReplicationHandler, self).__init__(hs.get_datastore()) - - self.store = hs.get_datastore() - self.typing_handler = hs.get_typing_handler() - # NB this is a SynchrotronPresence, not a normal PresenceHandler - self.presence_handler = hs.get_presence_handler() - self.notifier = hs.get_notifier() - - async def on_rdata(self, stream_name, token, rows): - await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) - run_in_background(self.process_and_notify, stream_name, token, rows) - - def get_streams_to_replicate(self): - args = super(SyncReplicationHandler, self).get_streams_to_replicate() - args.update(self.typing_handler.stream_positions()) - return args - - def get_currently_syncing_users(self): - return self.presence_handler.get_currently_syncing_users() - - async def process_and_notify(self, stream_name, token, rows): - try: - if stream_name == "events": - # We shouldn't get multiple rows per token for events stream, so - # we don't need to optimise this for multiple rows. - for row in rows: - if row.type != EventsStreamEventRow.TypeId: - continue - assert isinstance(row, EventsStreamRow) - - event = await self.store.get_event( - row.data.event_id, allow_rejected=True - ) - if event.rejected_reason: - continue - - extra_users = () - if event.type == EventTypes.Member: - extra_users = (event.state_key,) - max_token = self.store.get_room_max_stream_ordering() - self.notifier.on_new_room_event( - event, token, max_token, extra_users - ) - elif stream_name == "push_rules": - self.notifier.on_new_event( - "push_rules_key", token, users=[row.user_id for row in rows] - ) - elif stream_name in ("account_data", "tag_account_data"): - self.notifier.on_new_event( - "account_data_key", token, users=[row.user_id for row in rows] - ) - elif stream_name == "receipts": - self.notifier.on_new_event( - "receipt_key", token, rooms=[row.room_id for row in rows] - ) - elif stream_name == "typing": - self.typing_handler.process_replication_rows(token, rows) - self.notifier.on_new_event( - "typing_key", token, rooms=[row.room_id for row in rows] - ) - elif stream_name == "to_device": - entities = [row.entity for row in rows if row.entity.startswith("@")] - if entities: - self.notifier.on_new_event("to_device_key", token, users=entities) - elif stream_name == "device_lists": - all_room_ids = set() - for row in rows: - room_ids = await self.store.get_rooms_for_user(row.user_id) - all_room_ids.update(room_ids) - self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) - elif stream_name == "presence": - await self.presence_handler.process_replication_rows(token, rows) - elif stream_name == "receipts": - self.notifier.on_new_event( - "groups_key", token, users=[row.user_id for row in rows] - ) - except Exception: - logger.exception("Error processing replication") - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse synchrotron", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.synchrotron" - - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - - ss = SynchrotronServer( - config.server_name, - config=config, - version_string="Synapse/" + get_version_string(synapse), - application_service_handler=SynchrotronApplicationService(), - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-synchrotron", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index ba536d6f04..503d44f687 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -14,217 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import sys -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.server import JsonResource -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.replication.tcp.streams.events import ( - EventsStream, - EventsStreamCurrentStateRow, -) -from synapse.rest.client.v2_alpha import user_directory -from synapse.server import HomeServer -from synapse.storage.data_stores.main.user_directory import UserDirectoryStore -from synapse.storage.database import Database -from synapse.util.caches.stream_change_cache import StreamChangeCache -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.user_dir") - - -class UserDirectorySlaveStore( - SlavedEventStore, - SlavedApplicationServiceStore, - SlavedRegistrationStore, - SlavedClientIpStore, - UserDirectoryStore, - BaseSlavedStore, -): - def __init__(self, database: Database, db_conn, hs): - super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs) - - events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( - db_conn, - "current_state_delta_stream", - entity_column="room_id", - stream_column="stream_id", - max_value=events_max, # As we share the stream id with events token - limit=1000, - ) - self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", - min_curr_state_delta_id, - prefilled_cache=curr_state_delta_prefill, - ) - - def stream_positions(self): - result = super(UserDirectorySlaveStore, self).stream_positions() - return result - - def process_replication_rows(self, stream_name, token, rows): - if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(token) - for row in rows: - if row.type != EventsStreamCurrentStateRow.TypeId: - continue - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) - return super(UserDirectorySlaveStore, self).process_replication_rows( - stream_name, token, rows - ) - - -class UserDirectoryServer(HomeServer): - DATASTORE_CLASS = UserDirectorySlaveStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "client": - resource = JsonResource(self, canonical_json=False) - user_directory.register_servlets(self, resource) - resources.update( - { - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - } - ) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse user_dir now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warning( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warning("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return UserDirectoryReplicationHandler(self) - - -class UserDirectoryReplicationHandler(ReplicationClientHandler): - def __init__(self, hs): - super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) - self.user_directory = hs.get_user_directory_handler() - - async def on_rdata(self, stream_name, token, rows): - await super(UserDirectoryReplicationHandler, self).on_rdata( - stream_name, token, rows - ) - if stream_name == EventsStream.NAME: - run_in_background(self._notify_directory) - - @defer.inlineCallbacks - def _notify_directory(self): - try: - yield self.user_directory.notify_new_event() - except Exception: - logger.exception("Error notifiying user directory of state update") - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse user directory", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.user_dir" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - if config.update_user_directory: - sys.stderr.write( - "\nThe update_user_directory must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``update_user_directory: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.update_user_directory = True - - ss = UserDirectoryServer( - config.server_name, - config=config, - version_string="Synapse/" + get_version_string(synapse), - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-user-dir", config) - +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 3aa6cb8b96..e73342c657 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -32,6 +32,7 @@ from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.data_stores.main.stream import StreamWorkerStore from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore from synapse.storage.database import Database +from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -68,6 +69,21 @@ class SlavedEventStore( super(SlavedEventStore, self).__init__(database, db_conn, hs) + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( + db_conn, + "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache = StreamChangeCache( + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, + ) + # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. @@ -120,6 +136,10 @@ class SlavedEventStore( backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: + self._curr_state_delta_stream_cache.entity_has_changed( + row.data.room_id, token + ) + if data.type == EventTypes.Member: self.get_rooms_for_user_with_stream_ordering.invalidate( (data.state_key,) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index 6b03233262..547b9d69cb 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -197,6 +197,84 @@ class PusherWorkerStore(SQLBaseStore): return result + @defer.inlineCallbacks + def update_pusher_last_stream_ordering( + self, app_id, pushkey, user_id, last_stream_ordering + ): + yield self.db.simple_update_one( + "pushers", + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + {"last_stream_ordering": last_stream_ordering}, + desc="update_pusher_last_stream_ordering", + ) + + @defer.inlineCallbacks + def update_pusher_last_stream_ordering_and_success( + self, app_id, pushkey, user_id, last_stream_ordering, last_success + ): + """Update the last stream ordering position we've processed up to for + the given pusher. + + Args: + app_id (str) + pushkey (str) + last_stream_ordering (int) + last_success (int) + + Returns: + Deferred[bool]: True if the pusher still exists; False if it has been deleted. + """ + updated = yield self.db.simple_update( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + updatevalues={ + "last_stream_ordering": last_stream_ordering, + "last_success": last_success, + }, + desc="update_pusher_last_stream_ordering_and_success", + ) + + return bool(updated) + + @defer.inlineCallbacks + def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): + yield self.db.simple_update( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + updatevalues={"failing_since": failing_since}, + desc="update_pusher_failing_since", + ) + + @defer.inlineCallbacks + def get_throttle_params_by_room(self, pusher_id): + res = yield self.db.simple_select_list( + "pusher_throttle", + {"pusher": pusher_id}, + ["room_id", "last_sent_ts", "throttle_ms"], + desc="get_throttle_params_by_room", + ) + + params_by_room = {} + for row in res: + params_by_room[row["room_id"]] = { + "last_sent_ts": row["last_sent_ts"], + "throttle_ms": row["throttle_ms"], + } + + return params_by_room + + @defer.inlineCallbacks + def set_throttle_params(self, pusher_id, room_id, params): + # no need to lock because `pusher_throttle` has a primary key on + # (pusher, room_id) so simple_upsert will retry + yield self.db.simple_upsert( + "pusher_throttle", + {"pusher": pusher_id, "room_id": room_id}, + params, + desc="set_throttle_params", + lock=False, + ) + class PusherStore(PusherWorkerStore): def get_pushers_stream_token(self): @@ -282,81 +360,3 @@ class PusherStore(PusherWorkerStore): with self._pushers_id_gen.get_next() as stream_id: yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) - - @defer.inlineCallbacks - def update_pusher_last_stream_ordering( - self, app_id, pushkey, user_id, last_stream_ordering - ): - yield self.db.simple_update_one( - "pushers", - {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - {"last_stream_ordering": last_stream_ordering}, - desc="update_pusher_last_stream_ordering", - ) - - @defer.inlineCallbacks - def update_pusher_last_stream_ordering_and_success( - self, app_id, pushkey, user_id, last_stream_ordering, last_success - ): - """Update the last stream ordering position we've processed up to for - the given pusher. - - Args: - app_id (str) - pushkey (str) - last_stream_ordering (int) - last_success (int) - - Returns: - Deferred[bool]: True if the pusher still exists; False if it has been deleted. - """ - updated = yield self.db.simple_update( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - updatevalues={ - "last_stream_ordering": last_stream_ordering, - "last_success": last_success, - }, - desc="update_pusher_last_stream_ordering_and_success", - ) - - return bool(updated) - - @defer.inlineCallbacks - def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self.db.simple_update( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - updatevalues={"failing_since": failing_since}, - desc="update_pusher_failing_since", - ) - - @defer.inlineCallbacks - def get_throttle_params_by_room(self, pusher_id): - res = yield self.db.simple_select_list( - "pusher_throttle", - {"pusher": pusher_id}, - ["room_id", "last_sent_ts", "throttle_ms"], - desc="get_throttle_params_by_room", - ) - - params_by_room = {} - for row in res: - params_by_room[row["room_id"]] = { - "last_sent_ts": row["last_sent_ts"], - "throttle_ms": row["throttle_ms"], - } - - return params_by_room - - @defer.inlineCallbacks - def set_throttle_params(self, pusher_id, room_id, params): - # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so simple_upsert will retry - yield self.db.simple_upsert( - "pusher_throttle", - {"pusher": pusher_id, "room_id": room_id}, - params, - desc="set_throttle_params", - lock=False, - ) diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 8bdbc608a9..160e55aca9 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.app.frontend_proxy import FrontendProxyServer +from synapse.app.generic_worker import GenericWorkerServer from tests.unittest import HomeserverTestCase @@ -22,7 +22,7 @@ class FrontendProxyTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=FrontendProxyServer + http_client=None, homeserverToUse=GenericWorkerServer ) return hs @@ -46,9 +46,7 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - self.resource = ( - site.resource.children[b"_matrix"].children[b"client"].children[b"r0"] - ) + self.resource = site.resource.children[b"_matrix"].children[b"client"] request, channel = self.make_request("PUT", "presence/a/status") self.render(request) @@ -76,9 +74,7 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - self.resource = ( - site.resource.children[b"_matrix"].children[b"client"].children[b"r0"] - ) + self.resource = site.resource.children[b"_matrix"].children[b"client"] request, channel = self.make_request("PUT", "presence/a/status") self.render(request) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 48792d1480..1fe048048b 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -16,7 +16,7 @@ from mock import Mock, patch from parameterized import parameterized -from synapse.app.federation_reader import FederationReaderServer +from synapse.app.generic_worker import GenericWorkerServer from synapse.app.homeserver import SynapseHomeServer from tests.unittest import HomeserverTestCase @@ -25,7 +25,7 @@ from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=FederationReaderServer + http_client=None, homeserverToUse=GenericWorkerServer ) return hs -- cgit 1.5.1 From 8c75b621bfe03725cc8da071516ebc66d3872760 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 26 Feb 2020 12:22:55 +0000 Subject: Ensure 'deactivated' parameter is a boolean on user admin API, Fix error handling of call to deactivate user (#6990) --- changelog.d/6990.bugfix | 1 + synapse/rest/admin/users.py | 11 +++++--- synapse/rest/client/v1/login.py | 1 + tests/rest/admin/test_user.py | 59 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 changelog.d/6990.bugfix (limited to 'tests') diff --git a/changelog.d/6990.bugfix b/changelog.d/6990.bugfix new file mode 100644 index 0000000000..8c1c48f4d4 --- /dev/null +++ b/changelog.d/6990.bugfix @@ -0,0 +1 @@ +Prevent user from setting 'deactivated' to anything other than a bool on the v2 PUT /users Admin API. \ No newline at end of file diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2107b5dc56..c5b461a236 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -228,13 +228,16 @@ class UserRestServletV2(RestServlet): ) if "deactivated" in body: - deactivate = bool(body["deactivated"]) + deactivate = body["deactivated"] + if not isinstance(deactivate, bool): + raise SynapseError( + 400, "'deactivated' parameter is not of type boolean" + ) + if deactivate and not user["deactivated"]: - result = await self.deactivate_account_handler.deactivate_account( + await self.deactivate_account_handler.deactivate_account( target_user.to_string(), False ) - if not result: - raise SynapseError(500, "Could not deactivate user") user = await self.admin_handler.get_user(target_user) return 200, user diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 1294e080dc..2c99536678 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -599,6 +599,7 @@ class SSOAuthHandler(object): redirect_url = self._add_login_token_to_redirect_url( client_redirect_url, login_token ) + # Load page request.redirect(redirect_url) finish_request(request) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 490ce8f55d..cbe4a6a51f 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -507,3 +507,62 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(1, channel.json_body["admin"]) self.assertEqual(0, channel.json_body["is_guest"]) self.assertEqual(1, channel.json_body["deactivated"]) + + def test_accidental_deactivation_prevention(self): + """ + Ensure an account can't accidentally be deactivated by using a str value + for the deactivated body parameter + """ + self.hs.config.registration_shared_secret = None + + # Create user + body = json.dumps({"password": "abc123"}) + + request, channel = self.make_request( + "PUT", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("bob", channel.json_body["displayname"]) + + # Get user + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("bob", channel.json_body["displayname"]) + self.assertEqual(0, channel.json_body["deactivated"]) + + # Change password (and use a str for deactivate instead of a bool) + body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops! + + request, channel = self.make_request( + "PUT", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + # Check user is not deactivated + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("bob", channel.json_body["displayname"]) + + # Ensure they're still alive + self.assertEqual(0, channel.json_body["deactivated"]) -- cgit 1.5.1 From 1f773eec912e4908ab60f7823f5c0a024261af4d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 Feb 2020 15:33:26 +0000 Subject: Port PresenceHandler to async/await (#6991) --- changelog.d/6991.misc | 1 + synapse/handlers/message.py | 5 +- synapse/handlers/presence.py | 192 ++++++++++++++++-------------------- synapse/replication/tcp/resource.py | 6 +- synapse/server.pyi | 5 + tests/handlers/test_presence.py | 18 ++-- tox.ini | 1 + 7 files changed, 113 insertions(+), 115 deletions(-) create mode 100644 changelog.d/6991.misc (limited to 'tests') diff --git a/changelog.d/6991.misc b/changelog.d/6991.misc new file mode 100644 index 0000000000..5130f4e8af --- /dev/null +++ b/changelog.d/6991.misc @@ -0,0 +1 @@ +Port `synapse.handlers.presence` to async/await. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6be280952..a0103addd3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1016,11 +1016,10 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - @defer.inlineCallbacks - def _bump_active_time(self, user): + async def _bump_active_time(self, user): try: presence = self.hs.get_presence_handler() - yield presence.bump_presence_active_time(user) + await presence.bump_presence_active_time(user) except Exception: logger.exception("Error bumping presence active time") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 0d6cf2b008..5526015ddb 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -24,11 +24,12 @@ The methods that define policy are: import logging from contextlib import contextmanager -from typing import Dict, Set +from typing import Dict, List, Set from six import iteritems, itervalues from prometheus_client import Counter +from typing_extensions import ContextManager from twisted.internet import defer @@ -42,10 +43,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer +MYPY = False +if MYPY: + import synapse.server + logger = logging.getLogger(__name__) @@ -97,7 +102,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class PresenceHandler(object): def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs - self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.server_name = hs.hostname self.clock = hs.get_clock() @@ -150,7 +154,7 @@ class PresenceHandler(object): # Set of users who have presence in the `user_to_current_state` that # have not yet been persisted - self.unpersisted_users_changes = set() + self.unpersisted_users_changes = set() # type: Set[str] hs.get_reactor().addSystemEventTrigger( "before", @@ -160,12 +164,11 @@ class PresenceHandler(object): self._on_shutdown, ) - self.serial_to_user = {} self._next_serial = 1 # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs = {} + self.user_to_num_current_syncs = {} # type: Dict[str, int] # Keeps track of the number of *ongoing* syncs on other processes. # While any sync is ongoing on another process the user will never @@ -213,8 +216,7 @@ class PresenceHandler(object): self._event_pos = self.store.get_current_events_token() self._event_processing = False - @defer.inlineCallbacks - def _on_shutdown(self): + async def _on_shutdown(self): """Gets called when shutting down. This lets us persist any updates that we haven't yet persisted, e.g. updates that only changes some internal timers. This allows changes to persist across startup without having to @@ -235,7 +237,7 @@ class PresenceHandler(object): if self.unpersisted_users_changes: - yield self.store.update_presence( + await self.store.update_presence( [ self.user_to_current_state[user_id] for user_id in self.unpersisted_users_changes @@ -243,8 +245,7 @@ class PresenceHandler(object): ) logger.info("Finished _on_shutdown") - @defer.inlineCallbacks - def _persist_unpersisted_changes(self): + async def _persist_unpersisted_changes(self): """We periodically persist the unpersisted changes, as otherwise they may stack up and slow down shutdown times. """ @@ -253,12 +254,11 @@ class PresenceHandler(object): if unpersisted: logger.info("Persisting %d unpersisted presence updates", len(unpersisted)) - yield self.store.update_presence( + await self.store.update_presence( [self.user_to_current_state[user_id] for user_id in unpersisted] ) - @defer.inlineCallbacks - def _update_states(self, new_states): + async def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state should be sent to clients/servers. @@ -267,7 +267,7 @@ class PresenceHandler(object): with Measure(self.clock, "presence_update_states"): - # NOTE: We purposefully don't yield between now and when we've + # NOTE: We purposefully don't await between now and when we've # calculated what we want to do with the new states, to avoid races. to_notify = {} # Changes we want to notify everyone about @@ -311,7 +311,7 @@ class PresenceHandler(object): if to_notify: notified_presence_counter.inc(len(to_notify)) - yield self._persist_and_notify(list(to_notify.values())) + await self._persist_and_notify(list(to_notify.values())) self.unpersisted_users_changes |= {s.user_id for s in new_states} self.unpersisted_users_changes -= set(to_notify.keys()) @@ -326,7 +326,7 @@ class PresenceHandler(object): self._push_to_remotes(to_federation_ping.values()) - def _handle_timeouts(self): + async def _handle_timeouts(self): """Checks the presence of users that have timed out and updates as appropriate. """ @@ -368,10 +368,9 @@ class PresenceHandler(object): now=now, ) - return self._update_states(changes) + return await self._update_states(changes) - @defer.inlineCallbacks - def bump_presence_active_time(self, user): + async def bump_presence_active_time(self, user): """We've seen the user do something that indicates they're interacting with the app. """ @@ -383,16 +382,17 @@ class PresenceHandler(object): bump_active_time_counter.inc() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def user_syncing(self, user_id, affect_presence=True): + async def user_syncing( + self, user_id: str, affect_presence: bool = True + ) -> ContextManager[None]: """Returns a context manager that should surround any stream requests from the user. @@ -415,11 +415,11 @@ class PresenceHandler(object): curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) if prev_state.state == PresenceState.OFFLINE: # If they're currently offline then bring them online, otherwise # just update the last sync times. - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( state=PresenceState.ONLINE, @@ -429,7 +429,7 @@ class PresenceHandler(object): ] ) else: - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -437,13 +437,12 @@ class PresenceHandler(object): ] ) - @defer.inlineCallbacks - def _end(): + async def _end(): try: self.user_to_num_current_syncs[user_id] -= 1 - prev_state = yield self.current_state_for_user(user_id) - yield self._update_states( + prev_state = await self.current_state_for_user(user_id) + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -480,8 +479,7 @@ class PresenceHandler(object): else: return set() - @defer.inlineCallbacks - def update_external_syncs_row( + async def update_external_syncs_row( self, process_id, user_id, is_syncing, sync_time_msec ): """Update the syncing users for an external process as a delta. @@ -494,8 +492,8 @@ class PresenceHandler(object): is_syncing (bool): Whether or not the user is now syncing sync_time_msec(int): Time in ms when the user was last syncing """ - with (yield self.external_sync_linearizer.queue(process_id)): - prev_state = yield self.current_state_for_user(user_id) + with (await self.external_sync_linearizer.queue(process_id)): + prev_state = await self.current_state_for_user(user_id) process_presence = self.external_process_to_current_syncs.setdefault( process_id, set() @@ -525,25 +523,24 @@ class PresenceHandler(object): process_presence.discard(user_id) if updates: - yield self._update_states(updates) + await self._update_states(updates) self.external_process_last_updated_ms[process_id] = self.clock.time_msec() - @defer.inlineCallbacks - def update_external_syncs_clear(self, process_id): + async def update_external_syncs_clear(self, process_id): """Marks all users that had been marked as syncing by a given process as offline. Used when the process has stopped/disappeared. """ - with (yield self.external_sync_linearizer.queue(process_id)): + with (await self.external_sync_linearizer.queue(process_id)): process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = yield self.current_state_for_users(process_presence) + prev_states = await self.current_state_for_users(process_presence) time_now_ms = self.clock.time_msec() - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) for prev_state in itervalues(prev_states) @@ -551,15 +548,13 @@ class PresenceHandler(object): ) self.external_process_last_updated_ms.pop(process_id, None) - @defer.inlineCallbacks - def current_state_for_user(self, user_id): + async def current_state_for_user(self, user_id): """Get the current presence state for a user. """ - res = yield self.current_state_for_users([user_id]) + res = await self.current_state_for_users([user_id]) return res[user_id] - @defer.inlineCallbacks - def current_state_for_users(self, user_ids): + async def current_state_for_users(self, user_ids): """Get the current presence state for multiple users. Returns: @@ -574,7 +569,7 @@ class PresenceHandler(object): if missing: # There are things not in our in memory cache. Lets pull them out of # the database. - res = yield self.store.get_presence_for_users(missing) + res = await self.store.get_presence_for_users(missing) states.update(res) missing = [user_id for user_id, state in iteritems(states) if not state] @@ -587,14 +582,13 @@ class PresenceHandler(object): return states - @defer.inlineCallbacks - def _persist_and_notify(self, states): + async def _persist_and_notify(self, states): """Persist states in the database, poke the notifier and send to interested remote servers """ - stream_id, max_token = yield self.store.update_presence(states) + stream_id, max_token = await self.store.update_presence(states) - parties = yield get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -606,9 +600,8 @@ class PresenceHandler(object): self._push_to_remotes(states) - @defer.inlineCallbacks - def notify_for_states(self, state, stream_id): - parties = yield get_interested_parties(self.store, [state]) + async def notify_for_states(self, state, stream_id): + parties = await get_interested_parties(self.store, [state]) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -626,8 +619,7 @@ class PresenceHandler(object): """ self.federation.send_presence(states) - @defer.inlineCallbacks - def incoming_presence(self, origin, content): + async def incoming_presence(self, origin, content): """Called when we receive a `m.presence` EDU from a remote server. """ now = self.clock.time_msec() @@ -670,21 +662,19 @@ class PresenceHandler(object): new_fields["status_msg"] = push.get("status_msg", None) new_fields["currently_active"] = push.get("currently_active", False) - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) updates.append(prev_state.copy_and_replace(**new_fields)) if updates: federation_presence_counter.inc(len(updates)) - yield self._update_states(updates) + await self._update_states(updates) - @defer.inlineCallbacks - def get_state(self, target_user, as_event=False): - results = yield self.get_states([target_user.to_string()], as_event=as_event) + async def get_state(self, target_user, as_event=False): + results = await self.get_states([target_user.to_string()], as_event=as_event) return results[0] - @defer.inlineCallbacks - def get_states(self, target_user_ids, as_event=False): + async def get_states(self, target_user_ids, as_event=False): """Get the presence state for users. Args: @@ -695,7 +685,7 @@ class PresenceHandler(object): list """ - updates = yield self.current_state_for_users(target_user_ids) + updates = await self.current_state_for_users(target_user_ids) updates = list(updates.values()) for user_id in set(target_user_ids) - {u.user_id for u in updates}: @@ -713,8 +703,7 @@ class PresenceHandler(object): else: return updates - @defer.inlineCallbacks - def set_state(self, target_user, state, ignore_status_msg=False): + async def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ status_msg = state.get("status_msg", None) @@ -730,7 +719,7 @@ class PresenceHandler(object): user_id = target_user.to_string() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"state": presence} @@ -741,16 +730,15 @@ class PresenceHandler(object): if presence == PresenceState.ONLINE: new_fields["last_active_ts"] = self.clock.time_msec() - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def is_visible(self, observed_user, observer_user): + async def is_visible(self, observed_user, observer_user): """Returns whether a user can see another user's presence. """ - observer_room_ids = yield self.store.get_rooms_for_user( + observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() ) - observed_room_ids = yield self.store.get_rooms_for_user( + observed_room_ids = await self.store.get_rooms_for_user( observed_user.to_string() ) @@ -759,8 +747,7 @@ class PresenceHandler(object): return False - @defer.inlineCallbacks - def get_all_presence_updates(self, last_id, current_id): + async def get_all_presence_updates(self, last_id, current_id): """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -775,7 +762,7 @@ class PresenceHandler(object): """ # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = yield self.store.get_all_presence_updates(last_id, current_id) + rows = await self.store.get_all_presence_updates(last_id, current_id) return rows def notify_new_event(self): @@ -786,20 +773,18 @@ class PresenceHandler(object): if self._event_processing: return - @defer.inlineCallbacks - def _process_presence(): + async def _process_presence(): assert not self._event_processing self._event_processing = True try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._event_processing = False run_as_background_process("presence.notify_new_event", _process_presence) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "presence_delta"): @@ -812,10 +797,10 @@ class PresenceHandler(object): self._event_pos, room_max_stream_ordering, ) - max_pos, deltas = yield self.store.get_current_state_deltas( + max_pos, deltas = await self.store.get_current_state_deltas( self._event_pos, room_max_stream_ordering ) - yield self._handle_state_delta(deltas) + await self._handle_state_delta(deltas) self._event_pos = max_pos @@ -824,8 +809,7 @@ class PresenceHandler(object): max_pos ) - @defer.inlineCallbacks - def _handle_state_delta(self, deltas): + async def _handle_state_delta(self, deltas): """Process current state deltas to find new joins that need to be handled. """ @@ -846,13 +830,13 @@ class PresenceHandler(object): # joins. continue - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event or event.content.get("membership") != Membership.JOIN: # We only care about joins continue if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if ( prev_event and prev_event.content.get("membership") == Membership.JOIN @@ -860,10 +844,9 @@ class PresenceHandler(object): # Ignore changes to join events. continue - yield self._on_user_joined_room(room_id, state_key) + await self._on_user_joined_room(room_id, state_key) - @defer.inlineCallbacks - def _on_user_joined_room(self, room_id, user_id): + async def _on_user_joined_room(self, room_id, user_id): """Called when we detect a user joining the room via the current state delta stream. @@ -882,8 +865,8 @@ class PresenceHandler(object): # TODO: We should be able to filter the hosts down to those that # haven't previously seen the user - state = yield self.current_state_for_user(user_id) - hosts = yield self.state.get_current_hosts_in_room(room_id) + state = await self.current_state_for_user(user_id) + hosts = await self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. hosts = {host for host in hosts if host != self.server_name} @@ -903,10 +886,10 @@ class PresenceHandler(object): # TODO: Check that this is actually a new server joining the # room. - user_ids = yield self.state.get_current_users_in_room(room_id) + user_ids = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, user_ids)) - states = yield self.current_state_for_users(user_ids) + states = await self.current_state_for_users(user_ids) # Filter out old presence, i.e. offline presence states where # the user hasn't been active for a week. We can change this @@ -996,9 +979,8 @@ class PresenceEventSource(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - @defer.inlineCallbacks @log_function - def get_new_events( + async def get_new_events( self, user, from_key, @@ -1045,7 +1027,7 @@ class PresenceEventSource(object): presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache - users_interested_in = yield self._get_interested_in(user, explicit_room_id) + users_interested_in = await self._get_interested_in(user, explicit_room_id) user_ids_changed = set() changed = None @@ -1071,7 +1053,7 @@ class PresenceEventSource(object): else: user_ids_changed = users_interested_in - updates = yield presence.current_state_for_users(user_ids_changed) + updates = await presence.current_state_for_users(user_ids_changed) if include_offline: return (list(updates.values()), max_token) @@ -1084,11 +1066,11 @@ class PresenceEventSource(object): def get_current_key(self): return self.store.get_current_presence_token() - def get_pagination_rows(self, user, pagination_config, key): - return self.get_new_events(user, from_key=None, include_offline=False) + async def get_pagination_rows(self, user, pagination_config, key): + return await self.get_new_events(user, from_key=None, include_offline=False) - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _get_interested_in(self, user, explicit_room_id, cache_context): + @cached(num_args=2, cache_context=True) + async def _get_interested_in(self, user, explicit_room_id, cache_context): """Returns the set of users that the given user should see presence updates for """ @@ -1096,13 +1078,13 @@ class PresenceEventSource(object): users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(users_who_share_room) if explicit_room_id: - user_ids = yield self.store.get_users_in_room( + user_ids = await self.store.get_users_in_room( explicit_room_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(user_ids) @@ -1277,8 +1259,8 @@ def get_interested_parties(store, states): 2-tuple: `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ - room_ids_to_states = {} - users_to_states = {} + room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] + users_to_states = {} # type: Dict[str, List[UserPresenceState]] for state in states: room_ids = yield store.get_rooms_for_user(state.user_id) for room_id in room_ids: diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ce60ae2e07..ce9d1fae12 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -323,7 +323,11 @@ class ReplicationStreamer(object): # We need to tell the presence handler that the connection has been # lost so that it can handle any ongoing syncs on that connection. - self.presence_handler.update_external_syncs_clear(connection.conn_id) + run_as_background_process( + "update_external_syncs_clear", + self.presence_handler.update_external_syncs_clear, + connection.conn_id, + ) def _batch_updates(updates): diff --git a/synapse/server.pyi b/synapse/server.pyi index 40eabfe5d9..3844f0e12f 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -3,6 +3,7 @@ import twisted.internet import synapse.api.auth import synapse.config.homeserver import synapse.crypto.keyring +import synapse.federation.federation_server import synapse.federation.sender import synapse.federation.transport.client import synapse.handlers @@ -107,5 +108,9 @@ class HomeServer(object): self, ) -> synapse.replication.tcp.client.ReplicationClientHandler: pass + def get_federation_registry( + self, + ) -> synapse.federation.federation_server.FederationHandlerRegistry: + pass def is_mine_id(self, domain_id: str) -> bool: pass diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 64915bafcd..05ea40a7de 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -494,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, "@test2:server") # Mark test2 as online, test will be offline with a last_active of 0 - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -543,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) # Mark test as online - self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + ) ) # Mark test2 as online, test will be offline with a last_active of 0. # Note we don't join them to the room yet - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) # Add servers to the room diff --git a/tox.ini b/tox.ini index b715ea0bff..4ccfde01b5 100644 --- a/tox.ini +++ b/tox.ini @@ -183,6 +183,7 @@ commands = mypy \ synapse/events/spamcheck.py \ synapse/federation/sender \ synapse/federation/transport \ + synapse/handlers/presence.py \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ synapse/logging/ \ -- cgit 1.5.1 From 3e99528f2bfaa686c4708fb8efcddce935b2397d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 26 Feb 2020 16:58:33 +0000 Subject: Store room version on invite (#6983) When we get an invite over federation, store the room version in the rooms table. The general idea here is that, when we pull the invite out again, we'll want to know what room_version it belongs to (so that we can later redact it if need be). So we need to store it somewhere... --- changelog.d/6983.misc | 1 + synapse/handlers/federation.py | 12 +++++++++++ synapse/replication/http/_base.py | 2 +- synapse/replication/http/federation.py | 36 +++++++++++++++++++++++++++++++- synapse/storage/data_stores/main/room.py | 20 ++++++++++++++++++ tests/app/test_openid_listener.py | 8 +++++++ tests/handlers/test_typing.py | 1 + 7 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6983.misc (limited to 'tests') diff --git a/changelog.d/6983.misc b/changelog.d/6983.misc new file mode 100644 index 0000000000..08aa80bcd9 --- /dev/null +++ b/changelog.d/6983.misc @@ -0,0 +1 @@ +Refactoring work in preparation for changing the event redaction algorithm. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c2e6ee266d..38ab6a8fc3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -60,6 +60,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationFederationSendEventsRestServlet, + ReplicationStoreRoomOnInviteRestServlet, ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store @@ -160,8 +161,12 @@ class FederationHandler(BaseHandler): self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( hs ) + self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client( + hs + ) else: self._device_list_updater = hs.get_device_handler().device_list_updater + self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite # When joining a room we need to queue any events for that room up self.room_queues = {} @@ -1537,6 +1542,13 @@ class FederationHandler(BaseHandler): if event.state_key == self._server_notices_mxid: raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") + # keep a record of the room version, if we don't yet know it. + # (this may get overwritten if we later get a different room version in a + # join dance). + await self._maybe_store_room_on_invite( + room_id=event.room_id, room_version=room_version + ) + event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 444eb7b7f4..1be1ccbdf3 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -44,7 +44,7 @@ class ReplicationEndpoint(object): """Helper base class for defining new replication HTTP endpoints. This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..` - (with an `/:txn_id` prefix for cached requests.), where NAME is a name, + (with a `/:txn_id` suffix for cached requests), where NAME is a name, PATH_ARGS are a tuple of parameters to be encoded in the URL. For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 49a3251372..8794720101 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import event_type_from_format_version from synapse.events.snapshot import EventContext from synapse.http.servlet import parse_json_object_from_request @@ -211,7 +212,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): Request format: - POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id + POST /_synapse/replication/fed_cleanup_room/:room_id/:txn_id {} """ @@ -238,8 +239,41 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): return 200, {} +class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): + """Called to clean up any data in DB for a given room, ready for the + server to join the room. + + Request format: + + POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id + + { + "room_version": "1", + } + """ + + NAME = "store_room_on_invite" + PATH_ARGS = ("room_id",) + + def __init__(self, hs): + super().__init__(hs) + + self.store = hs.get_datastore() + + @staticmethod + def _serialize_payload(room_id, room_version): + return {"room_version": room_version.identifier} + + async def _handle_request(self, request, room_id): + content = parse_json_object_from_request(request) + room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] + await self.store.maybe_store_room_on_invite(room_id, room_version) + return 200, {} + + def register_servlets(hs, http_server): ReplicationFederationSendEventsRestServlet(hs).register(http_server) ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) ReplicationCleanRoomRestServlet(hs).register(http_server) + ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 70137dfbe4..e6c10c6316 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -1020,6 +1020,26 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") + async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion): + """ + When we receive an invite over federation, store the version of the room if we + don't already know the room version. + """ + await self.db.simple_upsert( + desc="maybe_store_room_on_invite", + table="rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={ + "room_version": room_version.identifier, + "is_public": False, + "creator": "", + }, + # rooms has a unique constraint on room_id, so no need to lock when doing an + # emulated upsert. + lock=False, + ) + @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 1fe048048b..89fcc3889a 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -29,6 +29,14 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): ) return hs + def default_config(self, name="test"): + conf = super().default_config(name) + # we're using FederationReaderServer, which uses a SlavedStore, so we + # have to tell the FederationHandler not to try to access stuff that is only + # in the primary store. + conf["worker_app"] = "yes" + return conf + @parameterized.expand( [ (["federation"], "auth_fail"), diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 07b204666e..51e2b37218 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -74,6 +74,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "set_received_txn_response", "get_destination_retry_timings", "get_devices_by_remote", + "maybe_store_room_on_invite", # Bits that user_directory needs "get_user_directory_stream_pos", "get_current_state_deltas", -- cgit 1.5.1 From cab4a52535097c5836fe67c5e09e8350d7ccf03c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 27 Feb 2020 13:08:43 +0000 Subject: set worker_app for frontend proxy test (#7003) to stop the federationhandler trying to do master stuff --- changelog.d/7003.misc | 1 + tests/app/test_frontend_proxy.py | 5 +++++ 2 files changed, 6 insertions(+) create mode 100644 changelog.d/7003.misc (limited to 'tests') diff --git a/changelog.d/7003.misc b/changelog.d/7003.misc new file mode 100644 index 0000000000..08aa80bcd9 --- /dev/null +++ b/changelog.d/7003.misc @@ -0,0 +1 @@ +Refactoring work in preparation for changing the event redaction algorithm. diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 160e55aca9..d3feafa1b7 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -27,6 +27,11 @@ class FrontendProxyTests(HomeserverTestCase): return hs + def default_config(self, name="test"): + c = super().default_config(name) + c["worker_app"] = "synapse.app.frontend_proxy" + return c + def test_listen_http_with_presence_enabled(self): """ When presence is on, the stub servlet will not register. -- cgit 1.5.1 From 9b06d8f8a62dc5c423aa9a694e0759eaf1c3c77e Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 28 Feb 2020 10:58:05 +0100 Subject: Fixed set a user as an admin with the new API (#6928) Fix #6910 --- changelog.d/6910.bugfix | 1 + synapse/rest/admin/users.py | 6 +- synapse/storage/data_stores/main/registration.py | 16 +- tests/rest/admin/test_user.py | 218 +++++++++++++++++++---- 4 files changed, 199 insertions(+), 42 deletions(-) create mode 100644 changelog.d/6910.bugfix (limited to 'tests') diff --git a/changelog.d/6910.bugfix b/changelog.d/6910.bugfix new file mode 100644 index 0000000000..707f1ff7b5 --- /dev/null +++ b/changelog.d/6910.bugfix @@ -0,0 +1 @@ +Fixed set a user as an admin with the admin API `PUT /_synapse/admin/v2/users/`. Contributed by @dklimpel. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c5b461a236..80f959248d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -211,9 +211,7 @@ class UserRestServletV2(RestServlet): if target_user == auth_user and not set_admin_to: raise SynapseError(400, "You may not demote yourself.") - await self.admin_handler.set_user_server_admin( - target_user, set_admin_to - ) + await self.store.set_server_admin(target_user, set_admin_to) if "password" in body: if ( @@ -651,6 +649,6 @@ class UserAdminServlet(RestServlet): if target_user == auth_user and not set_admin_to: raise SynapseError(400, "You may not demote yourself.") - await self.store.set_user_server_admin(target_user, set_admin_to) + await self.store.set_server_admin(target_user, set_admin_to) return 200, {} diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 49306642ed..3e53c8568a 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -301,12 +301,16 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self.db.simple_update_one( - table="users", - keyvalues={"name": user.to_string()}, - updatevalues={"admin": 1 if admin else 0}, - desc="set_server_admin", - ) + + def set_server_admin_txn(txn): + self.db.simple_update_one_txn( + txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user.to_string(),) + ) + + return self.db.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): sql = ( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index cbe4a6a51f..6416fb5d2a 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -16,6 +16,7 @@ import hashlib import hmac import json +import urllib.parse from mock import Mock @@ -371,22 +372,24 @@ class UserRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.url = "/_synapse/admin/v2/users/@bob:test" - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") self.other_user_token = self.login("user", "pass") + self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( + self.other_user + ) def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. """ self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" request, channel = self.make_request( - "GET", self.url, access_token=self.other_user_token, + "GET", url, access_token=self.other_user_token, ) self.render(request) @@ -394,7 +397,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("You are not a server admin", channel.json_body["error"]) request, channel = self.make_request( - "PUT", self.url, access_token=self.other_user_token, content=b"{}", + "PUT", url, access_token=self.other_user_token, content=b"{}", ) self.render(request) @@ -417,24 +420,73 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) - def test_requester_is_admin(self): + def test_create_server_admin(self): """ - If the user is a server admin, a new user is created. + Check that a new admin user is created successfully. """ self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" + # Create user (server admin) body = json.dumps( { "password": "abc123", "admin": True, + "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], } ) + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(True, channel.json_body["admin"]) + + # Get user + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(True, channel.json_body["admin"]) + self.assertEqual(False, channel.json_body["is_guest"]) + self.assertEqual(False, channel.json_body["deactivated"]) + + def test_create_user(self): + """ + Check that a new regular user is created successfully. + """ + self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" + # Create user + body = json.dumps( + { + "password": "abc123", + "admin": False, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } + ) + request, channel = self.make_request( "PUT", - self.url, + url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -442,29 +494,38 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("bob", channel.json_body["displayname"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(False, channel.json_body["admin"]) # Get user request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", url, access_token=self.admin_user_tok, ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("bob", channel.json_body["displayname"]) - self.assertEqual(1, channel.json_body["admin"]) - self.assertEqual(0, channel.json_body["is_guest"]) - self.assertEqual(0, channel.json_body["deactivated"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(False, channel.json_body["admin"]) + self.assertEqual(False, channel.json_body["is_guest"]) + self.assertEqual(False, channel.json_body["deactivated"]) + + def test_set_password(self): + """ + Test setting a new password for another user. + """ + self.hs.config.registration_shared_secret = None # Change password body = json.dumps({"password": "hahaha"}) request, channel = self.make_request( "PUT", - self.url, + self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -472,41 +533,133 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def test_set_displayname(self): + """ + Test setting the displayname of another user. + """ + self.hs.config.registration_shared_secret = None + # Modify user + body = json.dumps({"displayname": "foobar"}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("foobar", channel.json_body["displayname"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("foobar", channel.json_body["displayname"]) + + def test_set_threepid(self): + """ + Test setting threepid for an other user. + """ + self.hs.config.registration_shared_secret = None + + # Delete old and add new threepid to user body = json.dumps( - { - "displayname": "foobar", - "deactivated": True, - "threepids": [{"medium": "email", "address": "bob2@bob.bob"}], - } + {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]} ) request, channel = self.make_request( "PUT", - self.url, + self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("foobar", channel.json_body["displayname"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + + def test_deactivate_user(self): + """ + Test deactivating another user. + """ + + # Deactivate user + body = json.dumps({"deactivated": True}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) # the user is deactivated, the threepid will be deleted # Get user request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", self.url_other_user, access_token=self.admin_user_tok, ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("foobar", channel.json_body["displayname"]) - self.assertEqual(1, channel.json_body["admin"]) - self.assertEqual(0, channel.json_body["is_guest"]) - self.assertEqual(1, channel.json_body["deactivated"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + + def test_set_user_as_admin(self): + """ + Test setting the admin flag on a user. + """ + self.hs.config.registration_shared_secret = None + + # Set a user as an admin + body = json.dumps({"admin": True}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["admin"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["admin"]) def test_accidental_deactivation_prevention(self): """ @@ -514,13 +667,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): for the deactivated body parameter """ self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" # Create user body = json.dumps({"password": "abc123"}) request, channel = self.make_request( "PUT", - self.url, + url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -532,7 +686,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Get user request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", url, access_token=self.admin_user_tok, ) self.render(request) @@ -546,7 +700,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", - self.url, + url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -556,7 +710,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Check user is not deactivated request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", url, access_token=self.admin_user_tok, ) self.render(request) -- cgit 1.5.1 From bbeee33d63c43cb80118c0dccf8abd9d4ac1b8f3 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 28 Feb 2020 10:58:05 +0100 Subject: Fixed set a user as an admin with the new API (#6928) Fix #6910 --- changelog.d/6910.bugfix | 1 + synapse/rest/admin/users.py | 6 +- synapse/storage/data_stores/main/registration.py | 16 +- tests/rest/admin/test_user.py | 209 ++++++++++++++++++++--- 4 files changed, 194 insertions(+), 38 deletions(-) create mode 100644 changelog.d/6910.bugfix (limited to 'tests') diff --git a/changelog.d/6910.bugfix b/changelog.d/6910.bugfix new file mode 100644 index 0000000000..707f1ff7b5 --- /dev/null +++ b/changelog.d/6910.bugfix @@ -0,0 +1 @@ +Fixed set a user as an admin with the admin API `PUT /_synapse/admin/v2/users/`. Contributed by @dklimpel. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2107b5dc56..064908fbb0 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -211,9 +211,7 @@ class UserRestServletV2(RestServlet): if target_user == auth_user and not set_admin_to: raise SynapseError(400, "You may not demote yourself.") - await self.admin_handler.set_user_server_admin( - target_user, set_admin_to - ) + await self.store.set_server_admin(target_user, set_admin_to) if "password" in body: if ( @@ -648,6 +646,6 @@ class UserAdminServlet(RestServlet): if target_user == auth_user and not set_admin_to: raise SynapseError(400, "You may not demote yourself.") - await self.store.set_user_server_admin(target_user, set_admin_to) + await self.store.set_server_admin(target_user, set_admin_to) return 200, {} diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 49306642ed..3e53c8568a 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -301,12 +301,16 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self.db.simple_update_one( - table="users", - keyvalues={"name": user.to_string()}, - updatevalues={"admin": 1 if admin else 0}, - desc="set_server_admin", - ) + + def set_server_admin_txn(txn): + self.db.simple_update_one_txn( + txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user.to_string(),) + ) + + return self.db.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): sql = ( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 490ce8f55d..70688c2494 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -16,6 +16,7 @@ import hashlib import hmac import json +import urllib.parse from mock import Mock @@ -371,22 +372,24 @@ class UserRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.url = "/_synapse/admin/v2/users/@bob:test" - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") self.other_user_token = self.login("user", "pass") + self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( + self.other_user + ) def test_requester_is_no_admin(self): """ If the user is not a server admin, an error is returned. """ self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" request, channel = self.make_request( - "GET", self.url, access_token=self.other_user_token, + "GET", url, access_token=self.other_user_token, ) self.render(request) @@ -394,7 +397,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("You are not a server admin", channel.json_body["error"]) request, channel = self.make_request( - "PUT", self.url, access_token=self.other_user_token, content=b"{}", + "PUT", url, access_token=self.other_user_token, content=b"{}", ) self.render(request) @@ -417,24 +420,73 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) - def test_requester_is_admin(self): + def test_create_server_admin(self): """ - If the user is a server admin, a new user is created. + Check that a new admin user is created successfully. """ self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" + # Create user (server admin) body = json.dumps( { "password": "abc123", "admin": True, + "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], } ) + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(True, channel.json_body["admin"]) + + # Get user + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(True, channel.json_body["admin"]) + self.assertEqual(False, channel.json_body["is_guest"]) + self.assertEqual(False, channel.json_body["deactivated"]) + + def test_create_user(self): + """ + Check that a new regular user is created successfully. + """ + self.hs.config.registration_shared_secret = None + url = "/_synapse/admin/v2/users/@bob:test" + # Create user + body = json.dumps( + { + "password": "abc123", + "admin": False, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } + ) + request, channel = self.make_request( "PUT", - self.url, + url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -442,29 +494,38 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("bob", channel.json_body["displayname"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(False, channel.json_body["admin"]) # Get user request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", url, access_token=self.admin_user_tok, ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("bob", channel.json_body["displayname"]) - self.assertEqual(1, channel.json_body["admin"]) - self.assertEqual(0, channel.json_body["is_guest"]) - self.assertEqual(0, channel.json_body["deactivated"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(False, channel.json_body["admin"]) + self.assertEqual(False, channel.json_body["is_guest"]) + self.assertEqual(False, channel.json_body["deactivated"]) + + def test_set_password(self): + """ + Test setting a new password for another user. + """ + self.hs.config.registration_shared_secret = None # Change password body = json.dumps({"password": "hahaha"}) request, channel = self.make_request( "PUT", - self.url, + self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) @@ -472,38 +533,130 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def test_set_displayname(self): + """ + Test setting the displayname of another user. + """ + self.hs.config.registration_shared_secret = None + # Modify user + body = json.dumps({"displayname": "foobar"}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("foobar", channel.json_body["displayname"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("foobar", channel.json_body["displayname"]) + + def test_set_threepid(self): + """ + Test setting threepid for an other user. + """ + self.hs.config.registration_shared_secret = None + + # Delete old and add new threepid to user body = json.dumps( - { - "displayname": "foobar", - "deactivated": True, - "threepids": [{"medium": "email", "address": "bob2@bob.bob"}], - } + {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]} ) request, channel = self.make_request( "PUT", - self.url, + self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("foobar", channel.json_body["displayname"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + + def test_deactivate_user(self): + """ + Test deactivating another user. + """ + + # Deactivate user + body = json.dumps({"deactivated": True}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) # the user is deactivated, the threepid will be deleted # Get user request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, + "GET", self.url_other_user, access_token=self.admin_user_tok, ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("foobar", channel.json_body["displayname"]) - self.assertEqual(1, channel.json_body["admin"]) - self.assertEqual(0, channel.json_body["is_guest"]) - self.assertEqual(1, channel.json_body["deactivated"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + + def test_set_user_as_admin(self): + """ + Test setting the admin flag on a user. + """ + self.hs.config.registration_shared_secret = None + + # Set a user as an admin + body = json.dumps({"admin": True}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["admin"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["admin"]) -- cgit 1.5.1 From b2bd54a2e31d9a248f73fadb184ae9b4cbdb49f9 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 2 Mar 2020 16:36:32 +0000 Subject: Add a confirmation step to the SSO login flow --- docs/sample_config.yaml | 34 ++++++++++ synapse/config/_base.pyi | 2 + synapse/config/homeserver.py | 2 + synapse/config/sso.py | 74 +++++++++++++++++++++ synapse/res/templates/sso_redirect_confirm.html | 14 ++++ synapse/rest/client/v1/login.py | 40 ++++++++++-- tests/rest/client/v1/test_login.py | 85 +++++++++++++++++++++++++ 7 files changed, 245 insertions(+), 6 deletions(-) create mode 100644 synapse/config/sso.py create mode 100644 synapse/res/templates/sso_redirect_confirm.html (limited to 'tests') diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 8a036071e1..bbb8a4d934 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1360,6 +1360,40 @@ saml2_config: # # name: value +# Additional settings to use with single-sign on systems such as SAML2 and CAS. +# +sso: + # Directory in which Synapse will try to find the template files below. + # If not set, default templates from within the Synapse package will be used. + # + # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. + # If you *do* uncomment it, you will need to make sure that all the templates + # below are in the directory. + # + # Synapse will look for the following templates in this directory: + # + # * HTML page for confirmation of redirect during authentication: + # 'sso_redirect_confirm.html'. + # + # When rendering, this template is given three variables: + # * redirect_url: the URL the user is about to be redirected to. Needs + # manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * display_url: the same as `redirect_url`, but with the query + # parameters stripped. The intention is to have a + # human-readable URL to show to users, not to use it as + # the final address to redirect to. Needs manual escaping + # (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * server_name: the homeserver's name. + # + # You can see the default templates at: + # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates + # + #template_dir: "res/templates" + + # The JWT needs to contain a globally unique "sub" (subject) claim. # #jwt_config: diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 86bc965ee4..3053fc9d27 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -24,6 +24,7 @@ from synapse.config import ( server, server_notices_config, spam_checker, + sso, stats, third_party_event_rules, tls, @@ -57,6 +58,7 @@ class RootConfig: key: key.KeyConfig saml2: saml2_config.SAML2Config cas: cas.CasConfig + sso: sso.SSOConfig jwt: jwt_config.JWTConfig password: password.PasswordConfig email: emailconfig.EmailConfig diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 6e348671c7..b4bca08b20 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -38,6 +38,7 @@ from .saml2_config import SAML2Config from .server import ServerConfig from .server_notices_config import ServerNoticesConfig from .spam_checker import SpamCheckerConfig +from .sso import SSOConfig from .stats import StatsConfig from .third_party_event_rules import ThirdPartyRulesConfig from .tls import TlsConfig @@ -65,6 +66,7 @@ class HomeServerConfig(RootConfig): KeyConfig, SAML2Config, CasConfig, + SSOConfig, JWTConfig, PasswordConfig, EmailConfig, diff --git a/synapse/config/sso.py b/synapse/config/sso.py new file mode 100644 index 0000000000..f426b65b4f --- /dev/null +++ b/synapse/config/sso.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict + +import pkg_resources + +from ._base import Config, ConfigError + + +class SSOConfig(Config): + """SSO Configuration + """ + + section = "sso" + + def read_config(self, config, **kwargs): + sso_config = config.get("sso") or {} # type: Dict[str, Any] + + # Pick a template directory in order of: + # * The sso-specific template_dir + # * /path/to/synapse/install/res/templates + template_dir = sso_config.get("template_dir") + if not template_dir: + template_dir = pkg_resources.resource_filename("synapse", "res/templates",) + + self.sso_redirect_confirm_template_dir = template_dir + + def generate_config_section(self, **kwargs): + return """\ + # Additional settings to use with single-sign on systems such as SAML2 and CAS. + # + sso: + # Directory in which Synapse will try to find the template files below. + # If not set, default templates from within the Synapse package will be used. + # + # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. + # If you *do* uncomment it, you will need to make sure that all the templates + # below are in the directory. + # + # Synapse will look for the following templates in this directory: + # + # * HTML page for a confirmation step before redirecting back to the client + # with the login token: 'sso_redirect_confirm.html'. + # + # When rendering, this template is given three variables: + # * redirect_url: the URL the user is about to be redirected to. Needs + # manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * display_url: the same as `redirect_url`, but with the query + # parameters stripped. The intention is to have a + # human-readable URL to show to users, not to use it as + # the final address to redirect to. Needs manual escaping + # (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * server_name: the homeserver's name. + # + # You can see the default templates at: + # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates + # + #template_dir: "res/templates" + """ diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html new file mode 100644 index 0000000000..20a15e1e74 --- /dev/null +++ b/synapse/res/templates/sso_redirect_confirm.html @@ -0,0 +1,14 @@ + + + + + SSO redirect confirmation + + +

The application at {{ display_url | e }} is requesting full access to your {{ server_name }} Matrix account.

+

If you don't recognise this address, you should ignore this and close this tab.

+

+ I trust this address +

+ + \ No newline at end of file diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 1294e080dc..1acfd01d8e 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -29,6 +29,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest +from synapse.push.mailer import load_jinja2_templates from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import UserID, map_username_to_mxid_localpart @@ -548,6 +549,13 @@ class SSOAuthHandler(object): self._registration_handler = hs.get_registration_handler() self._macaroon_gen = hs.get_macaroon_generator() + # Load the redirect page HTML template + self._template = load_jinja2_templates( + hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], + )[0] + + self._server_name = hs.config.server_name + async def on_successful_auth( self, username, request, client_redirect_url, user_display_name=None ): @@ -592,21 +600,41 @@ class SSOAuthHandler(object): request: client_redirect_url: """ - + # Create a login token login_token = self._macaroon_gen.generate_short_term_login_token( registered_user_id ) - redirect_url = self._add_login_token_to_redirect_url( - client_redirect_url, login_token + + # Remove the query parameters from the redirect URL to get a shorter version of + # it. This is only to display a human-readable URL in the template, but not the + # URL we redirect users to. + redirect_url_no_params = client_redirect_url.split("?")[0] + + # Append the login token to the original redirect URL (i.e. with its query + # parameters kept intact) to build the URL to which the template needs to + # redirect the users once they have clicked on the confirmation link. + redirect_url = self._add_query_param_to_url( + client_redirect_url, "loginToken", login_token + ) + + # Serve the redirect confirmation page + html = self._template.render( + display_url=redirect_url_no_params, + redirect_url=redirect_url, + server_name=self._server_name, ) - request.redirect(redirect_url) + + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html),)) + request.write(html.encode("utf8")) finish_request(request) @staticmethod - def _add_login_token_to_redirect_url(url, token): + def _add_query_param_to_url(url, param_name, param): url_parts = list(urllib.parse.urlparse(url)) query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({"loginToken": token}) + query.update({param_name: param}) url_parts[4] = urllib.parse.urlencode(query) return urllib.parse.urlunparse(url_parts) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index eae5411325..2b8ad5c753 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1,4 +1,7 @@ import json +import urllib.parse + +from mock import Mock import synapse.rest.admin from synapse.rest.client.v1 import login @@ -252,3 +255,85 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEquals(channel.code, 200, channel.result) + + +class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.base_url = "https://matrix.goodserver.com/" + self.redirect_path = "_synapse/client/login/sso/redirect/confirm" + + config = self.default_config() + config["enable_registration"] = True + config["cas_config"] = { + "enabled": True, + "server_url": "https://fake.test", + "service_url": "https://matrix.goodserver.com:8448", + } + config["public_baseurl"] = self.base_url + + async def get_raw(uri, args): + """Return an example response payload from a call to the `/proxyValidate` + endpoint of a CAS server, copied from + https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 + + This needs to be returned by an async function (as opposed to set as the + mock's return value) because the corresponding Synapse code awaits on it. + """ + return """ + + + username + PGTIOU-84678-8a9d... + + https://proxy2/pgtUrl + https://proxy1/pgtUrl + + + + """ + + mocked_http_client = Mock(spec=["get_raw"]) + mocked_http_client.get_raw.side_effect = get_raw + + self.hs = self.setup_test_homeserver( + config=config, proxied_http_client=mocked_http_client, + ) + + return self.hs + + def test_cas_redirect_confirm(self): + """Tests that the SSO login flow serves a confirmation page before redirecting a + user to the redirect URL. + """ + base_url = "/login/cas/ticket?redirectUrl" + redirect_url = "https://dodgy-site.com/" + + url_parts = list(urllib.parse.urlparse(base_url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update({"redirectUrl": redirect_url}) + query.update({"ticket": "ticket"}) + url_parts[4] = urllib.parse.urlencode(query) + cas_ticket_url = urllib.parse.urlunparse(url_parts) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + # Test that the response is HTML. + content_type_header_value = "" + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type_header_value = header[1].decode("utf8") + + self.assertTrue(content_type_header_value.startswith("text/html")) + + # Test that the body isn't empty. + self.assertTrue(len(channel.result["body"]) > 0) + + # And that it contains our redirect link + self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) -- cgit 1.5.1 From b68041df3dcbcf3ca04c500d1712aa22a3c2580c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 2 Mar 2020 17:05:09 +0000 Subject: Add a whitelist for the SSO confirmation step. --- docs/sample_config.yaml | 22 +++++++++++++++++++--- synapse/config/sso.py | 18 ++++++++++++++++++ synapse/rest/client/v1/login.py | 26 ++++++++++++++++++-------- tests/rest/client/v1/test_login.py | 32 +++++++++++++++++++++++++++++--- 4 files changed, 84 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index bbb8a4d934..f719ec696f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1363,6 +1363,22 @@ saml2_config: # Additional settings to use with single-sign on systems such as SAML2 and CAS. # sso: + # A list of client URLs which are whitelisted so that the user does not + # have to confirm giving access to their account to the URL. Any client + # whose URL starts with an entry in the following list will not be subject + # to an additional confirmation step after the SSO login is completed. + # + # WARNING: An entry such as "https://my.client" is insecure, because it + # will also match "https://my.client.evil.site", exposing your users to + # phishing attacks from evil.site. To avoid this, include a slash after the + # hostname: "https://my.client/". + # + # By default, this list is empty. + # + #client_whitelist: + # - https://riot.im/develop + # - https://my.custom.client/ + # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # @@ -1372,8 +1388,8 @@ sso: # # Synapse will look for the following templates in this directory: # - # * HTML page for confirmation of redirect during authentication: - # 'sso_redirect_confirm.html'. + # * HTML page for a confirmation step before redirecting back to the client + # with the login token: 'sso_redirect_confirm.html'. # # When rendering, this template is given three variables: # * redirect_url: the URL the user is about to be redirected to. Needs @@ -1381,7 +1397,7 @@ sso: # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). # # * display_url: the same as `redirect_url`, but with the query - # parameters stripped. The intention is to have a + # parameters stripped. The intention is to have a # human-readable URL to show to users, not to use it as # the final address to redirect to. Needs manual escaping # (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). diff --git a/synapse/config/sso.py b/synapse/config/sso.py index f426b65b4f..56299bd4e4 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -37,11 +37,29 @@ class SSOConfig(Config): self.sso_redirect_confirm_template_dir = template_dir + self.sso_client_whitelist = sso_config.get("client_whitelist") or [] + def generate_config_section(self, **kwargs): return """\ # Additional settings to use with single-sign on systems such as SAML2 and CAS. # sso: + # A list of client URLs which are whitelisted so that the user does not + # have to confirm giving access to their account to the URL. Any client + # whose URL starts with an entry in the following list will not be subject + # to an additional confirmation step after the SSO login is completed. + # + # WARNING: An entry such as "https://my.client" is insecure, because it + # will also match "https://my.client.evil.site", exposing your users to + # phishing attacks from evil.site. To avoid this, include a slash after the + # hostname: "https://my.client/". + # + # By default, this list is empty. + # + #client_whitelist: + # - https://riot.im/develop + # - https://my.custom.client/ + # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 1acfd01d8e..b2bc7537db 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -556,6 +556,9 @@ class SSOAuthHandler(object): self._server_name = hs.config.server_name + # cast to tuple for use with str.startswith + self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) + async def on_successful_auth( self, username, request, client_redirect_url, user_display_name=None ): @@ -605,11 +608,6 @@ class SSOAuthHandler(object): registered_user_id ) - # Remove the query parameters from the redirect URL to get a shorter version of - # it. This is only to display a human-readable URL in the template, but not the - # URL we redirect users to. - redirect_url_no_params = client_redirect_url.split("?")[0] - # Append the login token to the original redirect URL (i.e. with its query # parameters kept intact) to build the URL to which the template needs to # redirect the users once they have clicked on the confirmation link. @@ -617,17 +615,29 @@ class SSOAuthHandler(object): client_redirect_url, "loginToken", login_token ) - # Serve the redirect confirmation page + # if the client is whitelisted, we can redirect straight to it + if client_redirect_url.startswith(self._whitelisted_sso_clients): + request.redirect(redirect_url) + finish_request(request) + return + + # Otherwise, serve the redirect confirmation page. + + # Remove the query parameters from the redirect URL to get a shorter version of + # it. This is only to display a human-readable URL in the template, but not the + # URL we redirect users to. + redirect_url_no_params = client_redirect_url.split("?")[0] + html = self._template.render( display_url=redirect_url_no_params, redirect_url=redirect_url, server_name=self._server_name, - ) + ).encode("utf-8") request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%d" % (len(html),)) - request.write(html.encode("utf8")) + request.write(html) finish_request(request) @staticmethod diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2b8ad5c753..da2c9bfa1e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -268,13 +268,11 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): self.redirect_path = "_synapse/client/login/sso/redirect/confirm" config = self.default_config() - config["enable_registration"] = True config["cas_config"] = { "enabled": True, "server_url": "https://fake.test", "service_url": "https://matrix.goodserver.com:8448", } - config["public_baseurl"] = self.base_url async def get_raw(uri, args): """Return an example response payload from a call to the `/proxyValidate` @@ -310,7 +308,7 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): """Tests that the SSO login flow serves a confirmation page before redirecting a user to the redirect URL. """ - base_url = "/login/cas/ticket?redirectUrl" + base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl" redirect_url = "https://dodgy-site.com/" url_parts = list(urllib.parse.urlparse(base_url)) @@ -325,6 +323,7 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): self.render(request) # Test that the response is HTML. + self.assertEqual(channel.code, 200) content_type_header_value = "" for header in channel.result.get("headers", []): if header[0] == b"Content-Type": @@ -337,3 +336,30 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): # And that it contains our redirect link self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) + + @override_config( + { + "sso": { + "client_whitelist": [ + "https://legit-site.com/", + "https://other-site.com/", + ] + } + } + ) + def test_cas_redirect_whitelisted(self): + """Tests that the SSO login flow serves a redirect to a whitelisted url + """ + redirect_url = "https://legit-site.com/" + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + self.assertEqual(channel.code, 302) + location_headers = channel.headers.getRawHeaders("Location") + self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) -- cgit 1.5.1 From 7dcbc33a1be04c46b930699c03c15bc759f4b22c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 3 Mar 2020 07:12:45 -0500 Subject: Validate the alt_aliases property of canonical alias events (#6971) --- changelog.d/6971.feature | 1 + synapse/api/errors.py | 1 + synapse/handlers/directory.py | 14 ++-- synapse/handlers/message.py | 47 ++++++++++- synapse/types.py | 15 ++-- tests/handlers/test_directory.py | 66 +++++++-------- tests/rest/client/v1/test_rooms.py | 160 +++++++++++++++++++++++++++++++++++++ tests/test_types.py | 2 +- 8 files changed, 254 insertions(+), 52 deletions(-) create mode 100644 changelog.d/6971.feature (limited to 'tests') diff --git a/changelog.d/6971.feature b/changelog.d/6971.feature new file mode 100644 index 0000000000..ccf02a61df --- /dev/null +++ b/changelog.d/6971.feature @@ -0,0 +1 @@ +Validate the alt_aliases property of canonical alias events. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 0c20601600..616942b057 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -66,6 +66,7 @@ class Codes(object): EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" + BAD_ALIAS = "M_BAD_ALIAS" class CodeMessageException(RuntimeError): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 0b23ca919a..61eb49059b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import collections import logging import string from typing import List @@ -307,15 +305,17 @@ class DirectoryHandler(BaseHandler): send_update = True content.pop("alias", "") - # Filter alt_aliases for the removed alias. - alt_aliases = content.pop("alt_aliases", None) - # If the aliases are not a list (or not found) do not attempt to modify - # the list. - if isinstance(alt_aliases, collections.Sequence): + # Filter the alt_aliases property for the removed alias. Note that the + # value is not modified if alt_aliases is of an unexpected form. + alt_aliases = content.get("alt_aliases") + if isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases: send_update = True alt_aliases = [alias for alias in alt_aliases if alias != alias_str] + if alt_aliases: content["alt_aliases"] = alt_aliases + else: + del content["alt_aliases"] if send_update: yield self.event_creation_handler.create_and_send_nonmember_event( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a0103addd3..0c84c6cec4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -888,19 +888,60 @@ class EventCreationHandler(object): yield self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: - # Check the alias is acually valid (at this time at least) + # Validate a newly added alias or newly added alt_aliases. + + original_alias = None + original_alt_aliases = set() + + original_event_id = event.unsigned.get("replaces_state") + if original_event_id: + original_event = yield self.store.get_event(original_event_id) + + if original_event: + original_alias = original_event.content.get("alias", None) + original_alt_aliases = original_event.content.get("alt_aliases", []) + + # Check the alias is currently valid (if it has changed). room_alias_str = event.content.get("alias", None) - if room_alias_str: + directory_handler = self.hs.get_handlers().directory_handler + if room_alias_str and room_alias_str != original_alias: room_alias = RoomAlias.from_string(room_alias_str) - directory_handler = self.hs.get_handlers().directory_handler mapping = yield directory_handler.get_association(room_alias) if mapping["room_id"] != event.room_id: raise SynapseError( 400, "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, ) + # Check that alt_aliases is the proper form. + alt_aliases = event.content.get("alt_aliases", []) + if not isinstance(alt_aliases, (list, tuple)): + raise SynapseError( + 400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM + ) + + # If the old version of alt_aliases is of an unknown form, + # completely replace it. + if not isinstance(original_alt_aliases, (list, tuple)): + original_alt_aliases = [] + + # Check that each alias is currently valid. + new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) + if new_alt_aliases: + for alias_str in new_alt_aliases: + room_alias = RoomAlias.from_string(alias_str) + mapping = yield directory_handler.get_association(room_alias) + + if mapping["room_id"] != event.room_id: + raise SynapseError( + 400, + "Room alias %s does not point to the room" + % (room_alias_str,), + Codes.BAD_ALIAS, + ) + federation_handler = self.hs.get_handlers().federation_handler if event.type == EventTypes.Member: diff --git a/synapse/types.py b/synapse/types.py index f3cd465735..acf60baddc 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -23,7 +23,7 @@ import attr from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError # define a version of typing.Collection that works on python 3.5 if sys.version_info[:3] >= (3, 6, 0): @@ -166,11 +166,13 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom return self @classmethod - def from_string(cls, s): + def from_string(cls, s: str): """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError( - 400, "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL) + 400, + "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, ) parts = s[1:].split(":", 1) @@ -179,6 +181,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 400, "Expected %s of the form '%slocalname:domain'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, ) domain = parts[1] @@ -235,11 +238,13 @@ class GroupID(DomainSpecificString): def from_string(cls, s): group_id = super(GroupID, cls).from_string(s) if not group_id.localpart: - raise SynapseError(400, "Group ID cannot be empty") + raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) if contains_invalid_mxid_characters(group_id.localpart): raise SynapseError( - 400, "Group ID can only contain characters a-z, 0-9, or '=_-./'" + 400, + "Group ID can only contain characters a-z, 0-9, or '=_-./'", + Codes.INVALID_PARAM, ) return group_id diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 27b916aed4..3397cfa485 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -88,6 +88,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) def test_delete_alias_not_allowed(self): + """Removing an alias should be denied if a user does not have the proper permissions.""" room_id = "!8765qwer:test" self.get_success( self.store.create_room_alias_association(self.my_room, room_id, ["test"]) @@ -101,6 +102,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) def test_delete_alias(self): + """Removing an alias should work when a user does has the proper permissions.""" room_id = "!8765qwer:test" user_id = "@user:test" self.get_success( @@ -159,30 +161,42 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) self.test_alias = "#test:test" - self.room_alias = RoomAlias.from_string(self.test_alias) + self.room_alias = self._add_alias(self.test_alias) + + def _add_alias(self, alias: str) -> RoomAlias: + """Add an alias to the test room.""" + room_alias = RoomAlias.from_string(alias) # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( - self.room_alias, self.room_id, ["test"], self.admin_user + room_alias, self.room_id, ["test"], self.admin_user ) ) + return room_alias - def test_remove_alias(self): - """Removing an alias that is the canonical alias should remove it there too.""" - # Set this new alias as the canonical alias for this room + def _set_canonical_alias(self, content): + """Configure the canonical alias state on the room.""" self.helper.send_state( - self.room_id, - "m.room.canonical_alias", - {"alias": self.test_alias, "alt_aliases": [self.test_alias]}, - tok=self.admin_user_tok, + self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok, ) - data = self.get_success( + def _get_canonical_alias(self): + """Get the canonical alias state of the room.""" + return self.get_success( self.state_handler.get_current_state( self.room_id, EventTypes.CanonicalAlias, "" ) ) + + def test_remove_alias(self): + """Removing an alias that is the canonical alias should remove it there too.""" + # Set this new alias as the canonical alias for this room + self._set_canonical_alias( + {"alias": self.test_alias, "alt_aliases": [self.test_alias]} + ) + + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) @@ -193,11 +207,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertNotIn("alias", data["content"]) self.assertNotIn("alt_aliases", data["content"]) @@ -205,29 +215,17 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): """Removing an alias listed as in alt_aliases should remove it there too.""" # Create a second alias. other_test_alias = "#test2:test" - other_room_alias = RoomAlias.from_string(other_test_alias) - self.get_success( - self.store.create_room_alias_association( - other_room_alias, self.room_id, ["test"], self.admin_user - ) - ) + other_room_alias = self._add_alias(other_test_alias) # Set the alias as the canonical alias for this room. - self.helper.send_state( - self.room_id, - "m.room.canonical_alias", + self._set_canonical_alias( { "alias": self.test_alias, "alt_aliases": [self.test_alias, other_test_alias], - }, - tok=self.admin_user_tok, + } ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual( data["content"]["alt_aliases"], [self.test_alias, other_test_alias] @@ -240,11 +238,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2f3df5f88f..7dd86d0c27 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1821,3 +1821,163 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(channel.code, expected_code, channel.result) + + +class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.alias = "#alias:test" + self._set_alias_via_directory(self.alias) + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "GET", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "PUT", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + json.dumps(content), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def test_canonical_alias(self): + """Test a basic alias message.""" + # There is no canonical alias to start with. + self._get_canonical_alias(expected_code=404) + + # Create an alias. + self._set_canonical_alias({"alias": self.alias}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + # Now remove the alias. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alt_aliases(self): + """Test a canonical alias message with alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alias_alt_aliases(self): + """Test a canonical alias message with an alias and alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alias and alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_partial_modify(self): + """Test removing only the alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({"alias": self.alias}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + def test_add_alias(self): + """Test removing only the alt_aliases.""" + # Create an additional alias. + second_alias = "#second:test" + self._set_alias_via_directory(second_alias) + + # Add the canonical alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Then add the second alias. + self._set_canonical_alias( + {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual( + res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + def test_bad_data(self): + """Invalid data for alt_aliases should cause errors.""" + self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": None}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 0}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 1}, expected_code=400) + self._set_canonical_alias({"alt_aliases": False}, expected_code=400) + self._set_canonical_alias({"alt_aliases": True}, expected_code=400) + self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) + + def test_bad_alias(self): + """An alias which does not point to the room raises a SynapseError.""" + self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) diff --git a/tests/test_types.py b/tests/test_types.py index 8d97c751ea..480bea1bdc 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -75,7 +75,7 @@ class GroupIDTestCase(unittest.TestCase): self.fail("Parsing '%s' should raise exception" % id_string) except SynapseError as exc: self.assertEqual(400, exc.code) - self.assertEqual("M_UNKNOWN", exc.errcode) + self.assertEqual("M_INVALID_PARAM", exc.errcode) class MapUsernameTestCase(unittest.TestCase): -- cgit 1.5.1 From 8ef8fb2c1c7c4aeb80fce4deea477b37754ce539 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 4 Mar 2020 13:11:04 +0000 Subject: Read the room version from database when fetching events (#6874) This is a precursor to giving EventBase objects the knowledge of which room version they belong to. --- changelog.d/6874.misc | 1 + synapse/storage/data_stores/main/events_worker.py | 84 ++++++++++++++++++----- tests/replication/slave/storage/test_events.py | 10 +++ 3 files changed, 79 insertions(+), 16 deletions(-) create mode 100644 changelog.d/6874.misc (limited to 'tests') diff --git a/changelog.d/6874.misc b/changelog.d/6874.misc new file mode 100644 index 0000000000..08aa80bcd9 --- /dev/null +++ b/changelog.d/6874.misc @@ -0,0 +1 @@ +Refactoring work in preparation for changing the event redaction algorithm. diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 47a3a26072..ca237c6f12 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -28,9 +28,12 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError -from synapse.api.room_versions import EventFormatVersions -from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersions, +) +from synapse.events import make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process @@ -580,8 +583,49 @@ class EventsWorkerStore(SQLBaseStore): # of a event format version, so it must be a V1 event. format_version = EventFormatVersions.V1 - original_ev = event_type_from_format_version(format_version)( + room_version_id = row["room_version_id"] + + if not room_version_id: + # this should only happen for out-of-band membership events + if not internal_metadata.get("out_of_band_membership"): + logger.warning( + "Room %s for event %s is unknown", d["room_id"], event_id + ) + continue + + # take a wild stab at the room version based on the event format + if format_version == EventFormatVersions.V1: + room_version = RoomVersions.V1 + elif format_version == EventFormatVersions.V2: + room_version = RoomVersions.V3 + else: + room_version = RoomVersions.V5 + else: + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version: + logger.error( + "Event %s in room %s has unknown room version %s", + event_id, + d["room_id"], + room_version_id, + ) + continue + + if room_version.event_format != format_version: + logger.error( + "Event %s in room %s with version %s has wrong format: " + "expected %s, was %s", + event_id, + d["room_id"], + room_version_id, + room_version.event_format, + format_version, + ) + continue + + original_ev = make_event_from_dict( event_dict=d, + room_version=room_version, internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) @@ -661,6 +705,12 @@ class EventsWorkerStore(SQLBaseStore): of EventFormatVersions. 'None' means the event predates EventFormatVersions (so the event is format V1). + * room_version_id (str|None): The version of the room which contains the event. + Hopefully one of RoomVersions. + + Due to historical reasons, there may be a few events in the database which + do not have an associated room; in this case None will be returned here. + * rejected_reason (str|None): if the event was rejected, the reason why. @@ -676,17 +726,18 @@ class EventsWorkerStore(SQLBaseStore): """ event_dict = {} for evs in batch_iter(event_ids, 200): - sql = ( - "SELECT " - " e.event_id, " - " e.internal_metadata," - " e.json," - " e.format_version, " - " rej.reason " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " WHERE " - ) + sql = """\ + SELECT + e.event_id, + e.internal_metadata, + e.json, + e.format_version, + r.room_version, + rej.reason + FROM event_json as e + LEFT JOIN rooms r USING (room_id) + LEFT JOIN rejections as rej USING (event_id) + WHERE """ clause, args = make_in_list_sql_clause( txn.database_engine, "e.event_id", evs @@ -701,7 +752,8 @@ class EventsWorkerStore(SQLBaseStore): "internal_metadata": row[1], "json": row[2], "format_version": row[3], - "rejected_reason": row[4], + "room_version_id": row[4], + "rejected_reason": row[5], "redactions": [], } diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index d31210fbe4..f0561b30e3 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -15,6 +15,7 @@ import logging from canonicaljson import encode_canonical_json +from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.handlers.room import RoomEventSource @@ -58,6 +59,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] return super(SlavedEventStoreTestCase, self).setUp() + def prepare(self, *args, **kwargs): + super().prepare(*args, **kwargs) + + self.get_success( + self.master_store.store_room( + ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1, + ) + ) + def tearDown(self): [unpatch() for unpatch in self.unpatches] -- cgit 1.5.1 From 13892776ef7e0b1af2f82c9ca53f7bbd1c60d66f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 4 Mar 2020 11:30:46 -0500 Subject: Allow deleting an alias if the user has sufficient power level (#6986) --- changelog.d/6986.feature | 1 + synapse/api/auth.py | 9 +-- synapse/handlers/directory.py | 107 ++++++++++++++++++++++---------- tests/handlers/test_directory.py | 128 +++++++++++++++++++++++++++++++-------- tox.ini | 1 + 5 files changed, 182 insertions(+), 64 deletions(-) create mode 100644 changelog.d/6986.feature (limited to 'tests') diff --git a/changelog.d/6986.feature b/changelog.d/6986.feature new file mode 100644 index 0000000000..16dea8bd7f --- /dev/null +++ b/changelog.d/6986.feature @@ -0,0 +1 @@ +Users with a power level sufficient to modify the canonical alias of a room can now delete room aliases. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 5ca18b4301..c1ade1333b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -539,7 +539,7 @@ class Auth(object): @defer.inlineCallbacks def check_can_change_room_list(self, room_id: str, user: UserID): - """Check if the user is allowed to edit the room's entry in the + """Determine whether the user is allowed to edit the room's entry in the published room list. Args: @@ -570,12 +570,7 @@ class Auth(object): ) user_level = event_auth.get_user_power_level(user_id, auth_events) - if user_level < send_level: - raise AuthError( - 403, - "This server requires you to be a moderator in the room to" - " edit its room list entry", - ) + return user_level >= send_level @staticmethod def has_access_token(request): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 61eb49059b..1d842c369b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -15,7 +15,7 @@ import logging import string -from typing import List +from typing import Iterable, List, Optional from twisted.internet import defer @@ -28,6 +28,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) +from synapse.appservice import ApplicationService from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler @@ -55,7 +56,13 @@ class DirectoryHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() @defer.inlineCallbacks - def _create_association(self, room_alias, room_id, servers=None, creator=None): + def _create_association( + self, + room_alias: RoomAlias, + room_id: str, + servers: Optional[Iterable[str]] = None, + creator: Optional[str] = None, + ): # general association creation for both human users and app services for wchar in string.whitespace: @@ -81,17 +88,21 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def create_association( - self, requester, room_alias, room_id, servers=None, check_membership=True, + self, + requester: Requester, + room_alias: RoomAlias, + room_id: str, + servers: Optional[List[str]] = None, + check_membership: bool = True, ): """Attempt to create a new alias Args: - requester (Requester) - room_alias (RoomAlias) - room_id (str) - servers (list[str]|None): List of servers that others servers - should try and join via - check_membership (bool): Whether to check if the user is in the room + requester + room_alias + room_id + servers: Iterable of servers that others servers should try and join via + check_membership: Whether to check if the user is in the room before the alias can be set (if the server's config requires it). Returns: @@ -145,15 +156,15 @@ class DirectoryHandler(BaseHandler): yield self._create_association(room_alias, room_id, servers, creator=user_id) @defer.inlineCallbacks - def delete_association(self, requester, room_alias): + def delete_association(self, requester: Requester, room_alias: RoomAlias): """Remove an alias from the directory (this is only meant for human users; AS users should call delete_appservice_association) Args: - requester (Requester): - room_alias (RoomAlias): + requester + room_alias Returns: Deferred[unicode]: room id that the alias used to point to @@ -189,16 +200,16 @@ class DirectoryHandler(BaseHandler): room_id = yield self._delete_association(room_alias) try: - yield self._update_canonical_alias( - requester, requester.user.to_string(), room_id, room_alias - ) + yield self._update_canonical_alias(requester, user_id, room_id, room_alias) except AuthError as e: logger.info("Failed to update alias events: %s", e) return room_id @defer.inlineCallbacks - def delete_appservice_association(self, service, room_alias): + def delete_appservice_association( + self, service: ApplicationService, room_alias: RoomAlias + ): if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( 400, @@ -208,7 +219,7 @@ class DirectoryHandler(BaseHandler): yield self._delete_association(room_alias) @defer.inlineCallbacks - def _delete_association(self, room_alias): + def _delete_association(self, room_alias: RoomAlias): if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") @@ -217,7 +228,7 @@ class DirectoryHandler(BaseHandler): return room_id @defer.inlineCallbacks - def get_association(self, room_alias): + def get_association(self, room_alias: RoomAlias): room_id = None if self.hs.is_mine(room_alias): result = yield self.get_association_from_room_alias(room_alias) @@ -282,7 +293,9 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def _update_canonical_alias(self, requester, user_id, room_id, room_alias): + def _update_canonical_alias( + self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias + ): """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. @@ -331,7 +344,7 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): + def get_association_from_room_alias(self, room_alias: RoomAlias): result = yield self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists @@ -339,7 +352,7 @@ class DirectoryHandler(BaseHandler): result = yield as_handler.query_room_alias_exists(room_alias) return result - def can_modify_alias(self, alias, user_id=None): + def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None): # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have @@ -360,22 +373,42 @@ class DirectoryHandler(BaseHandler): return defer.succeed(True) @defer.inlineCallbacks - def _user_can_delete_alias(self, alias, user_id): + def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): + """Determine whether a user can delete an alias. + + One of the following must be true: + + 1. The user created the alias. + 2. The user is a server administrator. + 3. The user has a power-level sufficient to send a canonical alias event + for the current room. + + """ creator = yield self.store.get_room_alias_creator(alias.to_string()) if creator is not None and creator == user_id: return True - is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) - return is_admin + # Resolve the alias to the corresponding room. + room_mapping = yield self.get_association(alias) + room_id = room_mapping["room_id"] + if not room_id: + return False + + res = yield self.auth.check_can_change_room_list( + room_id, UserID.from_string(user_id) + ) + return res @defer.inlineCallbacks - def edit_published_room_list(self, requester, room_id, visibility): + def edit_published_room_list( + self, requester: Requester, room_id: str, visibility: str + ): """Edit the entry of the room in the published room list. requester - room_id (str) - visibility (str): "public" or "private" + room_id + visibility: "public" or "private" """ user_id = requester.user.to_string() @@ -400,7 +433,15 @@ class DirectoryHandler(BaseHandler): if room is None: raise SynapseError(400, "Unknown room") - yield self.auth.check_can_change_room_list(room_id, requester.user) + can_change_room_list = yield self.auth.check_can_change_room_list( + room_id, requester.user + ) + if not can_change_room_list: + raise AuthError( + 403, + "This server requires you to be a moderator in the room to" + " edit its room list entry", + ) making_public = visibility == "public" if making_public: @@ -421,16 +462,16 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def edit_published_appservice_room_list( - self, appservice_id, network_id, room_id, visibility + self, appservice_id: str, network_id: str, room_id: str, visibility: str ): """Add or remove a room from the appservice/network specific public room list. Args: - appservice_id (str): ID of the appservice that owns the list - network_id (str): The ID of the network the list is associated with - room_id (str) - visibility (str): either "public" or "private" + appservice_id: ID of the appservice that owns the list + network_id: The ID of the network the list is associated with + room_id + visibility: either "public" or "private" """ if visibility not in ["public", "private"]: raise SynapseError(400, "Invalid visibility setting") diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 3397cfa485..5e40adba52 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,6 +18,7 @@ from mock import Mock from twisted.internet import defer +import synapse import synapse.api.errors from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig @@ -87,52 +88,131 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) - def test_delete_alias_not_allowed(self): - """Removing an alias should be denied if a user does not have the proper permissions.""" - room_id = "!8765qwer:test" + def test_incoming_fed_query(self): + self.get_success( + self.store.create_room_alias_association( + self.your_room, "!8765asdf:test", ["test"] + ) + ) + + response = self.get_success( + self.handler.on_directory_query({"room_alias": "#your-room:test"}) + ) + + self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + + +class TestDeleteAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_handlers().directory_handler + self.state_handler = hs.get_state_handler() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def _create_alias(self, user): + # Create a new alias to this room. self.get_success( - self.store.create_room_alias_association(self.my_room, room_id, ["test"]) + self.store.create_room_alias_association( + self.room_alias, self.room_id, ["test"], user + ) ) + def test_delete_alias_not_allowed(self): + """A user that doesn't meet the expected guidelines cannot delete an alias.""" + self._create_alias(self.admin_user) self.get_failure( self.handler.delete_association( - create_requester("@user:test"), self.my_room + create_requester(self.test_user), self.room_alias ), synapse.api.errors.AuthError, ) - def test_delete_alias(self): - """Removing an alias should work when a user does has the proper permissions.""" - room_id = "!8765qwer:test" - user_id = "@user:test" - self.get_success( - self.store.create_room_alias_association( - self.my_room, room_id, ["test"], user_id + def test_delete_alias_creator(self): + """An alias creator can delete their own alias.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + + # Delete the user's alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias ) ) + self.assertEquals(self.room_id, result) + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) + + def test_delete_alias_admin(self): + """A server admin can delete an alias created by another user.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + + # Delete the user's alias as the admin. result = self.get_success( - self.handler.delete_association(create_requester(user_id), self.my_room) + self.handler.delete_association( + create_requester(self.admin_user), self.room_alias + ) ) - self.assertEquals(room_id, result) + self.assertEquals(self.room_id, result) - # The alias should not be found. + # Confirm the alias is gone. self.get_failure( - self.handler.get_association(self.my_room), synapse.api.errors.SynapseError + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, ) - def test_incoming_fed_query(self): - self.get_success( - self.store.create_room_alias_association( - self.your_room, "!8765asdf:test", ["test"] - ) + def test_delete_alias_sufficient_power(self): + """A user with a sufficient power level should be able to delete an alias.""" + self._create_alias(self.admin_user) + + # Increase the user's power level. + self.helper.send_state( + self.room_id, + "m.room.power_levels", + {"users": {self.test_user: 100}}, + tok=self.admin_user_tok, ) - response = self.get_success( - self.handler.on_directory_query({"room_alias": "#your-room:test"}) + # They can now delete the alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias + ) ) + self.assertEquals(self.room_id, result) - self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) class CanonicalAliasTestCase(unittest.HomeserverTestCase): diff --git a/tox.ini b/tox.ini index 097ebb8774..7622aa19f1 100644 --- a/tox.ini +++ b/tox.ini @@ -185,6 +185,7 @@ commands = mypy \ synapse/federation/federation_client.py \ synapse/federation/sender \ synapse/federation/transport \ + synapse/handlers/directory.py \ synapse/handlers/presence.py \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ -- cgit 1.5.1 From 1d66dce83e58827aae12080552edeaeb357b1997 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Fri, 6 Mar 2020 18:14:19 +0000 Subject: Break down monthly active users by appservice_id (#7030) * Break down monthly active users by appservice_id and emit via prometheus. Co-authored-by: Brendan Abolivier --- changelog.d/7030.feature | 1 + synapse/app/homeserver.py | 13 +++++++ .../data_stores/main/monthly_active_users.py | 32 ++++++++++++++++- tests/storage/test_monthly_active_users.py | 42 ++++++++++++++++++++++ 4 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7030.feature (limited to 'tests') diff --git a/changelog.d/7030.feature b/changelog.d/7030.feature new file mode 100644 index 0000000000..fcfdb8d8a1 --- /dev/null +++ b/changelog.d/7030.feature @@ -0,0 +1 @@ +Break down monthly active users by `appservice_id` and emit via Prometheus. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index c2a334a2b0..e0fdddfdc9 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -298,6 +298,11 @@ class SynapseHomeServer(HomeServer): # Gauges to expose monthly active user control metrics current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") +current_mau_by_service_gauge = Gauge( + "synapse_admin_mau_current_mau_by_service", + "Current MAU by service", + ["app_service"], +) max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") registered_reserved_users_mau_gauge = Gauge( "synapse_admin_mau:registered_reserved_users", @@ -585,12 +590,20 @@ def run(hs): @defer.inlineCallbacks def generate_monthly_active_users(): current_mau_count = 0 + current_mau_count_by_service = {} reserved_users = () store = hs.get_datastore() if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: current_mau_count = yield store.get_monthly_active_count() + current_mau_count_by_service = ( + yield store.get_monthly_active_count_by_service() + ) reserved_users = yield store.get_registered_reserved_users() current_mau_gauge.set(float(current_mau_count)) + + for app_service, count in current_mau_count_by_service.items(): + current_mau_by_service_gauge.labels(app_service).set(float(count)) + registered_reserved_users_mau_gauge.set(float(len(reserved_users))) max_mau_gauge.set(float(hs.config.max_mau_value)) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 1507a14e09..925bc5691b 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -43,13 +43,40 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): def _count_users(txn): sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" - txn.execute(sql) (count,) = txn.fetchone() return count return self.db.runInteraction("count_users", _count_users) + @cached(num_args=0) + def get_monthly_active_count_by_service(self): + """Generates current count of monthly active users broken down by service. + A service is typically an appservice but also includes native matrix users. + Since the `monthly_active_users` table is populated from the `user_ips` table + `config.track_appservice_user_ips` must be set to `true` for this + method to return anything other than native matrix users. + + Returns: + Deferred[dict]: dict that includes a mapping between app_service_id + and the number of occurrences. + + """ + + def _count_users_by_service(txn): + sql = """ + SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users ON monthly_active_users.user_id=users.name + GROUP BY appservice_id; + """ + + txn.execute(sql) + result = txn.fetchall() + return dict(result) + + return self.db.runInteraction("count_users_by_service", _count_users_by_service) + @defer.inlineCallbacks def get_registered_reserved_users(self): """Of the reserved threepids defined in config, which are associated @@ -291,6 +318,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): ) self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + self._invalidate_cache_and_stream( + txn, self.get_monthly_active_count_by_service, () + ) self._invalidate_cache_and_stream( txn, self.user_last_seen_monthly_active, (user_id,) ) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 3c78faab45..bc53bf0951 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -303,3 +303,45 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.pump() self.store.upsert_monthly_active_user.assert_not_called() + + def test_get_monthly_active_count_by_service(self): + appservice1_user1 = "@appservice1_user1:example.com" + appservice1_user2 = "@appservice1_user2:example.com" + + appservice2_user1 = "@appservice2_user1:example.com" + native_user1 = "@native_user1:example.com" + + service1 = "service1" + service2 = "service2" + native = "native" + + self.store.register_user( + user_id=appservice1_user1, password_hash=None, appservice_id=service1 + ) + self.store.register_user( + user_id=appservice1_user2, password_hash=None, appservice_id=service1 + ) + self.store.register_user( + user_id=appservice2_user1, password_hash=None, appservice_id=service2 + ) + self.store.register_user(user_id=native_user1, password_hash=None) + self.pump() + + count = self.store.get_monthly_active_count_by_service() + self.assertEqual({}, self.get_success(count)) + + self.store.upsert_monthly_active_user(native_user1) + self.store.upsert_monthly_active_user(appservice1_user1) + self.store.upsert_monthly_active_user(appservice1_user2) + self.store.upsert_monthly_active_user(appservice2_user1) + self.pump() + + count = self.store.get_monthly_active_count() + self.assertEqual(4, self.get_success(count)) + + count = self.store.get_monthly_active_count_by_service() + result = self.get_success(count) + + self.assertEqual(2, result[service1]) + self.assertEqual(1, result[service2]) + self.assertEqual(1, result[native]) -- cgit 1.5.1 From 1f5f3ae8b1c5db96d36ac7c104f13553bc4283da Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Sun, 8 Mar 2020 14:49:33 +0100 Subject: Add options to disable setting profile info for prevent changes. --- synapse/config/registration.py | 11 +++++++++++ synapse/handlers/profile.py | 10 ++++++++++ tests/handlers/test_profile.py | 33 ++++++++++++++++++++++++++++++++- 3 files changed, 53 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 9bb3beedbc..d9f452dcea 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,6 +129,9 @@ class RegistrationConfig(Config): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.disable_set_displayname = config.get("disable_set_displayname", False) + self.disable_set_avatar_url = config.get("disable_set_avatar_url", False) + self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False ) @@ -330,6 +333,14 @@ class RegistrationConfig(Config): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process + # If enabled, don't let users set their own display names/avatars + # other than for the very first time (unless they are a server admin). + # Useful when provisioning users based on the contents of a 3rd party + # directory and to avoid ambiguities. + # + # disable_set_displayname: False + # disable_set_avatar_url: False + # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 50ce0c585b..fb7e84f3b8 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,6 +157,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") + if not by_admin and self.hs.config.disable_set_displayname: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.display_name: + raise SynapseError(400, "Changing displayname is disabled on this server") + if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -218,6 +223,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") + if not by_admin and self.hs.config.disable_set_avatar_url: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.avatar_url: + raise SynapseError(400, "Changing avatar url is disabled on this server") + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d60c124eec..b85520c688 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,7 +19,7 @@ from mock import Mock, NonCallableMock from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -70,6 +70,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() + self.config = hs.config @defer.inlineCallbacks def test_get_my_name(self): @@ -90,6 +91,19 @@ class ProfileTestCase(unittest.TestCase): "Frank Jr.", ) + @defer.inlineCallbacks + def test_set_my_name_if_disabled(self): + self.config.disable_set_displayname = True + + # Set first displayname is allowed, if displayname is null + self.store.set_profile_displayname(self.frank.localpart, "Frank") + + d = self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) + + yield self.assertFailure(d, SynapseError) + @defer.inlineCallbacks def test_set_my_name_noauth(self): d = self.handler.set_displayname( @@ -147,3 +161,20 @@ class ProfileTestCase(unittest.TestCase): (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) + + @defer.inlineCallbacks + def test_set_my_avatar_if_disabled(self): + self.config.disable_set_avatar_url = True + + # Set first time avatar is allowed, if displayname is null + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) + + d = self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) + + yield self.assertFailure(d, SynapseError) -- cgit 1.5.1 From 06eb5cae08272c401a586991fc81f788825f910b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 9 Mar 2020 08:58:25 -0400 Subject: Remove special auth and redaction rules for aliases events in experimental room ver. (#7037) --- changelog.d/7037.feature | 1 + synapse/api/room_versions.py | 9 +-- synapse/crypto/event_signing.py | 2 +- synapse/event_auth.py | 8 +-- synapse/events/utils.py | 12 ++-- synapse/storage/data_stores/main/events.py | 10 +++- tests/events/test_utils.py | 35 ++++++++++- tests/test_event_auth.py | 93 +++++++++++++++++++++++++++++- 8 files changed, 148 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7037.feature (limited to 'tests') diff --git a/changelog.d/7037.feature b/changelog.d/7037.feature new file mode 100644 index 0000000000..4bc1b3b19f --- /dev/null +++ b/changelog.d/7037.feature @@ -0,0 +1 @@ +Implement updated authorization rules and redaction rules for aliases events, from [MSC2261](https://github.com/matrix-org/matrix-doc/pull/2261) and [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432). diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index cf7ee60d3a..871179749a 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -57,7 +57,7 @@ class RoomVersion(object): state_res = attr.ib() # int; one of the StateResolutionVersions enforce_key_validity = attr.ib() # bool - # bool: before MSC2260, anyone was allowed to send an aliases event + # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules special_case_aliases_auth = attr.ib(type=bool, default=False) @@ -102,12 +102,13 @@ class RoomVersions(object): enforce_key_validity=True, special_case_aliases_auth=True, ) - MSC2260_DEV = RoomVersion( - "org.matrix.msc2260", + MSC2432_DEV = RoomVersion( + "org.matrix.msc2432", RoomDisposition.UNSTABLE, EventFormatVersions.V3, StateResolutionVersions.V2, enforce_key_validity=True, + special_case_aliases_auth=False, ) @@ -119,6 +120,6 @@ KNOWN_ROOM_VERSIONS = { RoomVersions.V3, RoomVersions.V4, RoomVersions.V5, - RoomVersions.MSC2260_DEV, + RoomVersions.MSC2432_DEV, ) } # type: Dict[str, RoomVersion] diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 5f733c1cf5..0422c43fab 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -140,7 +140,7 @@ def compute_event_signature( Returns: a dictionary in the same format of an event's signatures field. """ - redact_json = prune_event_dict(event_dict) + redact_json = prune_event_dict(room_version, event_dict) redact_json.pop("age_ts", None) redact_json.pop("unsigned", None) if logger.isEnabledFor(logging.DEBUG): diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 472f165044..46beb5334f 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -137,7 +137,7 @@ def check( raise AuthError(403, "This room has been marked as unfederatable.") # 4. If type is m.room.aliases - if event.type == EventTypes.Aliases: + if event.type == EventTypes.Aliases and room_version_obj.special_case_aliases_auth: # 4a. If event has no state_key, reject if not event.is_state(): raise AuthError(403, "Alias event must be a state event") @@ -152,10 +152,8 @@ def check( ) # 4c. Otherwise, allow. - # This is removed by https://github.com/matrix-org/matrix-doc/pull/2260 - if room_version_obj.special_case_aliases_auth: - logger.debug("Allowing! %s", event) - return + logger.debug("Allowing! %s", event) + return if logger.isEnabledFor(logging.DEBUG): logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index bc6f98ae3b..b75b097e5e 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -23,6 +23,7 @@ from frozendict import frozendict from twisted.internet import defer from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.room_versions import RoomVersion from synapse.util.async_helpers import yieldable_gather_results from . import EventBase @@ -43,7 +44,7 @@ def prune_event(event: EventBase) -> EventBase: the user has specified, but we do want to keep necessary information like type, state_key etc. """ - pruned_event_dict = prune_event_dict(event.get_dict()) + pruned_event_dict = prune_event_dict(event.room_version, event.get_dict()) from . import make_event_from_dict @@ -57,15 +58,12 @@ def prune_event(event: EventBase) -> EventBase: return pruned_event -def prune_event_dict(event_dict): +def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: """Redacts the event_dict in the same way as `prune_event`, except it operates on dicts rather than event objects - Args: - event_dict (dict) - Returns: - dict: A copy of the pruned event dict + A copy of the pruned event dict """ allowed_keys = [ @@ -112,7 +110,7 @@ def prune_event_dict(event_dict): "kick", "redact", ) - elif event_type == EventTypes.Aliases: + elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth: add_fields("aliases") elif event_type == EventTypes.RoomHistoryVisibility: add_fields("history_visibility") diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 8ae23df00a..d593ef47b8 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1168,7 +1168,11 @@ class EventsStore( and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = encode_json(prune_event_dict(original_event.get_dict())) + pruned_json = encode_json( + prune_event_dict( + original_event.room_version, original_event.get_dict() + ) + ) else: # Redaction wasn't allowed pruned_json = None @@ -1929,7 +1933,9 @@ class EventsStore( return # Prune the event's dict then convert it to JSON. - pruned_json = encode_json(prune_event_dict(event.get_dict())) + pruned_json = encode_json( + prune_event_dict(event.room_version, event.get_dict()) + ) # Update the event_json table to replace the event's JSON with the pruned # JSON. diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 45d55b9e94..ab5f5ac549 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.utils import ( copy_power_levels_contents, @@ -36,9 +37,9 @@ class PruneEventTestCase(unittest.TestCase): """ Asserts that a new event constructed with `evdict` will look like `matchdict` when it is redacted. """ - def run_test(self, evdict, matchdict): + def run_test(self, evdict, matchdict, **kwargs): self.assertEquals( - prune_event(make_event_from_dict(evdict)).get_dict(), matchdict + prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict ) def test_minimal(self): @@ -128,6 +129,36 @@ class PruneEventTestCase(unittest.TestCase): }, ) + def test_alias_event(self): + """Alias events have special behavior up through room version 6.""" + self.run_test( + { + "type": "m.room.aliases", + "event_id": "$test:domain", + "content": {"aliases": ["test"]}, + }, + { + "type": "m.room.aliases", + "event_id": "$test:domain", + "content": {"aliases": ["test"]}, + "signatures": {}, + "unsigned": {}, + }, + ) + + def test_msc2432_alias_event(self): + """After MSC2432, alias events have no special behavior.""" + self.run_test( + {"type": "m.room.aliases", "content": {"aliases": ["test"]}}, + { + "type": "m.room.aliases", + "content": {}, + "signatures": {}, + "unsigned": {}, + }, + room_version=RoomVersions.MSC2432_DEV, + ) + class SerializeEventTestCase(unittest.TestCase): def serialize(self, ev, fields): diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index bfa5d6f510..6c2351cf55 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -19,6 +19,7 @@ from synapse import event_auth from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict +from synapse.types import get_domain_from_id class EventAuthTestCase(unittest.TestCase): @@ -51,7 +52,7 @@ class EventAuthTestCase(unittest.TestCase): _random_state_event(joiner), auth_events, do_sig_check=False, - ), + ) def test_state_default_level(self): """ @@ -87,6 +88,83 @@ class EventAuthTestCase(unittest.TestCase): RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False, ) + def test_alias_event(self): + """Alias events have special behavior up through room version 6.""" + creator = "@creator:example.com" + other = "@other:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + } + + # creator should be able to send aliases + event_auth.check( + RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False, + ) + + # Reject an event with no state key. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V1, + _alias_event(creator, state_key=""), + auth_events, + do_sig_check=False, + ) + + # If the domain of the sender does not match the state key, reject. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V1, + _alias_event(creator, state_key="test.com"), + auth_events, + do_sig_check=False, + ) + + # Note that the member does *not* need to be in the room. + event_auth.check( + RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False, + ) + + def test_msc2432_alias_event(self): + """After MSC2432, alias events have no special behavior.""" + creator = "@creator:example.com" + other = "@other:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + } + + # creator should be able to send aliases + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(creator), + auth_events, + do_sig_check=False, + ) + + # No particular checks are done on the state key. + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(creator, state_key=""), + auth_events, + do_sig_check=False, + ) + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(creator, state_key="test.com"), + auth_events, + do_sig_check=False, + ) + + # Per standard auth rules, the member must be in the room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC2432_DEV, + _alias_event(other), + auth_events, + do_sig_check=False, + ) + # helpers for making events @@ -131,6 +209,19 @@ def _power_levels_event(sender, content): ) +def _alias_event(sender, **kwargs): + data = { + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.aliases", + "sender": sender, + "state_key": get_domain_from_id(sender), + "content": {"aliases": []}, + } + data.update(**kwargs) + return make_event_from_dict(data) + + def _random_state_event(sender): return make_event_from_dict( { -- cgit 1.5.1 From 04f4b5f6f87fbba0b2f1a4f011c496de3021c81a Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 19:51:31 +0100 Subject: add tests --- tests/handlers/test_profile.py | 6 +- tests/rest/client/v2_alpha/test_account.py | 308 +++++++++++++++++++++++++++++ 2 files changed, 311 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index b85520c688..98b508c3d4 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -70,7 +70,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() - self.config = hs.config + self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -93,7 +93,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name_if_disabled(self): - self.config.disable_set_displayname = True + self.hs.config.disable_set_displayname = True # Set first displayname is allowed, if displayname is null self.store.set_profile_displayname(self.frank.localpart, "Frank") @@ -164,7 +164,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): - self.config.disable_set_avatar_url = True + self.hs.config.disable_set_avatar_url = True # Set first time avatar is allowed, if displayname is null self.store.set_profile_avatar_url( diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index c3facc00eb..ac9f200de3 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -325,3 +325,311 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(request.code, 200) + + +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + self.email_attempts = [] + + def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + self.email_attempts.append(msg) + return + + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_email(self): + """Test add mail to profile + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_add_email_if_disabled(self): + """Test add mail to profile if disabled + """ + self.hs.config.disable_3pid_changes = True + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("3PID changes disabled on this server", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test delete mail from profile + """ + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + { + "medium": "email", + "address": self.email + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test delete mail from profile if disabled + """ + self.hs.config.disable_3pid_changes = True + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + { + "medium": "email", + "address": self.email + }, + access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("3PID changes disabled on this server", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def _request_token(self, email, client_secret): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + request, channel = self.make_request("GET", path, shorthand=False) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) -- cgit 1.5.1 From 50ea178c201588b5e6b3f93e1af56aef0b4e8368 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 19:57:04 +0100 Subject: lint --- tests/rest/client/v2_alpha/test_account.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index ac9f200de3..e178a53335 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -438,7 +438,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("3PID changes disabled on this server", channel.json_body["error"]) + self.assertEqual( + "3PID changes disabled on this server", channel.json_body["error"] + ) # Get user request, channel = self.make_request( @@ -466,10 +468,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", b"account/3pid/delete", - { - "medium": "email", - "address": self.email - }, + {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) self.render(request) @@ -503,16 +502,15 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", b"account/3pid/delete", - { - "medium": "email", - "address": self.email - }, + {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("3PID changes disabled on this server", channel.json_body["error"]) + self.assertEqual( + "3PID changes disabled on this server", channel.json_body["error"] + ) # Get user request, channel = self.make_request( -- cgit 1.5.1 From 7e5f40e7716813f0d32e2efcb32df3c263fbfc63 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 21:00:36 +0100 Subject: fix tests --- tests/handlers/test_profile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 98b508c3d4..f8c0da5ced 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -96,7 +96,7 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.disable_set_displayname = True # Set first displayname is allowed, if displayname is null - self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield self.store.set_profile_displayname(self.frank.localpart, "Frank") d = self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank Jr." @@ -167,7 +167,7 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.disable_set_avatar_url = True # Set first time avatar is allowed, if displayname is null - self.store.set_profile_avatar_url( + yield self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) -- cgit 1.5.1 From 885134529ffd95dd118d3228e69f0e3553f5a6a7 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 22:09:29 +0100 Subject: updates after review --- changelog.d/7053.feature | 2 +- docs/sample_config.yaml | 10 +++++----- synapse/config/registration.py | 16 ++++++++-------- synapse/handlers/profile.py | 8 ++++---- synapse/rest/client/v2_alpha/account.py | 18 ++++++++++++------ tests/handlers/test_profile.py | 6 +++--- tests/rest/client/v2_alpha/test_account.py | 17 +++++++---------- 7 files changed, 40 insertions(+), 37 deletions(-) (limited to 'tests') diff --git a/changelog.d/7053.feature b/changelog.d/7053.feature index 79955b9780..00f47b2a14 100644 --- a/changelog.d/7053.feature +++ b/changelog.d/7053.feature @@ -1 +1 @@ -Add options to disable setting profile info for prevent changes. \ No newline at end of file +Add options to prevent users from changing their profile or associated 3PIDs. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d3ecffac7d..8333800a10 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1057,18 +1057,18 @@ account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process -# If enabled, don't let users set their own display names/avatars +# If disabled, don't let users set their own display names/avatars # other than for the very first time (unless they are a server admin). # Useful when provisioning users based on the contents of a 3rd party # directory and to avoid ambiguities. # -#disable_set_displayname: false -#disable_set_avatar_url: false +#enable_set_displayname: true +#enable_set_avatar_url: true -# If true, stop users from trying to change the 3PIDs associated with +# If false, stop users from trying to change the 3PIDs associated with # their accounts. # -#disable_3pid_changes: false +#enable_3pid_changes: true # Users who register on this homeserver will automatically be joined # to these rooms diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 1abc0a79af..d4897ec9b6 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,9 +129,9 @@ class RegistrationConfig(Config): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) - self.disable_set_displayname = config.get("disable_set_displayname", False) - self.disable_set_avatar_url = config.get("disable_set_avatar_url", False) - self.disable_3pid_changes = config.get("disable_3pid_changes", False) + self.enable_set_displayname = config.get("enable_set_displayname", True) + self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) + self.enable_3pid_changes = config.get("enable_3pid_changes", True) self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False @@ -334,18 +334,18 @@ class RegistrationConfig(Config): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process - # If enabled, don't let users set their own display names/avatars + # If disabled, don't let users set their own display names/avatars # other than for the very first time (unless they are a server admin). # Useful when provisioning users based on the contents of a 3rd party # directory and to avoid ambiguities. # - #disable_set_displayname: false - #disable_set_avatar_url: false + #enable_set_displayname: true + #enable_set_avatar_url: true - # If true, stop users from trying to change the 3PIDs associated with + # If false, stop users from trying to change the 3PIDs associated with # their accounts. # - #disable_3pid_changes: false + #enable_3pid_changes: true # Users who register on this homeserver will automatically be joined # to these rooms diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b049dd8e26..eb85dba015 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,11 +157,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and self.hs.config.disable_set_displayname: + if not by_admin and not self.hs.config.enable_set_displayname: profile = yield self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( - 400, "Changing displayname is disabled on this server" + 400, "Changing display name is disabled on this server", Codes.FORBIDDEN ) if len(new_displayname) > MAX_DISPLAYNAME_LEN: @@ -225,11 +225,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and self.hs.config.disable_set_avatar_url: + if not by_admin and not self.hs.config.enable_set_avatar_url: profile = yield self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( - 400, "Changing avatar url is disabled on this server" + 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN ) if len(new_avatar_url) > MAX_AVATAR_URL_LEN: diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 97bddf36d9..e40136f2f3 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -599,8 +599,10 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request): - if self.hs.config.disable_3pid_changes: - raise SynapseError(400, "3PID changes disabled on this server") + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -646,8 +648,10 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request): - if self.hs.config.disable_3pid_changes: - raise SynapseError(400, "3PID changes disabled on this server") + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -749,8 +753,10 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() async def on_POST(self, request): - if self.hs.config.disable_3pid_changes: - raise SynapseError(400, "3PID changes disabled on this server") + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f8c0da5ced..e600b9777b 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -93,7 +93,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name_if_disabled(self): - self.hs.config.disable_set_displayname = True + self.hs.config.enable_set_displayname = False # Set first displayname is allowed, if displayname is null yield self.store.set_profile_displayname(self.frank.localpart, "Frank") @@ -164,9 +164,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): - self.hs.config.disable_set_avatar_url = True + self.hs.config.enable_set_avatar_url = False - # Set first time avatar is allowed, if displayname is null + # Set first time avatar is allowed, if avatar is null yield self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index e178a53335..34e40a36d0 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,6 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register @@ -412,7 +413,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test add mail to profile if disabled """ - self.hs.config.disable_3pid_changes = True + self.hs.config.enable_3pid_changes = True client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -438,9 +439,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "3PID changes disabled on this server", channel.json_body["error"] - ) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user request, channel = self.make_request( @@ -486,7 +485,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test delete mail from profile if disabled """ - self.hs.config.disable_3pid_changes = True + self.hs.config.enable_3pid_changes = True # Add a threepid self.get_success( @@ -508,9 +507,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "3PID changes disabled on this server", channel.json_body["error"] - ) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user request, channel = self.make_request( @@ -547,7 +544,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user request, channel = self.make_request( @@ -582,7 +579,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user request, channel = self.make_request( -- cgit 1.5.1 From 39f6595b4ab108cb451072ae251a91117002191c Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 22:13:20 +0100 Subject: lint, fix tests --- synapse/handlers/profile.py | 4 +++- tests/rest/client/v2_alpha/test_account.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index eb85dba015..6aa1c0f5e0 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -161,7 +161,9 @@ class BaseProfileHandler(BaseHandler): profile = yield self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( - 400, "Changing display name is disabled on this server", Codes.FORBIDDEN + 400, + "Changing display name is disabled on this server", + Codes.FORBIDDEN, ) if len(new_displayname) > MAX_DISPLAYNAME_LEN: diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 34e40a36d0..99cc9163f3 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -413,7 +413,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test add mail to profile if disabled """ - self.hs.config.enable_3pid_changes = True + self.hs.config.enable_3pid_changes = False client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -485,7 +485,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test delete mail from profile if disabled """ - self.hs.config.enable_3pid_changes = True + self.hs.config.enable_3pid_changes = False # Add a threepid self.get_success( -- cgit 1.5.1 From 6a35046363a6f5d41199256c80eef4ea7e385986 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 17 Mar 2020 11:25:01 +0000 Subject: Revert "Add options to disable setting profile info for prevent changes. (#7053)" This reverts commit 54dd28621b070ca67de9f773fe9a89e1f4dc19da, reversing changes made to 6640460d054e8f4444046a34bdf638921b31c01e. --- changelog.d/7053.feature | 1 - docs/sample_config.yaml | 13 -- synapse/config/registration.py | 17 -- synapse/handlers/profile.py | 16 -- synapse/rest/client/v2_alpha/account.py | 16 -- tests/handlers/test_profile.py | 33 +--- tests/rest/client/v2_alpha/test_account.py | 303 ----------------------------- 7 files changed, 1 insertion(+), 398 deletions(-) delete mode 100644 changelog.d/7053.feature (limited to 'tests') diff --git a/changelog.d/7053.feature b/changelog.d/7053.feature deleted file mode 100644 index 00f47b2a14..0000000000 --- a/changelog.d/7053.feature +++ /dev/null @@ -1 +0,0 @@ -Add options to prevent users from changing their profile or associated 3PIDs. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 91eff4c8ad..2ff0dd05a2 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1057,19 +1057,6 @@ account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process -# If disabled, don't let users set their own display names/avatars -# (unless they are a server admin) other than for the very first time. -# Useful when provisioning users based on the contents of a 3rd party -# directory and to avoid ambiguities. -# -#enable_set_displayname: true -#enable_set_avatar_url: true - -# If false, stop users from trying to change the 3PIDs associated with -# their accounts. -# -#enable_3pid_changes: true - # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ee737eb40d..9bb3beedbc 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,10 +129,6 @@ class RegistrationConfig(Config): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) - self.enable_set_displayname = config.get("enable_set_displayname", True) - self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) - self.enable_3pid_changes = config.get("enable_3pid_changes", True) - self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False ) @@ -334,19 +330,6 @@ class RegistrationConfig(Config): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process - # If disabled, don't let users set their own display names/avatars - # (unless they are a server admin) other than for the very first time. - # Useful when provisioning users based on the contents of a 3rd party - # directory and to avoid ambiguities. - # - #enable_set_displayname: true - #enable_set_avatar_url: true - - # If false, stop users from trying to change the 3PIDs associated with - # their accounts. - # - #enable_3pid_changes: true - # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 6aa1c0f5e0..50ce0c585b 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,15 +157,6 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and not self.hs.config.enable_set_displayname: - profile = yield self.store.get_profileinfo(target_user.localpart) - if profile.display_name: - raise SynapseError( - 400, - "Changing display name is disabled on this server", - Codes.FORBIDDEN, - ) - if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -227,13 +218,6 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and not self.hs.config.enable_set_avatar_url: - profile = yield self.store.get_profileinfo(target_user.localpart) - if profile.avatar_url: - raise SynapseError( - 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN - ) - if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index e40136f2f3..dc837d6c75 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -599,11 +599,6 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -648,11 +643,6 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -748,16 +738,10 @@ class ThreepidDeleteRestServlet(RestServlet): def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() - self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index e600b9777b..d60c124eec 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,7 +19,7 @@ from mock import Mock, NonCallableMock from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -70,7 +70,6 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() - self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -91,19 +90,6 @@ class ProfileTestCase(unittest.TestCase): "Frank Jr.", ) - @defer.inlineCallbacks - def test_set_my_name_if_disabled(self): - self.hs.config.enable_set_displayname = False - - # Set first displayname is allowed, if displayname is null - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") - - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." - ) - - yield self.assertFailure(d, SynapseError) - @defer.inlineCallbacks def test_set_my_name_noauth(self): d = self.handler.set_displayname( @@ -161,20 +147,3 @@ class ProfileTestCase(unittest.TestCase): (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) - - @defer.inlineCallbacks - def test_set_my_avatar_if_disabled(self): - self.hs.config.enable_set_avatar_url = False - - # Set first time avatar is allowed, if avatar is null - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" - ) - - d = self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", - ) - - yield self.assertFailure(d, SynapseError) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 99cc9163f3..c3facc00eb 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,7 +24,6 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership -from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register @@ -326,305 +325,3 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(request.code, 200) - - -class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): - - servlets = [ - account.register_servlets, - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - - # Email config. - self.email_attempts = [] - - def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - return - - config["email"] = { - "enable_notifs": False, - "template_dir": os.path.abspath( - pkg_resources.resource_filename("synapse", "res/templates") - ), - "smtp_host": "127.0.0.1", - "smtp_port": 20, - "require_transport_security": False, - "smtp_user": None, - "smtp_pass": None, - "notif_from": "test@example.com", - } - config["public_baseurl"] = "https://example.com" - - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) - return self.hs - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.user_id = self.register_user("kermit", "test") - self.user_id_tok = self.login("kermit", "test") - self.email = "test@example.com" - self.url_3pid = b"account/3pid" - - def test_add_email(self): - """Test add mail to profile - """ - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - - def test_add_email_if_disabled(self): - """Test add mail to profile if disabled - """ - self.hs.config.enable_3pid_changes = False - - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email(self): - """Test delete mail from profile - """ - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - request, channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email_if_disabled(self): - """Test delete mail from profile if disabled - """ - self.hs.config.enable_3pid_changes = False - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - request, channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - - def test_cant_add_email_without_clicking_link(self): - """Test that we do actually need to click the link in the email - """ - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - - # Attempt to add email without clicking the link - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_no_valid_token(self): - """Test that we do actually need to request a token and can't just - make a session up. - """ - client_secret = "foobar" - session_id = "weasle" - - # Attempt to add email without even requesting an email - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def _request_token(self, email, client_secret): - request, channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, - ) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - - return channel.json_body["sid"] - - def _validate_token(self, link): - # Remove the host - path = link.replace("https://example.com", "") - - request, channel = self.make_request("GET", path, shorthand=False) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - - def _get_link_from_email(self): - assert self.email_attempts, "No emails have been sent" - - raw_msg = self.email_attempts[-1].decode("UTF-8") - mail = Parser().parsestr(raw_msg) - - text = None - for part in mail.walk(): - if part.get_content_type() == "text/plain": - text = part.get_payload(decode=True).decode("UTF-8") - break - - if not text: - self.fail("Could not find text portion of email to parse") - - match = re.search(r"https://example.com\S+", text) - assert match, "Could not find link in email" - - return match.group(0) -- cgit 1.5.1 From 60724c46b7dc5300243fd97d5a485564b3e00afe Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 17 Mar 2020 07:37:04 -0400 Subject: Remove special casing of `m.room.aliases` events (#7034) --- changelog.d/7034.removal | 1 + synapse/handlers/room.py | 16 +------------ synapse/rest/client/v1/room.py | 12 ---------- tests/rest/admin/test_admin.py | 7 ++++++ tests/rest/client/v1/test_directory.py | 41 +++++++++++++++++++++------------- 5 files changed, 35 insertions(+), 42 deletions(-) create mode 100644 changelog.d/7034.removal (limited to 'tests') diff --git a/changelog.d/7034.removal b/changelog.d/7034.removal new file mode 100644 index 0000000000..be8d20e14f --- /dev/null +++ b/changelog.d/7034.removal @@ -0,0 +1 @@ +Remove special handling of aliases events from [MSC2260](https://github.com/matrix-org/matrix-doc/pull/2260) added in v1.10.0rc1. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 8ee870f0bb..f580ab2e9f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -292,16 +292,6 @@ class RoomCreationHandler(BaseHandler): except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) - new_pl_content = copy_power_levels_contents(old_room_pl_state.content) - - # pre-msc2260 rooms may not have the right setting for aliases. If no other - # value is set, set it now. - events_default = new_pl_content.get("events_default", 0) - new_pl_content.setdefault("events", {}).setdefault( - EventTypes.Aliases, events_default - ) - - logger.debug("Setting correct PLs in new room to %s", new_pl_content) yield self.event_creation_handler.create_and_send_nonmember_event( requester, { @@ -309,7 +299,7 @@ class RoomCreationHandler(BaseHandler): "state_key": "", "room_id": new_room_id, "sender": requester.user.to_string(), - "content": new_pl_content, + "content": old_room_pl_state.content, }, ratelimit=False, ) @@ -814,10 +804,6 @@ class RoomCreationHandler(BaseHandler): EventTypes.RoomHistoryVisibility: 100, EventTypes.CanonicalAlias: 50, EventTypes.RoomAvatar: 50, - # MSC2260: Allow everybody to send alias events by default - # This will be reudundant on pre-MSC2260 rooms, since the - # aliases event is special-cased. - EventTypes.Aliases: 0, EventTypes.Tombstone: 100, EventTypes.ServerACL: 100, }, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 64f51406fb..bffd43de5f 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -189,12 +189,6 @@ class RoomStateEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) - if event_type == EventTypes.Aliases: - # MSC2260 - raise SynapseError( - 400, "Cannot send m.room.aliases events via /rooms/{room_id}/state" - ) - event_dict = { "type": event_type, "content": content, @@ -242,12 +236,6 @@ class RoomSendEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) - if event_type == EventTypes.Aliases: - # MSC2260 - raise SynapseError( - 400, "Cannot send m.room.aliases events via /rooms/{room_id}/send" - ) - event_dict = { "type": event_type, "content": content, diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index e5984aaad8..0342aed416 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -868,6 +868,13 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=self.admin_user_tok, + state_key="test", + ) self.helper.send_state( room_id, "m.room.canonical_alias", diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index 914cf54927..633b7dbda0 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -51,30 +51,26 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.user = self.register_user("user", "test") self.user_tok = self.login("user", "test") - def test_cannot_set_alias_via_state_event(self): - self.ensure_user_joined_room() - url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( - self.room_id, - self.hs.hostname, - ) - - data = {"aliases": [self.random_alias(5)]} - request_data = json.dumps(data) - - request, channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok - ) - self.render(request) - self.assertEqual(channel.code, 400, channel.result) + def test_state_event_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_state_event(403) def test_directory_endpoint_not_in_room(self): self.ensure_user_left_room() self.set_alias_via_directory(403) + def test_state_event_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(400, alias_length=256) + def test_directory_in_room_too_long(self): self.ensure_user_joined_room() self.set_alias_via_directory(400, alias_length=256) + def test_state_event_in_room(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(200) + def test_directory_in_room(self): self.ensure_user_joined_room() self.set_alias_via_directory(200) @@ -106,6 +102,21 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(channel.code, 200, channel.result) + def set_alias_via_state_event(self, expected_code, alias_length=5): + url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( + self.room_id, + self.hs.hostname, + ) + + data = {"aliases": [self.random_alias(alias_length)]} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + def set_alias_via_directory(self, expected_code, alias_length=5): url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) data = {"room_id": self.room_id} -- cgit 1.5.1 From c37db0211e36cd298426ff8811e547b0acd10bf4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 17 Mar 2020 22:32:25 +0100 Subject: Share SSL contexts for non-federation requests (#7094) Extends #5794 etc to the SimpleHttpClient so that it also applies to non-federation requests. Fixes #7092. --- changelog.d/7094.misc | 1 + synapse/crypto/context_factory.py | 68 ++++++++++++++-------- synapse/http/client.py | 3 - synapse/http/federation/matrix_federation_agent.py | 2 +- synapse/server.py | 6 +- tests/config/test_tls.py | 29 +++++---- .../federation/test_matrix_federation_agent.py | 6 +- 7 files changed, 71 insertions(+), 44 deletions(-) create mode 100644 changelog.d/7094.misc (limited to 'tests') diff --git a/changelog.d/7094.misc b/changelog.d/7094.misc new file mode 100644 index 0000000000..aa093ee3c0 --- /dev/null +++ b/changelog.d/7094.misc @@ -0,0 +1 @@ +Improve performance when making HTTPS requests to sygnal, sydent, etc, by sharing the SSL context object between connections. diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index e93f0b3705..a5a2a7815d 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -75,7 +75,7 @@ class ServerContextFactory(ContextFactory): @implementer(IPolicyForHTTPS) -class ClientTLSOptionsFactory(object): +class FederationPolicyForHTTPS(object): """Factory for Twisted SSLClientConnectionCreators that are used to make connections to remote servers for federation. @@ -103,15 +103,15 @@ class ClientTLSOptionsFactory(object): # let us do). minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version] - self._verify_ssl = CertificateOptions( + _verify_ssl = CertificateOptions( trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS ) - self._verify_ssl_context = self._verify_ssl.getContext() - self._verify_ssl_context.set_info_callback(self._context_info_cb) + self._verify_ssl_context = _verify_ssl.getContext() + self._verify_ssl_context.set_info_callback(_context_info_cb) - self._no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS) - self._no_verify_ssl_context = self._no_verify_ssl.getContext() - self._no_verify_ssl_context.set_info_callback(self._context_info_cb) + _no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS) + self._no_verify_ssl_context = _no_verify_ssl.getContext() + self._no_verify_ssl_context.set_info_callback(_context_info_cb) def get_options(self, host: bytes): @@ -136,23 +136,6 @@ class ClientTLSOptionsFactory(object): return SSLClientConnectionCreator(host, ssl_context, should_verify) - @staticmethod - def _context_info_cb(ssl_connection, where, ret): - """The 'information callback' for our openssl context object.""" - # we assume that the app_data on the connection object has been set to - # a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator) - tls_protocol = ssl_connection.get_app_data() - try: - # ... we further assume that SSLClientConnectionCreator has set the - # '_synapse_tls_verifier' attribute to a ConnectionVerifier object. - tls_protocol._synapse_tls_verifier.verify_context_info_cb( - ssl_connection, where - ) - except: # noqa: E722, taken from the twisted implementation - logger.exception("Error during info_callback") - f = Failure() - tls_protocol.failVerification(f) - def creatorForNetloc(self, hostname, port): """Implements the IPolicyForHTTPS interace so that this can be passed directly to agents. @@ -160,6 +143,43 @@ class ClientTLSOptionsFactory(object): return self.get_options(hostname) +@implementer(IPolicyForHTTPS) +class RegularPolicyForHTTPS(object): + """Factory for Twisted SSLClientConnectionCreators that are used to make connections + to remote servers, for other than federation. + + Always uses the same OpenSSL context object, which uses the default OpenSSL CA + trust root. + """ + + def __init__(self): + trust_root = platformTrust() + self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext() + self._ssl_context.set_info_callback(_context_info_cb) + + def creatorForNetloc(self, hostname, port): + return SSLClientConnectionCreator(hostname, self._ssl_context, True) + + +def _context_info_cb(ssl_connection, where, ret): + """The 'information callback' for our openssl context objects. + + Note: Once this is set as the info callback on a Context object, the Context should + only be used with the SSLClientConnectionCreator. + """ + # we assume that the app_data on the connection object has been set to + # a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator) + tls_protocol = ssl_connection.get_app_data() + try: + # ... we further assume that SSLClientConnectionCreator has set the + # '_synapse_tls_verifier' attribute to a ConnectionVerifier object. + tls_protocol._synapse_tls_verifier.verify_context_info_cb(ssl_connection, where) + except: # noqa: E722, taken from the twisted implementation + logger.exception("Error during info_callback") + f = Failure() + tls_protocol.failVerification(f) + + @implementer(IOpenSSLClientConnectionCreator) class SSLClientConnectionCreator(object): """Creates openssl connection objects for client connections. diff --git a/synapse/http/client.py b/synapse/http/client.py index d4c285445e..3797545824 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -244,9 +244,6 @@ class SimpleHttpClient(object): pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) pool.cachedConnectionTimeout = 2 * 60 - # The default context factory in Twisted 14.0.0 (which we require) is - # BrowserLikePolicyForHTTPS which will do regular cert validation - # 'like a browser' self.agent = ProxyAgent( self.reactor, connectTimeout=15, diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 647d26dc56..f5f917f5ae 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -45,7 +45,7 @@ class MatrixFederationAgent(object): Args: reactor (IReactor): twisted reactor to use for underlying requests - tls_client_options_factory (ClientTLSOptionsFactory|None): + tls_client_options_factory (FederationPolicyForHTTPS|None): factory to use for fetching client tls options, or none to disable TLS. _srv_resolver (SrvResolver|None): diff --git a/synapse/server.py b/synapse/server.py index fd2f69e928..1b980371de 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -26,7 +26,6 @@ import logging import os from twisted.mail.smtp import sendmail -from twisted.web.client import BrowserLikePolicyForHTTPS from synapse.api.auth import Auth from synapse.api.filtering import Filtering @@ -35,6 +34,7 @@ from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory +from synapse.crypto.context_factory import RegularPolicyForHTTPS from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory from synapse.events.spamcheck import SpamChecker @@ -310,7 +310,7 @@ class HomeServer(object): return ( InsecureInterceptableContextFactory() if self.config.use_insecure_ssl_client_just_for_testing_do_not_use - else BrowserLikePolicyForHTTPS() + else RegularPolicyForHTTPS() ) def build_simple_http_client(self): @@ -420,7 +420,7 @@ class HomeServer(object): return PusherPool(self) def build_http_client(self): - tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + tls_client_options_factory = context_factory.FederationPolicyForHTTPS( self.config ) return MatrixFederationHttpClient(self, tls_client_options_factory) diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 1be6ff563b..ec32d4b1ca 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -23,7 +23,7 @@ from OpenSSL import SSL from synapse.config._base import Config, RootConfig from synapse.config.tls import ConfigError, TlsConfig -from synapse.crypto.context_factory import ClientTLSOptionsFactory +from synapse.crypto.context_factory import FederationPolicyForHTTPS from tests.unittest import TestCase @@ -180,12 +180,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) + options = _get_ssl_context_options(cf._verify_ssl_context) # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2 - self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) - self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) + self.assertNotEqual(options & SSL.OP_NO_TLSv1, 0) + self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) def test_tls_client_minimum_set_passed_through_1_0(self): """ @@ -195,12 +196,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) + options = _get_ssl_context_options(cf._verify_ssl_context) # The context has not had any of the NO_TLS set. - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) def test_acme_disabled_in_generated_config_no_acme_domain_provied(self): """ @@ -273,7 +275,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) # Not in the whitelist opts = cf.get_options(b"notexample.com") @@ -282,3 +284,10 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= # Caught by the wildcard opts = cf.get_options(idna.encode("テスト.ドメイン.テスト")) self.assertFalse(opts._verifier._verify_certs) + + +def _get_ssl_context_options(ssl_context: SSL.Context) -> int: + """get the options bits from an openssl context object""" + # the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to + # use the low-level interface + return SSL._lib.SSL_CTX_get_options(ssl_context._context) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index cfcd98ff7d..fdc1d918ff 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -31,7 +31,7 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS from synapse.config.homeserver import HomeServerConfig -from synapse.crypto.context_factory import ClientTLSOptionsFactory +from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.srv_resolver import Server from synapse.http.federation.well_known_resolver import ( @@ -79,7 +79,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self._config = config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") - self.tls_factory = ClientTLSOptionsFactory(config) + self.tls_factory = FederationPolicyForHTTPS(config) self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) @@ -715,7 +715,7 @@ class MatrixFederationAgentTests(unittest.TestCase): config = default_config("test", parse=True) # Build a new agent and WellKnownResolver with a different tls factory - tls_factory = ClientTLSOptionsFactory(config) + tls_factory = FederationPolicyForHTTPS(config) agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=tls_factory, -- cgit 1.5.1 From 4a17a647a9508b70de35130fd82e3e21474270a9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Mar 2020 16:46:41 +0000 Subject: Improve get auth chain difference algorithm. (#7095) It was originally implemented by pulling the full auth chain of all state sets out of the database and doing set comparison. However, that can take a lot work if the state and auth chains are large. Instead, lets try and fetch the auth chains at the same time and calculate the difference on the fly, allowing us to bail early if all the auth chains converge. Assuming that the auth chains do converge more often than not, this should improve performance. Hopefully. --- changelog.d/7095.misc | 1 + synapse/state/__init__.py | 28 ++-- synapse/state/v2.py | 32 +---- .../storage/data_stores/main/event_federation.py | 150 +++++++++++++++++++- tests/state/test_v2.py | 13 +- tests/storage/test_event_federation.py | 157 ++++++++++++++++++--- 6 files changed, 310 insertions(+), 71 deletions(-) create mode 100644 changelog.d/7095.misc (limited to 'tests') diff --git a/changelog.d/7095.misc b/changelog.d/7095.misc new file mode 100644 index 0000000000..44fc9f616f --- /dev/null +++ b/changelog.d/7095.misc @@ -0,0 +1 @@ +Attempt to improve performance of state res v2 algorithm. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index df7a4f6a89..4afefc6b1d 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -662,28 +662,16 @@ class StateResolutionStore(object): allow_rejected=allow_rejected, ) - def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]): - """Gets the full auth chain for a set of events (including rejected - events). - - Includes the given event IDs in the result. - - Note that: - 1. All events must be state events. - 2. For v1 rooms this may not have the full auth chain in the - presence of rejected events - - Args: - event_ids: The event IDs of the events to fetch the auth chain for. - Must be state events. - ignore_events: Set of events to exclude from the returned auth - chain. + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. Returns: - Deferred[list[str]]: List of event IDs of the auth chain. + Deferred[Set[str]]: Set of event IDs. """ - return self.store.get_auth_chain_ids( - event_ids, include_given=True, ignore_events=ignore_events, - ) + return self.store.get_auth_chain_difference(state_sets) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 0ffe6d8c14..18484e2fa6 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -227,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): Returns: Deferred[set[str]]: Set of event IDs """ - common = set(itervalues(state_sets[0])).intersection( - *(itervalues(s) for s in state_sets[1:]) - ) - - auth_sets = [] - for state_set in state_sets: - auth_ids = { - eid - for key, eid in iteritems(state_set) - if ( - key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) - or key - in ( - (EventTypes.PowerLevels, ""), - (EventTypes.Create, ""), - (EventTypes.JoinRules, ""), - ) - ) - and eid not in common - } - auth_chain = yield state_res_store.get_auth_chain(auth_ids, common) - auth_ids.update(auth_chain) - - auth_sets.append(auth_ids) - - intersection = set(auth_sets[0]).intersection(*auth_sets[1:]) - union = set().union(*auth_sets) + difference = yield state_res_store.get_auth_chain_difference( + [set(state_set.values()) for state_set in state_sets] + ) - return union - intersection + return difference def _seperate(state_sets): diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 49a7b8b433..62d4e9f599 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -14,7 +14,7 @@ # limitations under the License. import itertools import logging -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple from six.moves.queue import Empty, PriorityQueue @@ -103,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). + + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. + + Returns: + Deferred[Set[str]] + """ + + return self.db.runInteraction( + "get_auth_chain_difference", + self._get_auth_chain_difference_txn, + state_sets, + ) + + def _get_auth_chain_difference_txn( + self, txn, state_sets: List[Set[str]] + ) -> Set[str]: + + # Algorithm Description + # ~~~~~~~~~~~~~~~~~~~~~ + # + # The idea here is to basically walk the auth graph of each state set in + # tandem, keeping track of which auth events are reachable by each state + # set. If we reach an auth event we've already visited (via a different + # state set) then we mark that auth event and all ancestors as reachable + # by the state set. This requires that we keep track of the auth chains + # in memory. + # + # Doing it in a such a way means that we can stop early if all auth + # events we're currently walking are reachable by all state sets. + # + # *Note*: We can't stop walking an event's auth chain if it is reachable + # by all state sets. This is because other auth chains we're walking + # might be reachable only via the original auth chain. For example, + # given the following auth chain: + # + # A -> C -> D -> E + # / / + # B -´---------´ + # + # and state sets {A} and {B} then walking the auth chains of A and B + # would immediately show that C is reachable by both. However, if we + # stopped at C then we'd only reach E via the auth chain of B and so E + # would errornously get included in the returned difference. + # + # The other thing that we do is limit the number of auth chains we walk + # at once, due to practical limits (i.e. we can only query the database + # with a limited set of parameters). We pick the auth chains we walk + # each iteration based on their depth, in the hope that events with a + # lower depth are likely reachable by those with higher depths. + # + # We could use any ordering that we believe would give a rough + # topological ordering, e.g. origin server timestamp. If the ordering + # chosen is not topological then the algorithm still produces the right + # result, but perhaps a bit more inefficiently. This is why it is safe + # to use "depth" here. + + initial_events = set(state_sets[0]).union(*state_sets[1:]) + + # Dict from events in auth chains to which sets *cannot* reach them. + # I.e. if the set is empty then all sets can reach the event. + event_to_missing_sets = { + event_id: {i for i, a in enumerate(state_sets) if event_id not in a} + for event_id in initial_events + } + + # We need to get the depth of the initial events for sorting purposes. + sql = """ + SELECT depth, event_id FROM events + WHERE %s + ORDER BY depth ASC + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", initial_events + ) + txn.execute(sql % (clause,), args) + + # The sorted list of events whose auth chains we should walk. + search = txn.fetchall() # type: List[Tuple[int, str]] + + # Map from event to its auth events + event_to_auth_events = {} # type: Dict[str, Set[str]] + + base_sql = """ + SELECT a.event_id, auth_id, depth + FROM event_auth AS a + INNER JOIN events AS e ON (e.event_id = a.auth_id) + WHERE + """ + + while search: + # Check whether all our current walks are reachable by all state + # sets. If so we can bail. + if all(not event_to_missing_sets[eid] for _, eid in search): + break + + # Fetch the auth events and their depths of the N last events we're + # currently walking + search, chunk = search[:-100], search[-100:] + clause, args = make_in_list_sql_clause( + txn.database_engine, "a.event_id", [e_id for _, e_id in chunk] + ) + txn.execute(base_sql + clause, args) + + for event_id, auth_event_id, auth_event_depth in txn: + event_to_auth_events.setdefault(event_id, set()).add(auth_event_id) + + sets = event_to_missing_sets.get(auth_event_id) + if sets is None: + # First time we're seeing this event, so we add it to the + # queue of things to fetch. + search.append((auth_event_depth, auth_event_id)) + + # Assume that this event is unreachable from any of the + # state sets until proven otherwise + sets = event_to_missing_sets[auth_event_id] = set( + range(len(state_sets)) + ) + else: + # We've previously seen this event, so look up its auth + # events and recursively mark all ancestors as reachable + # by the current event's state set. + a_ids = event_to_auth_events.get(auth_event_id) + while a_ids: + new_aids = set() + for a_id in a_ids: + event_to_missing_sets[a_id].intersection_update( + event_to_missing_sets[event_id] + ) + + b = event_to_auth_events.get(a_id) + if b: + new_aids.update(b) + + a_ids = new_aids + + # Mark that the auth event is reachable by the approriate sets. + sets.intersection_update(event_to_missing_sets[event_id]) + + search.sort() + + # Return all events where not all sets can reach them. + return {eid for eid, n in event_to_missing_sets.items() if n} + def get_oldest_events_in_room(self, room_id): return self.db.runInteraction( "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 5059ade850..a44960203e 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -603,7 +603,7 @@ class TestStateResolutionStore(object): return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} - def get_auth_chain(self, event_ids, ignore_events): + def _get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected events). @@ -617,9 +617,6 @@ class TestStateResolutionStore(object): Args: event_ids (list): The event IDs of the events to fetch the auth chain for. Must be state events. - ignore_events: Set of events to exclude from the returned auth - chain. - Returns: Deferred[list[str]]: List of event IDs of the auth chain. """ @@ -629,7 +626,7 @@ class TestStateResolutionStore(object): stack = list(event_ids) while stack: event_id = stack.pop() - if event_id in result or event_id in ignore_events: + if event_id in result: continue result.add(event_id) @@ -639,3 +636,9 @@ class TestStateResolutionStore(object): stack.append(aid) return list(result) + + def get_auth_chain_difference(self, auth_sets): + chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] + + common = set(chains[0]).intersection(*chains[1:]) + return set(chains[0]).union(*chains[1:]) - common diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index a331517f4d..3aeec0dc0f 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,19 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import tests.unittest import tests.utils -class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks def test_get_prev_events_for_room(self): room_id = "@ROOM:local" @@ -61,15 +56,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 20): - yield self.store.db.runInteraction("insert", insert_event, i) + self.get_success(self.store.db.runInteraction("insert", insert_event, i)) # this should get the last ten - r = yield self.store.get_prev_events_for_room(room_id) + r = self.get_success(self.store.get_prev_events_for_room(room_id)) self.assertEqual(10, len(r)) for i in range(0, 10): self.assertEqual("$event_%i:local" % (19 - i), r[i]) - @defer.inlineCallbacks def test_get_rooms_with_many_extremities(self): room1 = "#room1" room2 = "#room2" @@ -86,25 +80,154 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 20): - yield self.store.db.runInteraction("insert", insert_event, i, room1) - yield self.store.db.runInteraction("insert", insert_event, i, room2) - yield self.store.db.runInteraction("insert", insert_event, i, room3) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room1) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room2) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room3) + ) # Test simple case - r = yield self.store.get_rooms_with_many_extremities(5, 5, []) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [])) self.assertEqual(len(r), 3) # Does filter work? - r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1]) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1])) self.assertTrue(room2 in r) self.assertTrue(room3 in r) self.assertEqual(len(r), 2) - r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + r = self.get_success( + self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + ) self.assertEqual(r, [room3]) # Does filter and limit work? - r = yield self.store.get_rooms_with_many_extremities(5, 1, [room1]) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) + + def test_auth_difference(self): + room_id = "@ROOM:local" + + # The silly auth graph we use to test the auth difference algorithm, + # where the top are the most recent events. + # + # A B + # \ / + # D E + # \ | + # ` F C + # | /| + # G ´ | + # | \ | + # H I + # | | + # K J + + auth_graph = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], + } + + depth_map = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, + } + + # We rudely fiddle with the appropriate tables directly, as that's much + # easier than constructing events properly. + + def insert_event(txn, event_id, stream_ordering): + + depth = depth_map[event_id] + + self.store.db.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + self.store.db.simple_insert_many_txn( + txn, + table="event_auth", + values=[ + {"event_id": event_id, "room_id": room_id, "auth_id": a} + for a in auth_graph[event_id] + ], + ) + + next_stream_ordering = 0 + for event_id in auth_graph: + next_stream_ordering += 1 + self.get_success( + self.store.db.runInteraction( + "insert", insert_event, event_id, next_stream_ordering + ) + ) + + # Now actually test that various combinations give the right result: + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "d", "e"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) + self.assertSetEqual(difference, set()) -- cgit 1.5.1 From e66bbf7a9db5cbd94ecd59b8375ab398d6951344 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 22 Apr 2020 14:23:10 +0100 Subject: Fix and refactor rewritten IS url feature. Add sample config docs (#40) --- docs/sample_config.yaml | 7 ++ synapse/config/registration.py | 11 ++- synapse/handlers/identity.py | 155 +++++++++++++++++++--------------------- tests/handlers/test_identity.py | 7 +- 4 files changed, 92 insertions(+), 88 deletions(-) (limited to 'tests') diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 34bdc3759f..274658bfb2 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1221,6 +1221,13 @@ account_threepid_delegates: # #autocreate_auto_join_rooms: true +# Rewrite identity server URLs with a map from one URL to another. Applies to URLs +# provided by clients (which have https:// prepended) and those specified +# in `account_threepid_delegates`. URLs should not feature a trailing slash. +# +#rewrite_identity_server_urls: +# "https://somewhere.example.com": "https://somewhereelse.example.com" + ## Metrics ### diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 426c0fecef..4bf5bbaa17 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -163,8 +163,8 @@ class RegistrationConfig(Config): self.shadow_server = config.get("shadow_server", None) self.rewrite_identity_server_urls = config.get( - "rewrite_identity_server_urls", {} - ) + "rewrite_identity_server_urls" + ) or {} self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False @@ -432,6 +432,13 @@ class RegistrationConfig(Config): # users cannot be auto-joined since they do not exist. # #autocreate_auto_join_rooms: true + + # Rewrite identity server URLs with a map from one URL to another. Applies to URLs + # provided by clients (which have https:// prepended) and those specified + # in `account_threepid_delegates`. URLs should not feature a trailing slash. + # + #rewrite_identity_server_urls: + # "https://somewhere.example.com": "https://somewhereelse.example.com" """ % locals() ) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 6c369b78aa..e8a6cf7788 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -45,8 +45,6 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) -id_server_scheme = "https://" - class IdentityHandler(BaseHandler): def __init__(self, hs): @@ -69,13 +67,13 @@ class IdentityHandler(BaseHandler): self._enable_lookup = hs.config.enable_3pid_lookup @defer.inlineCallbacks - def threepid_from_creds(self, id_server, creds): + def threepid_from_creds(self, id_server_url, creds): """ Retrieve and validate a threepid identifier from a "credentials" dictionary against a given identity server Args: - id_server (str): The identity server to validate 3PIDs against. Must be a + id_server_url (str): The identity server to validate 3PIDs against. Must be a complete URL including the protocol (http(s)://) creds (dict[str, str]): Dictionary containing the following keys: @@ -104,10 +102,10 @@ class IdentityHandler(BaseHandler): # if we have a rewrite rule set for the identity server, # apply it now. - id_server = self.rewrite_id_server_url(id_server) + id_server_url = self.rewrite_id_server_url(id_server_url) - url = "https://%s%s" % ( - id_server, + url = "%s%s" % ( + id_server_url, "/_matrix/identity/api/v1/3pid/getValidated3pid", ) @@ -118,7 +116,7 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: logger.info( "%s returned %i for threepid validation for: %s", - id_server, + id_server_url, e.code, creds, ) @@ -132,7 +130,7 @@ class IdentityHandler(BaseHandler): if "medium" in data: return data - logger.info("%s reported non-validated threepid: %s", id_server, creds) + logger.info("%s reported non-validated threepid: %s", id_server_url, creds) return None @defer.inlineCallbacks @@ -167,18 +165,18 @@ class IdentityHandler(BaseHandler): # if we have a rewrite rule set for the identity server, # apply it now, but only for sending the request (not # storing in the database). - id_server_host = self.rewrite_id_server_url(id_server) + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) # Decide which API endpoint URLs to use headers = {} bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid} if use_v2: - bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server_host,) + bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,) headers["Authorization"] = create_id_access_token_header( id_access_token ) else: - bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server_host,) + bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,) try: # Use the blacklisting http client as this call is only to identity servers @@ -265,9 +263,6 @@ class IdentityHandler(BaseHandler): Deferred[bool]: True on success, otherwise False if the identity server doesn't support unbinding """ - url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) - url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") - content = { "mxid": mxid, "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, @@ -276,6 +271,7 @@ class IdentityHandler(BaseHandler): # we abuse the federation http client to sign the request, but we have to send it # using the normal http client since we don't want the SRV lookup and want normal # 'browser-like' HTTPS. + url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") auth_headers = self.federation_http_client.build_auth_headers( destination=None, method="POST", @@ -290,9 +286,9 @@ class IdentityHandler(BaseHandler): # # Note that destination_is has to be the real id_server, not # the server we connect to. - id_server = self.rewrite_id_server_url(id_server) + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) - url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) + url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,) try: # Use the blacklisting http client as this call is only to identity servers @@ -408,33 +404,33 @@ class IdentityHandler(BaseHandler): return session_id - def rewrite_id_server_url(self, url: str) -> str: - """Given an identity server URL, rewrite it according to the - rewrite_identity_server_urls config option + def rewrite_id_server_url(self, url: str, add_https=False) -> str: + """Given an identity server URL, optionally add a protocol scheme + before rewriting it according to the rewrite_identity_server_urls + config option - First removes the protocol scheme from the URL if provided. - Then checks for a rewritten URL. If found, returns the new URL. - Otherwise, returns the original URL. + Adds https:// to the URL if specified, then tries to rewrite the + url. Returns either the rewritten URL or the URL with optional + protocol scheme additions. """ rewritten_url = url - if url.startswith("http://"): - rewritten_url = url[7:] - elif url.startswith("https://"): - rewritten_url = url[8:] - - return self.rewrite_identity_server_urls.get(rewritten_url, url) + if add_https: + rewritten_url = "https://" + rewritten_url + rewritten_url = self.rewrite_identity_server_urls.get(rewritten_url, rewritten_url) + logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url) + return rewritten_url @defer.inlineCallbacks def requestEmailToken( - self, id_server, email, client_secret, send_attempt, next_link=None + self, id_server_url, email, client_secret, send_attempt, next_link=None ): """ Request an external server send an email on our behalf for the purposes of threepid validation. Args: - id_server (str): The identity server to proxy to + id_server_url (str): The identity server to proxy to email (str): The email to send the message to client_secret (str): The unique client_secret sends by the user send_attempt (int): Which attempt this is @@ -451,7 +447,7 @@ class IdentityHandler(BaseHandler): # if we have a rewrite rule set for the identity server, # apply it now. - id_server = self.rewrite_id_server_url(id_server) + id_server_url = self.rewrite_id_server_url(id_server_url) if next_link: params["next_link"] = next_link @@ -467,7 +463,7 @@ class IdentityHandler(BaseHandler): try: data = yield self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/email/requestToken", + "%s/_matrix/identity/api/v1/validate/email/requestToken" % (id_server_url,), params, ) return data @@ -480,7 +476,7 @@ class IdentityHandler(BaseHandler): @defer.inlineCallbacks def requestMsisdnToken( self, - id_server, + id_server_url, country, phone_number, client_secret, @@ -491,7 +487,7 @@ class IdentityHandler(BaseHandler): Request an external server send an SMS message on our behalf for the purposes of threepid validation. Args: - id_server (str): The identity server to proxy to + id_server_url (str): The identity server to proxy to country (str): The country code of the phone number phone_number (str): The number to send the message to client_secret (str): The unique client_secret sends by the user @@ -521,10 +517,10 @@ class IdentityHandler(BaseHandler): # if we have a rewrite rule set for the identity server, # apply it now. - id_server = self.rewrite_id_server_url(id_server) + id_server_url = self.rewrite_id_server_url(id_server_url) try: data = yield self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken", + "%s/_matrix/identity/api/v1/validate/msisdn/requestToken" % (id_server_url,), params, ) except HttpResponseException as e: @@ -646,11 +642,11 @@ class IdentityHandler(BaseHandler): 403, "Looking up third-party identifiers is denied from this server" ) - target = self.rewrite_id_server_url(id_server) + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) try: data = yield self.http_client.get_json( - "https://%s/_matrix/identity/api/v1/lookup" % (target,), + "%s/_matrix/identity/api/v1/lookup" % (id_server_url,), {"medium": medium, "address": address}, ) @@ -663,7 +659,7 @@ class IdentityHandler(BaseHandler): logger.info("Proxied lookup failed: %r", e) raise e.to_synapse_error() except IOError as e: - logger.info("Failed to contact %r: %s", id_server, e) + logger.info("Failed to contact %s: %s", id_server, e) raise ProxiedRequestError(503, "Failed to contact identity server") defer.returnValue(data) @@ -688,11 +684,11 @@ class IdentityHandler(BaseHandler): 403, "Looking up third-party identifiers is denied from this server" ) - target = self.rewrite_id_server_url(id_server) + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) try: data = yield self.http_client.post_json_get_json( - "https://%s/_matrix/identity/api/v1/bulk_lookup" % (target,), + "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,), {"threepids": threepids}, ) @@ -700,7 +696,7 @@ class IdentityHandler(BaseHandler): logger.info("Proxied lookup failed: %r", e) raise e.to_synapse_error() except IOError as e: - logger.info("Failed to contact %r: %s", id_server, e) + logger.info("Failed to contact %s: %s", id_server, e) raise ProxiedRequestError(503, "Failed to contact identity server") defer.returnValue(data) @@ -721,12 +717,12 @@ class IdentityHandler(BaseHandler): str|None: the matrix ID of the 3pid, or None if it is not recognized. """ # Rewrite id_server URL if necessary - id_server_url = self.rewrite_id_server_url(id_server) + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) if id_access_token is not None: try: results = yield self._lookup_3pid_v2( - id_server, id_server_url, id_access_token, medium, address + id_server_url, id_access_token, medium, address ) return results @@ -762,7 +758,7 @@ class IdentityHandler(BaseHandler): """ try: data = yield self.http_client.get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server_url), + "%s/_matrix/identity/api/v1/lookup" % (id_server_url,), {"medium": medium, "address": address}, ) @@ -779,13 +775,11 @@ class IdentityHandler(BaseHandler): return None @defer.inlineCallbacks - def _lookup_3pid_v2(self, id_server, id_server_url, id_access_token, medium, address): + def _lookup_3pid_v2(self, id_server_url, id_access_token, medium, address): """Looks up a 3pid in the passed identity server using v2 lookup. Args: - id_server (str): The server name (including port, if required) - of the identity server to use. - id_server_url (str): The actual, reachable domain of the id server + id_server_url (str): The protocol scheme and domain of the id server id_access_token (str): The access token to authenticate to the identity server with medium (str): The type of the third party identifier (e.g. "email"). address (str): The third party identifier (e.g. "foo@example.com"). @@ -796,7 +790,7 @@ class IdentityHandler(BaseHandler): # Check what hashing details are supported by this identity server try: hash_details = yield self.http_client.get_json( - "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server_url), + "%s/_matrix/identity/v2/hash_details" % (id_server_url,), {"access_token": id_access_token}, ) except TimeoutError: @@ -804,15 +798,14 @@ class IdentityHandler(BaseHandler): if not isinstance(hash_details, dict): logger.warning( - "Got non-dict object when checking hash details of %s%s: %s", - id_server_scheme, - id_server, + "Got non-dict object when checking hash details of %s: %s", + id_server_url, hash_details, ) raise SynapseError( 400, - "Non-dict object from %s%s during v2 hash_details request: %s" - % (id_server_scheme, id_server, hash_details), + "Non-dict object from %s during v2 hash_details request: %s" + % (id_server_url, hash_details), ) # Extract information from hash_details @@ -826,8 +819,8 @@ class IdentityHandler(BaseHandler): ): raise SynapseError( 400, - "Invalid hash details received from identity server %s%s: %s" - % (id_server_scheme, id_server, hash_details), + "Invalid hash details received from identity server %s: %s" + % (id_server_url, hash_details), ) # Check if any of the supported lookup algorithms are present @@ -849,7 +842,7 @@ class IdentityHandler(BaseHandler): else: logger.warning( "None of the provided lookup algorithms of %s are supported: %s", - id_server, + id_server_url, supported_lookup_algorithms, ) raise SynapseError( @@ -863,7 +856,7 @@ class IdentityHandler(BaseHandler): try: lookup_results = yield self.http_client.post_json_get_json( - "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server_url), + "%s/_matrix/identity/v2/lookup" % (id_server_url,), { "addresses": [lookup_value], "algorithm": lookup_algorithm, @@ -891,30 +884,30 @@ class IdentityHandler(BaseHandler): return mxid @defer.inlineCallbacks - def _verify_any_signature(self, data, server_hostname): - if server_hostname not in data["signatures"]: - raise AuthError(401, "No signature from server %s" % (server_hostname,)) + def _verify_any_signature(self, data, id_server): + if id_server not in data["signatures"]: + raise AuthError(401, "No signature from server %s" % (id_server,)) - for key_name, signature in data["signatures"][server_hostname].items(): - target = self.rewrite_id_server_url(server_hostname) + for key_name, signature in data["signatures"][id_server].items(): + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) key_data = yield self.http_client.get_json( - "https://%s/_matrix/identity/api/v1/pubkey/%s" % (target, key_name) + "%s/_matrix/identity/api/v1/pubkey/%s" % (id_server_url, key_name) ) if "public_key" not in key_data: raise AuthError( - 401, "No public key named %s from %s" % (key_name, server_hostname) + 401, "No public key named %s from %s" % (key_name, id_server) ) verify_signed_json( data, - server_hostname, + id_server, decode_verify_key_bytes( key_name, decode_base64(key_data["public_key"]) ), ) return - raise AuthError(401, "No signature from server %s" % (server_hostname,)) + raise AuthError(401, "No signature from server %s" % (id_server,)) @defer.inlineCallbacks def ask_id_server_for_third_party_invite( @@ -977,17 +970,16 @@ class IdentityHandler(BaseHandler): } # Rewrite the identity server URL if necessary - id_server = self.rewrite_id_server_url(id_server) + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) # Add the identity service access token to the JSON body and use the v2 # Identity Service endpoints if id_access_token is present data = None - base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server) + base_url = "%s/_matrix/identity" % (id_server_url,) if id_access_token: - key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % ( - id_server_scheme, - id_server, + key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % ( + id_server_url, ) # Attempt a v2 lookup @@ -1006,9 +998,8 @@ class IdentityHandler(BaseHandler): raise e if data is None: - key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, - id_server, + key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_url, ) url = base_url + "/api/v1/store-invite" @@ -1020,9 +1011,8 @@ class IdentityHandler(BaseHandler): raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.warning( - "Error trying to call /store-invite on %s%s: %s", - id_server_scheme, - id_server, + "Error trying to call /store-invite on %s: %s", + id_server_url, e, ) @@ -1036,10 +1026,9 @@ class IdentityHandler(BaseHandler): ) except HttpResponseException as e: logger.warning( - "Error calling /store-invite on %s%s with fallback " + "Error calling /store-invite on %s with fallback " "encoding: %s", - id_server_scheme, - id_server, + id_server_url, e, ) raise e diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py index 34f6bfb422..0ab0356109 100644 --- a/tests/handlers/test_identity.py +++ b/tests/handlers/test_identity.py @@ -35,12 +35,13 @@ class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.address = "test@test" self.is_server_name = "testis" - self.rewritten_is_url = "int.testis" + self.is_server_url = "https://testis" + self.rewritten_is_url = "https://int.testis" config = self.default_config() config["trusted_third_party_id_servers"] = [self.is_server_name] config["rewrite_identity_server_urls"] = { - self.is_server_name: self.rewritten_is_url + self.is_server_url: self.rewritten_is_url } mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) @@ -98,7 +99,7 @@ class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase): # Check that the request was done against the rewritten server name. post_json_get_json.assert_called_once_with( - "https://%s/_matrix/identity/api/v1/3pid/bind" % self.rewritten_is_url, + "%s/_matrix/identity/api/v1/3pid/bind" % (self.rewritten_is_url,), { "sid": creds["sid"], "client_secret": creds["client_secret"], -- cgit 1.5.1 From 3fb2ea99329ad2c4f054deff83ad19af571241f8 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 14 May 2020 11:51:16 +0100 Subject: apply linting --- synapse/config/_base.py | 2 +- synapse/config/registration.py | 6 +++--- synapse/events/spamcheck.py | 2 +- synapse/handlers/account_validity.py | 4 +++- synapse/handlers/federation.py | 3 +-- synapse/handlers/identity.py | 24 +++++++++++----------- synapse/handlers/register.py | 12 +++++------ synapse/handlers/room_member.py | 10 +-------- synapse/rest/client/v2_alpha/account.py | 11 ++++++---- synapse/rest/client/v2_alpha/account_data.py | 2 -- synapse/rest/client/versions.py | 8 ++++---- synapse/storage/data_stores/main/profile.py | 2 -- .../federation/test_matrix_federation_agent.py | 4 +--- tests/rest/client/test_identity.py | 8 +++----- tests/rest/client/test_room_access_rules.py | 9 ++++---- 15 files changed, 48 insertions(+), 59 deletions(-) (limited to 'tests') diff --git a/synapse/config/_base.py b/synapse/config/_base.py index eda2b65ef7..132e48447c 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,8 +18,8 @@ import argparse import errno import os -from io import open as io_open from collections import OrderedDict +from io import open as io_open from textwrap import dedent from typing import Any, MutableMapping, Optional diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 4bf5bbaa17..1a44b17589 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -162,9 +162,9 @@ class RegistrationConfig(Config): self.replicate_user_profiles_to = [self.replicate_user_profiles_to] self.shadow_server = config.get("shadow_server", None) - self.rewrite_identity_server_urls = config.get( - "rewrite_identity_server_urls" - ) or {} + self.rewrite_identity_server_urls = ( + config.get("rewrite_identity_server_urls") or {} + ) self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 6e6d75bdcb..ae33b3f65d 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,7 +15,7 @@ # limitations under the License. import inspect -from typing import Dict, Optional, List +from typing import Dict, List, Optional from synapse.spam_checker_api import SpamCheckerApi diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 72baee44c3..a6c907b9c9 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -93,7 +93,9 @@ class AccountValidityHandler(object): # If account_validity is enabled,check every hour to remove expired users from # the user directory if self._account_validity.enabled: - self.clock.looping_call(self._mark_expired_users_as_inactive, 60 * 60 * 1000) + self.clock.looping_call( + self._mark_expired_users_as_inactive, 60 * 60 * 1000 + ) async def _send_renewal_emails(self): """Gets the list of users whose account is expiring in the amount of time diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5d686ab2c9..ebdc239fff 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -412,8 +412,7 @@ class FederationHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. evs = await self.store.get_events( - list(state_map.values()), - get_prev_content=False, + list(state_map.values()), get_prev_content=False, ) event_map.update(evs) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index e8a6cf7788..1d07361661 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -172,9 +172,7 @@ class IdentityHandler(BaseHandler): bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid} if use_v2: bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,) - headers["Authorization"] = create_id_access_token_header( - id_access_token - ) + headers["Authorization"] = create_id_access_token_header(id_access_token) else: bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,) @@ -417,7 +415,9 @@ class IdentityHandler(BaseHandler): if add_https: rewritten_url = "https://" + rewritten_url - rewritten_url = self.rewrite_identity_server_urls.get(rewritten_url, rewritten_url) + rewritten_url = self.rewrite_identity_server_urls.get( + rewritten_url, rewritten_url + ) logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url) return rewritten_url @@ -463,7 +463,8 @@ class IdentityHandler(BaseHandler): try: data = yield self.http_client.post_json_get_json( - "%s/_matrix/identity/api/v1/validate/email/requestToken" % (id_server_url,), + "%s/_matrix/identity/api/v1/validate/email/requestToken" + % (id_server_url,), params, ) return data @@ -520,7 +521,8 @@ class IdentityHandler(BaseHandler): id_server_url = self.rewrite_id_server_url(id_server_url) try: data = yield self.http_client.post_json_get_json( - "%s/_matrix/identity/api/v1/validate/msisdn/requestToken" % (id_server_url,), + "%s/_matrix/identity/api/v1/validate/msisdn/requestToken" + % (id_server_url,), params, ) except HttpResponseException as e: @@ -806,7 +808,7 @@ class IdentityHandler(BaseHandler): 400, "Non-dict object from %s during v2 hash_details request: %s" % (id_server_url, hash_details), - ) + ) # Extract information from hash_details supported_lookup_algorithms = hash_details.get("algorithms") @@ -821,7 +823,7 @@ class IdentityHandler(BaseHandler): 400, "Invalid hash details received from identity server %s: %s" % (id_server_url, hash_details), - ) + ) # Check if any of the supported lookup algorithms are present if LookupAlgorithm.SHA256 in supported_lookup_algorithms: @@ -863,7 +865,7 @@ class IdentityHandler(BaseHandler): "pepper": lookup_pepper, }, headers=headers, - ) + ) except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") except Exception as e: @@ -1011,9 +1013,7 @@ class IdentityHandler(BaseHandler): raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.warning( - "Error trying to call /store-invite on %s: %s", - id_server_url, - e, + "Error trying to call /store-invite on %s: %s", id_server_url, e, ) if data is None: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 85635f1a0f..34b876b6af 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -619,10 +619,7 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def post_registration_actions( - self, - user_id, - auth_result, - access_token, + self, user_id, auth_result, access_token, ): """A user has completed registration @@ -654,7 +651,8 @@ class RegistrationHandler(BaseHandler): # Bind the 3PID to the identity server logger.debug( "Binding email to %s on id_server %s", - user_id, self.hs.config.account_threepid_delegate_email, + user_id, + self.hs.config.account_threepid_delegate_email, ) threepid_creds = threepid["threepid_creds"] @@ -662,7 +660,9 @@ class RegistrationHandler(BaseHandler): # `bind_threepid` will add https:// to it, so this restricts # account_threepid_delegate.email to https:// addresses only # We assume this is always the case for dinsic however. - if self.hs.config.account_threepid_delegate_email.startswith("https://"): + if self.hs.config.account_threepid_delegate_email.startswith( + "https://" + ): id_server = self.hs.config.account_threepid_delegate_email[8:] else: # Must start with http:// instead diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 2d9ff97324..cddc95413a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -24,16 +24,8 @@ from twisted.internet import defer from synapse import types from synapse.api.constants import EventTypes, Membership -from synapse.api.ratelimiting import Ratelimiter -from synapse.api.errors import ( - AuthError, - Codes, - HttpResponseException, - SynapseError, -) -from synapse.handlers.identity import LookupAlgorithm, create_id_access_token_header -from synapse.http.client import SimpleHttpClient from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.types import Collection, RoomID, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index fbc5aa13f3..7d2cd29a60 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -31,10 +31,10 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.types import UserID from synapse.push.mailer import Mailer, load_jinja2_templates +from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn -from synapse.util.stringutils import assert_valid_client_secret, random_string +from synapse.util.stringutils import assert_valid_client_secret from synapse.util.threepids import check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -650,7 +650,10 @@ class ThreepidRestServlet(RestServlet): threepid = body.get("threepid") await self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], ) if self.hs.config.shadow_server: @@ -710,7 +713,7 @@ class ThreepidRestServlet(RestServlet): "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" % (shadow_hs_url, as_token, user_id), body, - ) + ) class ThreepidAddRestServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 17495f020b..11d7406040 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c1e5560565..c99250c2ee 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -54,7 +54,7 @@ class VersionsRestServlet(RestServlet): "unstable_features": { # as per MSC2190, as amended by MSC2264 # to be removed in r0.6.0 - #"m.id_access_token": True, + # "m.id_access_token": True, # Advertise to clients that they need not include an `id_server` # parameter during registration or password reset, as Synapse now decides # itself which identity server to use (or none at all). @@ -64,14 +64,14 @@ class VersionsRestServlet(RestServlet): # also requires `id_server`. If the homeserver is handling 3PID # verification itself, there is no need to ask the user for `id_server` to # be supplied. - #"m.require_identity_server": False, + # "m.require_identity_server": False, # as per MSC2290 - #"m.separate_add_and_bind": True, + # "m.separate_add_and_bind": True, # Implements support for label-based filtering as described in # MSC2326. "org.matrix.label_based_filtering": True, # Implements support for cross signing as described in MSC1756 - #"org.matrix.e2e_cross_signing": True, + # "org.matrix.e2e_cross_signing": True, # Implements additional endpoints as described in MSC2432 "org.matrix.msc2432": True, # Tchap does not currently assume this rule for r0.5.0 diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py index 2903de478e..f438dd38be 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -17,8 +17,6 @@ from twisted.internet import defer from synapse.api.errors import StoreError - -from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.roommember import ProfileInfo from synapse.util.caches.descriptors import cached diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index ea52330c8e..b1482db39a 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -32,9 +32,7 @@ from twisted.web.iweb import IPolicyForHTTPS from synapse.config.homeserver import HomeServerConfig from synapse.crypto.context_factory import FederationPolicyForHTTPS -from synapse.http.federation.matrix_federation_agent import ( - MatrixFederationAgent, -) +from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.srv_resolver import Server from synapse.http.federation.well_known_resolver import ( WellKnownResolver, diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index e163a46f6b..4224b0a92e 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -120,9 +120,7 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase): # TODO: This class does not use a singleton to get it's http client # This should be fixed for easier testing # https://github.com/matrix-org/synapse-dinsic/issues/26 - self.hs.get_handlers().identity_handler.http_client = ( - mock_http_client - ) + self.hs.get_handlers().identity_handler.http_client = mock_http_client return self.hs @@ -142,8 +140,8 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase): self.hs.get_room_member_handler().simple_http_client = Mock( spec=["get_json", "post_json_get_json"] ) - self.hs.get_room_member_handler().simple_http_client.get_json.return_value = ( - defer.succeed((200, "{}")) + self.hs.get_room_member_handler().simple_http_client.get_json.return_value = defer.succeed( + (200, "{}") ) params = { diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py index f10ae0adeb..7da0ef4e18 100644 --- a/tests/rest/client/test_room_access_rules.py +++ b/tests/rest/client/test_room_access_rules.py @@ -21,6 +21,7 @@ import string from mock import Mock from twisted.internet import defer + from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset from synapse.rest import admin from synapse.rest.client.v1 import login, room @@ -83,9 +84,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["send_invite"]) mock_federation_client.send_invite.side_effect = send_invite - mock_http_client = Mock( - spec=["get_json", "post_json_get_json"], - ) + mock_http_client = Mock(spec=["get_json", "post_json_get_json"],) # Mocking the response for /info on the IS API. mock_http_client.get_json.side_effect = get_json # Mocking the response for /store-invite on the IS API. @@ -99,7 +98,9 @@ class RoomAccessTestCase(unittest.HomeserverTestCase): # TODO: This class does not use a singleton to get it's http client # This should be fixed for easier testing # https://github.com/matrix-org/synapse-dinsic/issues/26 - self.hs.get_handlers().identity_handler.blacklisting_http_client = mock_http_client + self.hs.get_handlers().identity_handler.blacklisting_http_client = ( + mock_http_client + ) return self.hs -- cgit 1.5.1 From 8f259e883b2ec2cab40b4e31b59330b08e5bb13a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Apr 2020 10:52:55 -0400 Subject: Do not treat display names as globs for push rules. (#7271) --- changelog.d/7271.bugfix | 1 + synapse/push/push_rule_evaluator.py | 69 +++++++++++++++++++--------------- tests/push/test_push_rule_evaluator.py | 65 ++++++++++++++++++++++++++++++++ tox.ini | 1 + 4 files changed, 106 insertions(+), 30 deletions(-) create mode 100644 changelog.d/7271.bugfix create mode 100644 tests/push/test_push_rule_evaluator.py (limited to 'tests') diff --git a/changelog.d/7271.bugfix b/changelog.d/7271.bugfix new file mode 100644 index 0000000000..e8315e4ce4 --- /dev/null +++ b/changelog.d/7271.bugfix @@ -0,0 +1 @@ +Do not treat display names as globs in push rules. diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index b1587183a8..4cd702b5fa 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,9 +16,11 @@ import logging import re +from typing import Pattern from six import string_types +from synapse.events import EventBase from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache @@ -56,18 +58,18 @@ def _test_ineq_condition(condition, number): rhs = m.group(2) if not rhs.isdigit(): return False - rhs = int(rhs) + rhs_int = int(rhs) if ineq == "" or ineq == "==": - return number == rhs + return number == rhs_int elif ineq == "<": - return number < rhs + return number < rhs_int elif ineq == ">": - return number > rhs + return number > rhs_int elif ineq == ">=": - return number >= rhs + return number >= rhs_int elif ineq == "<=": - return number <= rhs + return number <= rhs_int else: return False @@ -83,7 +85,13 @@ def tweaks_for_actions(actions): class PushRuleEvaluatorForEvent(object): - def __init__(self, event, room_member_count, sender_power_level, power_levels): + def __init__( + self, + event: EventBase, + room_member_count: int, + sender_power_level: int, + power_levels: dict, + ): self._event = event self._room_member_count = room_member_count self._sender_power_level = sender_power_level @@ -92,7 +100,7 @@ class PushRuleEvaluatorForEvent(object): # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) - def matches(self, condition, user_id, display_name): + def matches(self, condition: dict, user_id: str, display_name: str) -> bool: if condition["kind"] == "event_match": return self._event_match(condition, user_id) elif condition["kind"] == "contains_display_name": @@ -106,7 +114,7 @@ class PushRuleEvaluatorForEvent(object): else: return True - def _event_match(self, condition, user_id): + def _event_match(self, condition: dict, user_id: str) -> bool: pattern = condition.get("pattern", None) if not pattern: @@ -134,7 +142,7 @@ class PushRuleEvaluatorForEvent(object): return _glob_matches(pattern, haystack) - def _contains_display_name(self, display_name): + def _contains_display_name(self, display_name: str) -> bool: if not display_name: return False @@ -142,51 +150,52 @@ class PushRuleEvaluatorForEvent(object): if not body: return False - return _glob_matches(display_name, body, word_boundary=True) + # Similar to _glob_matches, but do not treat display_name as a glob. + r = regex_cache.get((display_name, False, True), None) + if not r: + r = re.escape(display_name) + r = _re_word_boundary(r) + r = re.compile(r, flags=re.IGNORECASE) + regex_cache[(display_name, False, True)] = r + + return r.search(body) - def _get_value(self, dotted_key): + def _get_value(self, dotted_key: str) -> str: return self._value_cache.get(dotted_key, None) -# Caches (glob, word_boundary) -> regex for push. See _glob_matches +# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) register_cache("cache", "regex_push_cache", regex_cache) -def _glob_matches(glob, value, word_boundary=False): +def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: """Tests if value matches glob. Args: - glob (string) - value (string): String to test against glob. - word_boundary (bool): Whether to match against word boundaries or entire + glob + value: String to test against glob. + word_boundary: Whether to match against word boundaries or entire string. Defaults to False. - - Returns: - bool """ try: - r = regex_cache.get((glob, word_boundary), None) + r = regex_cache.get((glob, True, word_boundary), None) if not r: r = _glob_to_re(glob, word_boundary) - regex_cache[(glob, word_boundary)] = r + regex_cache[(glob, True, word_boundary)] = r return r.search(value) except re.error: logger.warning("Failed to parse glob to regex: %r", glob) return False -def _glob_to_re(glob, word_boundary): +def _glob_to_re(glob: str, word_boundary: bool) -> Pattern: """Generates regex for a given glob. Args: - glob (string) - word_boundary (bool): Whether to match against word boundaries or entire - string. Defaults to False. - - Returns: - regex object + glob + word_boundary: Whether to match against word boundaries or entire string. """ if IS_GLOB.search(glob): r = re.escape(glob) @@ -219,7 +228,7 @@ def _glob_to_re(glob, word_boundary): return re.compile(r, flags=re.IGNORECASE) -def _re_word_boundary(r): +def _re_word_boundary(r: str) -> str: """ Adds word boundary characters to the start and end of an expression to require that the match occur as a whole word, diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py new file mode 100644 index 0000000000..9ae6a87d7b --- /dev/null +++ b/tests/push/test_push_rule_evaluator.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.room_versions import RoomVersions +from synapse.events import FrozenEvent +from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent + +from tests import unittest + + +class PushRuleEvaluatorTestCase(unittest.TestCase): + def setUp(self): + event = FrozenEvent( + { + "event_id": "$event_id", + "type": "m.room.history_visibility", + "sender": "@user:test", + "state_key": "", + "room_id": "@room:test", + "content": {"body": "foo bar baz"}, + }, + RoomVersions.V1, + ) + room_member_count = 0 + sender_power_level = 0 + power_levels = {} + self.evaluator = PushRuleEvaluatorForEvent( + event, room_member_count, sender_power_level, power_levels + ) + + def test_display_name(self): + """Check for a matching display name in the body of the event.""" + condition = { + "kind": "contains_display_name", + } + + # Blank names are skipped. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "")) + + # Check a display name that doesn't match. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found")) + + # Check a display name which matches. + self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo")) + + # A display name that matches, but not a full word does not result in a match. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba")) + + # A display name should not be interpreted as a regular expression. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]")) + + # A display name with spaces should work fine. + self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar")) diff --git a/tox.ini b/tox.ini index bfe3c69014..f9d62ba83c 100644 --- a/tox.ini +++ b/tox.ini @@ -195,6 +195,7 @@ commands = mypy \ synapse/metrics \ synapse/module_api \ synapse/push/pusherpool.py \ + synapse/push/push_rule_evaluator.py \ synapse/replication \ synapse/rest \ synapse/spam_checker_api \ -- cgit 1.5.1