Batch inserts and LAST_INSERT_ID with Slick and MariaDB - scala

I'm trying to insert some data into a MariaDB database. I have two tables, and I have to insert rows (using a batch insert) into the first table and then use the IDs of the newly inserted rows to perform a second batch insert into the second table.
I'm doing this in Scala using Alpakka Slick. For the purposes of this question, let's call the main table tests and the second one dependent.
At the moment, my algorithm is as follows:
1) Insert the rows into tests
2) Fetch the ID of the first row in the batch using SELECT LAST_INSERT_ID();
3) Knowing the ID of the first row and the number of rows in the batch, compute the other IDs by hand and use them for the insertion into the second table
This works pretty well with only one connection at a time. However, I'm trying to simulate a scenario with multiple attempts to write simultaneously. To do that, I'm using Scala parallel collections and Akka Stream Source as follows:
// three sources of 10 random Strings each
val sources = Seq.fill(3)(Source(Seq.fill(10)(Random.alphanumeric.take(3).mkString))).zipWithIndex
val parallelSources: ParSeq[(Source[String, NotUsed], Int)] = sources.par
parallelSources.map { case (source, i) =>
  source
    .grouped(ChunkSize) // performs batch inserts of a given size
    .via(insert(i))
    .zipWithIndex
    .runWith(Sink.foreach { case (_, chunkIndex) => println(s"Chunk $chunkIndex of source $i done") })
}
I'm adding an index to each Source just to use it as a prefix in the data I write to the DB.
Here's the code of the insert Flow I've written so far:
def insert(srcIndex: Int): Flow[Seq[String], Unit, NotUsed] = {
  implicit val insertSession: SlickSession = slickSession
  system.registerOnTermination(() => insertSession.close())
  Flow[Seq[String]]
    .via(Slick.flowWithPassThrough { chunk =>
      (for {
        // insert data into `tests`
        _ <- InsTests ++= chunk.map(v => TestProj(s"source$srcIndex-$v"))
        // fetch last insert ID and connection ID
        queryResult <- sql"SELECT CONNECTION_ID(), LAST_INSERT_ID();".as[(Long, Long)].headOption
        _ <- queryResult match {
          case Some((connId, firstIdInChunk)) =>
            println(s"Source $srcIndex, last insert ID $firstIdInChunk, connection $connId")
            // compute IDs by hand and write to `dependent`
            val depValues = Seq.fill(ChunkSize)(s"source$srcIndex-${Random.alphanumeric.take(6).mkString}")
            val depRows =
              (firstIdInChunk to (firstIdInChunk + ChunkSize))
                .zip(depValues)
                .map { case (index, value) => DependentProj(index, value) }
            InsDependent ++= depRows
          case None => DBIO.failed(new Exception("..."))
        }
      } yield ()).transactionally
    })
}
Here, InsTests and InsDependent are Slick TableQuery objects. slickSession creates a new session for each insert and is defined as follows:
private def slickSession = {
  val db = Database.forURL(
    url = "jdbc:mariadb://localhost:3306/test",
    user = "root",
    password = "password",
    executor = AsyncExecutor(
      name = "executor",
      minThreads = 20,
      maxThreads = 20,
      queueSize = 1000,
      maxConnections = 20
    )
  )
  val profile = slick.jdbc.MySQLProfile
  SlickSession.forDbAndProfile(db, profile)
}
The problem is that the last insert IDs returned by the second step of the algorithm overlap. Every run of this app would print something like:
Source 2, last insert ID 6, connection 66
Source 1, last insert ID 5, connection 68
Source 0, last insert ID 7, connection 67
Chunk 0 of source 0 done
Chunk 0 of source 2 done
Chunk 0 of source 1 done
Source 2, last insert ID 40, connection 70
Source 0, last insert ID 26, connection 69
Source 1, last insert ID 27, connection 71
Chunk 1 of source 2 done
Chunk 1 of source 1 done
Chunk 1 of source 0 done
It looks like the connection is a different one for each Source, but the IDs overlap (Source 0 sees 7, Source 1 sees 5, Source 2 sees 6). It is correct that IDs start from 5, as I'm adding 4 dummy rows right after creating the tables (not shown in this question's code). Obviously, I then see multiple rows in dependent with the same tests.id, which shouldn't happen.
It's my understanding that last insert IDs refer to a single connection. How is it possible that three different connections see overlapping IDs, considering that the entire flow is wrapped in a transaction (via Slick's transactionally)?
This happens with innodb_autoinc_lock_mode=1. As far as I've seen, it doesn't happen with innodb_autoinc_lock_mode=0, which makes sense, since InnoDB would then lock tests until the whole batch insert terminates.
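(For reference, the active lock mode can be checked from the application side; a minimal sketch using Slick's plain SQL interpolator, with an illustrative val name of my own:)

// Sketch: inspect the server's InnoDB auto-increment lock mode
val lockMode: DBIO[Int] = sql"SELECT @@innodb_autoinc_lock_mode".as[Int].head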
UPDATE after Georg's answer: Due to other constraints in the project, I'd like the solution to be compatible with MariaDB 10.4, which, as far as I understand, doesn't feature INSERT...RETURNING. Additionally, Slick's support for returning with the ++= operator is quite limited, as also reported here. I tested it on both MariaDB 10.4 and 10.5, and, according to the query logs, Slick executes single INSERT INTO statements instead of a batch one. In my case, this is not acceptable, as I'm planning on writing several chunks of rows in a streaming fashion.
While I also understand that making assumptions about the auto-increment value being 1 is not ideal, we do have control over the Production setup and do not have multi-master replication.

You cannot generate subsequent values based on LAST_INSERT_ID():
There might be a second transaction running at the same time that was rolled back, which would leave a gap in your auto-incremented IDs.
Iterating over the number of rows by incrementing the LAST_INSERT_ID value will not work, since the step depends on the value of the session variable @@auto_increment_increment (which, especially with multi-master replication, is not necessarily 1).
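For instance, the effective step can be inspected from the application side; a minimal sketch using Slick's plain SQL interpolator, as the question already does (the val name is only illustrative):

// Sketch: read the session's auto_increment_increment system variable
val step: DBIO[Int] = sql"SELECT @@auto_increment_increment".as[Int].head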
Instead, you should use RETURNING to get the IDs of the inserted rows:
MariaDB [test]> create table t1 (a int not null auto_increment primary key);
Query OK, 0 rows affected (0,022 sec)
MariaDB [test]> insert into t1 (a) values (1),(3),(NULL), (NULL) returning a;
+---+
| a |
+---+
| 1 |
| 3 |
| 4 |
| 5 |
+---+
4 rows in set (0,006 sec)
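In Slick terms, the same idea maps onto a returning insert. Below is a sketch only, reusing the question's TableQuery objects and case classes and assuming InsTests exposes its auto-increment id column; note the question's caveat that, with the MySQL profile, this may be executed as individual INSERT statements rather than a true batch:

// Sketch: let the database hand back the generated IDs instead of computing them.
// Assumes the question's imports plus scala.util.Random and the profile API in scope.
def insertChunk(srcIndex: Int, chunk: Seq[String]): DBIO[Option[Int]] =
  (for {
    // insert the chunk and collect one generated ID per inserted row
    ids <- (InsTests returning InsTests.map(_.id)) ++=
             chunk.map(v => TestProj(s"source$srcIndex-$v"))
    // build the dependent rows from the actual IDs, no arithmetic involved
    depRows = ids.map(id => DependentProj(id, s"source$srcIndex-${Random.alphanumeric.take(6).mkString}"))
    inserted <- InsDependent ++= depRows
  } yield inserted).transactionally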

Related

TopologyTestDriver sending incorrect message on KTable aggregations

I have a topology that aggregates on a KTable.
This is a generic method I created to build this topology on different topics I have.
public static <A, B, C> KTable<C, Set<B>> groupTable(KTable<A, B> table, Function<B, C> getKeyFunction,
                                                     Serde<C> keySerde, Serde<B> valueSerde, Serde<Set<B>> aggregatedSerde) {
    return table
        .groupBy((key, value) -> KeyValue.pair(getKeyFunction.apply(value), value),
                 Serialized.with(keySerde, valueSerde))
        .aggregate(() -> new HashSet<>(), (key, newValue, agg) -> {
            agg.remove(newValue);
            agg.add(newValue);
            return agg;
        }, (key, oldValue, agg) -> {
            agg.remove(oldValue);
            return agg;
        }, Materialized.with(keySerde, aggregatedSerde));
}
This works pretty well when using Kafka, but not when testing via `TopologyTestDriver`.
In both scenarios, when I get an update, the subtractor is called first, and then the adder is called. The problem is that when using the TopologyTestDriver, two messages are sent out for updates: one after the subtractor call, and another one after the adder call. Not to mention that the message sent after the subtractor and before the adder is in an incorrect state.
Can anyone else confirm this is a bug? I've tested this with both Kafka versions 2.0.1 and 2.1.0.
EDIT:
I created a test case on GitHub to illustrate the issue: https://github.com/mulho/topology-testcase
It is expected behavior that there are two output records (one "minus" record, and one "plus" record). It's a little tricky to understand how it works, so let me try to explain.
Assume you have the following input table:
key | value
-----+---------
A | <10,2>
B | <10,3>
C | <11,4>
On KTable#groupBy() you extract the first part of the value as the new key (i.e., 10 or 11) and later sum the second part (i.e., 2, 3, 4) in the aggregation. Because records A and B both have 10 as the new key, you sum 2+3 for key 10, and you sum 4 for the new key 11. The result table would be:
key | value
-----+---------
10 | 5
11 | 4
Now assume that an update record <B,<11,5>> changes the original input KTable to:
key | value
-----+---------
A | <10,2>
B | <11,5>
C | <11,4>
Thus, the new result table should sum up 5+4 for 11 and 2 for 10:
key | value
-----+---------
10 | 2
11 | 9
If you compare the first result table with the second, you might notice that both rows got updated. The old B|<10,3> record is subtracted from 10|5, resulting in 10|2, and the new B|<11,5> record is added to 11|4, resulting in 11|9.
These are exactly the two output records you see. The first output record (emitted after the subtract is executed) updates the first row: it subtracts the old value that is no longer part of the aggregation result. The second record adds the new value to the aggregation result. In our example, the subtract record would be <10,<null,<10,3>>> and the add record would be <11,<<11,5>,null>>. The format of those records is <key, <plus,minus>> (note that the subtract record only sets the minus part, while the add record only sets the plus part).
Final remark: it is not possible to put the plus and minus records together, because the keys of the plus and minus records can be different (in our example, 11 and 10), and thus they might go to different partitions. This implies that the plus and minus operations might be executed by different machines, and thus it's not possible to emit only one record that contains both the plus and minus parts.

How to tune mapping/filtering on big datasets (cross joined from two datasets)?

Spark 2.2.0
I have the following code, converted from a SQL script. It has been running for two hours and is still going, even slower than SQL Server. Is anything not done correctly?
The following is the plan:
Push table2 to all executors.
Partition table1 and distribute the partitions to the executors.
Each row in table2/t2 joins (cross join) each partition of table1.
So the calculation on the result of the cross join can be run in a distributed/parallel way. (For example, suppose I have 16 executors: keep a copy of t2 on all 16 executors, then divide table1 into 16 partitions, one for each executor, and have each executor do the calculation on one partition of table1 and t2.)
case class Cols(Id: Int, F2: String, F3: BigDecimal, F4: Date, F5: String,
                F6: String, F7: BigDecimal, F8: String, F9: String, F10: String)
case class Result(Id1: Int, ID2: Int, Point: Int)

def getDataFromDB(source: String) = {
  import sqlContext.sparkSession.implicits._
  sqlContext.read.format("jdbc").options(Map(
    "driver" -> "com.microsoft.sqlserver.jdbc.SQLServerDriver",
    "url" -> jdbcSqlConn,
    "dbtable" -> s"$source"
  )).load()
    .select("Id", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10")
    .as[Cols]
}
val sc = new SparkContext(conf)
val table1: Dataset[Cols] = getDataFromDB("table1").repartition(32).cache()
println(table1.count()) // about 300K rows
val table2: Dataset[Cols] = getDataFromDB("table2") // ~20K rows
table2.take(1)
println(table2.count())
val t2 = sc.broadcast(table2)
import org.apache.spark.sql.{functions => func}
val j = table1.joinWith(t2.value, func.lit(true))
j.map(x => {
  val (l, r) = x
  Result(l.Id, r.Id,
    (if (l.F1 != null && r.F1 != null && l.F1 == r.F1) 3 else 0)
    + (if (l.F2 != null && r.F2 != null && l.F2 == r.F2) 2 else 0)
    + ..... // all the similar expressions
    + (if (l.F8 != null && r.F8 != null && l.F8 == r.F8) 1 else 0)
  )
}).filter(x => x.Point >= 10)
println("Total count %d", j.count()) // This takes forever, the count will be about 100
How can I rewrite this in an idiomatic Spark way?
Ref: https://forums.databricks.com/questions/6747/how-do-i-get-a-cartesian-product-of-a-huge-dataset.html
(Somehow I feel as if I have seen the code already)
The code is slow because you use just a single task to load the entire dataset from the database over JDBC, and despite the cache it does not benefit from it.
Start by checking out the physical plan and the Executors tab in the web UI to confirm that a single executor runs a single task to do the work.
You should use one of the following to fine-tune the number of tasks for loading:
Use partitionColumn, lowerBound, upperBound options for the JDBC data source
Use predicates option
See JDBC To Other Databases in Spark's official documentation.
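As a rough illustration, a partitioned JDBC read could look like the sketch below. It reuses sqlContext, jdbcSqlConn and Cols from the question; choosing Id as the partition column and the bounds/partition count are assumptions about the data, not values taken from the actual schema:

// Sketch only: 16 parallel JDBC reads instead of a single task
import sqlContext.sparkSession.implicits._

val table1 = sqlContext.read.format("jdbc").options(Map(
  "driver"          -> "com.microsoft.sqlserver.jdbc.SQLServerDriver",
  "url"             -> jdbcSqlConn,
  "dbtable"         -> "table1",
  "partitionColumn" -> "Id",     // must be a numeric column in Spark 2.2
  "lowerBound"      -> "1",      // only used to compute the partition ranges
  "upperBound"      -> "300000",
  "numPartitions"   -> "16"
)).load()
  .select("Id", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10")
  .as[Cols]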
After you're fine with the loading, you should work on improving the last count action: add another count action right after the following line:
val table1: Dataset[Cols] = getDataFromDB("table1").repartition(32).cache()
// trigger caching as it's lazy in Dataset API
table1.count
The reason the entire query is slow is that you only mark table1 to be cached when an action gets executed, which is exactly at the end (!). In other words, cache does nothing useful here and, more importantly, makes the query performance even worse.
Performance will also increase after you do table2.cache.count.
If you want to do cross join, use crossJoin operator.
crossJoin(right: Dataset[_]): DataFrame Explicit cartesian join with another DataFrame.
Please note the note from the scaladoc of crossJoin (no pun intended).
Cartesian joins are very expensive without an extra filter that can be pushed down.
The following requirement is already handled by Spark given all the optimizations available.
So the calculation on the result of the cross join can be run in a distributed/parallel way.
That's Spark's job (again, no pun intended).
The following requirement begs for broadcast.
For example, suppose I have 16 executors: keep a copy of t2 on all 16 executors, then divide table1 into 16 partitions, one for each executor, and have each executor do the calculation on one partition of table1 and t2.
Use broadcast function to hint Spark SQL's engine to use table2 in broadcast mode.
broadcast[T](df: Dataset[T]): Dataset[T] Marks a DataFrame as small enough for use in broadcast joins.
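Put together, a minimal sketch of the reshaped join, keeping the typed joinWith from the question and abbreviating the scoring expression (it assumes the case classes and the implicits import from the question):

import org.apache.spark.sql.functions.{broadcast, lit}

// Sketch: let Spark SQL broadcast table2 itself instead of wrapping the Dataset in sc.broadcast
val pairs = table1.joinWith(broadcast(table2), lit(true))

val results = pairs
  .map { case (l, r) =>
    // same scoring expression as in the question, abbreviated here
    Result(l.Id, r.Id, if (l.F2 != null && r.F2 != null && l.F2 == r.F2) 2 else 0 /* + ... */)
  }
  .filter(_.Point >= 10)

println(results.count()) // counts the filtered result rather than the raw cross join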

Replace rows based on a modified timestamp

I am looking for an efficient method (which I can reuse for similar situations) to drop rows which have been updated.
My table has many columns, but the important ones are:
creation_timestamp, id, last_modified_timestamp
My primary key is the creation_timestamp and the id. However, after an id has been created, it can be modified by other users, which is indicated by the last_modified_timestamp. I need to:
1) Read a daily file and add any new rows (based on creation_timestamp and id)
2) Remove old rows which have a different last_modified_timestamp and replace them with the latest versions.
I typically do most of my operations with pandas (Python library) and psycopg2, so I am not extremely familiar with PostgreSQL 9.6, which is the database I am using. My initial approach is to just add the last_modified_timestamp to the primary key and then use a view to SELECT DISTINCT based on the latest changes. However, it seems like that is 'cheating', and I will be wasting space, since I do not need to retain previous versions.
EDIT:
def create_update_query(df, table=FACT_TABLE):
    columns = ', '.join([f'{col}' for col in DATABASE_COLUMNS])
    constraint = ', '.join([f'{col}' for col in PRIMARY_KEY])
    placeholder = ', '.join([f'%({col})s' for col in DATABASE_COLUMNS])
    updates = ', '.join([f'{col} = EXCLUDED.{col}' for col in DATABASE_COLUMNS])
    query = f"""
        INSERT INTO {table} ({columns})
        VALUES ({placeholder})
        ON CONFLICT ({constraint})
        DO UPDATE SET {updates};"""
    query = ' '.join(query.split())
    return query

def load_updates(df, connection=DATABASE):
    conn = connection.get_conn()
    cursor = conn.cursor()
    df1 = df.where((pd.notnull(df)), None)
    insert_values = df1.to_dict(orient='records')
    for row in insert_values:
        cursor.execute(create_update_query(df), row)
    conn.commit()
    cursor.close()
    del cursor
    conn.close()
This appears to work. I was running into some issues, so right now I am looping through each row of the DataFrame as a dictionary and inserting that row. Also, I had to figure out a way to fill the NaN columns with None, because I was getting errors with Timestamp dtypes with blank values, etc.

Insert row based on existing row and update the existing row after (Postgres)

I have a table "tablea" with column "a" and "b"
In the table is a row with a = 1 and b = 3
Now I want to produce a Postgres SQL query which inserts a new row with a=x and b = 3, where x is some provided value and 3 is gotten from the existing row.
In the same go, I want to update the existing row to a = 1 and b = x.
In pseudo code, this would be something similar to
first = select(a=1) limit 1
second = (x, first.b)
first = (first.a, x)
I cannot imagine how to do this in SQL. Currently, I do a select in my application, build the insert and update, and send those to the db together.
This means 2 round trips to the db, and I have ~60,000 of these queries to send, which takes quite a while.
If I could implement this in a single query, I could just build the 60,000 queries and send them to the db all at once, which would be much faster.

What is the right way to work with slick's 3.0.0 streaming results and Postgresql?

I am trying to figure out how to work with Slick streaming. I use Slick 3.0.0 with the Postgres driver.
The situation is the following: the server has to give the client sequences of data split into chunks limited by size (in bytes). So, I wrote the following Slick query:
val sequences = TableQuery[Sequences]
def find(userId: Long, timestamp: Long) = sequences.filter(s ⇒ s.userId === userId && s.timestamp > timestamp).sortBy(_.timestamp.asc).result
val seq = db.stream(find(0L, 0L))
I combined seq with an akka-streams Source and wrote a custom PushPullStage that limits the size of the data (in bytes) and finishes the upstream when it reaches the size limit. It works just fine. The problem is that when I look into the Postgres logs, I see a query like this:
select * from sequences where user_id = 0 and timestamp > 0 order by timestamp;
So, at first glance there appears to be a lot of (unnecessary) database querying going on, only to use a few bytes from each query. What is the right way to do streaming with Slick so as to minimize database querying and make the best use of the data transferred in each query?
The "right way" to do streaming with Slick and Postgres includes three things:
Must use db.stream()
Must disable autoCommit in JDBC-driver. One way is to make the query run in a transaction by suffixing .transactionally.
Must set fetchSize to something other than 0, or else Postgres will push the whole ResultSet to the client in one go.
Ex:
DB.stream(
  find(0L, 0L)
    .transactionally
    .withStatementParameters(fetchSize = 1000)
).foreach(println)
Useful links:
https://github.com/slick/slick/issues/1038
https://github.com/slick/slick/issues/809
The correct way to stream in Slick, as provided in the documentation, is:
val q = for (c <- coffees) yield c.image
val a = q.result
val p1: DatabasePublisher[Blob] = db.stream(
  a.withStatementParameters(
    rsType = ResultSetType.ForwardOnly,
    rsConcurrency = ResultSetConcurrency.ReadOnly,
    fetchSize = 1000 /* your fetch size */
  ).transactionally
)