summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2016-09-12 11:14:56 +0100
committerMark Haines <mark.haines@matrix.org>2016-09-12 11:14:56 +0100
commit4a32d25d4c97b10f74e636d46b312ba6b660987b (patch)
tree53c35423e6469adc87cc5b213d5b960fd143277c /synapse
parentFix unit tests (diff)
parentMerge pull request #1102 from matrix-org/revert-1099-dbkr/make_notif_highligh... (diff)
downloadsynapse-4a32d25d4c97b10f74e636d46b312ba6b660987b.tar.xz
Merge branch 'develop' into markjh/bearer_token
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/synchrotron.py3
-rw-r--r--synapse/federation/federation_client.py8
-rw-r--r--synapse/federation/transaction_queue.py50
-rw-r--r--synapse/handlers/presence.py20
-rw-r--r--synapse/handlers/receipts.py1
-rw-r--r--synapse/handlers/typing.py1
-rw-r--r--synapse/replication/resource.py9
-rw-r--r--synapse/rest/client/v1/login.py7
-rw-r--r--synapse/storage/state.py54
9 files changed, 108 insertions, 45 deletions
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 07d3d047c6..dbaa48035d 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -242,6 +242,9 @@ class SynchrotronTyping(object):
         self._room_typing = {}
 
     def stream_positions(self):
+        # We must update this typing token from the response of the previous
+        # sync. In particular, the stream id may "reset" back to zero/a low
+        # value which we *must* use for the next replication request.
         return {"typing": self._latest_room_serial}
 
     def process_replication(self, result):
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 78719eed25..3395c9e41e 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -122,8 +122,12 @@ class FederationClient(FederationBase):
             pdu.event_id
         )
 
+    def send_presence(self, destination, states):
+        if destination != self.server_name:
+            self._transaction_queue.enqueue_presence(destination, states)
+
     @log_function
-    def send_edu(self, destination, edu_type, content):
+    def send_edu(self, destination, edu_type, content, key=None):
         edu = Edu(
             origin=self.server_name,
             destination=destination,
@@ -134,7 +138,7 @@ class FederationClient(FederationBase):
         sent_edus_counter.inc()
 
         # TODO, add errback, etc.
-        self._transaction_queue.enqueue_edu(edu)
+        self._transaction_queue.enqueue_edu(edu, key=key)
         return defer.succeed(None)
 
     @log_function
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 1ac569b305..f8ca93e4c3 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -26,6 +26,7 @@ from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
 from synapse.util.metrics import measure_func
+from synapse.handlers.presence import format_user_presence_state
 import synapse.metrics
 
 import logging
@@ -69,13 +70,21 @@ class TransactionQueue(object):
         # destination -> list of tuple(edu, deferred)
         self.pending_edus_by_dest = edus = {}
 
+        # Presence needs to be separate as we send single aggragate EDUs
+        self.pending_presence_by_dest = presence = {}
+        self.pending_edus_keyed_by_dest = edus_keyed = {}
+
         metrics.register_callback(
             "pending_pdus",
             lambda: sum(map(len, pdus.values())),
         )
         metrics.register_callback(
             "pending_edus",
-            lambda: sum(map(len, edus.values())),
+            lambda: (
+                sum(map(len, edus.values()))
+                + sum(map(len, presence.values()))
+                + sum(map(len, edus_keyed.values()))
+            ),
         )
 
         # destination -> list of tuple(failure, deferred)
@@ -130,13 +139,27 @@ class TransactionQueue(object):
                 self._attempt_new_transaction, destination
             )
 
-    def enqueue_edu(self, edu):
+    def enqueue_presence(self, destination, states):
+        self.pending_presence_by_dest.setdefault(destination, {}).update({
+            state.user_id: state for state in states
+        })
+
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
+
+    def enqueue_edu(self, edu, key=None):
         destination = edu.destination
 
         if not self.can_send_to(destination):
             return
 
-        self.pending_edus_by_dest.setdefault(destination, []).append(edu)
+        if key:
+            self.pending_edus_keyed_by_dest.setdefault(
+                destination, {}
+            )[(edu.edu_type, key)] = edu
+        else:
+            self.pending_edus_by_dest.setdefault(destination, []).append(edu)
 
         preserve_context_over_fn(
             self._attempt_new_transaction, destination
@@ -190,8 +213,13 @@ class TransactionQueue(object):
             while True:
                     pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
                     pending_edus = self.pending_edus_by_dest.pop(destination, [])
+                    pending_presence = self.pending_presence_by_dest.pop(destination, {})
                     pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
+                    pending_edus.extend(
+                        self.pending_edus_keyed_by_dest.pop(destination, {}).values()
+                    )
+
                     limiter = yield get_retry_limiter(
                         destination,
                         self.clock,
@@ -203,6 +231,22 @@ class TransactionQueue(object):
                     )
 
                     pending_edus.extend(device_message_edus)
+                    if pending_presence:
+                        pending_edus.append(
+                            Edu(
+                                origin=self.server_name,
+                                destination=destination,
+                                edu_type="m.presence",
+                                content={
+                                    "push": [
+                                        format_user_presence_state(
+                                            presence, self.clock.time_msec()
+                                        )
+                                        for presence in pending_presence.values()
+                                    ]
+                                },
+                            )
+                        )
 
                     if pending_pdus:
                         logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 16dbddee03..a949e39bda 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -625,18 +625,8 @@ class PresenceHandler(object):
         Args:
             hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
         """
-        now = self.clock.time_msec()
         for host, states in hosts_to_states.items():
-            self.federation.send_edu(
-                destination=host,
-                edu_type="m.presence",
-                content={
-                    "push": [
-                        _format_user_presence_state(state, now)
-                        for state in states
-                    ]
-                }
-            )
+            self.federation.send_presence(host, states)
 
     @defer.inlineCallbacks
     def incoming_presence(self, origin, content):
@@ -723,13 +713,13 @@ class PresenceHandler(object):
             defer.returnValue([
                 {
                     "type": "m.presence",
-                    "content": _format_user_presence_state(state, now),
+                    "content": format_user_presence_state(state, now),
                 }
                 for state in updates
             ])
         else:
             defer.returnValue([
-                _format_user_presence_state(state, now) for state in updates
+                format_user_presence_state(state, now) for state in updates
             ])
 
     @defer.inlineCallbacks
@@ -988,7 +978,7 @@ def should_notify(old_state, new_state):
     return False
 
 
-def _format_user_presence_state(state, now):
+def format_user_presence_state(state, now):
     """Convert UserPresenceState to a format that can be sent down to clients
     and to other servers.
     """
@@ -1101,7 +1091,7 @@ class PresenceEventSource(object):
         defer.returnValue(([
             {
                 "type": "m.presence",
-                "content": _format_user_presence_state(s, now),
+                "content": format_user_presence_state(s, now),
             }
             for s in updates.values()
             if include_offline or s.state != PresenceState.OFFLINE
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 726f7308d2..e536a909d0 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -156,6 +156,7 @@ class ReceiptsHandler(BaseHandler):
                             }
                         },
                     },
+                    key=(room_id, receipt_type, user_id),
                 )
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3b687957dd..0548b81c34 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -187,6 +187,7 @@ class TypingHandler(object):
                         "user_id": user_id,
                         "typing": typing,
                     },
+                    key=(room_id, user_id),
                 ))
 
         yield preserve_context_over_deferred(
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 857bc9795c..299e9419a4 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -274,11 +274,18 @@ class ReplicationResource(Resource):
 
     @defer.inlineCallbacks
     def typing(self, writer, current_token, request_streams):
-        current_position = current_token.presence
+        current_position = current_token.typing
 
         request_typing = request_streams.get("typing")
 
         if request_typing is not None:
+            # If they have a higher token than current max, we can assume that
+            # they had been talking to a previous instance of the master. Since
+            # we reset the token on restart, the best (but hacky) thing we can
+            # do is to simply resend down all the typing notifications.
+            if request_typing > current_position:
+                request_typing = 0
+
             typing_rows = yield self.typing_handler.get_all_typing_updates(
                 request_typing, current_position
             )
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 6c0eec8fb3..345018a8fc 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -318,7 +318,7 @@ class CasRedirectServlet(ClientV1RestServlet):
         service_param = urllib.urlencode({
             "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
         })
-        request.redirect("%s?%s" % (self.cas_server_url, service_param))
+        request.redirect("%s/login?%s" % (self.cas_server_url, service_param))
         finish_request(request)
 
 
@@ -385,7 +385,7 @@ class CasTicketServlet(ClientV1RestServlet):
 
     def parse_cas_response(self, cas_response_body):
         user = None
-        attributes = None
+        attributes = {}
         try:
             root = ET.fromstring(cas_response_body)
             if not root.tag.endswith("serviceResponse"):
@@ -395,7 +395,6 @@ class CasTicketServlet(ClientV1RestServlet):
                 if child.tag.endswith("user"):
                     user = child.text
                 if child.tag.endswith("attributes"):
-                    attributes = {}
                     for attribute in child:
                         # ElementTree library expands the namespace in
                         # attribute tags to the full URL of the namespace.
@@ -407,8 +406,6 @@ class CasTicketServlet(ClientV1RestServlet):
                         attributes[tag] = attribute.text
             if user is None:
                 raise Exception("CAS response does not contain user")
-            if attributes is None:
-                raise Exception("CAS response does not contain attributes")
         except Exception:
             logger.error("Error parsing CAS response", exc_info=1)
             raise LoginError(401, "Invalid CAS response",
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0cff0a0cda..f98d5d53ee 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -306,13 +306,6 @@ class StateStore(SQLBaseStore):
         defer.returnValue(results)
 
     def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
-        if types is not None:
-            where_clause = "AND (%s)" % (
-                " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
-            )
-        else:
-            where_clause = ""
-
         results = {group: {} for group in groups}
         if isinstance(self.database_engine, PostgresEngine):
             # Temporarily disable sequential scans in this transaction. This is
@@ -342,20 +335,43 @@ class StateStore(SQLBaseStore):
                 WHERE state_group IN (
                     SELECT state_group FROM state
                 )
-                %s;
-            """) % (where_clause,)
-
-            for group in groups:
-                args = [group]
-                if types is not None:
-                    args.extend([i for typ in types for i in typ])
+                %s
+            """)
 
-                txn.execute(sql, args)
-                rows = self.cursor_to_dict(txn)
-                for row in rows:
-                    key = (row["type"], row["state_key"])
-                    results[group][key] = row["event_id"]
+            # Turns out that postgres doesn't like doing a list of OR's and
+            # is about 1000x slower, so we just issue a query for each specific
+            # type seperately.
+            if types:
+                clause_to_args = [
+                    (
+                        "AND type = ? AND state_key = ?",
+                        (etype, state_key)
+                    )
+                    for etype, state_key in types
+                ]
+            else:
+                # If types is None we fetch all the state, and so just use an
+                # empty where clause with no extra args.
+                clause_to_args = [("", [])]
+
+            for where_clause, where_args in clause_to_args:
+                for group in groups:
+                    args = [group]
+                    args.extend(where_args)
+
+                    txn.execute(sql % (where_clause,), args)
+                    rows = self.cursor_to_dict(txn)
+                    for row in rows:
+                        key = (row["type"], row["state_key"])
+                        results[group][key] = row["event_id"]
         else:
+            if types is not None:
+                where_clause = "AND (%s)" % (
+                    " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
+                )
+            else:
+                where_clause = ""
+
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
             for group in groups: