 synapse/storage/registration.py | 12 +
 synapse/storage/stats.py        | 54 +-
 2 files changed, 52 insertions(+), 14 deletions(-)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 3f50324253..2d3c7e2dc9 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -869,6 +869,17 @@ class RegistrationStore(
                 (user_id_obj.localpart, create_profile_with_displayname),
             )
 
+        if self.hs.config.stats_enabled:
+            # create a brand-new, complete stats row for this user
+
+            # we don't strictly need current_token, since a brand-new user
+            # cannot have any state deltas before now, but we include it for
+            # completeness anyway.
+            current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
+            self._update_stats_delta_txn(
+                txn, now, "user", user_id, {}, complete_with_stream_id=current_token
+            )
+
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
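The hunk above gives every newly registered local user a stats row that is marked complete as of the current stream position, so the stats regenerator never has to backfill it. A rough restatement of the added call, with the role of each positional argument spelled out in comments (the argument roles are inferred from this patch, not from the function's signature):

    # illustrative sketch, not part of the patch
    def register_user_stats(store, txn, user_id, now):
        # a brand-new user cannot have state deltas before this point, so the
        # current stream position marks their stats as complete
        current_token = store._get_max_stream_id_in_current_state_deltas_txn(txn)
        store._update_stats_delta_txn(
            txn,
            now,          # timestamp for the stats bucket
            "user",       # stats type
            user_id,      # stats id
            {},           # no counters to bump yet
            complete_with_stream_id=current_token,  # row is complete from birth
        )
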
@@ -1140,6 +1151,7 @@ class RegistrationStore(
             deferred str|None: A str representing a link to redirect the user
             to if there is one.
         """
+
         # Insert everything into a transaction in order to run atomically
         def validate_threepid_session_txn(txn):
             row = self._simple_select_one_txn(
diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py
index 9d6c3027d5..63c8c2840a 100644
--- a/synapse/storage/stats.py
+++ b/synapse/storage/stats.py
@@ -260,6 +260,10 @@ class StatsStore(StateDeltasStore):
                 (i.e. not deltas) of absolute fields.
                 Does not work with per-slice fields.
         """
+
+        if absolute_field_overrides is None:
+            absolute_field_overrides = {}
+
         table, id_col = TYPE_TO_TABLE[stats_type]
 
         quantised_ts = self.quantise_stats_time(int(ts))
@@ -290,9 +294,6 @@ class StatsStore(StateDeltasStore):
             if key not in absolute_field_overrides
         }
 
-        if absolute_field_overrides is None:
-            absolute_field_overrides = {}
-
         if complete_with_stream_id is not None:
             absolute_field_overrides = absolute_field_overrides.copy()
             absolute_field_overrides[
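Moving the `absolute_field_overrides is None` default up is a correctness fix, not a tidy-up: the dict comprehension a few lines above filters `fields` with `key not in absolute_field_overrides`, so with the old ordering any caller relying on the `None` default would hit a `TypeError` as soon as `fields` was non-empty. A minimal standalone sketch of the hazard (function and variable names are illustrative):

    def broken(fields, absolute_field_overrides=None):
        # "key not in None" raises TypeError: argument of type 'NoneType'
        # is not iterable (for any non-empty fields dict)
        additive = {k: v for k, v in fields.items() if k not in absolute_field_overrides}
        if absolute_field_overrides is None:  # too late
            absolute_field_overrides = {}
        return additive

    def fixed(fields, absolute_field_overrides=None):
        if absolute_field_overrides is None:  # normalise the default first
            absolute_field_overrides = {}
        return {k: v for k, v in fields.items() if k not in absolute_field_overrides}
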
@@ -321,10 +322,8 @@ class StatsStore(StateDeltasStore):
                 txn=txn,
                 into_table=table + "_historical",
                 keyvalues={id_col: stats_id},
-                extra_dst_keyvalues={
-                    "end_ts": end_ts,
-                    "bucket_size": self.stats_bucket_size,
-                },
+                extra_dst_insvalues={"bucket_size": self.stats_bucket_size},
+                extra_dst_keyvalues={"end_ts": end_ts},
                 additive_relatives=per_slice_additive_relatives,
                 src_table=table + "_current",
                 copy_columns=abs_field_names,
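The call site above splits the destination values into two kinds: `end_ts` stays in `extra_dst_keyvalues` and so continues to identify the historical row, while `bucket_size` moves to the new `extra_dst_insvalues` and is only written when the row is first created. A tiny sketch of the distinction, with hypothetical values:

    # hypothetical values; only the split between the two dicts matters
    extra_dst_keyvalues = {"end_ts": 1_567_000_000_000}  # part of the row's identity
    extra_dst_insvalues = {"bucket_size": 86_400_000}     # set on insert, never matched on
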
@@ -357,7 +356,7 @@ class StatsStore(StateDeltasStore):
             ]
 
             insert_cols = []
-            qargs = [table]
+            qargs = []
 
             for (key, val) in chain(
                 keyvalues.items(), absolutes.items(), additive_relatives.items()
@@ -368,13 +367,14 @@ class StatsStore(StateDeltasStore):
             sql = """
                 INSERT INTO %(table)s (%(insert_cols_cs)s)
                 VALUES (%(insert_vals_qs)s)
-                ON CONFLICT DO UPDATE SET %(updates)s
+                ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s
             """ % {
                 "table": table,
                 "insert_cols_cs": ", ".join(insert_cols),
                 "insert_vals_qs": ", ".join(
                     ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives))
                 ),
+                "key_columns": ", ".join(keyvalues),
                 "updates": ", ".join(chain(absolute_updates, relative_updates)),
             }
 
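Two fixes land in this hunk. First, the table name is spliced into the SQL via `%(table)s`, so seeding `qargs` with it shifted every `?` placeholder by one; `qargs` must contain only bound values. Second, PostgreSQL requires an explicit conflict target for `ON CONFLICT ... DO UPDATE`, which is now built from the `keyvalues` columns. A sketch of roughly what the template renders to, with made-up column names (the real `SET` clause comes from `absolute_updates` / `relative_updates` built earlier in the function):

    # hypothetical inputs
    table = "user_stats_example"
    keyvalues = {"user_id": "@alice:example.com"}
    absolutes = {"joined_rooms": 3}
    additive_relatives = {"total_events": 5}

    # rendered SQL (shape only):
    #   INSERT INTO user_stats_example (user_id, joined_rooms, total_events)
    #   VALUES (?, ?, ?)
    #   ON CONFLICT (user_id) DO UPDATE SET ...
    #
    # and qargs supplies exactly one value per "?", in the same order:
    qargs = list(keyvalues.values()) + list(absolutes.values()) + list(additive_relatives.values())
    # -> ["@alice:example.com", 3, 5]
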
@@ -400,6 +400,7 @@ class StatsStore(StateDeltasStore):
         into_table,
         keyvalues,
         extra_dst_keyvalues,
+        extra_dst_insvalues,
         additive_relatives,
         src_table,
         copy_columns,
@@ -412,6 +413,8 @@ class StatsStore(StateDeltasStore):
              keyvalues (dict[str, any]): Row-identifying key values
              extra_dst_keyvalues (dict[str, any]): Additional keyvalues
                 for `into_table`.
+             extra_dst_insvalues (dict[str, any]): Additional values to insert
+                on new row creation for `into_table`.
              additive_relatives (dict[str, any]): Fields that will be added onto
                 if existing row present. (Must be disjoint from copy_columns.)
              src_table (str): The source table to copy from
@@ -421,18 +424,28 @@ class StatsStore(StateDeltasStore):
         """
         if self.database_engine.can_native_upsert:
             ins_columns = chain(
-                keyvalues, copy_columns, additive_relatives, extra_dst_keyvalues
+                keyvalues,
+                copy_columns,
+                additive_relatives,
+                extra_dst_keyvalues,
+                extra_dst_insvalues,
             )
             sel_exprs = chain(
                 keyvalues,
                 copy_columns,
-                ("?" for _ in chain(additive_relatives, extra_dst_keyvalues)),
+                (
+                    "?"
+                    for _ in chain(
+                        additive_relatives, extra_dst_keyvalues, extra_dst_insvalues
+                    )
+                ),
             )
             keyvalues_where = ("%s = ?" % f for f in keyvalues)
 
             sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns)
             sets_ar = (
-                "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f) for f in copy_columns
+                "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f)
+                for f in additive_relatives
             )
 
             sql = """
@@ -455,7 +468,14 @@ class StatsStore(StateDeltasStore):
                 "additional_where": additional_where,
             }
 
-            qargs = chain(additive_relatives.values(), keyvalues.values())
+            qargs = list(
+                chain(
+                    additive_relatives.values(),
+                    extra_dst_keyvalues.values(),
+                    extra_dst_insvalues.values(),
+                    keyvalues.values(),
+                )
+            )
             txn.execute(sql, qargs)
         else:
             self.database_engine.lock_table(txn, into_table)
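In the native-upsert branch the new `extra_dst_insvalues` is threaded through consistently: its columns join the `INSERT` column list, each of its values gets a `?` placeholder in the `SELECT` expression list, and `qargs` now binds values in exactly the placeholder order (additive relatives, then extra destination keyvalues, then extra insert values, then the `WHERE`-clause keyvalues). The hunk also fixes `sets_ar`, which previously iterated `copy_columns` instead of `additive_relatives`. A reduced sketch of the alignment rule, with hypothetical dicts:

    from itertools import chain

    keyvalues = {"user_id": "@alice:example.com"}          # WHERE clause
    additive_relatives = {"total_events": 5}
    extra_dst_keyvalues = {"end_ts": 1_567_000_000_000}
    extra_dst_insvalues = {"bucket_size": 86_400_000}

    # one "?" per non-copied value in the SELECT ...
    placeholders = ["?" for _ in chain(additive_relatives, extra_dst_keyvalues, extra_dst_insvalues)]

    # ... and qargs must bind values in exactly that order, then the WHERE values
    qargs = list(
        chain(
            additive_relatives.values(),
            extra_dst_keyvalues.values(),
            extra_dst_insvalues.values(),
            keyvalues.values(),
        )
    )
    assert len(qargs) == len(placeholders) + len(keyvalues)
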
@@ -471,7 +491,13 @@ class StatsStore(StateDeltasStore):
             )
 
             if dest_current_row is None:
-                merged_dict = {**keyvalues, **src_row, **additive_relatives}
+                merged_dict = {
+                    **keyvalues,
+                    **extra_dst_keyvalues,
+                    **extra_dst_insvalues,
+                    **src_row,
+                    **additive_relatives,
+                }
                 self._simple_insert_txn(txn, into_table, merged_dict)
             else:
                 for (key, val) in additive_relatives.items():
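In the emulated branch (no native upsert), the same extra values now make it into brand-new destination rows: `merged_dict` layers the dicts left to right, with later entries winning on key collisions, so values copied from `src_row` can still be superseded by `additive_relatives`. A tiny runnable sketch of that precedence, with placeholder values:

    # placeholder values; later mappings win, mirroring merged_dict above
    keyvalues = {"user_id": "@alice:example.com"}
    extra_dst_keyvalues = {"end_ts": 1_567_000_000_000}
    extra_dst_insvalues = {"bucket_size": 86_400_000}
    src_row = {"joined_rooms": 2, "total_events": 0}
    additive_relatives = {"total_events": 5}

    merged_dict = {
        **keyvalues,
        **extra_dst_keyvalues,
        **extra_dst_insvalues,
        **src_row,
        **additive_relatives,
    }
    # -> total_events is 5 (from additive_relatives); this full row is inserted
    #    when no destination row exists yet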