diff options
Diffstat (limited to 'synapse')
36 files changed, 402 insertions, 186 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index e395a0a9ee..65a2b894cc 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -17,4 +17,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.33.3.1" +__version__ = "0.33.4" diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 2fd9c48abf..b8d5690f2b 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -21,7 +21,7 @@ from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig from .groups import GroupsConfig -from .jwt import JWTConfig +from .jwt_config import JWTConfig from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig diff --git a/synapse/config/jwt.py b/synapse/config/jwt_config.py index 51e7f7e003..51e7f7e003 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt_config.py diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 1a391adec1..02b76dfcfb 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -123,6 +123,6 @@ class ClientTLSOptionsFactory(object): def get_options(self, host): return ClientTLSOptions( - host.decode('utf-8'), + host, CertificateOptions(verify=False).getContext() ) diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index e94400b8e2..57d4665e84 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -50,7 +50,7 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1): defer.returnValue((server_response, server_certificate)) except SynapseKeyClientError as e: logger.warn("Error getting key for %r: %s", server_name, e) - if e.status.startswith("4"): + if e.status.startswith(b"4"): # Don't retry for 4xx responses. raise IOError("Cannot get key for %r" % server_name) except (ConnectError, DomainError) as e: @@ -82,6 +82,12 @@ class SynapseKeyClientProtocol(HTTPClient): self._peer = self.transport.getPeer() logger.debug("Connected to %s", self._peer) + if not isinstance(self.path, bytes): + self.path = self.path.encode('ascii') + + if not isinstance(self.host, bytes): + self.host = self.host.encode('ascii') + self.sendCommand(b"GET", self.path) if self.host: self.sendHeader(b"Host", self.host) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index a6b31e8b5b..d89f94c219 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -16,9 +16,10 @@ import hashlib import logging -import urllib from collections import namedtuple +from six.moves import urllib + from signedjson.key import ( decode_verify_key_bytes, encode_verify_key_base64, @@ -444,7 +445,7 @@ class Keyring(object): # an incoming request. query_response = yield self.client.post_json( destination=perspective_name, - path=b"/_matrix/key/v2/query", + path="/_matrix/key/v2/query", data={ u"server_keys": { server_name: { @@ -525,8 +526,8 @@ class Keyring(object): (response, tls_certificate) = yield fetch_server_key( server_name, self.hs.tls_client_options_factory, - path=(b"/_matrix/key/v2/server/%s" % ( - urllib.quote(requested_key_id), + path=("/_matrix/key/v2/server/%s" % ( + urllib.parse.quote(requested_key_id), )).encode("ascii"), ) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 94d7423d01..8cbf8c4f7f 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -463,7 +463,19 @@ class TransactionQueue(object): # pending_transactions flag. pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + + # We can only include at most 50 PDUs per transactions + pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:] + if leftover_pdus: + self.pending_pdus_by_dest[destination] = leftover_pdus + pending_edus = self.pending_edus_by_dest.pop(destination, []) + + # We can only include at most 100 EDUs per transactions + pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:] + if leftover_edus: + self.pending_edus_by_dest[destination] = leftover_edus + pending_presence = self.pending_presence_by_dest.pop(destination, {}) pending_edus.extend( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4a81bd2ba9..2a5eab124f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -895,22 +895,24 @@ class AuthHandler(BaseHandler): Args: password (unicode): Password to hash. - stored_hash (unicode): Expected hash value. + stored_hash (bytes): Expected hash value. Returns: Deferred(bool): Whether self.hash(password) == stored_hash. """ - def _do_validate_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), - stored_hash.encode('utf8') + stored_hash ) if stored_hash: + if not isinstance(stored_hash, bytes): + stored_hash = stored_hash.encode('ascii') + return make_deferred_yieldable( threads.deferToThreadPool( self.hs.get_reactor(), diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5816bf8b4f..578e9250fb 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -330,7 +330,8 @@ class E2eKeysHandler(object): (algorithm, key_id, ex_json, key) ) else: - new_keys.append((algorithm, key_id, encode_canonical_json(key))) + new_keys.append(( + algorithm, key_id, encode_canonical_json(key).decode('ascii'))) yield self.store.add_e2e_one_time_keys( user_id, device_id, time_now, new_keys @@ -358,7 +359,7 @@ def _exception_to_failure(e): # Note that some Exceptions (notably twisted's ResponseFailed etc) don't # give a string for e.message, which json then fails to serialize. return { - "status": 503, "message": str(e.message), + "status": 503, "message": str(e), } diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3fa7a98445..0c68e8a472 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -594,7 +594,7 @@ class FederationHandler(BaseHandler): required_auth = set( a_id - for event in events + state_events.values() + auth_events.values() + for event in events + list(state_events.values()) + list(auth_events.values()) for a_id, _ in event.auth_events ) auth_events.update({ @@ -802,7 +802,7 @@ class FederationHandler(BaseHandler): ) continue except NotRetryingDestination as e: - logger.info(e.message) + logger.info(str(e)) continue except FederationDeniedError as e: logger.info(e) @@ -1358,7 +1358,7 @@ class FederationHandler(BaseHandler): ) if state_groups: - _, state = state_groups.items().pop() + _, state = list(state_groups.items()).pop() results = state if event.is_state(): diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 37e41afd61..38e1737ec9 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -162,7 +162,7 @@ class RoomListHandler(BaseHandler): # Filter out rooms that we don't want to return rooms_to_scan = [ r for r in sorted_rooms - if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0 + if r not in newly_unpublished and rooms_to_num_joined[r] > 0 ] total_room_count = len(rooms_to_scan) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index c464adbd0b..0c1d52fd11 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -54,7 +54,7 @@ class SearchHandler(BaseHandler): batch_token = None if batch: try: - b = decode_base64(batch) + b = decode_base64(batch).decode('ascii') batch_group, batch_group_key, batch_token = b.split("\n") assert batch_group is not None @@ -258,18 +258,18 @@ class SearchHandler(BaseHandler): # it returns more from the same group (if applicable) rather # than reverting to searching all results again. if batch_group and batch_group_key: - global_next_batch = encode_base64("%s\n%s\n%s" % ( + global_next_batch = encode_base64(("%s\n%s\n%s" % ( batch_group, batch_group_key, pagination_token - )) + )).encode('ascii')) else: - global_next_batch = encode_base64("%s\n%s\n%s" % ( + global_next_batch = encode_base64(("%s\n%s\n%s" % ( "all", "", pagination_token - )) + )).encode('ascii')) for room_id, group in room_groups.items(): - group["next_batch"] = encode_base64("%s\n%s\n%s" % ( + group["next_batch"] = encode_base64(("%s\n%s\n%s" % ( "room_id", room_id, pagination_token - )) + )).encode('ascii')) allowed_events.extend(room_events) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ef20c2296c..23983a51ab 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.push.clientformat import format_push_rules_for_user +from synapse.storage.roommember import MemberSummary from synapse.types import RoomStreamToken from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache @@ -525,6 +526,8 @@ class SyncHandler(object): A deferred dict describing the room summary """ + # FIXME: we could/should get this from room_stats when matthew/stats lands + # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 last_events, _ = yield self.store.get_recent_event_ids_for_room( room_id, end_token=now_token.room_key, limit=1, @@ -537,44 +540,54 @@ class SyncHandler(object): last_event = last_events[-1] state_ids = yield self.store.get_state_ids_for_event( last_event.event_id, [ - (EventTypes.Member, None), (EventTypes.Name, ''), (EventTypes.CanonicalAlias, ''), ] ) - member_ids = { - state_key: event_id - for (t, state_key), event_id in state_ids.iteritems() - if t == EventTypes.Member - } + # this is heavily cached, thus: fast. + details = yield self.store.get_room_summary(room_id) + name_id = state_ids.get((EventTypes.Name, '')) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) summary = {} - - # FIXME: it feels very heavy to load up every single membership event - # just to calculate the counts. - member_events = yield self.store.get_events(member_ids.values()) - - joined_user_ids = [] - invited_user_ids = [] - - for ev in member_events.values(): - if ev.content.get("membership") == Membership.JOIN: - joined_user_ids.append(ev.state_key) - elif ev.content.get("membership") == Membership.INVITE: - invited_user_ids.append(ev.state_key) + empty_ms = MemberSummary([], 0) # TODO: only send these when they change. - summary["m.joined_member_count"] = len(joined_user_ids) - summary["m.invited_member_count"] = len(invited_user_ids) + summary["m.joined_member_count"] = ( + details.get(Membership.JOIN, empty_ms).count + ) + summary["m.invited_member_count"] = ( + details.get(Membership.INVITE, empty_ms).count + ) if name_id or canonical_alias_id: defer.returnValue(summary) - # FIXME: order by stream ordering, not alphabetic + joined_user_ids = [ + r[0] for r in details.get(Membership.JOIN, empty_ms).members + ] + invited_user_ids = [ + r[0] for r in details.get(Membership.INVITE, empty_ms).members + ] + gone_user_ids = ( + [r[0] for r in details.get(Membership.LEAVE, empty_ms).members] + + [r[0] for r in details.get(Membership.BAN, empty_ms).members] + ) + + # FIXME: only build up a member_ids list for our heroes + member_ids = {} + for membership in ( + Membership.JOIN, + Membership.INVITE, + Membership.LEAVE, + Membership.BAN + ): + for user_id, event_id in details.get(membership, empty_ms).members: + member_ids[user_id] = event_id + # FIXME: order by stream ordering rather than as returned by SQL me = sync_config.user.to_string() if (joined_user_ids or invited_user_ids): summary['m.heroes'] = sorted( @@ -586,7 +599,11 @@ class SyncHandler(object): )[0:5] else: summary['m.heroes'] = sorted( - [user_id for user_id in member_ids.keys() if user_id != me] + [ + user_id + for user_id in gone_user_ids + if user_id != me + ] )[0:5] if not sync_config.filter_collection.lazy_load_members(): @@ -719,6 +736,26 @@ class SyncHandler(object): lazy_load_members=lazy_load_members, ) elif batch.limited: + state_at_timeline_start = yield self.store.get_state_ids_for_event( + batch.events[0].event_id, types=types, + filtered_types=filtered_types, + ) + + # for now, we disable LL for gappy syncs - see + # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346 + # N.B. this slows down incr syncs as we are now processing way + # more state in the server than if we were LLing. + # + # We still have to filter timeline_start to LL entries (above) in order + # for _calculate_state's LL logic to work, as we have to include LL + # members for timeline senders in case they weren't loaded in the initial + # sync. We do this by (counterintuitively) by filtering timeline_start + # members to just be ones which were timeline senders, which then ensures + # all of the rest get included in the state block (if we need to know + # about them). + types = None + filtered_types = None + state_at_previous_sync = yield self.get_state_at( room_id, stream_position=since_token, types=types, filtered_types=filtered_types, @@ -729,24 +766,21 @@ class SyncHandler(object): filtered_types=filtered_types, ) - state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, types=types, - filtered_types=filtered_types, - ) - state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_at_timeline_start, previous=state_at_previous_sync, current=current_state_ids, + # we have to include LL members in case LL initial sync missed them lazy_load_members=lazy_load_members, ) else: state_ids = {} if lazy_load_members: if types: - # We're returning an incremental sync, with no "gap" since - # the previous sync, so normally there would be no state to return + # We're returning an incremental sync, with no + # "gap" since the previous sync, so normally there would be + # no state to return. # But we're lazy-loading, so the client might need some more # member events to understand the events in this timeline. # So we fish out all the member events corresponding to the @@ -774,7 +808,7 @@ class SyncHandler(object): logger.debug("filtering state from %r...", state_ids) state_ids = { t: event_id - for t, event_id in state_ids.iteritems() + for t, event_id in iteritems(state_ids) if cache.get(t[1]) != event_id } logger.debug("...to %r", state_ids) @@ -1575,6 +1609,19 @@ class SyncHandler(object): newly_joined_room=newly_joined, ) + # When we join the room (or the client requests full_state), we should + # send down any existing tags. Usually the user won't have tags in a + # newly joined room, unless either a) they've joined before or b) the + # tag was added by synapse e.g. for server notice rooms. + if full_state: + user_id = sync_result_builder.sync_config.user.to_string() + tags = yield self.store.get_tags_for_room(user_id, room_id) + + # If there aren't any tags, don't send the empty tags list down + # sync + if not tags: + tags = None + account_data_events = [] if tags is not None: account_data_events.append({ @@ -1603,10 +1650,24 @@ class SyncHandler(object): ) summary = {} + + # we include a summary in room responses when we're lazy loading + # members (as the client otherwise doesn't have enough info to form + # the name itself). if ( sync_config.filter_collection.lazy_load_members() and ( + # we recalulate the summary: + # if there are membership changes in the timeline, or + # if membership has changed during a gappy sync, or + # if this is an initial sync. any(ev.type == EventTypes.Member for ev in batch.events) or + ( + # XXX: this may include false positives in the form of LL + # members which have snuck into state + batch.limited and + any(t == EventTypes.Member for (t, k) in state) + ) or since_token is None ) ): @@ -1636,6 +1697,16 @@ class SyncHandler(object): unread_notifications["highlight_count"] = notifs["highlight_count"] sync_result_builder.joined.append(room_sync) + + if batch.limited: + user_id = sync_result_builder.sync_config.user.to_string() + logger.info( + "Incremental syncing room %s for user %s with %d state events" % ( + room_id, + user_id, + len(state), + ) + ) elif room_builder.rtype == "archived": room_sync = ArchivedSyncResult( room_id=room_id, @@ -1729,17 +1800,17 @@ def _calculate_state( event_id_to_key = { e: key for key, e in itertools.chain( - timeline_contains.items(), - previous.items(), - timeline_start.items(), - current.items(), + iteritems(timeline_contains), + iteritems(previous), + iteritems(timeline_start), + iteritems(current), ) } - c_ids = set(e for e in current.values()) - ts_ids = set(e for e in timeline_start.values()) - p_ids = set(e for e in previous.values()) - tc_ids = set(e for e in timeline_contains.values()) + c_ids = set(e for e in itervalues(current)) + ts_ids = set(e for e in itervalues(timeline_start)) + p_ids = set(e for e in itervalues(previous)) + tc_ids = set(e for e in itervalues(timeline_contains)) # If we are lazyloading room members, we explicitly add the membership events # for the senders in the timeline into the state block returned by /sync, @@ -1753,7 +1824,7 @@ def _calculate_state( if lazy_load_members: p_ids.difference_update( - e for t, e in timeline_start.iteritems() + e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member ) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 6a1fc8ca55..f9a1fbf95d 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -280,7 +280,10 @@ class MatrixFederationHttpClient(object): # :'( # Update transactions table? with logcontext.PreserveLoggingContext(): - body = yield treq.content(response) + body = yield self._timeout_deferred( + treq.content(response), + timeout, + ) raise HttpResponseException( response.code, response.phrase, body ) @@ -394,7 +397,10 @@ class MatrixFederationHttpClient(object): check_content_type_is_json(response.headers) with logcontext.PreserveLoggingContext(): - body = yield treq.json_content(response) + body = yield self._timeout_deferred( + treq.json_content(response), + timeout, + ) defer.returnValue(body) @defer.inlineCallbacks @@ -444,7 +450,10 @@ class MatrixFederationHttpClient(object): check_content_type_is_json(response.headers) with logcontext.PreserveLoggingContext(): - body = yield treq.json_content(response) + body = yield self._timeout_deferred( + treq.json_content(response), + timeout, + ) defer.returnValue(body) @@ -496,7 +505,10 @@ class MatrixFederationHttpClient(object): check_content_type_is_json(response.headers) with logcontext.PreserveLoggingContext(): - body = yield treq.json_content(response) + body = yield self._timeout_deferred( + treq.json_content(response), + timeout, + ) defer.returnValue(body) @@ -543,7 +555,10 @@ class MatrixFederationHttpClient(object): check_content_type_is_json(response.headers) with logcontext.PreserveLoggingContext(): - body = yield treq.json_content(response) + body = yield self._timeout_deferred( + treq.json_content(response), + timeout, + ) defer.returnValue(body) @@ -585,8 +600,10 @@ class MatrixFederationHttpClient(object): try: with logcontext.PreserveLoggingContext(): - length = yield _readBodyToFile( - response, output_stream, max_size + length = yield self._timeout_deferred( + _readBodyToFile( + response, output_stream, max_size + ), ) except Exception: logger.exception("Failed to download body") @@ -594,6 +611,27 @@ class MatrixFederationHttpClient(object): defer.returnValue((length, headers)) + def _timeout_deferred(self, deferred, timeout_ms=None): + """Times the deferred out after `timeout_ms` ms + + Args: + deferred (Deferred) + timeout_ms (int|None): Timeout in milliseconds. If None defaults + to 60 seconds. + + Returns: + Deferred + """ + + add_timeout_to_deferred( + deferred, + timeout_ms / 1000. if timeout_ms else 60, + self.hs.get_reactor(), + cancelled_to_request_timed_out_error, + ) + + return deferred + class _ReadBodyToFileProtocol(protocol.Protocol): def __init__(self, stream, deferred, max_size): diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 6dd5179320..0d8de600cf 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -42,8 +42,8 @@ REQUIREMENTS = { "Twisted>=17.1.0": ["twisted>=17.1.0"], "treq>=15.1": ["treq>=15.1"], - # We use crypto.get_elliptic_curve which is only supported in >=0.15 - "pyopenssl>=0.15": ["OpenSSL>=0.15"], + # Twisted has required pyopenssl 16.0 since about Twisted 16.6. + "pyopenssl>=16.0.0": ["OpenSSL>=16.0.0"], "pyyaml": ["yaml"], "pyasn1": ["pyasn1"], diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 74e892c104..5dc7b3fffc 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -590,9 +590,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): pending_commands = LaterGauge( "synapse_replication_tcp_protocol_pending_commands", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections + (p.name,): len(p.pending_commands) for p in connected_connections }, ) @@ -607,9 +607,9 @@ def transport_buffer_size(protocol): transport_send_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_send_buffer", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections + (p.name,): transport_buffer_size(p) for p in connected_connections }, ) @@ -632,9 +632,9 @@ def transport_kernel_read_buffer_size(protocol, read=True): tcp_transport_kernel_send_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_kernel_send_buffer", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False) + (p.name,): transport_kernel_read_buffer_size(p, False) for p in connected_connections }, ) @@ -643,9 +643,9 @@ tcp_transport_kernel_send_buffer = LaterGauge( tcp_transport_kernel_read_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_kernel_read_buffer", "", - ["name", "conn_id"], + ["name"], lambda: { - (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True) + (p.name,): transport_kernel_read_buffer_size(p, True) for p in connected_connections }, ) @@ -654,9 +654,9 @@ tcp_transport_kernel_read_buffer = LaterGauge( tcp_inbound_commands = LaterGauge( "synapse_replication_tcp_protocol_inbound_commands", "", - ["command", "name", "conn_id"], + ["command", "name"], lambda: { - (k[0], p.name, p.conn_id): count + (k[0], p.name,): count for p in connected_connections for k, count in iteritems(p.inbound_commands_counter) }, @@ -665,9 +665,9 @@ tcp_inbound_commands = LaterGauge( tcp_outbound_commands = LaterGauge( "synapse_replication_tcp_protocol_outbound_commands", "", - ["command", "name", "conn_id"], + ["command", "name"], lambda: { - (k[0], p.name, p.conn_id): count + (k[0], p.name,): count for p in connected_connections for k, count in iteritems(p.outbound_commands_counter) }, diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index ad536ab570..41534b8c2a 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -101,7 +101,7 @@ class UserRegisterServlet(ClientV1RestServlet): nonce = self.hs.get_secrets().token_hex(64) self.nonces[nonce] = int(self.reactor.seconds()) - return (200, {"nonce": nonce.encode('ascii')}) + return (200, {"nonce": nonce}) @defer.inlineCallbacks def on_POST(self, request): @@ -164,7 +164,7 @@ class UserRegisterServlet(ClientV1RestServlet): key=self.hs.config.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) - want_mac.update(nonce) + want_mac.update(nonce.encode('utf8')) want_mac.update(b"\x00") want_mac.update(username) want_mac.update(b"\x00") @@ -173,7 +173,10 @@ class UserRegisterServlet(ClientV1RestServlet): want_mac.update(b"admin" if admin else b"notadmin") want_mac = want_mac.hexdigest() - if not hmac.compare_digest(want_mac, got_mac.encode('ascii')): + if not hmac.compare_digest( + want_mac.encode('ascii'), + got_mac.encode('ascii') + ): raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 0f3a2e8b51..cd9b3bdbd1 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -45,20 +45,20 @@ class EventStreamRestServlet(ClientV1RestServlet): is_guest = requester.is_guest room_id = None if is_guest: - if "room_id" not in request.args: + if b"room_id" not in request.args: raise SynapseError(400, "Guest users must specify room_id param") - if "room_id" in request.args: - room_id = request.args["room_id"][0] + if b"room_id" in request.args: + room_id = request.args[b"room_id"][0].decode('ascii') pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if "timeout" in request.args: + if b"timeout" in request.args: try: - timeout = int(request.args["timeout"][0]) + timeout = int(request.args[b"timeout"][0]) except ValueError: raise SynapseError(400, "timeout must be in milliseconds.") - as_client_event = "raw" not in request.args + as_client_event = b"raw" not in request.args chunk = yield self.event_stream_handler.get_stream( requester.user.to_string(), diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index fd5f85b53e..3ead75cb77 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -32,7 +32,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - as_client_event = "raw" not in request.args + as_client_event = b"raw" not in request.args pagination_config = PaginationConfig.from_request(request) include_archived = parse_boolean(request, "archived", default=False) content = yield self.initial_sync_handler.snapshot_all_rooms( diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index cb85fa1436..0010699d31 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,10 +14,9 @@ # limitations under the License. import logging -import urllib import xml.etree.ElementTree as ET -from six.moves.urllib import parse as urlparse +from six.moves import urllib from canonicaljson import json from saml2 import BINDING_HTTP_POST, config @@ -134,7 +133,7 @@ class LoginRestServlet(ClientV1RestServlet): LoginRestServlet.SAML2_TYPE): relay_state = "" if "relay_state" in login_submission: - relay_state = "&RelayState=" + urllib.quote( + relay_state = "&RelayState=" + urllib.parse.quote( login_submission["relay_state"]) result = { "uri": "%s%s" % (self.idp_redirect_url, relay_state) @@ -366,7 +365,7 @@ class SAML2RestServlet(ClientV1RestServlet): (user_id, token) = yield handler.register_saml2(username) # Forward to the RelayState callback along with ava if 'RelayState' in request.args: - request.redirect(urllib.unquote( + request.redirect(urllib.parse.unquote( request.args['RelayState'][0]) + '?status=authenticated&access_token=' + token + '&user_id=' + user_id + '&ava=' + @@ -377,7 +376,7 @@ class SAML2RestServlet(ClientV1RestServlet): "user_id": user_id, "token": token, "ava": saml2_auth.ava})) elif 'RelayState' in request.args: - request.redirect(urllib.unquote( + request.redirect(urllib.parse.unquote( request.args['RelayState'][0]) + '?status=not_authenticated') finish_request(request) @@ -390,21 +389,22 @@ class CasRedirectServlet(ClientV1RestServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__(hs) - self.cas_server_url = hs.config.cas_server_url - self.cas_service_url = hs.config.cas_service_url + self.cas_server_url = hs.config.cas_server_url.encode('ascii') + self.cas_service_url = hs.config.cas_service_url.encode('ascii') def on_GET(self, request): args = request.args - if "redirectUrl" not in args: + if b"redirectUrl" not in args: return (400, "Redirect URL not specified for CAS auth") - client_redirect_url_param = urllib.urlencode({ - "redirectUrl": args["redirectUrl"][0] - }) - hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket" - service_param = urllib.urlencode({ - "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) - }) - request.redirect("%s/login?%s" % (self.cas_server_url, service_param)) + client_redirect_url_param = urllib.parse.urlencode({ + b"redirectUrl": args[b"redirectUrl"][0] + }).encode('ascii') + hs_redirect_url = (self.cas_service_url + + b"/_matrix/client/api/v1/login/cas/ticket") + service_param = urllib.parse.urlencode({ + b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) + }).encode('ascii') + request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) finish_request(request) @@ -422,11 +422,11 @@ class CasTicketServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - client_redirect_url = request.args["redirectUrl"][0] + client_redirect_url = request.args[b"redirectUrl"][0] http_client = self.hs.get_simple_http_client() uri = self.cas_server_url + "/proxyValidate" args = { - "ticket": request.args["ticket"], + "ticket": request.args[b"ticket"][0].decode('ascii'), "service": self.cas_service_url } try: @@ -471,11 +471,11 @@ class CasTicketServlet(ClientV1RestServlet): finish_request(request) def add_login_token_to_redirect_url(self, url, token): - url_parts = list(urlparse.urlparse(url)) - query = dict(urlparse.parse_qsl(url_parts[4])) + url_parts = list(urllib.parse.urlparse(url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) query.update({"loginToken": token}) - url_parts[4] = urllib.urlencode(query) - return urlparse.urlunparse(url_parts) + url_parts[4] = urllib.parse.urlencode(query).encode('ascii') + return urllib.parse.urlunparse(url_parts) def parse_cas_response(self, cas_response_body): user = None diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 6e95d9bec2..9382b1f124 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -46,7 +46,7 @@ class PushRuleRestServlet(ClientV1RestServlet): try: priority_class = _priority_class_from_spec(spec) except InvalidRuleException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) requester = yield self.auth.get_user_by_req(request) @@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet): content, ) except InvalidRuleException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) before = parse_string(request, "before") if before: @@ -95,9 +95,9 @@ class PushRuleRestServlet(ClientV1RestServlet): ) self.notify_user(user_id) except InconsistentRuleException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) except RuleNotFoundException as e: - raise SynapseError(400, e.message) + raise SynapseError(400, str(e)) defer.returnValue((200, {})) @@ -142,10 +142,10 @@ class PushRuleRestServlet(ClientV1RestServlet): PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == b'': defer.returnValue((200, rules)) - elif path[0] == 'global': - path = path[1:] + elif path[0] == b'global': + path = [x.decode('ascii') for x in path[1:]] result = _filter_ruleset_with_path(rules['global'], path) defer.returnValue((200, result)) else: @@ -192,10 +192,10 @@ class PushRuleRestServlet(ClientV1RestServlet): def _rule_spec_from_path(path): if len(path) < 2: raise UnrecognizedRequestError() - if path[0] != 'pushrules': + if path[0] != b'pushrules': raise UnrecognizedRequestError() - scope = path[1] + scope = path[1].decode('ascii') path = path[2:] if scope != 'global': raise UnrecognizedRequestError() @@ -203,13 +203,13 @@ def _rule_spec_from_path(path): if len(path) == 0: raise UnrecognizedRequestError() - template = path[0] + template = path[0].decode('ascii') path = path[1:] if len(path) == 0 or len(path[0]) == 0: raise UnrecognizedRequestError() - rule_id = path[0] + rule_id = path[0].decode('ascii') spec = { 'scope': scope, @@ -220,7 +220,7 @@ def _rule_spec_from_path(path): path = path[1:] if len(path) > 0 and len(path[0]) > 0: - spec['attr'] = path[0] + spec['attr'] = path[0].decode('ascii') return spec diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 182a68b1e2..b84f0260f2 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -59,7 +59,7 @@ class PushersRestServlet(ClientV1RestServlet): ] for p in pushers: - for k, v in p.items(): + for k, v in list(p.items()): if k not in allowed_keys: del p[k] @@ -126,7 +126,7 @@ class PushersSetRestServlet(ClientV1RestServlet): profile_tag=content.get('profile_tag', ""), ) except PusherConfigException as pce: - raise SynapseError(400, "Config Error: " + pce.message, + raise SynapseError(400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM) self.notifier.on_new_replication_data() diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 976d98387d..663934efd0 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -207,7 +207,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): "sender": requester.user.to_string(), } - if 'ts' in request.args and requester.app_service: + if b'ts' in request.args and requester.app_service: event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) event = yield self.event_creation_hander.create_and_send_nonmember_event( @@ -255,7 +255,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet): if RoomID.is_valid(room_identifier): room_id = room_identifier try: - remote_room_hosts = request.args["server_name"] + remote_room_hosts = [ + x.decode('ascii') for x in request.args[b"server_name"] + ] except Exception: remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): @@ -461,10 +463,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet): pagination_config = PaginationConfig.from_request( request, default_limit=10, ) - as_client_event = "raw" not in request.args - filter_bytes = parse_string(request, "filter") + as_client_event = b"raw" not in request.args + filter_bytes = parse_string(request, b"filter", encoding=None) if filter_bytes: - filter_json = urlparse.unquote(filter_bytes).decode("UTF-8") + filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) event_filter = Filter(json.loads(filter_json)) else: event_filter = None @@ -560,7 +562,7 @@ class RoomEventContextServlet(ClientV1RestServlet): # picking the API shape for symmetry with /messages filter_bytes = parse_string(request, "filter") if filter_bytes: - filter_json = urlparse.unquote(filter_bytes).decode("UTF-8") + filter_json = urlparse.unquote(filter_bytes) event_filter = Filter(json.loads(filter_json)) else: event_filter = None diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 263d8eb73e..0251146722 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -89,7 +89,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - if "from" in request.args: + if b"from" in request.args: # /events used to use 'from', but /sync uses 'since'. # Lets be helpful and whine if we see a 'from'. raise SynapseError( diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index d9d379182e..b9b5d07677 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -79,7 +79,7 @@ class ThirdPartyUserServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args - fields.pop("access_token", None) + fields.pop(b"access_token", None) results = yield self.appservice_handler.query_3pe( ThirdPartyEntityKind.USER, protocol, fields @@ -102,7 +102,7 @@ class ThirdPartyLocationServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args - fields.pop("access_token", None) + fields.pop(b"access_token", None) results = yield self.appservice_handler.query_3pe( ThirdPartyEntityKind.LOCATION, protocol, fields diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py index b9ee6e1c13..38eb2ee23f 100644 --- a/synapse/rest/key/v1/server_key_resource.py +++ b/synapse/rest/key/v1/server_key_resource.py @@ -88,5 +88,5 @@ class LocalKey(Resource): ) def getChild(self, name, request): - if name == '': + if name == b'': return self diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py index 3491fd2118..cb5abcf826 100644 --- a/synapse/rest/key/v2/__init__.py +++ b/synapse/rest/key/v2/__init__.py @@ -22,5 +22,5 @@ from .remote_key_resource import RemoteKey class KeyApiV2Resource(Resource): def __init__(self, hs): Resource.__init__(self) - self.putChild("server", LocalKey(hs)) - self.putChild("query", RemoteKey(hs)) + self.putChild(b"server", LocalKey(hs)) + self.putChild(b"query", RemoteKey(hs)) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 7d67e4b064..eb8782aa6e 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -103,7 +103,7 @@ class RemoteKey(Resource): def async_render_GET(self, request): if len(request.postpath) == 1: server, = request.postpath - query = {server: {}} + query = {server.decode('ascii'): {}} elif len(request.postpath) == 2: server, key_id = request.postpath minimum_valid_until_ts = parse_integer( @@ -112,11 +112,12 @@ class RemoteKey(Resource): arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server: {key_id: arguments}} + query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}} else: raise SynapseError( 404, "Not found %r" % request.postpath, Codes.NOT_FOUND ) + yield self.query_keys(request, query, query_remote_on_cache_miss=True) def render_POST(self, request): @@ -135,6 +136,7 @@ class RemoteKey(Resource): @defer.inlineCallbacks def query_keys(self, request, query, query_remote_on_cache_miss=False): logger.info("Handling query for keys %r", query) + store_queries = [] for server_name, key_ids in query.items(): if ( diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index f255f2883f..5a426ff2f6 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -56,7 +56,7 @@ class ContentRepoResource(resource.Resource): # servers. # TODO: A little crude here, we could do this better. - filename = request.path.split('/')[-1] + filename = request.path.decode('ascii').split('/')[-1] # be paranoid filename = re.sub("[^0-9A-z.-_]", "", filename) @@ -78,7 +78,7 @@ class ContentRepoResource(resource.Resource): # select private. don't bother setting Expires as all our matrix # clients are smart enough to be happy with Cache-Control (right?) request.setHeader( - "Cache-Control", "public,max-age=86400,s-maxage=86400" + b"Cache-Control", b"public,max-age=86400,s-maxage=86400" ) d = FileSender().beginFileTransfer(f, request) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 65f4bd2910..76e479afa3 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -15,9 +15,8 @@ import logging import os -import urllib -from six.moves.urllib import parse as urlparse +from six.moves import urllib from twisted.internet import defer from twisted.protocols.basic import FileSender @@ -35,10 +34,15 @@ def parse_media_id(request): # This allows users to append e.g. /test.png to the URL. Useful for # clients that parse the URL to see content type. server_name, media_id = request.postpath[:2] + + if isinstance(server_name, bytes): + server_name = server_name.decode('utf-8') + media_id = media_id.decode('utf8') + file_name = None if len(request.postpath) > 2: try: - file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8") + file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8")) except UnicodeDecodeError: pass return server_name, media_id, file_name @@ -93,22 +97,18 @@ def add_file_headers(request, media_type, file_size, upload_name): file_size (int): Size in bytes of the media, if known. upload_name (str): The name of the requested file, if any. """ + def _quote(x): + return urllib.parse.quote(x.encode("utf-8")) + request.setHeader(b"Content-Type", media_type.encode("UTF-8")) if upload_name: if is_ascii(upload_name): - request.setHeader( - b"Content-Disposition", - b"inline; filename=%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) + disposition = ("inline; filename=%s" % (_quote(upload_name),)).encode("ascii") else: - request.setHeader( - b"Content-Disposition", - b"inline; filename*=utf-8''%s" % ( - urllib.quote(upload_name.encode("utf-8")), - ), - ) + disposition = ( + "inline; filename*=utf-8''%s" % (_quote(upload_name),)).encode("ascii") + + request.setHeader(b"Content-Disposition", disposition) # cache for at least a day. # XXX: we might want to turn this off for data we don't want to diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index fbfa85f74f..ca90964d1d 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -47,12 +47,12 @@ class DownloadResource(Resource): def _async_render_GET(self, request): set_cors_headers(request) request.setHeader( - "Content-Security-Policy", - "default-src 'none';" - " script-src 'none';" - " plugin-types application/pdf;" - " style-src 'unsafe-inline';" - " object-src 'self';" + b"Content-Security-Policy", + b"default-src 'none';" + b" script-src 'none';" + b" plugin-types application/pdf;" + b" style-src 'unsafe-inline';" + b" object-src 'self';" ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 241c972070..a828ff4438 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -20,7 +20,7 @@ import logging import os import shutil -from six import iteritems +from six import PY3, iteritems from six.moves.urllib import parse as urlparse import twisted.internet.error @@ -397,13 +397,13 @@ class MediaRepository(object): yield finish() - media_type = headers["Content-Type"][0] + media_type = headers[b"Content-Type"][0].decode('ascii') time_now_ms = self.clock.time_msec() - content_disposition = headers.get("Content-Disposition", None) + content_disposition = headers.get(b"Content-Disposition", None) if content_disposition: - _, params = cgi.parse_header(content_disposition[0],) + _, params = cgi.parse_header(content_disposition[0].decode('ascii'),) upload_name = None # First check if there is a valid UTF-8 filename @@ -419,9 +419,13 @@ class MediaRepository(object): upload_name = upload_name_ascii if upload_name: - upload_name = urlparse.unquote(upload_name) + if PY3: + upload_name = urlparse.unquote(upload_name) + else: + upload_name = urlparse.unquote(upload_name.encode('ascii')) try: - upload_name = upload_name.decode("utf-8") + if isinstance(upload_name, bytes): + upload_name = upload_name.decode("utf-8") except UnicodeDecodeError: upload_name = None else: @@ -755,14 +759,15 @@ class MediaRepositoryResource(Resource): Resource.__init__(self) media_repo = hs.get_media_repository() - self.putChild("upload", UploadResource(hs, media_repo)) - self.putChild("download", DownloadResource(hs, media_repo)) - self.putChild("thumbnail", ThumbnailResource( + + self.putChild(b"upload", UploadResource(hs, media_repo)) + self.putChild(b"download", DownloadResource(hs, media_repo)) + self.putChild(b"thumbnail", ThumbnailResource( hs, media_repo, media_repo.media_storage, )) - self.putChild("identicon", IdenticonResource()) + self.putChild(b"identicon", IdenticonResource()) if hs.config.url_preview_enabled: - self.putChild("preview_url", PreviewUrlResource( + self.putChild(b"preview_url", PreviewUrlResource( hs, media_repo, media_repo.media_storage, )) - self.putChild("config", MediaConfigResource(hs)) + self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 778ef97337..cad2dec33a 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -261,7 +261,7 @@ class PreviewUrlResource(Resource): logger.debug("Calculated OG for %s as %s" % (url, og)) - jsonog = json.dumps(og) + jsonog = json.dumps(og).encode('utf8') # store OG in history-aware DB cache yield self.store.store_url_cache( @@ -301,20 +301,20 @@ class PreviewUrlResource(Resource): logger.warn("Error downloading %s: %r", url, e) raise SynapseError( 500, "Failed to download content: %s" % ( - traceback.format_exception_only(sys.exc_type, e), + traceback.format_exception_only(sys.exc_info()[0], e), ), Codes.UNKNOWN, ) yield finish() try: - if "Content-Type" in headers: - media_type = headers["Content-Type"][0] + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode('ascii') else: media_type = "application/octet-stream" time_now_ms = self.clock.time_msec() - content_disposition = headers.get("Content-Disposition", None) + content_disposition = headers.get(b"Content-Disposition", None) if content_disposition: _, params = cgi.parse_header(content_disposition[0],) download_name = None diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 8bf87f38f7..30ff87a4c4 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -930,6 +930,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) self._invalidate_cache_and_stream( + txn, self.get_room_summary, (room_id,) + ) + + self._invalidate_cache_and_stream( txn, self.get_current_state_ids, (room_id,) ) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index c7899d7fd2..b890c152db 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -199,10 +199,14 @@ class MonthlyActiveUsersStore(SQLBaseStore): Args: user_id(str): the user_id to query """ + if self.hs.config.limit_usage_by_mau: + # Trial users and guests should not be included as part of MAU group + is_guest = yield self.is_guest(user_id) + if is_guest: + return is_trial = yield self.is_trial_user(user_id) if is_trial: - # we don't track trial users in the MAU table. return last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9b4e6d6aa8..0707f9a86a 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -51,6 +51,12 @@ ProfileInfo = namedtuple( "ProfileInfo", ("avatar_url", "display_name") ) +# "members" points to a truncated list of (user_id, event_id) tuples for users of +# a given membership type, suitable for use in calculating heroes for a room. +# "count" points to the total numberr of users of a given membership type. +MemberSummary = namedtuple( + "MemberSummary", ("members", "count") +) _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" @@ -82,6 +88,65 @@ class RoomMemberWorkerStore(EventsWorkerStore): return [to_ascii(r[0]) for r in txn] return self.runInteraction("get_users_in_room", f) + @cached(max_entries=100000) + def get_room_summary(self, room_id): + """ Get the details of a room roughly suitable for use by the room + summary extension to /sync. Useful when lazy loading room members. + Args: + room_id (str): The room ID to query + Returns: + Deferred[dict[str, MemberSummary]: + dict of membership states, pointing to a MemberSummary named tuple. + """ + + def _get_room_summary_txn(txn): + # first get counts. + # We do this all in one transaction to keep the cache small. + # FIXME: get rid of this when we have room_stats + sql = """ + SELECT count(*), m.membership FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + res = {} + for count, membership in txn: + summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) + + # we order by membership and then fairly arbitrarily by event_id so + # heroes are consistent + sql = """ + SELECT m.user_id, m.membership, m.event_id + FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? + ORDER BY + CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + m.event_id ASC + LIMIT ? + """ + + # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. + txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) + for user_id, membership, event_id in txn: + summary = res[to_ascii(membership)] + # we will always have a summary for this membership type at this + # point given the summary currently contains the counts. + members = summary.members + members.append((to_ascii(user_id), to_ascii(event_id))) + + return res + + return self.runInteraction("get_room_summary", _get_room_summary_txn) + @cached() def get_invited_rooms_for_user(self, user_id): """ Get all the rooms the user is invited to |