summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/http/client.py14
-rw-r--r--synapse/http/endpoint.py3
-rw-r--r--synapse/push/httppusher.py37
-rw-r--r--synapse/rest/client/v1/admin.py22
-rw-r--r--synapse/state.py20
-rw-r--r--synapse/storage/events.py6
-rw-r--r--synapse/storage/room.py153
7 files changed, 180 insertions, 75 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 4abb479ae3..f3e4973c2e 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
 from synapse.api.errors import (
     CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
 )
+from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.logcontext import make_deferred_yieldable
 from synapse.util import logcontext
 import synapse.metrics
@@ -30,6 +31,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.web.client import (
     BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
     readBody, PartialDownloadError,
+    HTTPConnectionPool,
 )
 from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
 from twisted.web.http import PotentialDataLoss
@@ -64,13 +66,23 @@ class SimpleHttpClient(object):
     """
     def __init__(self, hs):
         self.hs = hs
+
+        pool = HTTPConnectionPool(reactor)
+
+        # the pusher makes lots of concurrent SSL connections to sygnal, and
+        # tends to do so in batches, so we need to allow the pool to keep lots
+        # of idle connections around.
+        pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+        pool.cachedConnectionTimeout = 2 * 60
+
         # The default context factory in Twisted 14.0.0 (which we require) is
         # BrowserLikePolicyForHTTPS which will do regular cert validation
         # 'like a browser'
         self.agent = Agent(
             reactor,
             connectTimeout=15,
-            contextFactory=hs.get_http_client_context_factory()
+            contextFactory=hs.get_http_client_context_factory(),
+            pool=pool,
         )
         self.user_agent = hs.version_string
         self.clock = hs.get_clock()
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index e2b99ef3bd..87639b9151 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -357,8 +357,7 @@ def _get_hosts_for_srv_record(dns_client, host):
     def eb(res, record_type):
         if res.check(DNSNameError):
             return []
-        logger.warn("Error looking up %s for %s: %s",
-                    record_type, host, res, res.value)
+        logger.warn("Error looking up %s for %s: %s", record_type, host, res)
         return res
 
     # no logcontexts here, so we can safely fire these off and gatherResults
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index c16f61452c..2cbac571b8 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -13,21 +13,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from synapse.push import PusherConfigException
+import logging
 
 from twisted.internet import defer, reactor
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 
-import logging
 import push_rule_evaluator
 import push_tools
-
+import synapse
+from synapse.push import PusherConfigException
 from synapse.util.logcontext import LoggingContext
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
 
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+http_push_processed_counter = metrics.register_counter(
+    "http_pushes_processed",
+)
+
+http_push_failed_counter = metrics.register_counter(
+    "http_pushes_failed",
+)
+
 
 class HttpPusher(object):
     INITIAL_BACKOFF_SEC = 1  # in seconds because that's what Twisted takes
@@ -152,9 +161,16 @@ class HttpPusher(object):
             self.user_id, self.last_stream_ordering, self.max_stream_ordering
         )
 
+        logger.info(
+            "Processing %i unprocessed push actions for %s starting at "
+            "stream_ordering %s",
+            len(unprocessed), self.name, self.last_stream_ordering,
+        )
+
         for push_action in unprocessed:
             processed = yield self._process_one(push_action)
             if processed:
+                http_push_processed_counter.inc()
                 self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
                 self.last_stream_ordering = push_action['stream_ordering']
                 yield self.store.update_pusher_last_stream_ordering_and_success(
@@ -169,6 +185,7 @@ class HttpPusher(object):
                         self.failing_since
                     )
             else:
+                http_push_failed_counter.inc()
                 if not self.failing_since:
                     self.failing_since = self.clock.time_msec()
                     yield self.store.update_pusher_failing_since(
@@ -316,7 +333,10 @@ class HttpPusher(object):
         try:
             resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
         except Exception:
-            logger.warn("Failed to push %s ", self.url)
+            logger.warn(
+                "Failed to push event %s to %s",
+                event.event_id, self.name, exc_info=True,
+            )
             defer.returnValue(False)
         rejected = []
         if 'rejected' in resp:
@@ -325,7 +345,7 @@ class HttpPusher(object):
 
     @defer.inlineCallbacks
     def _send_badge(self, badge):
-        logger.info("Sending updated badge count %d to %r", badge, self.user_id)
+        logger.info("Sending updated badge count %d to %s", badge, self.name)
         d = {
             'notification': {
                 'id': '',
@@ -347,7 +367,10 @@ class HttpPusher(object):
         try:
             resp = yield self.http_client.post_json_get_json(self.url, d)
         except Exception:
-            logger.exception("Failed to push %s ", self.url)
+            logger.warn(
+                "Failed to send badge count to %s",
+                self.name, exc_info=True,
+            )
             defer.returnValue(False)
         rejected = []
         if 'rejected' in resp:
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 5022808ea9..0615e5d807 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -289,6 +289,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
         defer.returnValue((200, {"num_quarantined": num_quarantined}))
 
 
+class ListMediaInRoom(ClientV1RestServlet):
+    """Lists all of the media in a given room.
+    """
+    PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+    def __init__(self, hs):
+        super(ListMediaInRoom, self).__init__(hs)
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        is_admin = yield self.auth.is_server_admin(requester.user)
+        if not is_admin:
+            raise AuthError(403, "You are not a server admin")
+
+        local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+        defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
 class ResetPasswordRestServlet(ClientV1RestServlet):
     """Post request to allow an administrator reset password for a user.
     This needs user to have administrator access in Synapse.
@@ -487,3 +508,4 @@ def register_servlets(hs, http_server):
     SearchUsersRestServlet(hs).register(http_server)
     ShutdownRoomRestServlet(hs).register(http_server)
     QuarantineMediaInRoom(hs).register(http_server)
+    ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/state.py b/synapse/state.py
index 30b16e201a..037f24dd79 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -150,8 +150,20 @@ class StateHandler(object):
         defer.returnValue(state)
 
     @defer.inlineCallbacks
-    def get_current_state_ids(self, room_id, event_type=None, state_key="",
-                              latest_event_ids=None):
+    def get_current_state_ids(self, room_id, latest_event_ids=None):
+        """Get the current state, or the state at a set of events, for a room
+
+        Args:
+            room_id (str):
+
+            latest_event_ids (iterable[str]|None): if given, the forward
+                extremities to resolve. If None, we look them up from the
+                database (via a cache)
+
+        Returns:
+            Deferred[dict[(str, str), str)]]: the state dict, mapping from
+                (event_type, state_key) -> event_id
+        """
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
@@ -159,10 +171,6 @@ class StateHandler(object):
         ret = yield self.resolve_state_groups(room_id, latest_event_ids)
         state = ret.state
 
-        if event_type:
-            defer.returnValue(state.get((event_type, state_key)))
-            return
-
         defer.returnValue(state)
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 7a9cd3ec90..33fccfa7a8 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -110,7 +110,7 @@ class _EventPeristenceQueue(object):
                 end_item.events_and_contexts.extend(events_and_contexts)
                 return end_item.deferred.observe()
 
-        deferred = ObservableDeferred(defer.Deferred())
+        deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
 
         queue.append(self._EventPersistQueueItem(
             events_and_contexts=events_and_contexts,
@@ -152,8 +152,8 @@ class _EventPeristenceQueue(object):
                     try:
                         ret = yield per_item_callback(item)
                         item.deferred.callback(ret)
-                    except Exception as e:
-                        item.deferred.errback(e)
+                    except Exception:
+                        item.deferred.errback()
             finally:
                 queue = self._event_persist_queues.pop(room_id, None)
                 if queue:
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 23688430b7..cf2c4dae39 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -533,73 +533,114 @@ class RoomStore(SQLBaseStore):
         )
         self.is_room_blocked.invalidate((room_id,))
 
+    def get_media_mxcs_in_room(self, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            The local and remote media as a lists of tuples where the key is
+            the hostname and the value is the media ID.
+        """
+        def _get_media_mxcs_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            local_media_mxcs = []
+            remote_media_mxcs = []
+
+            # Convert the IDs to MXC URIs
+            for media_id in local_mxcs:
+                local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
+            for hostname, media_id in remote_mxcs:
+                remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+            return local_media_mxcs, remote_media_mxcs
+        return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+
     def quarantine_media_ids_in_room(self, room_id, quarantined_by):
         """For a room loops through all events with media and quarantines
         the associated media
         """
-        def _get_media_ids_in_room(txn):
-            mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+        def _quarantine_media_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            total_media_quarantined = 0
 
-            next_token = self.get_current_events_token() + 1
+            # Now update all the tables to set the quarantined_by flag
 
-            total_media_quarantined = 0
+            txn.executemany("""
+                UPDATE local_media_repository
+                SET quarantined_by = ?
+                WHERE media_id = ?
+            """, ((quarantined_by, media_id) for media_id in local_mxcs))
 
-            while next_token:
-                sql = """
-                    SELECT stream_ordering, content FROM events
-                    WHERE room_id = ?
-                        AND stream_ordering < ?
-                        AND contains_url = ? AND outlier = ?
-                    ORDER BY stream_ordering DESC
-                    LIMIT ?
+            txn.executemany(
                 """
-                txn.execute(sql, (room_id, next_token, True, False, 100))
-
-                next_token = None
-                local_media_mxcs = []
-                remote_media_mxcs = []
-                for stream_ordering, content_json in txn:
-                    next_token = stream_ordering
-                    content = json.loads(content_json)
-
-                    content_url = content.get("url")
-                    thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
-                    for url in (content_url, thumbnail_url):
-                        if not url:
-                            continue
-                        matches = mxc_re.match(url)
-                        if matches:
-                            hostname = matches.group(1)
-                            media_id = matches.group(2)
-                            if hostname == self.hostname:
-                                local_media_mxcs.append(media_id)
-                            else:
-                                remote_media_mxcs.append((hostname, media_id))
-
-                # Now update all the tables to set the quarantined_by flag
-
-                txn.executemany("""
-                    UPDATE local_media_repository
+                    UPDATE remote_media_cache
                     SET quarantined_by = ?
-                    WHERE media_id = ?
-                """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
-
-                txn.executemany(
-                    """
-                        UPDATE remote_media_cache
-                        SET quarantined_by = ?
-                        WHERE media_origin AND media_id = ?
-                    """,
-                    (
-                        (quarantined_by, origin, media_id)
-                        for origin, media_id in remote_media_mxcs
-                    )
+                    WHERE media_origin = ? AND media_id = ?
+                """,
+                (
+                    (quarantined_by, origin, media_id)
+                    for origin, media_id in remote_mxcs
                 )
+            )
 
-                total_media_quarantined += len(local_media_mxcs)
-                total_media_quarantined += len(remote_media_mxcs)
+            total_media_quarantined += len(local_mxcs)
+            total_media_quarantined += len(remote_mxcs)
 
             return total_media_quarantined
 
-        return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
+        return self.runInteraction(
+            "quarantine_media_in_room",
+            _quarantine_media_in_room_txn,
+        )
+
+    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            txn (cursor)
+            room_id (str)
+
+        Returns:
+            The local and remote media as a lists of tuples where the key is
+            the hostname and the value is the media ID.
+        """
+        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+        next_token = self.get_current_events_token() + 1
+        local_media_mxcs = []
+        remote_media_mxcs = []
+
+        while next_token:
+            sql = """
+                SELECT stream_ordering, content FROM events
+                WHERE room_id = ?
+                    AND stream_ordering < ?
+                    AND contains_url = ? AND outlier = ?
+                ORDER BY stream_ordering DESC
+                LIMIT ?
+            """
+            txn.execute(sql, (room_id, next_token, True, False, 100))
+
+            next_token = None
+            for stream_ordering, content_json in txn:
+                next_token = stream_ordering
+                content = json.loads(content_json)
+
+                content_url = content.get("url")
+                thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+                for url in (content_url, thumbnail_url):
+                    if not url:
+                        continue
+                    matches = mxc_re.match(url)
+                    if matches:
+                        hostname = matches.group(1)
+                        media_id = matches.group(2)
+                        if hostname == self.hostname:
+                            local_media_mxcs.append(media_id)
+                        else:
+                            remote_media_mxcs.append((hostname, media_id))
+
+        return local_media_mxcs, remote_media_mxcs