Unable to mock method using pytest-mock - pytest

I have a class MyClass with the code below:
from typing import Tuple

class MyClass:
    def __init__(self, update_type_id='1'):
        self._update_type_id = update_type_id
        self._cursor = <database_connection>.cursor()

    def update_start_dt(self):
        self._update_job_ctrl_start_dt()

    def _update_job_ctrl_start_dt(self):
        update_sql, v_1 = self._get_update_sql(self._update_type_id)
        logger.debug(f'Update sql: {update_sql} and v_1: {v_1}')

    def _get_update_sql(self, update_type_id: int) -> Tuple:
        sql = f"SELECT start_sql, end_sql FROM <database.table> where update_type_key = {self._update_type_id}"
        self._run_sql(sql)
        record = self._cursor.fetchone()
        if record:
            return record
        else:
            logger.error(f'Record Not Found. Update type key ({update_type_id}) not found in the table in the database')
            raise Exception

    def _run_sql(self, sql_statement: str):
        try:
            self._cursor.execute(sql_statement)
        except (Exception, Error) as e:
            logger.error(f'Error {e} encountered when reading from table')
            raise e
I am trying to write a test function using pytest-mock which will test the update_start_dt method. The method internally invokes a series of private methods, and I am having difficulty mocking the code that runs through all those private methods. Can anyone help me understand the different ways this can be mocked?
I tried referring to multiple online websites but couldn't get a complete picture.
class TestMyClass:
    def test_update_start_dt(mocker, mock_get_connection, mock_run_sql):
        mock_manager = mocker.Mock()
        mock_get_update_sql = mock_manager.patch('MyClass._get_update_sql')
        mock_get_update_sql.return_value = ('123', '234')
        myclass = MyClass(update_type_id='1')
        myclass.update_start_dt()
I am getting the error below for the above test code:
update_sql, v_1 = self._get_update_sql(self._update_type_id)
ValueError: not enough values to unpack (expected 2, got 0)

The issue here is that you are patching on a Mock object that you are creating; for the purposes of your test you do not need to explicitly create a Mock object. Shown below is how you would test it instead, where we patch directly on the class.
class MyClass:
    def __init__(self, update_type_id='1'):
        self._update_type_id = update_type_id
        self._cursor = None

    def update_start_dt(self):
        self._update_job_ctrl_start_dt()

    def _update_job_ctrl_start_dt(self):
        update_sql, v_1 = self._get_update_sql(self._update_type_id)
        logger.debug(f'Update sql: {update_sql} and v_1: {v_1}')

    def _get_update_sql(self, update_type_id: int):
        sql = f"SELECT start_sql, end_sql FROM <database.table> where update_type_key = {self._update_type_id}"
        self._run_sql(sql)
        record = self._cursor.fetchone()
        if record:
            return record
        else:
            logger.error(f'Record Not Found. Update type key ({update_type_id}) not found in the table in the database')
            raise Exception

    def _run_sql(self, sql_statement: str):
        try:
            self._cursor.execute(sql_statement)
        except (Exception, Error) as e:
            logger.error(f'Error {e} encountered when reading from table')
            raise e
def test_update_start_dt(mocker):
    mock_get_update_sql = mocker.patch.object(MyClass, "_get_update_sql")
    mock_get_update_sql.return_value = ("123", "234")
    myclass = MyClass(update_type_id='1')
    myclass.update_start_dt()
=========================================== test session starts ============================================
platform darwin -- Python 3.8.9, pytest-7.0.1, pluggy-1.0.0
rootdir: ***
plugins: mock-3.7.0
collected 1 item
test_script.py . [100%]
======================================= 1 passed, 1 warning in 0.01s =======================================
Your code would work if you called the Mock object you created instead of the class; that is shown below. Note that the test then passes only because everything it touches is a Mock, so MyClass itself is never actually exercised.
def test_update_start_dt(mocker):
    mock_manager = mocker.Mock()
    mock_get_update_sql = mock_manager.patch('MyClass._get_update_sql')
    mock_get_update_sql.return_value = ('123', '234')
    # Notice how we use `mock_manager` instead of MyClass
    # tests will now pass
    myclass = mock_manager(update_type_id='1')
    myclass.update_start_dt()
Hopefully you see what the issue is now.

Related

Scala 3 compiler plugin generating expected code but failing at runtime

I'm trying to get started with writing a compiler plugin for scala 3. At this stage, it's primarily based on https://github.com/liufengyun/scala3-plugin-example/blob/master/plugin/src/main/scala/Phases.scala (and the accompanying youtube video explaining how it works).
It's been an interesting process so far, and I'm getting a bit of a feel for some aspects of the compiler.
As a first step, I'm simply trying to wrap a method body into a block, print whatever the returned object was going to be, and then return the object.
This differs from the original plugin mainly in that the original added a single side-effecting method call to each method; my version is also assigning a local variable (which I think is probably the cause of the problem) and moving the method body into a block.
I've produced as minimal of a working example as I could in a fork here: https://github.com/robmwalsh/scala3-plugin-example
The plugin compiles fine, seems to run as part of compilation as expected, and then blows up at runtime. I'm not entirely sure if this is me doing something wrong (not unlikely) or a bug in the compiler (less likely, but a distinct possibility!).
Can anybody please shed some light on why this isn't working? I don't know what flags should be set when creating a new Symbol, so that's one possibility, but there's heaps of stuff that sorta seemed to work so I rolled with it.
Here's where I'm at (the interesting bits):
...
override def prepareForUnit(tree: Tree)(using ctx: Context): Context =
  // find the println method
  val predef = requiredModule("scala.Predef")
  printlnSym = predef.requiredMethod("println", List(defn.AnyType))
  ctx

override def transformDefDef(tree: DefDef)(using ctx: Context): Tree =
  val sym = tree.symbol
  // ignore abstract and synthetic methods
  if tree.rhs.isEmpty || sym.isOneOf(Synthetic | Deferred | Private | Accessor)
  then return tree
  try {
    println("\n\n\n\n")
    println("========================== tree ==========================")
    println(tree.show)
    // val body = {tree.rhs}
    val body = ValDef(
      newSymbol(
        tree.symbol, termName("body"), tree.symbol.flags, tree.rhs.tpe),
      Block(Nil, tree.rhs)
    )
    // println(body)
    val bodyRef = ref(body.symbol)
    val printRes = ref(printlnSym).appliedTo(bodyRef)
    // shove it all together in a block
    val rhs1 = tpd.Block(body :: printRes :: Nil, bodyRef)
    // replace the RHS with the new one
    val newDefDef = cpy.DefDef(tree)(rhs = rhs1)
    println("====================== transformed ======================")
    println(newDefDef.show)
    newDefDef
  } catch {
    case e: Throwable =>
      println("====================== error ===========================")
      println(e)
      e.printStackTrace()
      tree
  }
...
Test program for the compiler plugin:
object Test extends App:
  def foo: String = "forty two"
  def bar(x: String): Int = x.length
  def baz(x: String, y: Int): String = x + y
  baz(foo, bar(foo))
Output during compilation using the plugin (exactly what I wanted! I got very excited at this point):
========================== tree ==========================
def foo: String = "forty two"
====================== transformed ======================
def foo: String =
  {
    val body: ("forty two" : String) =
      {
        "forty two"
      }
    println(body)
    body
  }
========================== tree ==========================
def bar(x: String): Int = x.length()
====================== transformed ======================
def bar(x: String): Int =
  {
    val body: Int =
      {
        x.length()
      }
    println(body)
    body
  }
========================== tree ==========================
def baz(x: String, y: Int): String = x.+(y)
====================== transformed ======================
def baz(x: String, y: Int): String =
  {
    val body: String =
      {
        x.+(y)
      }
    println(body)
    body
  }
Output at runtime :'( (this changes depending on the code it's running on, but it's always the same theme):
Exception in thread "main" java.lang.VerifyError: Bad local variable type
Exception Details:
  Location:
    testing/Test$.body$2()I @0: aload_1
  Reason:
    Type top (current frame, locals[1]) is not assignable to reference type
  Current Frame:
    bci: @0
    flags: { }
    locals: { 'testing/Test$' }
    stack: { }
  Bytecode:
    0000000: 2bb6 007d ac
        at testing.Test.main(Example.scala)
Edit: I'm using Scala 3.1.2.
I was using the existing flags when creating my new symbol. Instead, I needed to use the Local flag, which I suppose makes sense.
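For reference, the corrected symbol creation would look roughly like this; a minimal sketch of the fix described above, assuming Local comes from the compiler's Flags module that the plugin code already has in scope:

  // the new symbol gets the Local flag instead of the enclosing
  // method's flags, so the bytecode verifier sees a proper local
  val body = ValDef(
    newSymbol(
      tree.symbol, termName("body"), Local, tree.rhs.tpe),
    Block(Nil, tree.rhs)
  )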

RDD constructed from parsed case class: serialization fails

I'm trying to understand how serialization works in the case of a self-constructed case class and a parser in a separate object, and I'm failing.
I tried to boil down the problem to:
parsing a string into case classes
constructing an RDD from those
taking the first element in order to print it
case class article(title: String, text: String) extends Serializable {
  override def toString = title + s"/" + text
}

object parser {
  def parse(line: String): article = {
    val subs = "</end>"
    val i = line.indexOf(subs)
    val title = line.substring(6, i)
    val text = line.substring(i + subs.length, line.length)
    article(title, text)
  }
}
val text = """"<beg>Title1</end>Text 1"
"<beg>Title2</end>Text 2"
"""
val lines = text.split('\n')
val res = lines.map( line => parser.parse(line) )
val rdd = sc.parallelize(res)
rdd.take(1).map( println )
I get a
Job aborted due to stage failure: Failed to serialize task, not attempting to retry it. Exception during serialization: java.io.NotSerializableException
Can a gifted Scala expert please help me understand the interaction of serialization between workers and master, and how to fix the parser/article interaction so that serialization works?
Thank you very much.
In your map function from lines.map( line => parser.parse(line) ) you call parser.parse, and parser is your object, which is not serializable. Spark internally uses partitions which are spread across the cluster. The map function will be called on each partition. Because the partitions are not in the same JVM process, the function that is called on each partition needs to be serializable; that is why your parser object has to obey that rule.
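One way to satisfy that rule, shown as a minimal sketch below, is to mark the parser object itself as Serializable, so the closure that references parser.parse can be shipped to the worker JVMs:

// only the Serializable marker trait is new; the parse logic is
// unchanged from the question
object parser extends Serializable {
  def parse(line: String): article = {
    val subs = "</end>"
    val i = line.indexOf(subs)
    article(line.substring(6, i), line.substring(i + subs.length, line.length))
  }
}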

How to call groovy inner class

This is my code, and I try to call the method in the inner class as shown below (the last line, ic = new oc.Inner()), but I get an error.
I am using the Groovy console, and according to the Groovy documentation I expect that the Inner class can be used from outside the Outer class. I am not sure about the syntax.
class Outer {
    private String privateStr = 'some string'

    def callInnerMethod() {
        new Inner().methodA()
    }

    class Inner {
        def methodA() {
            println "${privateStr}."
        }
    }
}
Outer oc = new Outer()
ic = new oc.Inner()
This is the result I get:
startup failed:
Script1.groovy: 14: unable to resolve class oc.Inner
 @ line 14, column 6.
      ic = new oc.Inner()
           ^
1 error
How about this:
def ic = new Outer.Inner()
This will likely only work if your inner class is static. For a non-static inner class, pass an instance of the outer class to the constructor:
def oc = new Outer()
def ic = new Outer.Inner(oc)
https://groovy-lang.org/differences.html#_creating_instances_of_non_static_inner_classes

Implicits over function closures in Scala

I've been trying to understand implicits in Scala and to use them at work. One particular place I'm stuck at is trying to pass implicits in the following manner:
object DBUtils {
  case class DB(val jdbcConnection: Connection) {
    def execute[A](op: => Unit): Any = {
      implicit val con = jdbcConnection
      op
    }
  }

  object DB {
    def SQL(query: String)(implicit jdbcConnection: Connection): PreparedStatement = {
      jdbcConnection.prepareStatement(query)
    }
  }

  val someDB1 = DB(jdbcConnection)
  val someDB2 = DB(jdbcConnection2)
  val someSQL = SQL("SOME SQL HERE")
  someDB1.execute{someSQL}
  someDB2.execute{someSQL}
}
Currently I get an exception saying that the SQL() function cannot find the implicit jdbcConnection. What gives, and what do I do to make it work in the format I need?
PS: I'm on a slightly older version of Scala (2.10.4) and cannot upgrade.
Edit: Changed the problem statement to be clearer. I cannot use a single implicit connection in scope, since I can have multiple DBs with different Connections.
At the point where SQL is invoked there is no implicit value of type Connection in scope.
In your code snippet the declaration of jdbcConnection is missing, but if you change it from
val jdbcConnection = //...
to
implicit val jdbcConnection = // ...
then you will have an implicit instance of Connection in scope and the compiler should be happy.
Try this:
implicit val con = jdbcConnection // define implicit here
val someDB = DB(jdbcConnection)
val someSQL = SQL("SOME SQL HERE") // need implicit here
someDB.execute{someSQL}
The implicit must be defined in the scope where you need it. (In reality, it's more complicated, because there are rules for looking elsewhere, as you can find in the documentation. But the simplest thing is to make sure the implicit is available in the scope where you need it.)
Make the following changes:
1) Have the execute method take a function from Connection to Unit.
2) Instead of someDB1.execute{someSQL}, use someDB1.execute{implicit con => someSQL}.
object DBUtils {
  case class DB(val jdbcConnection: Connection) {
    def execute[A](op: Connection => Unit): Any = {
      val con = jdbcConnection
      op(con)
    }
  }
Here is the complete code.
  object DB {
    def SQL(query: String)(implicit jdbcConnection: Connection): PreparedStatement = {
      jdbcConnection.prepareStatement(query)
    }
  }
}

val someDB1 = DB(jdbcConnection)
val someDB2 = DB(jdbcConnection2)
// a def, so the implicit lookup is deferred until execute supplies a connection
def someSQL(implicit con: Connection) = SQL("SOME SQL HERE")
someDB1.execute{implicit con => someSQL}
someDB2.execute{implicit con => someSQL}

How to return full row using Slick's insertOrUpdate

I am currently learning Play 2, Scala and Slick 3.1, and am pretty stuck with the syntax for using insertOrUpdate; I wonder if anyone can please help me.
What I want to do is return the full row when using insertOrUpdate, including the auto-inc primary key, but I have only managed to return the number of updated/inserted rows.
Here is my table definition:
package models

final case class Report(session_id: Option[Long], session_name: String, tester_name: String, date: String,
                        jira_ref: String, duration: String, environment: String, notes: Option[String])

trait ReportDBTableDefinitions {
  import slick.driver.PostgresDriver.api._

  class Reports(tag: Tag) extends Table[Report](tag, "REPORTS") {
    def session_id = column[Long]("SESSION_ID", O.PrimaryKey, O.AutoInc)
    def session_name = column[String]("SESSION_NAME")
    def tester_name = column[String]("TESTER_NAME")
    def date = column[String]("DATE")
    def jira_ref = column[String]("JIRA_REF")
    def duration = column[String]("DURATION")
    def environment = column[String]("ENVIRONMENT")
    def notes = column[Option[String]]("NOTES")
    def * = (session_id.?, session_name, tester_name, date, jira_ref, duration, environment, notes) <> (Report.tupled, Report.unapply)
  }

  lazy val reportsTable = TableQuery[Reports]
}
Here is the section of my DAO that relates to insertOrUpdate, and it works just fine, but only returns the number of updated/inserted rows:
package models

import com.google.inject.Inject
import play.api.db.slick.DatabaseConfigProvider
import scala.concurrent.Future

class ReportsDAO @Inject()(protected val dbConfigProvider: DatabaseConfigProvider) extends DAOSlick {
  import driver.api._

  def save_report(report: Report): Future[Int] = {
    dbConfig.db.run(reportsTable.insertOrUpdate(report).transactionally)
  }
}
I have tried playing with "returning", but I can't get the syntax I need and keep getting type mismatches, e.g. the code below doesn't compile (because it's probably completely wrong!):
def save_report(report: Report): Future[Report] = {
  dbConfig.db.run(reportsTable.returning(reportsTable).insertOrUpdate(report))
}
Any help appreciated - I'm new to Scala and Slick so apologies if I'm missing something really obvious.
Solved. Posting it in case it helps anyone else trying to do something similar:
// will return the new session_id on insert, and None on update
def save_report(report: Report): Future[Option[Long]] = {
  val insertQuery = (reportsTable returning reportsTable.map(_.session_id)).insertOrUpdate(report)
  dbConfig.db.run(insertQuery)
}
Works well. It seems insertOrUpdate doesn't return anything on update, so if I need the updated data after the update operation I can run a subsequent query using the session id, as sketched below.
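That subsequent query could look roughly like this; a minimal sketch, assuming the reportsTable and dbConfig from the DAO above (the name get_report is just for illustration):

// fetch the full row for a known session id; None if it does not exist
def get_report(id: Long): Future[Option[Report]] =
  dbConfig.db.run(reportsTable.filter(_.session_id === id).result.headOption)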
You cannot return the whole Report directly; first return the id (returning(reportsTable.map(_.session_id))) and then fetch the whole object.
Check whether the report exists in the database: if it exists, update it; if not, go ahead and insert the report into the database.
Note: do the above operations in an all-or-none fashion by using transactions.
def getReportDBIO(id: Long): DBIO[Report] = reportsTable.filter(_.session_id === id).result.head

def save_report(report: Report): Future[Report] = {
  val query = reportsTable.filter(_.session_id === report.session_id)
  val existsAction = query.exists.result
  val insertOrUpdateAction =
    (for {
      exists <- existsAction
      result <- exists match {
        case true =>
          query.update(report).flatMap { _ => getReportDBIO(report.session_id.get) }.transactionally
        case false =>
          val insertAction = reportsTable.returning(reportsTable.map(_.session_id)) += report
          val finalAction = insertAction.flatMap(id => getReportDBIO(id)).transactionally // transactionally is important
          finalAction
      }
    } yield result).transactionally
  dbConfig.db.run(insertOrUpdateAction)
}
Update your insertOrUpdate function accordingly.
You can return the full row, but it is an Option: as the documentation states, it will be empty on an update and a Some(...) representing the inserted row on an insert.
So the correct code would be
def save_report(report: Report): Future[Option[Report]] = {
  dbConfig.db.run(reportsTable.returning(reportsTable).insertOrUpdate(report))
}