diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 5f9af8529b..f8f3b1a31e 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -13,20 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Union
+
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions
+from synapse.config.homeserver import HomeServerConfig
+from synapse.events import EventBase
+from synapse.events.builder import EventBuilder
from synapse.events.utils import validate_canonicaljson
+from synapse.federation.federation_server import server_matches_acl_event
from synapse.types import EventID, RoomID, UserID
class EventValidator:
- def validate_new(self, event, config):
+ def validate_new(self, event: EventBase, config: HomeServerConfig):
"""Validates the event has roughly the right format
Args:
- event (FrozenEvent): The event to validate.
- config (Config): The homeserver's configuration.
+ event: The event to validate.
+ config: The homeserver's configuration.
"""
self.validate_builder(event)
@@ -76,12 +82,18 @@ class EventValidator:
if event.type == EventTypes.Retention:
self._validate_retention(event)
- def _validate_retention(self, event):
+ if event.type == EventTypes.ServerACL:
+ if not server_matches_acl_event(config.server_name, event):
+ raise SynapseError(
+ 400, "Can't create an ACL event that denies the local server"
+ )
+
+ def _validate_retention(self, event: EventBase):
"""Checks that an event that defines the retention policy for a room respects the
format enforced by the spec.
Args:
- event (FrozenEvent): The event to validate.
+ event: The event to validate.
"""
if not event.is_state():
raise SynapseError(code=400, msg="must be a state event")
@@ -116,13 +128,10 @@ class EventValidator:
errcode=Codes.BAD_JSON,
)
- def validate_builder(self, event):
+ def validate_builder(self, event: Union[EventBase, EventBuilder]):
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
fields an event would have
-
- Args:
- event (EventBuilder|FrozenEvent)
"""
strings = ["room_id", "sender", "type"]
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index ff103cbb92..213baea2e3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -181,10 +181,15 @@ class AuthHandler(BaseHandler):
# better way to break the loop
account_handler = ModuleApi(hs, self)
- self.password_providers = [
- module(config=config, account_handler=account_handler)
- for module, config in hs.config.password_providers
- ]
+ self.password_providers = []
+ for module, config in hs.config.password_providers:
+ try:
+ self.password_providers.append(
+ module(config=config, account_handler=account_handler)
+ )
+ except Exception as e:
+ logger.error("Error while initializing %r: %s", module, e)
+ raise
logger.info("Extra password_providers: %r", self.password_providers)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 31f91e0a1a..2f3f3a7ef5 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1138,6 +1138,9 @@ class EventCreationHandler:
if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room")
+ if original_event.type == EventTypes.ServerACL:
+ raise AuthError(403, "Redacting server ACL events is not permitted")
+
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 14348faaf3..74a1ddd780 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -189,7 +189,9 @@ class ProfileHandler(BaseHandler):
)
if not isinstance(new_displayname, str):
- raise SynapseError(400, "Invalid displayname")
+ raise SynapseError(
+ 400, "'displayname' must be a string", errcode=Codes.INVALID_PARAM
+ )
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
@@ -273,7 +275,9 @@ class ProfileHandler(BaseHandler):
)
if not isinstance(new_avatar_url, str):
- raise SynapseError(400, "Invalid displayname")
+ raise SynapseError(
+ 400, "'avatar_url' must be a string", errcode=Codes.INVALID_PARAM
+ )
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 04766ca965..7e17cdb73e 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1063,13 +1063,19 @@ def check_content_type_is_json(headers):
"""
c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
- raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False)
+ raise RequestSendFailed(
+ RuntimeError("No Content-Type header received from remote server"),
+ can_retry=False,
+ )
c_type = c_type[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RequestSendFailed(
- RuntimeError("Content-Type not application/json: was '%s'" % c_type),
+ RuntimeError(
+ "Remote server sent Content-Type header of '%s', not 'application/json'"
+ % c_type,
+ ),
can_retry=False,
)
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index b8d2a8e8a9..cbf0dbb871 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -502,6 +502,16 @@ build_info.labels(
last_ticked = time.time()
+# 3PID send info
+threepid_send_requests = Histogram(
+ "synapse_threepid_send_requests_with_tries",
+ documentation="Number of requests for a 3pid token by try count. Note if"
+ " there is a request with try count of 4, then there would have been one"
+ " each for 1, 2 and 3",
+ buckets=(1, 2, 3, 4, 5, 10),
+ labelnames=("type", "reason"),
+)
+
class ReactorLastSeenMetric:
def collect(self):
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 2858b61fb1..f5788c1de7 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -498,6 +498,30 @@ BASE_APPEND_UNDERRIDE_RULES = [
],
"actions": ["notify", {"set_tweak": "highlight", "value": False}],
},
+ {
+ "rule_id": "global/underride/.im.vector.jitsi",
+ "conditions": [
+ {
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "im.vector.modular.widgets",
+ "_id": "_type_modular_widgets",
+ },
+ {
+ "kind": "event_match",
+ "key": "content.type",
+ "pattern": "jitsi",
+ "_id": "_content_type_jitsi",
+ },
+ {
+ "kind": "event_match",
+ "key": "state_key",
+ "pattern": "*",
+ "_id": "_is_state_event",
+ },
+ ],
+ "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+ },
]
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index fa7e9e4043..2a4f7a1740 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -47,6 +47,7 @@ from synapse.rest.admin.rooms import (
ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
+from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
from synapse.rest.admin.users import (
AccountValidityRenewServlet,
DeactivateAccountRestServlet,
@@ -227,6 +228,7 @@ def register_servlets(hs, http_server):
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
+ UserMediaStatisticsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
new file mode 100644
index 0000000000..f2490e382d
--- /dev/null
+++ b/synapse/rest/admin/statistics.py
@@ -0,0 +1,122 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.storage.databases.main.stats import UserSortOrder
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class UserMediaStatisticsRestServlet(RestServlet):
+ """
+ Get statistics about uploaded media by users.
+ """
+
+ PATTERNS = admin_patterns("/statistics/users/media$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+
+ order_by = parse_string(
+ request, "order_by", default=UserSortOrder.USER_ID.value
+ )
+ if order_by not in (
+ UserSortOrder.MEDIA_LENGTH.value,
+ UserSortOrder.MEDIA_COUNT.value,
+ UserSortOrder.USER_ID.value,
+ UserSortOrder.DISPLAYNAME.value,
+ ):
+ raise SynapseError(
+ 400,
+ "Unknown value for order_by: %s" % (order_by,),
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ start = parse_integer(request, "from", default=0)
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ limit = parse_integer(request, "limit", default=100)
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter limit must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ from_ts = parse_integer(request, "from_ts", default=0)
+ if from_ts < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from_ts must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ until_ts = parse_integer(request, "until_ts")
+ if until_ts is not None:
+ if until_ts < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter until_ts must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+ if until_ts <= from_ts:
+ raise SynapseError(
+ 400,
+ "Query parameter until_ts must be greater than from_ts.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ search_term = parse_string(request, "search_term")
+ if search_term == "":
+ raise SynapseError(
+ 400,
+ "Query parameter search_term cannot be an empty string.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ direction = parse_string(request, "dir", default="f")
+ if direction not in ("f", "b"):
+ raise SynapseError(
+ 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+ )
+
+ users_media, total = await self.store.get_users_media_usage_paginate(
+ start, limit, from_ts, until_ts, order_by, direction, search_term
+ )
+ ret = {"users": users_media, "total": total}
+ if (start + limit) < total:
+ ret["next_token"] = start + len(users_media)
+
+ return 200, ret
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b337311a37..3638e219f2 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -412,6 +412,7 @@ class UserRegisterServlet(RestServlet):
admin = body.get("admin", None)
user_type = body.get("user_type", None)
+ displayname = body.get("displayname", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")
@@ -448,6 +449,7 @@ class UserRegisterServlet(RestServlet):
password_hash=password_hash,
admin=bool(admin),
user_type=user_type,
+ default_display_name=displayname,
by_admin=True,
)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 51effc4d8e..a54e1011f7 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -38,6 +38,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -143,6 +144,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Wrap the session id in a JSON object
ret = {"sid": sid}
+ threepid_send_requests.labels(type="email", reason="password_reset").observe(
+ send_attempt
+ )
+
return 200, ret
@@ -411,6 +416,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
# Wrap the session id in a JSON object
ret = {"sid": sid}
+ threepid_send_requests.labels(type="email", reason="add_threepid").observe(
+ send_attempt
+ )
+
return 200, ret
@@ -481,6 +490,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
next_link,
)
+ threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
+ send_attempt
+ )
+
return 200, ret
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 8f2c8cd991..ea68114026 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -45,6 +45,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -163,6 +164,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
# Wrap the session id in a JSON object
ret = {"sid": sid}
+ threepid_send_requests.labels(type="email", reason="register").observe(
+ send_attempt
+ )
+
return 200, ret
@@ -234,6 +239,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
next_link,
)
+ threepid_send_requests.labels(type="msisdn", reason="register").observe(
+ send_attempt
+ )
+
return 200, ret
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 0422d4c7ce..d464c75c03 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -119,7 +119,7 @@ class ServerNoticesManager:
# manages to invite the system user to a room, that doesn't make it
# the server notices room.
user_ids = await self._store.get_users_in_room(room.room_id)
- if self.server_notices_mxid in user_ids:
+ if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
# we found a room which our user shares with the system notice
# user
logger.info(
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a0572b2952..d1b5760c2c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -88,13 +88,18 @@ def make_pool(
"""Get the connection pool for the database.
"""
+ # By default enable `cp_reconnect`. We need to fiddle with db_args in case
+ # someone has explicitly set `cp_reconnect`.
+ db_args = dict(db_config.config.get("args", {}))
+ db_args.setdefault("cp_reconnect", True)
+
return adbapi.ConnectionPool(
db_config.config["name"],
cp_reactor=reactor,
cp_openfun=lambda conn: engine.on_new_connection(
LoggingDatabaseConnection(conn, engine, "on_new_connection")
),
- **db_config.config.get("args", {}),
+ **db_args,
)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4415909414..4d1b92d1aa 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -24,7 +24,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import make_in_list_sql_clause
+from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -33,6 +33,7 @@ from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
+ from synapse.server import HomeServer
@attr.s(slots=True)
@@ -47,7 +48,20 @@ class DeviceKeyLookupResult:
keys = attr.ib(type=Optional[JsonDict])
-class EndToEndKeyWorkerStore(SQLBaseStore):
+class EndToEndKeyBackgroundStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ "e2e_cross_signing_keys_idx",
+ index_name="e2e_cross_signing_keys_stream_idx",
+ table="e2e_cross_signing_keys",
+ columns=["stream_id"],
+ unique=True,
+ )
+
+
+class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
) -> Tuple[int, List[JsonDict]]:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index a6279a6c13..2e07c37340 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -26,6 +26,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.types import Collection
from synapse.util.caches.descriptors import cached
+from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
@@ -40,6 +41,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
+ # Cache of event ID to list of auth event IDs and their depths.
+ self._event_auth_cache = LruCache(
+ 500000, "_event_auth_cache", size_callback=len
+ ) # type: LruCache[str, List[Tuple[str, int]]]
+
async def get_auth_chain(
self, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
@@ -84,17 +90,45 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
results = set()
- base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
+ # We pull out the depth simply so that we can populate the
+ # `_event_auth_cache` cache.
+ base_sql = """
+ SELECT a.event_id, auth_id, depth
+ FROM event_auth AS a
+ INNER JOIN events AS e ON (e.event_id = a.auth_id)
+ WHERE
+ """
front = set(event_ids)
while front:
new_front = set()
for chunk in batch_iter(front, 100):
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "event_id", chunk
- )
- txn.execute(base_sql + clause, args)
- new_front.update(r[0] for r in txn)
+ # Pull the auth events either from the cache or DB.
+ to_fetch = [] # Event IDs to fetch from DB # type: List[str]
+ for event_id in chunk:
+ res = self._event_auth_cache.get(event_id)
+ if res is None:
+ to_fetch.append(event_id)
+ else:
+ new_front.update(auth_id for auth_id, depth in res)
+
+ if to_fetch:
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "a.event_id", to_fetch
+ )
+ txn.execute(base_sql + clause, args)
+
+ # Note we need to batch up the results by event ID before
+ # adding to the cache.
+ to_cache = {}
+ for event_id, auth_event_id, auth_event_depth in txn:
+ to_cache.setdefault(event_id, []).append(
+ (auth_event_id, auth_event_depth)
+ )
+ new_front.add(auth_event_id)
+
+ for event_id, auth_events in to_cache.items():
+ self._event_auth_cache.set(event_id, auth_events)
new_front -= results
@@ -213,14 +247,38 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
break
# Fetch the auth events and their depths of the N last events we're
- # currently walking
+ # currently walking, either from cache or DB.
search, chunk = search[:-100], search[-100:]
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
- )
- txn.execute(base_sql + clause, args)
- for event_id, auth_event_id, auth_event_depth in txn:
+ found = [] # Results found # type: List[Tuple[str, str, int]]
+ to_fetch = [] # Event IDs to fetch from DB # type: List[str]
+ for _, event_id in chunk:
+ res = self._event_auth_cache.get(event_id)
+ if res is None:
+ to_fetch.append(event_id)
+ else:
+ found.extend((event_id, auth_id, depth) for auth_id, depth in res)
+
+ if to_fetch:
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "a.event_id", to_fetch
+ )
+ txn.execute(base_sql + clause, args)
+
+ # We parse the results and add the to the `found` set and the
+ # cache (note we need to batch up the results by event ID before
+ # adding to the cache).
+ to_cache = {}
+ for event_id, auth_event_id, auth_event_depth in txn:
+ to_cache.setdefault(event_id, []).append(
+ (auth_event_id, auth_event_depth)
+ )
+ found.append((event_id, auth_event_id, auth_event_depth))
+
+ for event_id, auth_events in to_cache.items():
+ self._event_auth_cache.set(event_id, auth_events)
+
+ for event_id, auth_event_id, auth_event_depth in found:
event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
sets = event_to_missing_sets.get(auth_event_id)
diff --git a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
new file mode 100644
index 0000000000..61c558db77
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('e2e_cross_signing_keys_idx', '{}');
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5beb302be3..0cdb3ec1f7 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,15 +16,18 @@
import logging
from collections import Counter
+from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import StoreError
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -59,6 +62,23 @@ TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user
TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
+class UserSortOrder(Enum):
+ """
+ Enum to define the sorting method used when returning users
+ with get_users_media_usage_paginate
+
+ MEDIA_LENGTH = ordered by size of uploaded media. Smallest to largest.
+ MEDIA_COUNT = ordered by number of uploaded media. Smallest to largest.
+ USER_ID = ordered alphabetically by `user_id`.
+ DISPLAYNAME = ordered alphabetically by `displayname`
+ """
+
+ MEDIA_LENGTH = "media_length"
+ MEDIA_COUNT = "media_count"
+ USER_ID = "user_id"
+ DISPLAYNAME = "displayname"
+
+
class StatsStore(StateDeltasStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -882,3 +902,110 @@ class StatsStore(StateDeltasStore):
complete_with_stream_id=pos,
absolute_field_overrides={"joined_rooms": joined_rooms},
)
+
+ async def get_users_media_usage_paginate(
+ self,
+ start: int,
+ limit: int,
+ from_ts: Optional[int] = None,
+ until_ts: Optional[int] = None,
+ order_by: Optional[UserSortOrder] = UserSortOrder.USER_ID.value,
+ direction: Optional[str] = "f",
+ search_term: Optional[str] = None,
+ ) -> Tuple[List[JsonDict], Dict[str, int]]:
+ """Function to retrieve a paginated list of users and their uploaded local media
+ (size and number). This will return a json list of users and the
+ total number of users matching the filter criteria.
+
+ Args:
+ start: offset to begin the query from
+ limit: number of rows to retrieve
+ from_ts: request only media that are created later than this timestamp (ms)
+ until_ts: request only media that are created earlier than this timestamp (ms)
+ order_by: the sort order of the returned list
+ direction: sort ascending or descending
+ search_term: a string to filter user names by
+ Returns:
+ A list of user dicts and an integer representing the total number of
+ users that exist given this query
+ """
+
+ def get_users_media_usage_paginate_txn(txn):
+ filters = []
+ args = [self.hs.config.server_name]
+
+ if search_term:
+ filters.append("(lmr.user_id LIKE ? OR displayname LIKE ?)")
+ args.extend(["@%" + search_term + "%:%", "%" + search_term + "%"])
+
+ if from_ts:
+ filters.append("created_ts >= ?")
+ args.extend([from_ts])
+ if until_ts:
+ filters.append("created_ts <= ?")
+ args.extend([until_ts])
+
+ # Set ordering
+ if UserSortOrder(order_by) == UserSortOrder.MEDIA_LENGTH:
+ order_by_column = "media_length"
+ elif UserSortOrder(order_by) == UserSortOrder.MEDIA_COUNT:
+ order_by_column = "media_count"
+ elif UserSortOrder(order_by) == UserSortOrder.USER_ID:
+ order_by_column = "lmr.user_id"
+ elif UserSortOrder(order_by) == UserSortOrder.DISPLAYNAME:
+ order_by_column = "displayname"
+ else:
+ raise StoreError(
+ 500, "Incorrect value for order_by provided: %s" % order_by
+ )
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+ sql_base = """
+ FROM local_media_repository as lmr
+ LEFT JOIN profiles AS p ON lmr.user_id = '@' || p.user_id || ':' || ?
+ {}
+ GROUP BY lmr.user_id, displayname
+ """.format(
+ where_clause
+ )
+
+ # SQLite does not support SELECT COUNT(*) OVER()
+ sql = """
+ SELECT COUNT(*) FROM (
+ SELECT lmr.user_id
+ {sql_base}
+ ) AS count_user_ids
+ """.format(
+ sql_base=sql_base,
+ )
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = """
+ SELECT
+ lmr.user_id,
+ displayname,
+ COUNT(lmr.user_id) as media_count,
+ SUM(media_length) as media_length
+ {sql_base}
+ ORDER BY {order_by_column} {order}
+ LIMIT ? OFFSET ?
+ """.format(
+ sql_base=sql_base, order_by_column=order_by_column, order=order,
+ )
+
+ args += [limit, start]
+ txn.execute(sql, args)
+ users = self.db_pool.cursor_to_dict(txn)
+
+ return users, count
+
+ return await self.db_pool.runInteraction(
+ "get_users_media_usage_paginate_txn", get_users_media_usage_paginate_txn
+ )
|