author     Amber Brown <hawkowl@atleastfornow.net>  2018-05-28 18:57:23 +1000
committer  Amber Brown <hawkowl@atleastfornow.net>  2018-05-28 18:57:23 +1000
commit     754826a8305b3f32a45367cb6bd8bb26bd489a6b (patch)
tree       c5aba5e1e720474b3f94d15bd9202bda0b120e20 /synapse/storage
parent     pepeightttt (diff)
parent     Merge pull request #3288 from matrix-org/rav/no_spam_guests (diff)
download   synapse-754826a8305b3f32a45367cb6bd8bb26bd489a6b.tar.xz
Merge remote-tracking branch 'origin/develop' into 3218-official-prom
Diffstat (limited to 'synapse/storage')
 -rw-r--r--  synapse/storage/_base.py               47
 -rw-r--r--  synapse/storage/client_ips.py           6
 -rw-r--r--  synapse/storage/devices.py              9
 -rw-r--r--  synapse/storage/end_to_end_keys.py      6
 -rw-r--r--  synapse/storage/event_push_actions.py   4
 -rw-r--r--  synapse/storage/events_worker.py        2
 -rw-r--r--  synapse/storage/filtering.py            2
 -rw-r--r--  synapse/storage/keys.py                 2
 -rw-r--r--  synapse/storage/roommember.py          10
 -rw-r--r--  synapse/storage/stream.py              15
10 files changed, 60 insertions, 43 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index d1b625dc30..d963af5c89 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -27,9 +27,17 @@ import sys
 import time
 import threading
 
+from six import itervalues, iterkeys, iteritems
+from six.moves import intern, range
 
 logger = logging.getLogger(__name__)
 
+try:
+    MAX_TXN_ID = sys.maxint - 1
+except AttributeError:
+    # python 3 does not have a maximum int value
+    MAX_TXN_ID = 2**63 - 1
+
 sql_logger = logging.getLogger("synapse.storage.SQL")
 transaction_logger = logging.getLogger("synapse.storage.txn")
 perf_logger = logging.getLogger("synapse.storage.TIME")
@@ -134,7 +142,7 @@ class PerformanceCounters(object):
 
     def interval(self, interval_duration, limit=3):
         counters = []
-        for name, (count, cum_time) in self.current_counters.iteritems():
+        for name, (count, cum_time) in iteritems(self.current_counters):
             prev_count, prev_time = self.previous_counters.get(name, (0, 0))
             counters.append((
                 (cum_time - prev_time) / interval_duration,
@@ -219,7 +227,7 @@ class SQLBaseStore(object):
 
         # We don't really need these to be unique, so lets stop it from
         # growing really large.
-        self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+        self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
 
         name = "%s-%x" % (desc, txn_id, )
 
@@ -540,7 +548,7 @@ class SQLBaseStore(object):
             ", ".join("%s = ?" % (k,) for k in values),
             " AND ".join("%s = ?" % (k,) for k in keyvalues)
         )
-        sqlargs = values.values() + keyvalues.values()
+        sqlargs = list(values.values()) + list(keyvalues.values())
 
         txn.execute(sql, sqlargs)
         if txn.rowcount > 0:
@@ -558,7 +566,7 @@ class SQLBaseStore(object):
             ", ".join(k for k in allvalues),
             ", ".join("?" for _ in allvalues)
         )
-        txn.execute(sql, allvalues.values())
+        txn.execute(sql, list(allvalues.values()))
         # successfully inserted
         return True
 
@@ -626,8 +634,8 @@ class SQLBaseStore(object):
         }
 
         if keyvalues:
-            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
-            txn.execute(sql, keyvalues.values())
+            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+            txn.execute(sql, list(keyvalues.values()))
         else:
             txn.execute(sql)
 
@@ -691,7 +699,7 @@ class SQLBaseStore(object):
                 table,
                 " AND ".join("%s = ?" % (k, ) for k in keyvalues)
             )
-            txn.execute(sql, keyvalues.values())
+            txn.execute(sql, list(keyvalues.values()))
         else:
             sql = "SELECT %s FROM %s" % (
                 ", ".join(retcols),
@@ -722,9 +730,12 @@ class SQLBaseStore(object):
         if not iterable:
             defer.returnValue(results)
 
+        # iterables can not be sliced, so convert it to a list first
+        it_list = list(iterable)
+
         chunks = [
-            iterable[i:i + batch_size]
-            for i in xrange(0, len(iterable), batch_size)
+            it_list[i:i + batch_size]
+            for i in range(0, len(it_list), batch_size)
         ]
         for chunk in chunks:
             rows = yield self.runInteraction(
@@ -764,7 +775,7 @@ class SQLBaseStore(object):
         )
         values.extend(iterable)
 
-        for key, value in keyvalues.iteritems():
+        for key, value in iteritems(keyvalues):
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
@@ -787,7 +798,7 @@ class SQLBaseStore(object):
     @staticmethod
     def _simple_update_txn(txn, table, keyvalues, updatevalues):
         if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
+            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
         else:
             where = ""
 
@@ -799,7 +810,7 @@ class SQLBaseStore(object):
 
         txn.execute(
             update_sql,
-            updatevalues.values() + keyvalues.values()
+            list(updatevalues.values()) + list(keyvalues.values())
         )
 
         return txn.rowcount
@@ -847,7 +858,7 @@ class SQLBaseStore(object):
             " AND ".join("%s = ?" % (k,) for k in keyvalues)
         )
 
-        txn.execute(select_sql, keyvalues.values())
+        txn.execute(select_sql, list(keyvalues.values()))
 
         row = txn.fetchone()
         if not row:
@@ -885,7 +896,7 @@ class SQLBaseStore(object):
             " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
-        txn.execute(sql, keyvalues.values())
+        txn.execute(sql, list(keyvalues.values()))
         if txn.rowcount == 0:
             raise StoreError(404, "No row found")
         if txn.rowcount > 1:
@@ -903,7 +914,7 @@ class SQLBaseStore(object):
             " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
-        return txn.execute(sql, keyvalues.values())
+        return txn.execute(sql, list(keyvalues.values()))
 
     def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
         return self.runInteraction(
@@ -935,7 +946,7 @@ class SQLBaseStore(object):
         )
         values.extend(iterable)
 
-        for key, value in keyvalues.iteritems():
+        for key, value in iteritems(keyvalues):
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
@@ -975,7 +986,7 @@ class SQLBaseStore(object):
         txn.close()
 
         if cache:
-            min_val = min(cache.itervalues())
+            min_val = min(itervalues(cache))
         else:
             min_val = max_value
 
@@ -1090,7 +1101,7 @@ class SQLBaseStore(object):
                 " AND ".join("%s = ?" % (k,) for k in keyvalues),
                 " ? ASC LIMIT ? OFFSET ?"
             )
-            txn.execute(sql, keyvalues.values() + pagevalues)
+            txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
         else:
             sql = "SELECT %s FROM %s ORDER BY %s" % (
                 ", ".join(retcols),
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index ba46907737..ce338514e8 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -22,6 +22,8 @@ from . import background_updates
 
 from synapse.util.caches import CACHE_SIZE_FACTOR
 
+from six import iteritems
+
 
 logger = logging.getLogger(__name__)
 
@@ -99,7 +101,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
     def _update_client_ips_batch_txn(self, txn, to_update):
         self.database_engine.lock_table(txn, "user_ips")
 
-        for entry in to_update.iteritems():
+        for entry in iteritems(to_update):
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
 
             self._simple_upsert_txn(
@@ -231,5 +233,5 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 "user_agent": user_agent,
                 "last_seen": last_seen,
             }
-            for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+            for (access_token, ip), (user_agent, last_seen) in iteritems(results)
         ))
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 712106b83a..d149d8392e 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -21,6 +21,7 @@ from synapse.api.errors import StoreError
 from ._base import SQLBaseStore, Cache
 from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
 
+from six import itervalues, iteritems
 
 logger = logging.getLogger(__name__)
 
@@ -360,7 +361,7 @@ class DeviceStore(SQLBaseStore):
             return (now_stream_id, [])
 
         if len(query_map) >= 20:
-            now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+            now_stream_id = max(stream_id for stream_id in itervalues(query_map))
 
         devices = self._get_e2e_device_keys_txn(
             txn, query_map.keys(), include_all_devices=True
@@ -373,13 +374,13 @@ class DeviceStore(SQLBaseStore):
         """
 
         results = []
-        for user_id, user_devices in devices.iteritems():
+        for user_id, user_devices in iteritems(devices):
             # The prev_id for the first row is always the last row before
             # `from_stream_id`
             txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
             rows = txn.fetchall()
             prev_id = rows[0][0]
-            for device_id, device in user_devices.iteritems():
+            for device_id, device in iteritems(user_devices):
                 stream_id = query_map[(user_id, device_id)]
                 result = {
                     "user_id": user_id,
@@ -483,7 +484,7 @@ class DeviceStore(SQLBaseStore):
         if devices:
             user_devices = devices[user_id]
             results = []
-            for device_id, device in user_devices.iteritems():
+            for device_id, device in iteritems(user_devices):
                 result = {
                     "device_id": device_id,
                 }
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index ff8538ddf8..b146487943 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -21,6 +21,8 @@ import simplejson as json
 
 from ._base import SQLBaseStore
 
+from six import iteritems
+
 
 class EndToEndKeyStore(SQLBaseStore):
     def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
@@ -81,8 +83,8 @@ class EndToEndKeyStore(SQLBaseStore):
             query_list, include_all_devices,
         )
 
-        for user_id, device_keys in results.iteritems():
-            for device_id, device_info in device_keys.iteritems():
+        for user_id, device_keys in iteritems(results):
+            for device_id, device_info in iteritems(device_keys):
                 device_info["keys"] = json.loads(device_info.pop("key_json"))
 
         defer.returnValue(results)
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index f084a5f54b..d0350ee5fe 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -22,6 +22,8 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks
 import logging
 import simplejson as json
 
+from six import iteritems
+
 logger = logging.getLogger(__name__)
 
 
@@ -420,7 +422,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
             txn.executemany(sql, (
                 _gen_entry(user_id, actions)
-                for user_id, actions in user_id_actions.iteritems()
+                for user_id, actions in iteritems(user_id_actions)
             ))
 
         return self.runInteraction(
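
The event_push_actions.py change above feeds a generator built over iteritems() into executemany(), so the parameter tuples are produced lazily on both Python versions. A rough sqlite3 stand-in (the table and the _gen_entry helper here are made up for illustration):

    import sqlite3
    from six import iteritems

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE summary (user_id TEXT, notif_count INTEGER)")

    user_id_actions = {"@alice:example.com": 3, "@bob:example.com": 1}

    def _gen_entry(user_id, count):
        # Illustrative only; the real helper builds a larger tuple.
        return (user_id, count)

    # executemany() accepts any iterable of parameter tuples, so a
    # generator over iteritems() avoids materialising an extra list.
    conn.executemany(
        "INSERT INTO summary (user_id, notif_count) VALUES (?, ?)",
        (_gen_entry(uid, n) for uid, n in iteritems(user_id_actions)),
    )
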
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index ba834854e1..32d9d00ffb 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -337,7 +337,7 @@ class EventsWorkerStore(SQLBaseStore):
     def _fetch_event_rows(self, txn, events):
         rows = []
         N = 200
-        for i in range(1 + len(events) / N):
+        for i in range(1 + len(events) // N):
             evs = events[i * N:(i + 1) * N]
             if not evs:
                 break
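
The events_worker.py change swaps "/" for "//" because Python 3's "/" always returns a float, which range() rejects. Roughly, the batching loop behaves like this sketch (the event list and batch size are placeholders):

    events = ["$ev%d" % i for i in range(450)]  # placeholder event IDs
    N = 200

    for i in range(1 + len(events) // N):
        # Floor division keeps the loop bound an int on Python 3; plain
        # "/" would make range() raise "TypeError: 'float' object cannot
        # be interpreted as an integer".
        evs = events[i * N:(i + 1) * N]
        if not evs:
            break
        # ... query the next batch of at most N events here ...
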
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 78b1e30945..2e2763126d 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
             desc="get_user_filter",
         )
 
-        defer.returnValue(json.loads(str(def_json).decode("utf-8")))
+        defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
 
     def add_user_filter(self, user_localpart, user_filter):
         def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 87aeaf71d6..0540c2b0b1 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -92,7 +92,7 @@ class KeyStore(SQLBaseStore):
 
         if verify_key_bytes:
             defer.returnValue(decode_verify_key_bytes(
-                key_id, str(verify_key_bytes)
+                key_id, bytes(verify_key_bytes)
             ))
 
     @defer.inlineCallbacks
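
filtering.py and keys.py both replace str() with bytes() around values read from the database. On Python 3, calling str() on a buffer-like object yields its repr rather than the underlying bytes, so the subsequent decode would see garbage. A hedged sketch of the idea, with memoryview standing in for whatever buffer type the database driver actually returns:

    import json

    raw = memoryview(b'{"room": {"timeline": {"limit": 10}}}')

    # str(raw) on Python 3 gives something like "<memory at 0x...>",
    # which is not valid JSON. bytes(raw) copies the underlying bytes,
    # which can then be decoded and parsed as before.
    filter_json = json.loads(bytes(raw).decode("utf-8"))
    assert filter_json["room"]["timeline"]["limit"] == 10
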
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 6a861943a2..7bfc3d91b5 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -30,6 +30,8 @@ from synapse.types import get_domain_from_id
 import logging
 import simplejson as json
 
+from six import itervalues, iteritems
+
 logger = logging.getLogger(__name__)
 
 
@@ -272,7 +274,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         users_in_room = {}
         member_event_ids = [
             e_id
-            for key, e_id in current_state_ids.iteritems()
+            for key, e_id in iteritems(current_state_ids)
             if key[0] == EventTypes.Member
         ]
 
@@ -289,7 +291,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                     users_in_room = dict(prev_res)
                     member_event_ids = [
                         e_id
-                        for key, e_id in context.delta_ids.iteritems()
+                        for key, e_id in iteritems(context.delta_ids)
                         if key[0] == EventTypes.Member
                     ]
                     for etype, state_key in context.delta_ids:
@@ -741,7 +743,7 @@ class _JoinedHostsCache(object):
             if state_entry.state_group == self.state_group:
                 pass
             elif state_entry.prev_group == self.state_group:
-                for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
+                for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
                     if typ != EventTypes.Member:
                         continue
 
@@ -771,7 +773,7 @@ class _JoinedHostsCache(object):
                 self.state_group = state_entry.state_group
             else:
                 self.state_group = object()
-            self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
+            self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
         defer.returnValue(frozenset(self.hosts_to_joined_users))
 
     def __len__(self):
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index ea24710ad8..fb463c525a 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -684,8 +684,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 results to only those before
             direction(char): Either 'b' or 'f' to indicate whether we are
                 paginating forwards or backwards from `from_key`.
-            limit (int): The maximum number of events to return. Zero or less
-                means no limit.
+            limit (int): The maximum number of events to return.
             event_filter (Filter|None): If provided filters the events to
                 those that match the filter.
 
@@ -694,6 +693,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             as a list of _EventDictReturn and a token that points to the end
             of the result set.
         """
+
+        assert int(limit) >= 0
+
         # Tokens really represent positions between elements, but we use
         # the convention of pointing to the event before the gap. Hence
         # we have a bit of asymmetry when it comes to equalities.
@@ -723,22 +725,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             bounds += " AND " + filter_clause
             args.extend(filter_args)
 
-        if int(limit) > 0:
-            args.append(int(limit))
-            limit_str = " LIMIT ?"
-        else:
-            limit_str = ""
+        args.append(int(limit))
 
         sql = (
             "SELECT event_id, topological_ordering, stream_ordering"
             " FROM events"
             " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
             " ORDER BY topological_ordering %(order)s,"
-            " stream_ordering %(order)s %(limit)s"
+            " stream_ordering %(order)s LIMIT ?"
         ) % {
             "bounds": bounds,
             "order": order,
-            "limit": limit_str
         }
 
         txn.execute(sql, args)
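
The stream.py change removes the "zero or less means no limit" special case: callers must now pass a non-negative limit, which is always bound as a query parameter. A rough sketch of the resulting query shape, using sqlite3 stand-ins rather than Synapse's transaction object and bounds handling:

    import sqlite3

    def paginate_events(conn, room_id, limit, order="DESC"):
        # The limit is validated up front and always appended to the
        # argument list, so the SQL no longer needs a conditional
        # " LIMIT ?" fragment. `order` is assumed to be "ASC" or "DESC",
        # never user input, since it is interpolated into the SQL string.
        assert int(limit) >= 0
        args = [False, room_id, int(limit)]
        sql = (
            "SELECT event_id, topological_ordering, stream_ordering"
            " FROM events"
            " WHERE outlier = ? AND room_id = ?"
            " ORDER BY topological_ordering %(order)s,"
            " stream_ordering %(order)s LIMIT ?"
        ) % {"order": order}
        return conn.execute(sql, args).fetchall()
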