Understanding actors and making them thread safe - Swift

I have an actor that is processing values and is then publishing the values with a Combine Publisher.
I have trouble understanding actors: I thought that when using an actor in an async context, access would automatically be serialised. However, the numbers get processed in a different order than expected (see the class test for comparison).
I understand that if I wrapped a single Task around the for loop the results would come back in order, but my understanding was that calling a function on an actor would automatically be serialised.
How can I make my actor thread safe so that it publishes the values in the expected order even when it is called from different threads?
import XCTest
import Combine
import CryptoKit
actor AddNumbersActor {
private let _numberPublisher: PassthroughSubject<(Int,String), Never> = .init()
nonisolated lazy var numberPublisher = _numberPublisher.eraseToAnyPublisher()
func process(_ number: Int) {
let string = SHA512.hash(data: Data(String(number).utf8))
.description
_numberPublisher.send((number, string))
}
}
class AddNumbersClass {
private let _numberPublisher: PassthroughSubject<(Int,String), Never> = .init()
lazy var numberPublisher = _numberPublisher.eraseToAnyPublisher()
func process(_ number: Int) {
let string = SHA512.hash(data: Data(String(number).utf8))
.description
_numberPublisher.send((number, string))
}
}
final class TestActorWithPublisher: XCTestCase {
var subscription: AnyCancellable?
override func tearDownWithError() throws {
subscription = nil
}
func testActor() throws {
let addNumbers = AddNumbersActor()
var numbersResults = [(int: Int, string: String)]()
let expectation = expectation(description: "numberOfExpectedResults")
let numberCount = 1000
subscription = addNumbers.numberPublisher
.sink { results in
print(results)
numbersResults.append(results)
if numberCount == numbersResults.count {
expectation.fulfill()
}
}
for number in 1...numberCount {
Task {
await addNumbers.process(number)
}
}
wait(for: [expectation], timeout: 5)
print(numbersResults.count)
XCTAssertEqual(numbersResults[10].0, 11)
XCTAssertEqual(numbersResults[100].0, 101)
XCTAssertEqual(numbersResults[500].0, 501)
}
func testClass() throws {
let addNumbers = AddNumbersClass()
var numbersResults = [(int: Int, string: String)]()
let expectation = expectation(description: "numberOfExpectedResults")
let numberCount = 1000
subscription = addNumbers.numberPublisher
.sink { results in
print(results)
numbersResults.append(results)
if numberCount == numbersResults.count {
expectation.fulfill()
}
}
for number in 1...numberCount {
addNumbers.process(number)
}
wait(for: [expectation], timeout: 5)
print(numbersResults.count)
XCTAssertEqual(numbersResults[10].0, 11)
XCTAssertEqual(numbersResults[100].0, 101)
XCTAssertEqual(numbersResults[500].0, 501)
}
}

Using an actor does indeed serialize access.
The issue you're running into is that the tests aren't testing whether calls to process() are serialized; they are testing the order in which the calls are executed. And the execution order of the individual Task calls is not guaranteed.
Try changing your AddNumbers objects so that, instead of the output order reflecting the order in which the calls were made, they succeed if calls are serialized but fail if concurrent calls are made. You can do this by keeping a count variable, incrementing it, sleeping a bit, then publishing the count. Concurrent calls will fail, since count will have been incremented multiple times before it is published.
If you make that change, the test using the actor will pass. The test using the class will fail if it calls process() concurrently:
DispatchQueue.global(qos: .default).async {
addNumbers.process()
}
It will also help to understand that Task scheduling depends on a number of factors. GCD will spin up tons of threads, whereas Swift concurrency will only use one worker thread per available core (I think!). So in some execution environments, just wrapping your work in Task { } might be enough to serialize it for you. I've found that iOS simulators act as if they have a single core, so task execution ends up being serialized. Also, otherwise unsafe code will work if you ensure the task runs on the main actor, since that guarantees serial execution:
Task { @MainActor in
// ...
}
Here are modified tests showing all this:
class TestActorWithPublisher: XCTestCase {
actor AddNumbersActor {
private let _numberPublisher: PassthroughSubject<Int, Never> = .init()
nonisolated lazy var numberPublisher = _numberPublisher.eraseToAnyPublisher()
var count = 0
func process() {
// Increment the count here
count += 1
// Wait a bit...
Thread.sleep(forTimeInterval: TimeInterval.random(in: 0...0.010))
// Send it back. If other calls to process() were made concurrently, count may have been incremented again before being sent:
_numberPublisher.send(count)
}
}
class AddNumbersClass {
private let _numberPublisher: PassthroughSubject<Int, Never> = .init()
lazy var numberPublisher = _numberPublisher.eraseToAnyPublisher()
var count = 0
func process() {
count += 1
Thread.sleep(forTimeInterval: TimeInterval.random(in: 0...0.010))
_numberPublisher.send(count)
}
}
var subscription: AnyCancellable?
override func tearDownWithError() throws {
subscription = nil
}
func testActor() throws {
let addNumbers = AddNumbersActor()
var numbersResults = [Int]()
let expectation = expectation(description: "numberOfExpectedResults")
let numberCount = 1000
subscription = addNumbers.numberPublisher
.sink { results in
numbersResults.append(results)
if numberCount == numbersResults.count {
expectation.fulfill()
}
}
for _ in 1...numberCount {
Task.detached(priority: .high) {
await addNumbers.process()
}
}
wait(for: [expectation], timeout: 10)
XCTAssertEqual(numbersResults, Array(1...numberCount))
}
func testClass() throws {
let addNumbers = AddNumbersClass()
var numbersResults = [Int]()
let expectation = expectation(description: "numberOfExpectedResults")
let numberCount = 1000
subscription = addNumbers.numberPublisher
.sink { results in
numbersResults.append(results)
if numberCount == numbersResults.count {
expectation.fulfill()
}
}
for _ in 1...numberCount {
DispatchQueue.global(qos: .default).async {
addNumbers.process()
}
}
wait(for: [expectation], timeout: 5)
XCTAssertEqual(numbersResults, Array(1...numberCount))
}
}

Related

Swift 5.5 test async Task in init

I would like to test if my init function works as expected. There is an async call in the init within a Task {} block. How can I make my test wait for the result of the Task block?
class ViewModel: ObservableObject {
@Published private(set) var result: [Item]
init(fetching: RemoteFetching) {
self.result = []
Task {
do {
let result = try await fetching.fetch()
self.result = result // <- need to do something with @MainActor?
} catch {
print(error)
}
}
}
}
Test:
func testFetching() async {
let items = [Item(), Item()]
let fakeFetching = FakeFetching(returnValue: items)
let vm = ViewModel(fetching: FakeFetching())
XCTAssertEqual(vm.result, [])
// wait for fetching, but how?
XCTAssertEqual(vm.result, items)
}
I tried this, but setting the items only happens after the XCTWaiter. The compiler warns that XCTWaiter cannot be called with await because it isn't async.
func testFetching() async {
let items = [Item(), Item()]
let fakeFetching = FakeFetching(returnValue: items)
let expectation = XCTestExpectation()
let vm = ViewModel(fetching: FakeFetching())
XCTAssertEqual(vm.result, [])
vm.$items
.dropFirst()
.sink { value in
XCTAssertEqual(value, items)
expectation.fulfill()
}
.store(in: &cancellables)
let result = await XCTWaiter.wait(for: [expectation], timeout: 1)
XCTAssertEqual(result, .completed)
}
Expectation-and-wait is correct. You're just using it wrong.
You are way overthinking this. You don't need an async test method. You don't need to call fulfill yourself. You don't need a Combine chain. Simply use a predicate expectation to wait until vm.result is set.
Basically the rule is this: testing an async method requires an async test method. But testing the asynchronous "result" of a method that happens to make an asynchronous call, like your init method, simply requires a good old-fashioned expectation-and-wait test.
I'll give an example. Here's a reduced version of your code; the structure is essentially the same as what you're doing:
protocol Fetching {
func fetch() async -> String
}
class MyClass {
var result = ""
init(fetcher: Fetching) {
Task {
self.result = await fetcher.fetch()
}
}
}
Okay then, here's how to test it:
final class MockFetcher: Fetching {
func fetch() async -> String { "howdy" }
}
final class MyLibraryTests: XCTestCase {
let fetcher = MockFetcher()
func testMyClassInit() {
let subject = MyClass(fetcher: fetcher)
let expectation = XCTNSPredicateExpectation(
predicate: NSPredicate(block: { _, _ in
subject.result == "howdy"
}), object: nil
)
wait(for: [expectation], timeout: 2)
}
}
Extra for experts: a Bool predicate expectation is such a common thing to use that it's handy to have a convenience method that combines the expectation, the predicate, and the wait into a single package:
extension XCTestCase {
func wait(
_ condition: @escaping @autoclosure () -> Bool,
timeout: TimeInterval = 10)
{
wait(for: [XCTNSPredicateExpectation(
predicate: NSPredicate(block: { _, _ in condition() }), object: nil
)], timeout: timeout)
}
}
The outcome is that, for example, the above test code can be reduced to this:
func testMyClassInit() {
let subject = MyClass(fetcher: fetcher)
wait(subject.result == "howdy")
}
Convenient indeed. In my own code, I often add an explicit assert, even when it is completely redundant, just to make it perfectly clear what I'm claiming my code does:
func testMyClassInit() {
let subject = MyClass(fetcher: fetcher)
wait(subject.result == "howdy")
XCTAssertEqual(subject.result, "howdy") // redundant but nice
}
Thanks to matt, this is the correct way. There's no need for async in the test function; just using a predicate did the job.
func testFetching() {
let items = [Item(), Item()]
let fakeFetching = FakeFetching(returnValue: items)
let vm = ViewModel(fetching: FakeFetching())
let pred = NSPredicate { _, _ in
vm.items == items
}
let expectation = XCTNSPredicateExpectation(predicate: pred, object: vm)
wait(for: [expectation], timeout: 1)
}
Slight variation on Matt's excellent answer. In my case, I've broken out his extension method into even more granular extensions for additional convenience.
Helper Framework
public typealias Predicate = () -> Bool
public extension NSPredicate {
convenience init(predicate: @escaping @autoclosure Predicate) {
self.init { _, _ in predicate() }
}
}
public extension XCTNSPredicateExpectation {
convenience init(predicate: @escaping @autoclosure Predicate, object: Any) {
self.init(predicate: NSPredicate(predicate: predicate()), object: object)
}
convenience init(predicate: @escaping @autoclosure Predicate) {
self.init(predicate: NSPredicate(predicate: predicate()))
}
convenience init(predicate: NSPredicate) {
self.init(predicate: predicate, object: nil)
}
}
public extension XCTestCase {
func XCTWait(for condition: @escaping @autoclosure Predicate, timeout: TimeInterval = 10) {
let expectation = XCTNSPredicateExpectation(predicate: condition())
wait(for: [expectation], timeout: timeout)
}
}
With the above in place, the OP's code can be reduced to this...
Unit Test
func testFetching() {
let items = [Item(), Item()]
let fakeFetching = FakeFetching(returnValue: items)
let vm = ViewModel(fetching: FakeFetching())
XCTWait(for: vm.items == items, timeout: 1)
}
Notes on Naming
Above, I'm using a somewhat controversial name in calling my function XCTWait, because the XCT prefix should be considered reserved for Apple's XCTest framework. However, the decision to name it this way stems from the desire to improve its discoverability: when a developer types XCT in their code editor, XCTWait is presented as one of the auto-complete suggestions**, making it much more likely to be found and used.
However, some purists may frown on this approach, pointing out that if Apple ever added something similarly named, this code could suddenly break or stop working (although that's unlikely unless the signatures also matched).
As such, use such names at your own discretion, or simply rename it to something that meets your own naming standards.
(** Provided it is in the same project or in a library/package they've imported somewhere above)

Concurrently run async tasks with unnamed async let

With Swift concurrency, is it possible to have something almost like an 'unnamed' async let?
Here is an example. You have the following actor:
actor MyActor {
private var foo: Int = 0
private var bar: Int = 0
func setFoo(to value: Int) async {
foo = value
}
func setBar(to value: Int) async {
bar = value
}
func printResult() {
print("foo =", foo)
print("bar =", bar)
}
}
Now I want to set foo and bar using the given methods. Simple usage would look like the following:
let myActor = MyActor()
await myActor.setFoo(to: 5)
await myActor.setBar(to: 10)
await myActor.printResult()
However this code is sequentially run. For all intents and purposes, assume setFoo(to:) and setBar(to:) may be a long running task. You can also assume the methods are mutually exclusive (don't share variables & won't affect each other).
To make this code concurrent, async let can be used. However, this just starts the tasks, which then have to be awaited later on. In my example you'll notice I don't need the return values from these methods. All I need is for the previous tasks to have completed before printResult() is called.
I could come up with the following:
let myActor = MyActor()
async let tempFoo: Void = myActor.setFoo(to: 5)
async let tempBar: Void = myActor.setBar(to: 10)
let _ = await [tempFoo, tempBar]
await myActor.printResult()
Explicitly creating these tasks and then awaiting an array of them seems incorrect. Is this really the best way?
This can be achieved with a task group using withTaskGroup(of:returning:body:). The method calls are individual tasks, and then we await waitForAll() which continues when all tasks have completed.
Code:
await withTaskGroup(of: Void.self) { group in
let myActor = MyActor()
group.addTask {
await myActor.setFoo(to: 5)
}
group.addTask {
await myActor.setBar(to: 10)
}
await group.waitForAll()
await myActor.printResult()
}
I made your actor a class to allow concurrent execution of the two methods.
import Foundation
final class Jeep {
private var foo: Int = 0
private var bar: Int = 0
func setFoo(to value: Int) {
print("begin foo")
foo = value
sleep(1)
print("end foo \(value)")
}
func setBar(to value: Int) {
print("begin bar")
bar = value
sleep(2)
print("end bar \(bar)")
}
func printResult() {
print("printResult foo:\(foo), bar:\(bar)")
}
}
let jeep = Jeep()
let blocks = [
{ jeep.setFoo(to: 1) },
{ jeep.setBar(to: 2) },
]
// ...WORK
RunLoop.current.run(mode: RunLoop.Mode.default, before: NSDate(timeIntervalSinceNow: 5) as Date)
Replace WORK with one of these:
// no concurrency, ordered execution
for block in blocks {
block()
}
jeep.printResult()
// concurrency, unordered execution, tasks created upfront programmatically
Task {
async let foo: Void = blocks[0]()
async let bar: Void = blocks[1]()
await [foo, bar]
jeep.printResult()
}
// concurrency, unordered execution, tasks created upfront, but started by the system (I think)
Task {
await withTaskGroup(of: Void.self) { group in
for block in blocks {
group.addTask { block() }
}
}
// when the initialization closure exits, all child tasks are awaited implicitly
jeep.printResult()
}
// concurrency, unordered execution, awaited in order
Task {
let tasks = blocks.map { block in
Task { block() }
}
for task in tasks {
await task.value
}
jeep.printResult()
}
// tasks created upfront, all tasks start concurrently, produce result as soon as they finish
let stream = AsyncStream<Void> { continuation in
Task {
let tasks = blocks.map { block in
Task { block() }
}
for task in tasks {
continuation.yield(await task.value)
}
continuation.finish()
}
}
Task {
// now waiting for all values, bad use of a stream, I know
for await value in stream {
print(value as Any)
}
jeep.printResult()
}

Data Race Issue with Swift Actor

I'm in the process of converting my code to Swift concurrency and I'm running into an issue with an actor that I don't know how to fix correctly.
Here is a simple actor:
actor MyActor {
private var count = 0
func increase() {
count += 1
}
}
In other places where I need to update the actor, I have to call its functions in a concurrency context:
Task {
await myActor.increase()
}
That's good. But what I don't understand is what happens if the actor returns the increase operation as a closure, like this:
actor MyActor {
private var count = 0
func increase() -> () -> Void {
{
print("START")
self.count += 1
print("END")
}
}
}
In other places, I can get a reference to the returned closure and call it freely in a non-concurrency context:
class ViewController: UIViewController {
let actor = MyActor()
var increase: (() -> Void)?
override func viewDidLoad() {
super.viewDidLoad()
Task {
increase = await actor.increase()
}
DispatchQueue.main.asyncAfter(deadline: .now() + 1) {
let increase = self.increase
DispatchQueue.concurrentPerform(iterations: 100) { _ in
increase?()
}
}
}
}
The above code print this to the output:
START
START
START
START
START
END
END
START
START
...
I'm not sure if I understand or am using actors correctly. An actor protects its state from data races, but in this case it does not prevent them. Is this the correct behavior? Is there a way to fix it?
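One direction I can imagine is to have the returned closure hop back onto the actor instead of mutating count directly. Here is a rough sketch of that idea (MyActorSafe and makeIncrementer are just illustrative names, not part of the code above):
actor MyActorSafe {
    private var count = 0

    func increase() {
        count += 1
    }

    // The returned closure re-enters the actor, so the increment itself
    // always runs with actor isolation.
    nonisolated func makeIncrementer() -> @Sendable () -> Void {
        return {
            Task { await self.increase() }
        }
    }
}
With this shape, each call from DispatchQueue.concurrentPerform only enqueues work; the increments themselves are serialized by the actor, at the cost of happening asynchronously rather than by the time the closure returns.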

How to convert DispatchQueue debounce to Swift Concurrency task?

I have an existing debouncer utility using DispatchQueue. It accepts a closure and executes it once the time threshold since the last call has been met. It can be used like this:
let limiter = Debouncer(limit: 5)
var value = ""
func sendToServer() {
limiter.execute {
print("\(Date.now.timeIntervalSince1970): Fire! \(value)")
}
}
value.append("h")
sendToServer() // Waits until 5 seconds
value.append("e")
sendToServer() // Waits until 5 seconds
value.append("l")
sendToServer() // Waits until 5 seconds
value.append("l")
sendToServer() // Waits until 5 seconds
value.append("o")
sendToServer() // Waits until 5 seconds
print("\(Date.now.timeIntervalSince1970): Last operation called")
// 1635691696.482115: Last operation called
// 1635691701.859087: Fire! hello
Notice it does not call Fire! multiple times; it fires just once, 5 seconds after the last call, with the value from the last task. The Debouncer instance is configured to hold the last task in the queue for 5 seconds no matter how many times it is called. The closure is passed into the execute(block:) method:
final class Debouncer {
private let limit: TimeInterval
private let queue: DispatchQueue
private var workItem: DispatchWorkItem?
private let syncQueue = DispatchQueue(label: "Debouncer", attributes: [])
init(limit: TimeInterval, queue: DispatchQueue = .main) {
self.limit = limit
self.queue = queue
}
@objc func execute(block: @escaping () -> Void) {
syncQueue.async { [weak self] in
if let workItem = self?.workItem {
workItem.cancel()
self?.workItem = nil
}
guard let queue = self?.queue, let limit = self?.limit else { return }
let workItem = DispatchWorkItem(block: block)
queue.asyncAfter(deadline: .now() + limit, execute: workItem)
self?.workItem = workItem
}
}
}
How can I convert this into a concurrent operation so it can be called like below:
let limiter = Debouncer(limit: 5)
func sendToServer() {
await limiter.waitUntilFinished
print("\(Date.now.timeIntervalSince1970): Fire! \(value)")
}
sendToServer()
sendToServer()
sendToServer()
However, this wouldn't debounce the tasks but suspend them until the next one gets called. Instead, it should cancel the previous task and hold the current task until the debounce time has elapsed. Can this be done with Swift concurrency, or is there a better approach?
Tasks have the ability to use isCancelled or checkCancellation, but for the sake of a debounce routine, where you want to wait for a period of time, you might just use the throwing rendition of Task.sleep(nanoseconds:), whose documentation says:
If the task is canceled before the time ends, this function throws CancellationError.
Thus, this effectively debounces for 2 seconds.
var task: Task<(), Never>?
func debounced(_ string: String) {
task?.cancel()
task = Task {
do {
try await Task.sleep(nanoseconds: 2_000_000_000)
logger.log("result \(string)")
} catch {
logger.log("canceled \(string)")
}
}
}
Note, Apple's swift-async-algorithms package has a debounce for asynchronous sequences.
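As a rough illustration of that (assuming the swift-async-algorithms package has been added as a dependency and an SDK new enough for Duration-based clocks; the stream and values below are made up):
import AsyncAlgorithms

// Illustrative only: pipe values into an AsyncStream and debounce the sequence.
let (strings, continuation) = AsyncStream.makeStream(of: String.self)

Task {
    // A value is emitted only once 2 seconds pass with no newer value upstream.
    for await value in strings.debounce(for: .seconds(2)) {
        print("debounced: \(value)")
    }
}

// Rapid-fire updates; only the last one should survive the debounce window.
continuation.yield("h")
continuation.yield("he")
continuation.yield("hello")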
Based on @Rob's great answer, here's a sample using an actor and Task:
actor Limiter {
enum Policy {
case throttle
case debounce
}
private let policy: Policy
private let duration: TimeInterval
private var task: Task<Void, Never>?
init(policy: Policy, duration: TimeInterval) {
self.policy = policy
self.duration = duration
}
nonisolated func callAsFunction(task: @escaping () async -> Void) {
Task {
switch policy {
case .throttle:
await throttle(task: task)
case .debounce:
await debounce(task: task)
}
}
}
private func throttle(task: @escaping () async -> Void) {
guard self.task?.isCancelled ?? true else { return }
Task {
await task()
}
self.task = Task {
try? await sleep()
self.task?.cancel()
self.task = nil
}
}
private func debounce(task: @escaping () async -> Void) {
self.task?.cancel()
self.task = Task {
do {
try await sleep()
guard !Task.isCancelled else { return }
await task()
} catch {
return
}
}
}
private func sleep() async throws {
try await Task.sleep(nanoseconds: UInt64(duration * 1_000_000_000))
}
}
The tests pass inconsistently, so I think my assumption about the order in which the tasks fire is incorrect, but I think the sample is a good start:
final class LimiterTests: XCTestCase {
func testThrottler() async throws {
// Given
let promise = expectation(description: "Ensure first task fired")
let throttler = Limiter(policy: .throttle, duration: 1)
var value = ""
var fulfillmentCount = 0
promise.expectedFulfillmentCount = 2
func sendToServer(_ input: String) {
throttler {
value += input
// Then
switch fulfillmentCount {
case 0:
XCTAssertEqual(value, "h")
case 1:
XCTAssertEqual(value, "hwor")
default:
XCTFail()
}
promise.fulfill()
fulfillmentCount += 1
}
}
// When
sendToServer("h")
sendToServer("e")
sendToServer("l")
sendToServer("l")
sendToServer("o")
await sleep(2)
sendToServer("wor")
sendToServer("ld")
wait(for: [promise], timeout: 10)
}
func testDebouncer() async throws {
// Given
let promise = expectation(description: "Ensure last task fired")
let limiter = Limiter(policy: .debounce, duration: 1)
var value = ""
var fulfillmentCount = 0
promise.expectedFulfillmentCount = 2
func sendToServer(_ input: String) {
limiter {
value += input
// Then
switch fulfillmentCount {
case 0:
XCTAssertEqual(value, "o")
case 1:
XCTAssertEqual(value, "old")
default:
XCTFail()
}
promise.fulfill()
fulfillmentCount += 1
}
}
// When
sendToServer("h")
sendToServer("e")
sendToServer("l")
sendToServer("l")
sendToServer("o")
await sleep(2)
sendToServer("wor")
sendToServer("ld")
wait(for: [promise], timeout: 10)
}
func testThrottler2() async throws {
// Given
let promise = expectation(description: "Ensure throttle before duration")
let throttler = Limiter(policy: .throttle, duration: 1)
var end = Date.now + 1
promise.expectedFulfillmentCount = 2
func test() {
// Then
XCTAssertLessThan(.now, end)
promise.fulfill()
}
// When
throttler(task: test)
throttler(task: test)
throttler(task: test)
throttler(task: test)
throttler(task: test)
await sleep(2)
end = .now + 1
throttler(task: test)
throttler(task: test)
throttler(task: test)
await sleep(2)
wait(for: [promise], timeout: 10)
}
func testDebouncer2() async throws {
// Given
let promise = expectation(description: "Ensure debounce after duration")
let debouncer = Limiter(policy: .debounce, duration: 1)
var end = Date.now + 1
promise.expectedFulfillmentCount = 2
func test() {
// Then
XCTAssertGreaterThan(.now, end)
promise.fulfill()
}
// When
debouncer(task: test)
debouncer(task: test)
debouncer(task: test)
debouncer(task: test)
debouncer(task: test)
await sleep(2)
end = .now + 1
debouncer(task: test)
debouncer(task: test)
debouncer(task: test)
await sleep(2)
wait(for: [promise], timeout: 10)
}
private func sleep(_ duration: TimeInterval) async {
try? await Task.sleep(nanoseconds: UInt64(duration * 1_000_000_000))
}
}

Share execution among several tasks in GCD?

I have an async function that synchronizes the network and database from the last call, then returns the results. There are several callers from different threads that call this function.
Instead of executing and serving a request per call, I'd like to queue up the tasks while the async function runs, then flush the queue so the next set of tasks can be queued up.
Here's what I came up so far:
extension DataWorker {
// Handle simultaneous pull requests in a queue
private static let pullQueue = DispatchQueue(label: "DataWorker.remotePull")
private static var pullTasks = [((SomeType) -> Void)]()
private static var isPulling = false
func remotePull(completion: ((SomeType) -> Void)?) {
DataWorker.pullQueue.async {
if let completion = completion {
DataWorker.pullTasks.append(completion)
}
guard !DataWorker.isPulling else { return }
DataWorker.isPulling = true
self.store.remotePull { result in
print("Remote pull executed")
DataWorker.pullQueue.async {
let tasks = DataWorker.pullTasks
DataWorker.pullTasks.removeAll()
DataWorker.isPulling = false
DispatchQueue.main.async {
tasks.forEach { $0(result) }
}
}
}
}
}
}
Below is how I'm testing it; I expect exactly 100 iterations but only a couple of remotePull executions:
DispatchQueue.concurrentPerform(iterations: 100) { iteration in
self.dataWorker.remotePull { _ in
print("Iteration: \(iteration)")
}
}
Is this approach even correct, or is there a more elegant or efficient way of achieving this shared-task behavior?
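For comparison, here is a rough sketch of how I imagine the same coalescing could look with Swift concurrency instead of GCD, where a single shared Task stands in for the pullTasks array (PullCoordinator and the SomeType placeholder are illustrative names, not my actual DataWorker):
// Placeholder standing in for SomeType above.
struct SomeType {}

actor PullCoordinator {
    private var inflight: Task<SomeType, Never>?

    // Callers that arrive while a pull is running all await the same task,
    // so the underlying operation runs once per burst of callers.
    func remotePull(_ operation: @escaping @Sendable () async -> SomeType) async -> SomeType {
        if let task = inflight {
            return await task.value
        }
        let task = Task { await operation() }
        inflight = task
        let result = await task.value
        inflight = nil
        return result
    }
}
Each caller would then await remotePull directly instead of passing a completion handler.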