summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/push/__init__.py12
-rw-r--r--synapse/push/httppusher.py25
-rw-r--r--synapse/push/pusherpool.py47
-rw-r--r--synapse/rest/pusher.py13
-rw-r--r--synapse/storage/_base.py45
-rw-r--r--synapse/storage/pusher.py83
-rw-r--r--synapse/storage/schema/delta/v7.sql3
-rw-r--r--synapse/storage/schema/pusher.sql3
8 files changed, 158 insertions, 73 deletions
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 5fca3bd772..5fe8719fe7 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -30,7 +30,7 @@ class Pusher(object):
     MAX_BACKOFF = 60 * 60 * 1000
     GIVE_UP_AFTER = 24 * 60 * 60 * 1000
 
-    def __init__(self, _hs, user_name, app_id, app_instance_id,
+    def __init__(self, _hs, user_name, app_id,
                  app_display_name, device_display_name, pushkey, data,
                  last_token, last_success, failing_since):
         self.hs = _hs
@@ -39,7 +39,6 @@ class Pusher(object):
         self.clock = self.hs.get_clock()
         self.user_name = user_name
         self.app_id = app_id
-        self.app_instance_id = app_instance_id
         self.app_display_name = app_display_name
         self.device_display_name = device_display_name
         self.pushkey = pushkey
@@ -48,6 +47,7 @@ class Pusher(object):
         self.last_success = last_success  # not actually used
         self.backoff_delay = Pusher.INITIAL_BACKOFF
         self.failing_since = failing_since
+        self.alive = True
 
     @defer.inlineCallbacks
     def start(self):
@@ -65,7 +65,7 @@ class Pusher(object):
             logger.info("Pusher %s for user %s starting from token %s",
                         self.pushkey, self.user_name, self.last_token)
 
-        while True:
+        while self.alive:
             from_tok = StreamToken.from_string(self.last_token)
             config = PaginationConfig(from_token=from_tok, limit='1')
             chunk = yield self.evStreamHandler.get_stream(
@@ -81,6 +81,9 @@ class Pusher(object):
             if not single_event:
                 continue
 
+            if not self.alive:
+                continue
+
             ret = yield self.dispatch_push(single_event)
             if ret:
                 self.backoff_delay = Pusher.INITIAL_BACKOFF
@@ -142,6 +145,9 @@ class Pusher(object):
                     if self.backoff_delay > Pusher.MAX_BACKOFF:
                         self.backoff_delay = Pusher.MAX_BACKOFF
 
+    def stop(self):
+        self.alive = False
+
     def dispatch_push(self, p):
         pass
 
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index fd7fe4e39c..f94f673391 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -24,14 +24,13 @@ logger = logging.getLogger(__name__)
 
 
 class HttpPusher(Pusher):
-    def __init__(self, _hs, user_name, app_id, app_instance_id,
+    def __init__(self, _hs, user_name, app_id,
                  app_display_name, device_display_name, pushkey, data,
                  last_token, last_success, failing_since):
         super(HttpPusher, self).__init__(
             _hs,
             user_name,
             app_id,
-            app_instance_id,
             app_display_name,
             device_display_name,
             pushkey,
@@ -69,16 +68,18 @@ class HttpPusher(Pusher):
                 # we may have to fetch this over federation and we
                 # can't trust it anyway: is it worth it?
                 #'fromDisplayName': 'Steve Stevington'
-            },
-            #'counts': { -- we don't mark messages as read yet so
-            # we have no way of knowing
-            #    'unread': 1,
-            #    'missedCalls': 2
-            # },
-            'devices': {
-                self.pushkey: {
-                    'data': self.data_minus_url
-                }
+                #'counts': { -- we don't mark messages as read yet so
+                # we have no way of knowing
+                #    'unread': 1,
+                #    'missedCalls': 2
+                # },
+                'devices': [
+                    {
+                        'app_id': self.app_id,
+                        'pushkey': self.pushkey,
+                        'data': self.data_minus_url
+                    }
+                ]
             }
         }
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 045c36f3b7..d34ef3f6cf 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -24,17 +24,23 @@ import json
 
 logger = logging.getLogger(__name__)
 
+
 class PusherPool:
     def __init__(self, _hs):
         self.hs = _hs
         self.store = self.hs.get_datastore()
-        self.pushers = []
+        self.pushers = {}
         self.last_pusher_started = -1
 
+    @defer.inlineCallbacks
     def start(self):
-        self._pushers_added()
+        pushers = yield self.store.get_all_pushers()
+        for p in pushers:
+            p['data'] = json.loads(p['data'])
+        self._start_pushers(pushers)
 
-    def add_pusher(self, user_name, kind, app_id, app_instance_id,
+    @defer.inlineCallbacks
+    def add_pusher(self, user_name, kind, app_id,
                    app_display_name, device_display_name, pushkey, data):
         # we try to create the pusher just to validate the config: it
         # will then get pulled out of the database,
@@ -44,7 +50,6 @@ class PusherPool:
             "user_name": user_name,
             "kind": kind,
             "app_id": app_id,
-            "app_instance_id": app_instance_id,
             "app_display_name": app_display_name,
             "device_display_name": device_display_name,
             "pushkey": pushkey,
@@ -53,25 +58,26 @@ class PusherPool:
             "last_success": None,
             "failing_since": None
         })
-        self._add_pusher_to_store(user_name, kind, app_id, app_instance_id,
-                                  app_display_name, device_display_name,
-                                  pushkey, data)
+        yield self._add_pusher_to_store(
+            user_name, kind, app_id,
+            app_display_name, device_display_name,
+            pushkey, data
+        )
 
     @defer.inlineCallbacks
-    def _add_pusher_to_store(self, user_name, kind, app_id, app_instance_id,
+    def _add_pusher_to_store(self, user_name, kind, app_id,
                              app_display_name, device_display_name,
                              pushkey, data):
         yield self.store.add_pusher(
             user_name=user_name,
             kind=kind,
             app_id=app_id,
-            app_instance_id=app_instance_id,
             app_display_name=app_display_name,
             device_display_name=device_display_name,
             pushkey=pushkey,
             data=json.dumps(data)
         )
-        self._pushers_added()
+        self._refresh_pusher((app_id, pushkey))
 
     def _create_pusher(self, pusherdict):
         if pusherdict['kind'] == 'http':
@@ -79,7 +85,6 @@ class PusherPool:
                 self.hs,
                 user_name=pusherdict['user_name'],
                 app_id=pusherdict['app_id'],
-                app_instance_id=pusherdict['app_instance_id'],
                 app_display_name=pusherdict['app_display_name'],
                 device_display_name=pusherdict['device_display_name'],
                 pushkey=pusherdict['pushkey'],
@@ -95,21 +100,21 @@ class PusherPool:
             )
 
     @defer.inlineCallbacks
-    def _pushers_added(self):
-        pushers = yield self.store.get_all_pushers_after_id(
-            self.last_pusher_started
+    def _refresh_pusher(self, app_id_pushkey):
+        p = yield self.store.get_pushers_by_app_id_and_pushkey(
+            app_id_pushkey
         )
-        for p in pushers:
-            p['data'] = json.loads(p['data'])
-        if len(pushers):
-            self.last_pusher_started = pushers[-1]['id']
+        p['data'] = json.loads(p['data'])
 
-        self._start_pushers(pushers)
+        self._start_pushers([p])
 
     def _start_pushers(self, pushers):
-        logger.info("Starting %d pushers", (len(pushers)))
+        logger.info("Starting %d pushers", len(pushers))
         for pusherdict in pushers:
             p = self._create_pusher(pusherdict)
             if p:
-                self.pushers.append(p)
+                fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey'])
+                if fullid in self.pushers:
+                    self.pushers[fullid].stop()
+                self.pushers[fullid] = p
                 p.start()
diff --git a/synapse/rest/pusher.py b/synapse/rest/pusher.py
index a39341cd8b..5b371318d0 100644
--- a/synapse/rest/pusher.py
+++ b/synapse/rest/pusher.py
@@ -23,16 +23,16 @@ import json
 
 
 class PusherRestServlet(RestServlet):
-    PATTERN = client_path_pattern("/pushers/(?P<pushkey>[\w]*)$")
+    PATTERN = client_path_pattern("/pushers/set$")
 
     @defer.inlineCallbacks
-    def on_PUT(self, request, pushkey):
+    def on_POST(self, request):
         user = yield self.auth.get_user_by_req(request)
 
         content = _parse_json(request)
 
-        reqd = ['kind', 'app_id', 'app_instance_id', 'app_display_name',
-                'device_display_name', 'data']
+        reqd = ['kind', 'app_id', 'app_display_name',
+                'device_display_name', 'pushkey', 'data']
         missing = []
         for i in reqd:
             if i not in content:
@@ -43,14 +43,13 @@ class PusherRestServlet(RestServlet):
 
         pusher_pool = self.hs.get_pusherpool()
         try:
-            pusher_pool.add_pusher(
+            yield pusher_pool.add_pusher(
                 user_name=user.to_string(),
                 kind=content['kind'],
                 app_id=content['app_id'],
-                app_instance_id=content['app_instance_id'],
                 app_display_name=content['app_display_name'],
                 device_display_name=content['device_display_name'],
-                pushkey=pushkey,
+                pushkey=content['pushkey'],
                 data=content['data']
             )
         except PusherConfigException as pce:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 4881f03368..eb8cc4a9f3 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -195,6 +195,51 @@ class SQLBaseStore(object):
         txn.execute(sql, values.values())
         return txn.lastrowid
 
+    def _simple_upsert(self, table, keyvalues, values):
+        """
+        :param table: The table to upsert into
+        :param keyvalues: Dict of the unique key tables and their new values
+        :param values: Dict of all the nonunique columns and their new values
+        :return: A deferred
+        """
+        return self.runInteraction(
+            "_simple_upsert",
+            self._simple_upsert_txn, table, keyvalues, values
+        )
+
+    def _simple_upsert_txn(self, txn, table, keyvalues, values):
+        # Try to update
+        sql = "UPDATE %s SET %s WHERE %s" % (
+            table,
+            ", ".join("%s = ?" % (k) for k in values),
+            " AND ".join("%s = ?" % (k) for k in keyvalues)
+        )
+        sqlargs = values.values() + keyvalues.values()
+        logger.debug(
+            "[SQL] %s Args=%s",
+            sql, sqlargs,
+        )
+
+        txn.execute(sql, sqlargs)
+        if txn.rowcount == 0:
+            # We didn't update and rows so insert a new one
+            allvalues = {}
+            allvalues.update(keyvalues)
+            allvalues.update(values)
+
+            sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+                table,
+                ", ".join(k for k in allvalues),
+                ", ".join("?" for _ in allvalues)
+            )
+            logger.debug(
+                "[SQL] %s Args=%s",
+                sql, keyvalues.values(),
+            )
+            txn.execute(sql, allvalues.values())
+
+
+
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False):
         """Executes a SELECT query on the named table, which is expected to
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index a858e46f3b..deabd9cd2e 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -28,16 +28,48 @@ logger = logging.getLogger(__name__)
 
 class PusherStore(SQLBaseStore):
     @defer.inlineCallbacks
-    def get_all_pushers_after_id(self, min_id):
+    def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey):
         sql = (
-            "SELECT id, user_name, kind, app_id, app_instance_id,"
+            "SELECT id, user_name, kind, app_id,"
             "app_display_name, device_display_name, pushkey, data, "
             "last_token, last_success, failing_since "
             "FROM pushers "
-            "WHERE id > ?"
+            "WHERE app_id = ? AND pushkey = ?"
         )
 
-        rows = yield self._execute(None, sql, min_id)
+        rows = yield self._execute(
+            None, sql, app_id_and_pushkey[0], app_id_and_pushkey[1]
+        )
+
+        ret = [
+            {
+                "id": r[0],
+                "user_name": r[1],
+                "kind": r[2],
+                "app_id": r[3],
+                "app_display_name": r[4],
+                "device_display_name": r[5],
+                "pushkey": r[6],
+                "data": r[7],
+                "last_token": r[8],
+                "last_success": r[9],
+                "failing_since": r[10]
+            }
+            for r in rows
+        ]
+
+        defer.returnValue(ret[0])
+
+    @defer.inlineCallbacks
+    def get_all_pushers(self):
+        sql = (
+            "SELECT id, user_name, kind, app_id,"
+            "app_display_name, device_display_name, pushkey, data, "
+            "last_token, last_success, failing_since "
+            "FROM pushers"
+        )
+
+        rows = yield self._execute(None, sql)
 
         ret = [
             {
@@ -45,14 +77,13 @@ class PusherStore(SQLBaseStore):
                 "user_name": r[1],
                 "kind": r[2],
                 "app_id": r[3],
-                "app_instance_id": r[4],
-                "app_display_name": r[5],
-                "device_display_name": r[6],
-                "pushkey": r[7],
-                "data": r[8],
-                "last_token": r[9],
-                "last_success": r[10],
-                "failing_since": r[11]
+                "app_display_name": r[4],
+                "device_display_name": r[5],
+                "pushkey": r[6],
+                "data": r[7],
+                "last_token": r[8],
+                "last_success": r[9],
+                "failing_since": r[10]
             }
             for r in rows
         ]
@@ -60,21 +91,22 @@ class PusherStore(SQLBaseStore):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def add_pusher(self, user_name, kind, app_id, app_instance_id,
+    def add_pusher(self, user_name, kind, app_id,
                    app_display_name, device_display_name, pushkey, data):
         try:
-            yield self._simple_insert(PushersTable.table_name, dict(
-                user_name=user_name,
-                kind=kind,
-                app_id=app_id,
-                app_instance_id=app_instance_id,
-                app_display_name=app_display_name,
-                device_display_name=device_display_name,
-                pushkey=pushkey,
-                data=data
-            ))
-        except IntegrityError:
-            raise StoreError(409, "Pushkey in use.")
+            yield self._simple_upsert(
+                PushersTable.table_name,
+                dict(
+                    app_id=app_id,
+                    pushkey=pushkey,
+                ),
+                dict(
+                    user_name=user_name,
+                    kind=kind,
+                    app_display_name=app_display_name,
+                    device_display_name=device_display_name,
+                    data=data
+                ))
         except Exception as e:
             logger.error("create_pusher with failed: %s", e)
             raise StoreError(500, "Problem creating pusher.")
@@ -113,7 +145,6 @@ class PushersTable(Table):
         "user_name",
         "kind",
         "app_id",
-        "app_instance_id",
         "app_display_name",
         "device_display_name",
         "pushkey",
diff --git a/synapse/storage/schema/delta/v7.sql b/synapse/storage/schema/delta/v7.sql
index b60aeda756..799e48d780 100644
--- a/synapse/storage/schema/delta/v7.sql
+++ b/synapse/storage/schema/delta/v7.sql
@@ -18,7 +18,6 @@ CREATE TABLE IF NOT EXISTS pushers (
   user_name TEXT NOT NULL,
   kind varchar(8) NOT NULL,
   app_id varchar(64) NOT NULL,
-  app_instance_id varchar(64) NOT NULL,
   app_display_name varchar(64) NOT NULL,
   device_display_name varchar(128) NOT NULL,
   pushkey blob NOT NULL,
@@ -27,5 +26,5 @@ CREATE TABLE IF NOT EXISTS pushers (
   last_success BIGINT,
   failing_since BIGINT,
   FOREIGN KEY(user_name) REFERENCES users(name),
-  UNIQUE (user_name, pushkey)
+  UNIQUE (app_id, pushkey)
 );
diff --git a/synapse/storage/schema/pusher.sql b/synapse/storage/schema/pusher.sql
index b60aeda756..799e48d780 100644
--- a/synapse/storage/schema/pusher.sql
+++ b/synapse/storage/schema/pusher.sql
@@ -18,7 +18,6 @@ CREATE TABLE IF NOT EXISTS pushers (
   user_name TEXT NOT NULL,
   kind varchar(8) NOT NULL,
   app_id varchar(64) NOT NULL,
-  app_instance_id varchar(64) NOT NULL,
   app_display_name varchar(64) NOT NULL,
   device_display_name varchar(128) NOT NULL,
   pushkey blob NOT NULL,
@@ -27,5 +26,5 @@ CREATE TABLE IF NOT EXISTS pushers (
   last_success BIGINT,
   failing_since BIGINT,
   FOREIGN KEY(user_name) REFERENCES users(name),
-  UNIQUE (user_name, pushkey)
+  UNIQUE (app_id, pushkey)
 );