Abort download if file too large in Scala

I am trying to download a file using the Scala APIs, but I would like it to abort if that file is too large (50MB).
I have managed to put together a very inefficient way that works for small files (< 10KB) but runs my CPU through the roof for large files:
var size = 0
val bytes = scala.io.Source.fromURL(url)(scala.io.Codec.ISO8859).toStream.map { c =>
  size = size + 1
  if (size > (maxMbSize*1024*1024)) {
    throw new Exception(s"File size is greater than the maximum allowed size of $maxMbSize MB")
  }
  c.toByte
}.toArray
I would like to be able to do this check more efficiently and also avoid the use of the var for the size. Is this possible?
Also, I am using the Play framework, in case anyone knows of an API within that framework that does what I am looking for.

You don't need to load the data into a byte array at all - you can generate the hash on the fly with a DigestInputStream using existing Java libraries. In this example I am loading the data from a String, but you can adapt it to load from a URL. We use a tail-recursive function to eliminate the var, and return an Option so that we can indicate an over-size file by returning None.
import java.io._
import java.security._
import scala.annotation.tailrec

def calculateHash(algorithm: MessageDigest, in: String, limit: Int): Option[Array[Byte]] = {
  val input = new ByteArrayInputStream(in.getBytes())
  val dis = new DigestInputStream(input, algorithm)

  @tailrec
  def read(total: Int): Option[Array[Byte]] = {
    if (total > limit) None
    else {
      val byte = dis.read()
      if (byte == -1) Some(algorithm.digest())
      else read(total + 1)
    }
  }

  read(0)
}
Example usage:
val sha1 = MessageDigest.getInstance("SHA1")
calculateHash(sha1, "Hello", 5).get
//> res0: Array[Byte] = Array(-9, -1, -98, -117, 123, -78, -32, -101, 112, -109, 90, 93, 120, 94, 12, -59, -39, -48, -85, -16)
calculateHash(sha1, "Too long!!!", 5)
//> res1: Option[Array[Byte]] = None
You may also be able to get better performance by using the variant of DigestInputStream.read() that uses a buffer:
...
val buffer = new Array[Byte](1024)

@tailrec
def read(total: Int): Option[Array[Byte]] = {
  if (total > limit) None
  else {
    val count = dis.read(buffer, 0, buffer.length)
    if (count == -1) Some(algorithm.digest())
    else read(total + count)
  }
}
...

Since you are materializing all the data into memory, you don't really need to buffer it in small chunks (which is what the scala.io library is doing). Also, since you want the bytes anyway, you don't need to decode the bytes into Chars only to reverse the process.
To lose your size var, you can make use of the zipWithIndex function, which pairs each element with its index. Do note that the index starts at 0, so you need the + 1.
def readMyURL(url: String): Array[Byte] = {
  val is = new java.net.URL(url).openStream()
  val byteArray = Iterator.continually(is.read).zipWithIndex.takeWhile { zb =>
    if (zb._2 > (maxMbSize*1024*1024) + 1) {
      throw new Exception(s"File size is greater than the maximum allowed size of $maxMbSize MB")
    }
    -1 != zb._1 // -1 is the end of stream
  }.map(_._1.toByte).toArray
  is.close()
  byteArray
}
This is lazy, so the iterator will not traverse the stream until you call toArray.
You may be able to get away with not closing that URL InputStream (it appears the scala.io library does not close it).
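If you want the stream closed even when the size limit triggers the exception, a try/finally variant of the same code is a small, safe addition (a sketch, not part of the original answer):
def readMyURL(url: String, maxMbSize: Int): Array[Byte] = {
  val is = new java.net.URL(url).openStream()
  try {
    Iterator.continually(is.read).zipWithIndex.takeWhile { zb =>
      if (zb._2 > (maxMbSize*1024*1024) + 1) {
        throw new Exception(s"File size is greater than the maximum allowed size of $maxMbSize MB")
      }
      -1 != zb._1 // -1 is the end of stream
    }.map(_._1.toByte).toArray
  } finally {
    is.close() // runs even if the size check throws
  }
}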

Related

Why does passing the same val to a function a second time not return the same value?

I am starting to learn Scala, and I wrote this code. My question: why does a val, which is supposed to be constant, give a different value when I pass it to the same function a second time? How do I write a pure function in Scala?
Any comments on whether the counting is right are also welcome.
import java.io.FileNotFoundException
import java.io.IOException
import scala.io.BufferedSource
import scala.io.Source.fromFile

object Main {
  def main(args: Array[String]): Unit = {
    val fileName: String = if (args.length == 1) args(0) else ""
    try {
      val file = fromFile(fileName)
      /* In file tekst.txt is 4 lines */
      println(s"In file $fileName is ${countLines(file)} lines")
      /* In file tekst.txt is 0 lines */
      println(s"In file $fileName is ${countLines(file)} lines")
      file.close
    }
    catch {
      case e: FileNotFoundException => println(s"File $fileName not found")
      case _: Throwable => println("Other error")
    }
  }

  def countLines(file: BufferedSource): Long = {
    file.getLines.count(_ => true)
  }
}
val means that you cannot assign a new value to it. If it holds something immutable - a number, an immutable collection, a tuple or case class of other immutable things - then your value will not change over its lifetime: if it is a val inside a function, it stays the same until you leave that function; if it is a value in a class, it stays the same across all calls on that instance; if it is in an object, it stays the same over the whole program's life.
But if you are talking about objects which are mutable on their own, then the only immutable part is the reference to the object. If you have a val of mutable.MutableList, you cannot swap it for another mutable.MutableList, but you can modify the contents of the list. Here:
val file = fromFile(fileName)
/* In file tekst.txt is 4 lines */
println(s"In file $fileName is ${countLines(file)} lines")
/* In file tekst.txt is 0 lines */
println(s"In file $fileName is ${countLines(file)} lines")
file.close
file is an immutable reference to a BufferedSource. You cannot replace it with another BufferedSource - but this class has internal state: it tracks how much of the file it has already read, so the first time you operate on it you receive the total number of lines in the file, and then (since the file is already read) 0.
If you wanted that code to be purer, you should contain the mutability so that it won't be observable to the user, e.g.
def countFileLines(fileName: String): Either[String, Long] = try {
  val file = fromFile(fileName)
  try {
    Right(file.getLines.count(_ => true))
  } finally {
    file.close()
  }
} catch {
  case e: FileNotFoundException => Left(s"File $fileName not found")
  case _: Throwable => Left("Other error")
}

println(s"In file $fileName is ${countFileLines(fileName)} lines")
println(s"In file $fileName is ${countFileLines(fileName)} lines")
Still, you have side effects there, so ideally it should be written using an IO monad, but for now remember that you should aim for referential transparency: if you could replace each call to countLines(file) with a value from val counted = countLines(file), it would be RT. As you checked, it isn't. So replace it with something whose behavior wouldn't change if it were called twice. One way to do that is to run the whole computation each time, without any global state preserved between runs (e.g. the internal counter in BufferedSource). IO monads make that easier, so go after them once you feel comfortable with the syntax itself (to avoid learning too many things at once).
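For reference, a minimal sketch of what an IO-monad version could look like, assuming the cats-effect library is on the classpath (the IO value only describes the computation, so running it twice repeats the whole read and both runs agree):
import cats.effect.IO
import scala.io.Source.fromFile

// each run opens, counts and closes the file on its own
def countLinesIO(fileName: String): IO[Long] = IO {
  val file = fromFile(fileName)
  try file.getLines.size.toLong
  finally file.close()
}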
file.getLines returns an Iterator[String], and an iterator is consumable, meaning we can iterate over it only once. For example, consider
val it = Iterator("a", "b", "c")
it.count(_ => true)
// val res0: Int = 3
it.count(_ => true)
// val res1: Int = 0
Looking at the implementation of count
def count(p: A => Boolean): Int = {
  var res = 0
  val it = iterator
  while (it.hasNext) if (p(it.next())) res += 1
  res
}
notice the call to it.next(). This call advances the state of the iterator and if it happens then we cannot go back to previous state.
As an alternative you could try length instead of count
val it = Iterator("a", "b", "c")
it.length
// val res0: Int = 3
it.length
// val res1: Int = 3
Looking at the definition of length which just delegates to size
def size: Int = {
  if (knownSize >= 0) knownSize
  else {
    val it = iterator
    var len = 0
    while (it.hasNext) { len += 1; it.next() }
    len
  }
}
notice the guard
if (knownSize >= 0) knownSize
Some collections know their size without having to compute it by iterating over them. For example,
Array(1,2,3).knownSize // 3: I know my size in advance
List(1,2,3).knownSize // -1: I do not know my size in advance so I have to traverse the whole collection to find it
So if the underlying concrete collection of the Iterator knows its size, then the call to length will short-circuit and it.next() will never execute, which means the iterator will not be consumed. This is the case for the default concrete collection used by the Iterator factory, which is Array:
val it = Iterator("a", "b", "c")
it.getClass
// res6: Class[_ <: Iterator[String]] = class scala.collection.ArrayOps$ArrayIterator
however this is not true for BufferedSource. To work around the issue, consider creating a new iterator each time countLines is called:
def countLines(fileName: String): Long = {
  fromFile(fileName).getLines().length
}
println(s"In file $fileName is ${countLines(fileName)} lines")
println(s"In file $fileName is ${countLines(fileName)} lines")
// In file build.sbt is 22 lines
// In file build.sbt is 22 lines
Final point regarding value definitions and immutability. Consider
object Foo { var x = 42 } // object contains mutable state
val foo = Foo // value definition
foo.x
// val res0: Int = 42
Foo.x = -11 // mutation happening here
foo.x
// val res1: Int = -11
Here the identifier foo is an immutable reference to a mutable object.

How to write binary data to a file to implement Huffman compression

I'm trying to implement Huffman compression. After encoding the chars into 0s and 1s, how do I write them to a file so the file actually gets smaller? Obviously simply writing the characters '0' and '1' will only make the file larger.
Let's say I have a string "0101010" which represents bits.
I wish to write the string to a file, but in binary, i.e. not the char '0' but the bit 0.
I tried using getBytes() on the string, but it seems to make no difference compared to simply writing the string.
Although Sarvesh Kumar Singh's code will probably work, it looks so Javish to me that I think this question needs one more answer. The way I imagine Huffman coding being implemented in Scala is something like this:
import scala.collection._
import scala.io.Source

type Bit = Boolean
type BitStream = Iterator[Bit]
type BitArray = Array[Bit]
type ByteStream = Iterator[Byte]
type CharStream = Iterator[Char]

case class EncodingMapping(charMapping: Map[Char, BitArray], eofCharMapping: BitArray)

def buildMapping(src: CharStream): EncodingMapping = {
  def accumulateStats(src: CharStream): Map[Char, Int] = ???
  def buildMappingImpl(stats: Map[Char, Int]): EncodingMapping = ???

  val stats = accumulateStats(src)
  buildMappingImpl(stats)
}

def convertBitsToBytes(bits: BitStream): ByteStream = {
  bits.grouped(8).map(bits => {
    val res = bits.foldLeft((0.toByte, 0))((acc, bit) => ((acc._1 * 2 + (if (bit) 1 else 0)).toByte, acc._2 + 1))
    // last byte might be less than 8 bits
    if (res._2 == 8)
      res._1
    else
      (res._1 << (8 - res._2)).toByte
  })
}

def encodeImpl(src: CharStream, mapping: EncodingMapping): ByteStream = {
  val mainData = src.flatMap(ch => mapping.charMapping(ch))
  val fullData = mainData ++ mapping.eofCharMapping
  convertBitsToBytes(fullData)
}

// can be used to encode a String as src, thanks to the StringLike/StringOps extension
def encode(src: Iterable[Char]): (EncodingMapping, ByteStream) = {
  val mapping = buildMapping(src.iterator)
  val encoded = encodeImpl(src.iterator, mapping)
  (mapping, encoded)
}

def wrapClose[A <: java.io.Closeable, B](resource: A)(op: A => B): B = {
  try {
    op(resource)
  }
  finally {
    resource.close()
  }
}

def encodeFile(fileName: String): (EncodingMapping, ByteStream) = {
  // note in real life you probably want to specify the file encoding as well
  val mapping = wrapClose(Source.fromFile(fileName))(buildMapping)
  val encoded = wrapClose(Source.fromFile(fileName))(file => encodeImpl(file, mapping))
  (mapping, encoded)
}
where in accumulateStats you find out how often each Char is present in the src, and in buildMappingImpl (which is the main part of the whole Huffman encoding) you first build a tree from those stats and then use it to create a fixed EncodingMapping. eofCharMapping is a mapping for the pseudo-EOF char, as mentioned in one of the comments. Note that the high-level encode methods return both the EncodingMapping and the ByteStream, because in any real-life scenario you want to save both.
The piece of logic specifically being asked about is located in the convertBitsToBytes method. Note that I use Boolean to represent a single bit rather than Char, and thus Iterator[Bit] (effectively Iterator[Boolean]) rather than String to represent a sequence of bits. The idea of the implementation is based on the grouped method, which converts a BitStream into a stream of Bits grouped into byte-sized groups (except possibly for the last one).
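As a quick sanity check of convertBitsToBytes, here is what the question's bit pattern 0101010 (seven bits, so the padding branch applies) turns into under this scheme:
val bits: BitStream = "0101010".iterator.map(_ == '1')
convertBitsToBytes(bits).toList
// List(84): the seven bits make 42, shifted left one bit to fill the byte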
IMHO the main advantage of this stream-oriented approach compared to Sarvesh Kumar Singh's answer is that you don't need to load the whole file into memory at once or keep the whole encoded file in memory. Note however that in this case you'll have to read the file twice: first to build the EncodingMapping and second to apply it. Obviously, if the file is small enough you can load it into memory first and then convert the ByteStream to an Array[Byte] using a .toArray call. But if your file is big, you can use the pure stream-based approach and easily save the ByteStream into a file using something like .foreach(b => out.write(b)).
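That last step might look like the following sketch (writeByteStream is my own hypothetical helper, not part of the answer's code):
import java.io.{BufferedOutputStream, FileOutputStream}

def writeByteStream(fileName: String, bytes: ByteStream): Unit = {
  val out = new BufferedOutputStream(new FileOutputStream(fileName))
  try bytes.foreach(b => out.write(b.toInt)) // write one encoded byte at a time
  finally out.close()
}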
I don't think this will help you achieve your Huffman compression goal, but in answer to your question:
string-to-value
Converting a String to the value it represents is pretty easy.
val x: Int = "10101010".foldLeft(0)(_*2 + _.asDigit)
Note: You'll have to check for formatting (only ones and zeros) and overflow (strings too long) before conversion.
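Such a check could look like this sketch (parseBits is a hypothetical helper; it assumes the value must fit in an Int):
// None for an empty string, non-binary characters, or values too long for an Int
def parseBits(s: String): Option[Int] =
  if (s.nonEmpty && s.length <= 31 && s.forall(c => c == '0' || c == '1'))
    Some(s.foldLeft(0)(_ * 2 + _.asDigit))
  else None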
value-to-file
There are a number of ways to write data to a file. Here's a simple one.
import java.io.{FileOutputStream, File}
val fos = new FileOutputStream(new File("output.dat"))
fos.write(x) // OutputStream.write(Int) writes only the low-order 8 bits
fos.flush()
fos.close()
Note: You'll want to catch any errors thrown.
I will specify all the required imports first,
import java.io.{File, FileInputStream, FileOutputStream}
import java.nio.file.Paths
import scala.collection.mutable.ArrayBuffer
Now, we are going to need the following smaller units to achieve the whole thing:
1 - We need to be able to convert our binary string (e.g. "01010") to an Array[Byte],
def binaryStringToByteArray(binaryString: String) = {
  val byteBuffer = ArrayBuffer.empty[Byte]
  var byteStr = ""
  for (binaryChar <- binaryString) {
    if (byteStr.length < 7) {
      byteStr = byteStr + binaryChar
    }
    else {
      try {
        val byte = java.lang.Byte.parseByte(byteStr + binaryChar, 2)
        byteBuffer += byte
        byteStr = ""
      }
      catch {
        case ex: java.lang.NumberFormatException =>
          val byte = java.lang.Byte.parseByte(byteStr, 2)
          byteBuffer += byte
          byteStr = "" + binaryChar
      }
    }
  }
  if (!byteStr.isEmpty) {
    val byte = java.lang.Byte.parseByte(byteStr, 2)
    byteBuffer += byte
    byteStr = ""
  }
  byteBuffer.toArray
}
2 - We need to be able to open the file to serve in our little play,
def openFile(filePath: String): File = {
  val path = Paths.get(filePath)
  val file = path.toFile
  if (file.exists()) file.delete()
  if (!file.exists()) file.createNewFile()
  file
}
3 - We need to be able to write bytes to a file,
def writeBytesToFile(bytes: Array[Byte], file: File): Unit = {
  val fos = new FileOutputStream(file)
  fos.write(bytes)
  fos.close()
}
4 - We need to be able to read bytes back from the file,
def readBytesFromFile(file: File): Array[Byte] = {
  val fis = new FileInputStream(file)
  val bytes = new Array[Byte](file.length().toInt)
  fis.read(bytes)
  fis.close()
  bytes
}
5 - We need to be able to convert bytes back to our binaryString,
def byteArrayToBinaryString(byteArray: Array[Byte]): String = {
  byteArray.map(b => b.toBinaryString).mkString("")
}
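One caveat: toBinaryString drops leading zeros within each byte, which is exactly why the scheme below pads the string with a leading "1". If you ever need an exact 8-bits-per-byte round trip instead, a padded variant (a hypothetical helper, not used in the scheme below) could look like:
// pad every byte to exactly 8 bits so no leading zeros are lost
def byteArrayToBinaryStringPadded(byteArray: Array[Byte]): String =
  byteArray.map(b => String.format("%8s", (b & 0xFF).toBinaryString).replace(' ', '0')).mkString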
Now, we are ready to do everything we want:
// let's say we had this binary string,
scala> val binaryString = "00101110011010101010101010101"
// binaryString: String = 00101110011010101010101010101

// Now, we need to "pad" this with a leading "1" to avoid byte related issues
scala> val paddedBinaryString = "1" + binaryString
// paddedBinaryString: String = 100101110011010101010101010101

// The file which we will use for this,
scala> val file = openFile("/tmp/a_bit")
// file: java.io.File = /tmp/a_bit

// convert our padded binary string to bytes
scala> val bytes = binaryStringToByteArray(paddedBinaryString)
// bytes: Array[Byte] = Array(75, 77, 85, 85)

// write the bytes to our file,
scala> writeBytesToFile(bytes, file)

// read bytes back from file,
scala> val bytesFromFile = readBytesFromFile(file)
// bytesFromFile: Array[Byte] = Array(75, 77, 85, 85)

// so now, we have our padded string back (leading zeros inside bytes are not preserved),
scala> val paddedBinaryStringFromFile = byteArrayToBinaryString(bytesFromFile)
// paddedBinaryStringFromFile: String = 1001011100110110101011010101

// remove that "1" from the front and we have our binaryString back,
scala> val binaryStringFromFile = paddedBinaryString.tail
// binaryStringFromFile: String = 00101110011010101010101010101
NOTE :: you may have to make a few changes if you want to deal with very large "binary strings" (more than a few million characters long), to improve performance or even make this usable. For example, you will need to start using Streams or Iterators instead of Array[Byte].

Surprisingly slow mutable.ArrayOps.drop

I am a newbie in Scala, and when trying to profile my Scala code with YourKit, I found something surprising about the use of array.drop.
Here is what I wrote:
...
val items = s.split(" +") // s is a string
...
val s1 = items.drop(2).mkString(" ")
...
In a 1 min run of my code, YourKit told me that the call items.drop(2) takes around 11% of the total execution time:
Lexer.scala:33 scala.collection.mutable.ArrayOps$ofRef.drop(int) 1054 11%
This is really surprising to me; is there an internal memory copy that slows down the processing? If so, what is the best practice for optimizing this simple code snippet? Thank you.
This is really surprising to me, is there any internal memory copy
that slow down the processing?
ArrayOps.drop internally calls IterableLike.slice, which allocates a builder that produces a new Array for each call:
override def slice(from: Int, until: Int): Repr = {
  val lo = math.max(from, 0)
  val hi = math.min(math.max(until, 0), length)
  val elems = math.max(hi - lo, 0)
  val b = newBuilder
  b.sizeHint(elems)
  var i = lo
  while (i < hi) {
    b += self(i)
    i += 1
  }
  b.result()
}
You're seeing the cost of the iteration + allocation. You didn't specify how many times this happens or how big the collection is, but if it's large this could be time consuming.
One way of optimizing this is to generate a List[String] instead, whose drop simply iterates the collection and drops its head element. Note this will incur an additional traversal of the Array[T] to create the list, so make sure to benchmark this to see that you actually gain anything:
val items = s.split(" +").toList
val afterDrop = items.drop(2).mkString(" ")
Another possibility is to enrich Array[T] to include your own version of mkString which manually populates a StringBuilder:
object RichOps {
  implicit class RichArray[T](val arr: Array[T]) extends AnyVal {
    def mkStringWithIndex(start: Int, end: Int, separator: String): String = {
      var idx = start
      val stringBuilder = new StringBuilder(end - start)
      while (idx < end) {
        stringBuilder.append(arr(idx))
        if (idx != end - 1) {
          stringBuilder.append(separator)
        }
        idx += 1
      }
      stringBuilder.toString()
    }
  }
}
And now we have:
object Test {
  def main(args: Array[String]): Unit = {
    import RichOps._
    val items = "hello everyone and welcome".split(" ")
    println(items.mkStringWithIndex(2, items.length, " "))
  }
}
Yields:
and welcome
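Another option worth benchmarking (my suggestion, not from the answers above) is to take a lazy view of the array, so that drop never materializes an intermediate copy; mkString then walks the view directly:
val items = "hello everyone and welcome".split(" ")
// drop on a view is lazy, so no intermediate Array is allocated
val s1 = items.view.drop(2).mkString(" ")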

Storing a sequence of Long efficiently in a file in Scala

So I have an association that maps a pair of Ints to a Vector[Long] that can be up to size 10000, and I have anywhere from several hundred thousand to a million such entries. I would like to store this in a single file for later processing in Scala.
Clearly storing this in a plain-text format would take way too much space, so I've been trying to figure out how to do it by writing a byte stream. However, I'm not too sure if this will work, since it seems to me that the byteValue() of a Long returns its Byte representation, which is still 4 bytes long, and hence I won't save any space? I do not have much experience working with binary formats.
It seems the Scala standard library had a BytePickle that might have been what I was looking for, but has since been deprecated?
An arbitrary Long is about 19.5 ASCII digits long, but only 8 bytes in binary, so you'll gain a savings of a factor of ~2 if you write it in binary. Now, it may be that most of the values do not actually need all 8 bytes, in which case you could define some compression scheme yourself.
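As one illustration of such a scheme (my sketch, not part of this answer): a varint-style encoding spends one byte per 7 bits of payload, so small values take far fewer than 8 bytes:
import java.io.OutputStream

// 7 payload bits per byte; the high bit flags "more bytes follow"
def writeVarLong(out: OutputStream, value: Long): Unit = {
  var v = value
  while ((v & ~0x7FL) != 0L) {
    out.write(((v & 0x7FL) | 0x80L).toInt)
    v >>>= 7
  }
  out.write(v.toInt)
}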
In any case, you are probably best off writing block data using java.nio.ByteBuffer and friends. Binary data is most efficiently read in blocks, and you might want your file to be randomly accessible, in which case you want your data to look something like so:
<some unique binary header that lets you check the file type>
<int saying how many records you have>
<offset of the first record>
<offset of the second record>
...
<offset of the last record>
<int><int><length of vector><long><long>...<long>
<int><int><length of vector><long><long>...<long>
...
<int><int><length of vector><long><long>...<long>
This is a particularly convenient format for reading and writing using ByteBuffer because you know in advance how big everything is going to be. So you can
import java.io.FileOutputStream
import java.nio.ByteBuffer

val fos = new FileOutputStream(myFileName)
val fc = fos.getChannel // java.nio.channels.FileChannel
val header = ByteBuffer.allocate(28)
header.put("This is my cool header!!".getBytes)
header.putInt(data.length)
header.flip() // flip before writing, so the channel sees what we just put
fc.write(header)
val offsets = ByteBuffer.allocate(8*data.length)
data.foldLeft(28L + 8*data.length){ (n, d) =>
  offsets.putLong(n)
  n + 12 + d.vector.length*8 // next offset: two Ints, the length Int, then the Longs
}
offsets.flip()
fc.write(offsets)
...
and on the way back in
val fis = new FileInputStream(myFileName)
val fc = fis.getChannel
val header = ByteBuffer.allocate(28)
fc.read(header)
header.flip() // flip before reading the contents back out
val hbytes = new Array[Byte](24)
header.get(hbytes)
if (new String(hbytes) != "This is my cool header!!") ???
val nrec = header.getInt
val offsets = ByteBuffer.allocate(8*nrec)
fc.read(offsets)
offsets.flip()
val offsetArray = offsets.getLongs(nrec) // See below!
...
Some handy methods are absent from ByteBuffer, but you can add them on with implicits (here for Scala 2.10; with 2.9, make it a plain class, drop the extends AnyVal, and supply an implicit conversion from ByteBuffer to RichByteBuffer):
implicit class RichByteBuffer(val b: java.nio.ByteBuffer) extends AnyVal {
def getBytes(n: Int) = { val a = new Array[Byte](n); b.get(a); a }
def getShorts(n: Int) = { val a = new Array[Short](n); var i=0; while (i<n) { a(i)=b.getShort(); i+=1 } ; a }
def getInts(n: Int) = { val a = new Array[Int](n); var i=0; while (i<n) { a(i)=b.getInt(); i+=1 } ; a }
def getLongs(n: Int) = { val a = new Array[Long](n); var i=0; while (i<n) { a(i)=b.getLong(); i+=1 } ; a }
def getFloats(n: Int) = { val a = new Array[Float](n); var i=0; while (i<n) { a(i)=b.getFloat(); i+=1 } ; a }
def getDoubles(n: Int) = { val a = new Array[Double](n); var i=0; while (i<n) { a(i)=b.getDouble(); i+=1 } ; a }
}
Anyway, the reason to do things this way is that you'll end up with decent performance, which is also a consideration when you have tens of gigabytes of data (which it sounds like you do, given hundreds of thousands of vectors of length up to ten thousand).
If your problem is actually much smaller, then don't worry so much about it--pack it into XML or use JSON or some custom text solution (or use DataOutputStream and DataInputStream, which don't perform as well and won't give you random access).
If your problem is actually bigger, you can define two lists of longs; first, the ones that will fit in an Int, say, and then the ones that actually need a full Long (with indices so you know where they are). Data compression is a very case-specific task--assuming you don't just want to use java.util.zip--so without a lot more knowledge about what the data looks like, it's hard to know what to recommend beyond just storing it as a weakly hierarchical binary file as I've described above.
See Java's DataOutputStream. It allows easy and efficient writing of primitive types and Strings to byte streams. In particular, you want something like:
val stream = new DataOutputStream(new FileOutputStream("your_file.bin"))
You can then use the equivalent DataInputStream methods to read from that file to variables again.
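For instance, a minimal round trip might look like this sketch (the Int-count-then-Longs layout is my assumption, not something specified here):
import java.io.{DataInputStream, DataOutputStream, FileInputStream, FileOutputStream}

val values = Vector(1L, 2L, 3L)

// write an Int count followed by that many Longs
val out = new DataOutputStream(new FileOutputStream("your_file.bin"))
try { out.writeInt(values.length); values.foreach(out.writeLong) }
finally out.close()

// read them back with the matching DataInputStream calls
val in = new DataInputStream(new FileInputStream("your_file.bin"))
try { val n = in.readInt(); val restored = Vector.fill(n)(in.readLong()) }
finally in.close()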
I used scala-io and scala-arm to write a binary stream of Longs. The libraries themselves are supposed to be a Scala way of doing things, but they are not in the Scala master branch - maybe someone knows why? I use them from time to time.
1) Clone scala-io:
git clone https://github.com/scala-incubator/scala-io.git
Go to the scala-io directory, change val scalaVersion in Build.scala to yours, then run:
sbt package
2) Clone scala-arm:
git clone https://github.com/jsuereth/scala-arm.git
Go to the scala-arm directory, change scalaVersion := in build.scala to yours, then run:
sbt package
3) Copy somewhere not too far:
scala-io/core/target/scala-xxx/scala-io-core_xxx-0.5.0-SNAPSHOT.jar
scala-io/file/target/scala-xxx/scala-io-file_xxx-0.5.0-SNAPSHOT.jar
scala-arm/target/scala-xxx/scala-arm_xxx-1.3-SNAPSHOT.jar
4) Start REPL:
scala -classpath "/opt/scala-io/scala-io-core_2.10-0.5.0-SNAPSHOT.jar:
/opt/scala-io/scala-io-file_2.10-0.5.0-SNAPSHOT.jar:
/opt/scala-arm/scala-arm_2.10-1.3-SNAPSHOT.jar"
5) :paste actual code:
import scalax.io._

// create the data stream
val EOData = Vector(0xffffffffffffffffL)
val data = List(
  (0, Vector(0L,1L,2L,3L))
, (1, Vector(4L,5L))
, (2, Vector(6L,7L,8L))
, (3, Vector(9L))
)
var it = Iterator[Long]()
for (rec <- data) {
  it = it ++ Vector(rec._1).iterator.map(_.toLong)
  it = it ++ rec._2.iterator
  it = it ++ EOData.iterator
}

// write the data at once
val out: Output = Resource.fromFile("/tmp/data")
out.write(it)(OutputConverter.TraversableLongConverter)

Efficient way to fold a list in Scala, while avoiding allocations and vars

I have a bunch of items in a list, and I need to analyze the content to find out how many of them are "complete". I started out with partition, but then realized that I didn't need the two lists back, so I switched to a fold:
val counts = groupRows.foldLeft((0, 0)) { (pair, row) =>
  if (row.time == 0) (pair._1+1, pair._2)
  else (pair._1, pair._2+1)
}
but I have a lot of rows to go through for a lot of parallel users, and it is causing a lot of GC activity (an assumption on my part... the GC could be from other things, but I suspect this, since I understand a fold will allocate a new tuple on every item folded).
for the time being, I've rewritten this as
var complete = 0
var incomplete = 0
list.foreach(row => if(row.time != 0) complete += 1 else incomplete += 1)
which fixes the GC, but introduces vars.
I was wondering if there was a way of doing this without using vars while also not abusing the GC?
EDIT:
Hard call on the answers I've received. A var implementation seems to be considerably faster on large lists (like by 40%) than even a tail-recursive optimized version that is more functional but should be equivalent.
The first answer from dhg seems to be on-par with the performance of the tail-recursive one, implying that the size pass is super-efficient...in fact, when optimized it runs very slightly faster than the tail-recursive one on my hardware.
The cleanest two-pass solution is probably to just use the built-in count method:
val complete = groupRows.count(_.time == 0)
val counts = (complete, groupRows.size - complete)
But you can do it in one pass if you use partition on an iterator:
val (complete, incomplete) = groupRows.iterator.partition(_.time == 0)
val counts = (complete.size, incomplete.size)
This works because the new returned iterators are linked behind the scenes and calling next on one will cause it to move the original iterator forward until it finds a matching element, but it remembers the non-matching elements for the other iterator so that they don't need to be recomputed.
Example of the one-pass solution:
scala> val groupRows = List(Row(0), Row(1), Row(1), Row(0), Row(0)).view.map{x => println(x); x}
scala> val (complete, incomplete) = groupRows.iterator.partition(_.time == 0)
Row(0)
Row(1)
complete: Iterator[Row] = non-empty iterator
incomplete: Iterator[Row] = non-empty iterator
scala> val counts = (complete.size, incomplete.size)
Row(1)
Row(0)
Row(0)
counts: (Int, Int) = (3,2)
I see you've already accepted an answer, but you rightly mention that that solution will traverse the list twice. The way to do it efficiently is with recursion.
@annotation.tailrec
def counts(xs: List[...], complete: Int = 0, incomplete: Int = 0): (Int, Int) =
  xs match {
    case Nil => (complete, incomplete)
    case row :: tail =>
      if (row.time == 0) counts(tail, complete + 1, incomplete)
      else counts(tail, complete, incomplete + 1)
  }
This is effectively just a customized fold, except we use 2 accumulators, which are just Ints (primitives), instead of a tuple (a reference type). It should also be just as efficient as a while-loop with vars - in fact, the bytecode should be identical.
Maybe it's just me, but I prefer using the various specialized folds (.size, .exists, .sum, .product) if they are available. I find them clearer and less error-prone than the heavy-duty power of general folds.
val complete = groupRows.view.filter(_.time==0).size
(complete, groupRows.length - complete)
How about this one? No import tax.
import scala.collection.generic.CanBuildFrom
import scala.collection.Traversable
import scala.collection.mutable.Builder
case class Count(n: Int, total: Int) {
  def not = total - n
}
object Count {
  implicit def cbf[A]: CanBuildFrom[Traversable[A], Boolean, Count] = new CanBuildFrom[Traversable[A], Boolean, Count] {
    def apply(): Builder[Boolean, Count] = new Counter
    def apply(from: Traversable[A]): Builder[Boolean, Count] = apply()
  }
}
class Counter extends Builder[Boolean, Count] {
  var n = 0
  var ttl = 0
  override def +=(b: Boolean) = { if (b) n += 1; ttl += 1; this }
  override def clear() { n = 0; ttl = 0 }
  override def result = Count(n, ttl)
}
object Counting extends App {
  val vs = List(4, 17, 12, 21, 9, 24, 11)
  val res: Count = vs map (_ % 2 == 0)
  Console println s"${vs} have ${res.n} evens out of ${res.total}; ${res.not} were odd."
  val res2: Count = vs collect { case i if i % 2 == 0 => i > 10 }
  Console println s"${vs} have ${res2.n} evens over 10 out of ${res2.total}; ${res2.not} were smaller."
}
OK, inspired by the answers above, but really wanting to pass over the list only once and avoid GC, I decided that, in the face of a lack of direct API support, I would add this to my central library code:
class RichList[T](private val theList: List[T]) {
  def partitionCount(f: T => Boolean): (Int, Int) = {
    var matched = 0
    var unmatched = 0
    theList.foreach(r => { if (f(r)) matched += 1 else unmatched += 1 })
    (matched, unmatched)
  }
}
object RichList {
  implicit def apply[T](list: List[T]): RichList[T] = new RichList(list)
}
Then in my application code (if I've imported the implicit), I can write var-free expressions:
val (complete, incomplete) = groupRows.partitionCount(_.time != 0)
and get what I want: an optimized GC-friendly routine that prevents me from polluting the rest of the program with vars.
However, I then saw Luigi's benchmark, and updated it to:
Use a longer list so that multiple passes on the list were more obvious in the numbers
Use a boolean function in all cases, so that we are comparing things fairly
http://pastebin.com/2XmrnrrB
The var implementation is definitely considerably faster, even though Luigi's routine should be identical (as one would expect with optimized tail recursion). Surprisingly, dhg's dual-pass original is just as fast as (slightly faster than, with compiler optimization on) the tail-recursive one. I do not understand why.
It is slightly tidier to use a mutable accumulator pattern, like so, especially if you can re-use your accumulator:
case class Accum(var complete: Int = 0, var incomplete: Int = 0) {
  def inc(compl: Boolean): this.type = {
    if (compl) complete += 1 else incomplete += 1
    this
  }
}
val counts = groupRows.foldLeft(Accum()) { (a, row) => a.inc(row.time == 0) }
If you really want to, you can hide your vars by making them private; even if you don't, this is still a lot more self-contained than the pattern with loose vars.
You could just calculate it using the difference like so:
def counts(groupRows: List[Row]) = {
  val complete = groupRows.foldLeft(0) { (count, row) =>
    if (row.time == 0) count + 1 else count
  }
  (complete, groupRows.length - complete)
}