I have a polars dataframe
import polars as pl
data = [
(1, 3),
(1, 3),
(1, 10),
(2, 2),
(2, 2),
(3, 1),
(3, 5)
]
df = pl.DataFrame(data, columns=["item_id", "num_days_after_first_review"])
I would like to have a column that acts as a counter within each item_id, following the order of num_days_after_first_review, so the result would be:
data = [
(1, 3, 1),
(1, 3, 2),
(1, 10, 3),
(2, 2, 1),
(2, 2, 2),
(3, 1, 1),
(3, 5, 2)
]
df = pl.DataFrame(data, columns=["item_id", "num_days_after_first_review", "num"])
One approach is to use .over() with pl.count() and pl.arange()
df.with_columns(
pl.arange(1, pl.count() + 1)
.over("item_id")
.alias("num"))
shape: (7, 3)
┌─────────┬─────────────────────────────┬─────┐
│ item_id | num_days_after_first_review | num │
│ --- | --- | --- │
│ i64 | i64 | i64 │
╞═════════╪═════════════════════════════╪═════╡
│ 1 | 3 | 1 │
│ 1 | 3 | 2 │
│ 1 | 10 | 3 │
│ 2 | 2 | 1 │
│ 2 | 2 | 2 │
│ 3 | 1 | 1 │
│ 3 | 5 | 2 │
└─────────┴─────────────────────────────┴─────┘
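To make the semantics concrete without polars, here is a plain-Python sketch of the same per-group counter. It assumes, like the .over() answer, that rows are already in the desired order within each item_id; running_counter is a hypothetical helper name, not part of polars.

```python
from collections import defaultdict


def running_counter(rows):
    """Return (item_id, num_days, num) tuples, where num counts rows per item_id.

    Assumes rows are already ordered within each item_id.
    """
    seen = defaultdict(int)  # item_id -> number of rows seen so far
    out = []
    for item_id, days in rows:
        seen[item_id] += 1
        out.append((item_id, days, seen[item_id]))
    return out


data = [(1, 3), (1, 3), (1, 10), (2, 2), (2, 2), (3, 1), (3, 5)]
result = running_counter(data)
```

The counter restarts at 1 whenever a new item_id is seen, which is what pl.arange(1, pl.count() + 1) evaluated .over("item_id") does for each group.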
I will try to explain this as well as possible, because I am unfortunately quite new to polars. I have a large time series dataset where each separate time series is identified by a group_idx. Additionally, a time_idx column identifies which of the possible time steps are present, i.e. which steps have a corresponding target value. As a minimal example, consider the following:
min_df = pl.DataFrame(
{"group_idx": [0, 1, 2, 3], "time_idx": [[0, 1, 2, 3], [2, 3], [0, 2, 3], [0,3]]}
)
┌──────────┬───────────────┐
│ group_idx ┆ time_idx │
│ --- ┆ --- │
│ i64 ┆ list[i64] │
╞══════════╪═══════════════╡
│ 0 ┆ [0, 1, 2, 3] │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 1 ┆ [2, 3] │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ [0, 2, 3] │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 3 ┆ [0, 3] │
└──────────┴───────────────┘
Here, the time range in the dataset is 4 steps long, but not every individual series has all time steps present. So while group_idx=0 has all steps present, group_idx=3 only has steps 0 and 3, meaning that no target value is recorded for steps 1 and 2.
Now I would like to obtain all possible sub-sequences of a given sequence length, starting from each possible time step and extending at most to max_time_step (in this case 3). For example, for sequence_length=3, the expected output would be:
result_df = pl.DataFrame(
{
"group_idx": [0, 0, 1, 1, 2, 2, 3, 3],
"time_idx": [[0, 1, 2, 3], [0, 1, 2, 3], [2, 3], [2, 3], [0,2,3], [0,2,3], [0,3], [0,3]],
"sub_sequence": [[0,1,2], [1,2,3], [None, None, 2], [None, 2, 3], [0, None, 2], [None, 2, 3], [0, None, None], [None, None, 3]]
}
)
┌───────────┬───────────────┬─────────────────┐
│ group_idx ┆ time_idx ┆ sub_sequence │
│ --- ┆ --- ┆ --- │
│ i64 ┆ list[i64] ┆ list[i64] │
╞═══════════╪═══════════════╪═════════════════╡
│ 0 ┆ [0, 1, ... 3] ┆ [0, 1, 2] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 0 ┆ [0, 1, ... 3] ┆ [1, 2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 1 ┆ [2, 3] ┆ [null, null, 2] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 1 ┆ [2, 3] ┆ [null, 2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ [0, 2, 3] ┆ [0, null, 2] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ [0, 2, 3] ┆ [null, 2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 3 ┆ [0, 3] ┆ [0, null, null] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 3 ┆ [0, 3] ┆ [null, null, 3] │
└───────────┴───────────────┴─────────────────┘
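Before reaching for polars, it may help to pin down the expected semantics with a small plain-Python reference (a sketch only, not a scalable solution): fill every missing step in 0..max_time_step with None, then take each window of sequence_length consecutive steps. The helper name sub_sequences and the dict input are illustrative assumptions.

```python
def sub_sequences(time_idx, max_time_step, sequence_length):
    """All windows of `sequence_length` consecutive steps in 0..max_time_step,
    with None for steps that are missing from `time_idx`."""
    present = set(time_idx)
    # Dense view of one series: the step itself if recorded, else None
    dense = [t if t in present else None for t in range(max_time_step + 1)]
    n_windows = len(dense) - sequence_length + 1  # empty if the window is too long
    return [dense[start:start + sequence_length] for start in range(n_windows)]


# Same data as min_df, keyed by group_idx
series = {0: [0, 1, 2, 3], 1: [2, 3], 2: [0, 2, 3], 3: [0, 3]}
result = [
    (group_idx, window)
    for group_idx, steps in series.items()
    for window in sub_sequences(steps, max_time_step=3, sequence_length=3)
]
```

The result has one entry per row of result_df above, which makes it handy for spot-checking a polars implementation on a small sample.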
All of this should be computed within polars, because the real dataset is much larger both in terms of the number of time series and time series length.
Edit:
Based on the suggestion by #ΩΠΟΚΕΚΡΥΜΜΕΝΟΣ, I have tried the following on the actual dataset (~200 million rows after .explode()), but the process gets killed. I forgot to mention that we can assume group_idx and time_idx are already sorted.
(
min_df.lazy()
.with_column(
pl.col("time_idx").alias("time_idx_nulls")
)
.groupby_rolling(
index_column='time_idx',
by='group_idx',
period=str(max_sequence_length) + 'i',
)
.agg(pl.col("time_idx_nulls"))
.filter(pl.col('time_idx_nulls').arr.lengths() == max_sequence_length)
)
Here's an algorithm that needs only the desired sub-sequence length as input. It uses groupby_rolling to create your sub-sequences.
period = 3
min_df = min_df.explode('time_idx')
(
min_df.get_column('group_idx').unique().to_frame()
.join(
min_df.get_column('time_idx').unique().to_frame(),
how='cross'
)
.join(
min_df.with_column(pl.col('time_idx').alias('time_idx_nulls')),
on=['group_idx', 'time_idx'],
how='left',
)
.groupby_rolling(
index_column='time_idx',
by='group_idx',
period=str(period) + 'i',
)
.agg(pl.col("time_idx_nulls"))
.filter(pl.col('time_idx_nulls').arr.lengths() == period)
.sort('group_idx')
)
shape: (8, 3)
┌───────────┬──────────┬─────────────────┐
│ group_idx ┆ time_idx ┆ time_idx_nulls │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ list[i64] │
╞═══════════╪══════════╪═════════════════╡
│ 0 ┆ 2 ┆ [0, 1, 2] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 0 ┆ 3 ┆ [1, 2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 1 ┆ 2 ┆ [null, null, 2] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 1 ┆ 3 ┆ [null, 2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ 2 ┆ [0, null, 2] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ 3 ┆ [null, 2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 3 ┆ 2 ┆ [0, null, null] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 3 ┆ 3 ┆ [null, null, 3] │
└───────────┴──────────┴─────────────────┘
And for example, with period = 2:
shape: (12, 3)
┌───────────┬──────────┬────────────────┐
│ group_idx ┆ time_idx ┆ time_idx_nulls │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ list[i64] │
╞═══════════╪══════════╪════════════════╡
│ 0 ┆ 1 ┆ [0, 1] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 0 ┆ 2 ┆ [1, 2] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 0 ┆ 3 ┆ [2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 1 ┆ 1 ┆ [null, null] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 1 ┆ 2 ┆ [null, 2] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 1 ┆ 3 ┆ [2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ 1 ┆ [0, null] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ 2 ┆ [null, 2] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ 3 ┆ [2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 3 ┆ 1 ┆ [0, null] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 3 ┆ 2 ┆ [null, null] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 3 ┆ 3 ┆ [null, 3] │
└───────────┴──────────┴────────────────┘
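The three stages of this pipeline (cross join, left join, rolling window plus length filter) can be mirrored in plain Python to see why it produces the tables above. This is a sketch under two assumptions: each rolling window covers the time_idx values in (t - period, t], which I believe matches groupby_rolling's default closed="right", and the time grid is built from the unique time_idx values in the data, as the cross join above does. rolling_sub_sequences is a hypothetical name.

```python
from itertools import product


def rolling_sub_sequences(rows, period):
    """Plain-Python mirror of the cross join / left join / groupby_rolling steps.

    rows: exploded (group_idx, time_idx) pairs.
    Returns (group_idx, time_idx, time_idx_nulls) tuples, one per full window.
    """
    present = set(rows)
    groups = sorted({g for g, _ in rows})
    times = sorted({t for _, t in rows})  # unique time_idx values seen in the data
    # Cross join: every (group_idx, time_idx) combination, group-major order
    grid = list(product(groups, times))
    # Left join: time_idx_nulls is the step itself if recorded, else None
    joined = {(g, t): (t if (g, t) in present else None) for g, t in grid}
    result = []
    for g, t_end in grid:
        # Window (t_end - period, t_end] (assumed closed="right" behaviour)
        window = [joined[(g, t)] for t in times if t_end - period < t <= t_end]
        # The .filter(... == period) step: keep only full windows
        if len(window) == period:
            result.append((g, t_end, window))
    return result


rows = [(0, 0), (0, 1), (0, 2), (0, 3), (1, 2), (1, 3),
        (2, 0), (2, 2), (2, 3), (3, 0), (3, 3)]
result_3 = rolling_sub_sequences(rows, period=3)
result_2 = rolling_sub_sequences(rows, period=2)
```

The time_idx column in the output is the last step of each window, which is why the period = 3 table starts at time_idx 2 and the period = 2 table at time_idx 1.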
Edit: managing RAM requirements
One way that we can manage RAM requirements (for this, or any other algorithm on large datasets) is to find ways to divide-and-conquer.
If we look at our particular problem, each line in the input dataset leads to results that are independent of any other line. We can use this fact to apply our algorithm in batches.
But first, let's create some data that leads to a large problem:
min_time = 0
max_time = 1_000
nbr_groups = 400_000
min_df = (
pl.DataFrame({"time_idx": [list(range(min_time, max_time, 2))]})
.join(
pl.arange(0, nbr_groups, eager=True).alias("group_idx").to_frame(),
how="cross"
)
)
min_df.explode('time_idx')
shape: (200000000, 2)
┌──────────┬───────────┐
│ time_idx ┆ group_idx │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞══════════╪═══════════╡
│ 0 ┆ 0 │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ 0 │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
│ 4 ┆ 0 │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
│ 6 ┆ 0 │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
│ ... ┆ ... │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
│ 992 ┆ 399999 │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
│ 994 ┆ 399999 │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
│ 996 ┆ 399999 │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
│ 998 ┆ 399999 │
└──────────┴───────────┘
The input dataset, when exploded, is 200 million records, so roughly the size you describe. (Of course, using divide-and-conquer, we won't explode the full dataset at once.)
To divide-and-conquer this, we'll slice our input dataset into smaller datasets, run the algorithm on the smaller datasets, and then concat the results into one large dataset. (One nice feature of slice is that it's very cheap - it's simply a window into the original dataset, so it consumes very little additional RAM.)
Notice the slice_size variable. You'll need to experiment with this value on your particular computing platform. You want to set this as large as your RAM requirements allow. If set too low, your program will take too long. If set too high, your program will crash. (I've arbitrarily set this to 10,000 as a starting value.)
time_index_df = (
pl.arange(min_time, max_time, eager=True, dtype=pl.Int64)
.alias("time_idx")
.to_frame()
.lazy()
)
period = 3
slice_size = 10_000
result = pl.concat(
[
(
time_index_df
.join(
min_df
.lazy()
.slice(next_index, slice_size)
.select("group_idx"),
how="cross",
)
.join(
min_df
.lazy()
.slice(next_index, slice_size)
.explode('time_idx')
.with_column(pl.col("time_idx").alias("time_idx_nulls")),
on=["group_idx", "time_idx"],
how="left",
)
.groupby_rolling(
index_column='time_idx',
by='group_idx',
period=str(period) + 'i',
)
.agg(pl.col("time_idx_nulls"))
.filter(pl.col('time_idx_nulls').arr.lengths() == period)
.select(['group_idx', 'time_idx_nulls'])
.collect()
)
for next_index in range(0, min_df.height, slice_size)
]
)
result.sort('group_idx')
shape: (399200000, 2)
┌───────────┬───────────────────┐
│ group_idx ┆ time_idx_nulls │
│ --- ┆ --- │
│ i64 ┆ list[i64] │
╞═══════════╪═══════════════════╡
│ 0 ┆ [0, null, 2] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 0 ┆ [null, 2, null] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 0 ┆ [2, null, 4] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 0 ┆ [null, 4, null] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ ... ┆ ... │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 399999 ┆ [994, null, 996] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 399999 ┆ [null, 996, null] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 399999 ┆ [996, null, 998] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 399999 ┆ [null, 998, null] │
└───────────┴───────────────────┘
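The divide-and-conquer idea does not depend on polars: because each input row yields results independent of every other row, processing slices and concatenating gives the same output as one big pass. Here is a plain-Python sketch of that property on a small stand-in for min_df; windows_for_row and run_in_slices are hypothetical helpers, and unlike polars' slice, Python list slicing makes a copy.

```python
def windows_for_row(group_idx, steps, min_time, max_time, period):
    """Full windows of `period` steps over min_time..max_time-1 for one input row."""
    present = set(steps)
    dense = [t if t in present else None for t in range(min_time, max_time)]
    return [(group_idx, dense[i:i + period]) for i in range(len(dense) - period + 1)]


def run_in_slices(rows, slice_size, **params):
    """Process `rows` slice by slice and concatenate the per-slice results."""
    result = []
    for next_index in range(0, len(rows), slice_size):
        # Python list slicing copies; polars' DataFrame.slice is a cheap view
        batch = rows[next_index:next_index + slice_size]
        for group_idx, steps in batch:
            result.extend(windows_for_row(group_idx, steps, **params))
    return result


rows = [(g, list(range(0, 20, 2))) for g in range(10)]  # small stand-in for min_df
params = dict(min_time=0, max_time=20, period=3)
batched = run_in_slices(rows, slice_size=3, **params)
unbatched = run_in_slices(rows, slice_size=len(rows), **params)
```

Changing slice_size only changes how much is held in memory at once, not the result, which is what makes the slice_size experiment safe.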
Some other things
You do actually need to use the joins. The joins are used to fill in the "holes" in your sequences with null values.
Also, notice that I've put each slice/batch into lazy mode, but not the entire algorithm. Depending on your computing platform, using lazy mode for the entire algorithm may again overwhelm your system, as Polars attempts to spread the work across multiple processors which could lead to more out-of-memory situations for you.
Also note the humongous size of my output dataset: almost 400 million records. I did this purposely as a reminder that your output dataset may be the ultimate problem. That is, any algorithm would fail if the result dataset is larger than your RAM can hold.
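The output size can be predicted before running anything. Every group gets one row per full window on the dense time grid, so with T = max_time - min_time steps the result has nbr_groups × (T − period + 1) rows. A quick check against the numbers used above (expected_rows is just an illustrative helper):

```python
def expected_rows(nbr_groups, min_time, max_time, period):
    """Rows produced by the batched algorithm for a dense grid of time steps."""
    n_steps = max_time - min_time  # time_index_df = arange(min_time, max_time)
    windows_per_group = max(n_steps - period + 1, 0)  # only full windows survive the filter
    return nbr_groups * windows_per_group


input_rows = 400_000 * len(range(0, 1_000, 2))  # exploded min_df
output_rows = expected_rows(400_000, 0, 1_000, 3)
```

This gives 200,000,000 input rows and 399,200,000 output rows, matching the two shapes shown above.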
Here's another approach using duckdb
It seems to perform much better in both runtime and memory in my local benchmark.
You can .explode("subsequence") afterwards to get a row per subsequence - this seems to be quite memory intensive though.
Update: polars can perform single column .explode() for "free". https://github.com/pola-rs/polars/pull/5676
You can unnest([ ... ] subsequence) to do the explode in duckdb - it seems to be a bit slower currently.
Update: Using the explode_table() from https://issues.apache.org/jira/plugins/servlet/mobile#issue/ARROW-12099 seems to add very little overhead.
>>> import duckdb
... import polars as pl
...
... min_df = pl.DataFrame({
... "group_idx": [0, 1, 2, 3],
... "time_idx": [[0, 1, 2, 3], [2, 3], [0, 2, 3], [0,3]]
... })
... max_time_step = 3
... sequence_length = 2
... # number of full windows of length sequence_length over steps 0..max_time_step
... upper_bound = max_time_step - sequence_length + 2
... tbl = min_df.to_arrow()
... pl.from_arrow(
... duckdb.connect().execute(f"""
... select
... group_idx, [
... time_idx_nulls[n: n + {sequence_length - 1}]
... for n in range(1, {upper_bound + 1})
... ] subsequence
... from (
... from tbl select group_idx, list_transform(
... range(0, {max_time_step + 1}),
... n -> case when list_has(time_idx, n) then n end
... ) time_idx_nulls
... )
... """)
... .arrow()
... )
shape: (4, 2)
┌───────────┬─────────────────────────────────────┐
│ group_idx | subsequence │
│ --- | --- │
│ i64 | list[list[i64]] │
╞═══════════╪═════════════════════════════════════╡
│ 0 | [[0, 1], [1, 2], [2, 3]] │
├───────────┼─────────────────────────────────────┤
│ 1 | [[null, null], [null, 2], [2, 3]... │
├───────────┼─────────────────────────────────────┤
│ 2 | [[0, null], [null, 2], [2, 3]] │
├───────────┼─────────────────────────────────────┤
│ 3 | [[0, null], [null, null], [null,... │
└─//────────┴─//──────────────────────────────────┘
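To see how the slice bounds in the query line up, here is a plain-Python emulation of it, assuming DuckDB's list slice l[a:b] is 1-based with an inclusive end (which is what the n: n + sequence_length - 1 bounds rely on). With max_time_step + 1 steps there are max_time_step - sequence_length + 2 full windows. duckdb_slice and duckdb_style_windows are illustrative names only.

```python
def duckdb_slice(values, begin, end):
    """Emulate DuckDB's list slice values[begin:end] (1-based, end inclusive)."""
    return values[begin - 1:end]


def duckdb_style_windows(time_idx, max_time_step, sequence_length):
    # list_transform(range(0, max_time_step + 1), n -> case when list_has(time_idx, n) then n end)
    time_idx_nulls = [n if n in time_idx else None for n in range(max_time_step + 1)]
    # (max_time_step + 1) steps hold this many full windows
    n_windows = max_time_step - sequence_length + 2
    return [
        duckdb_slice(time_idx_nulls, n, n + sequence_length - 1)
        for n in range(1, n_windows + 1)
    ]


group_1 = duckdb_style_windows([2, 3], max_time_step=3, sequence_length=2)
full_length = duckdb_style_windows([2, 3], max_time_step=3, sequence_length=4)
```

group_1 reproduces the row for group_idx 1 in the table above, and a window as long as the whole range yields exactly one sub-sequence.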
I suspect there should be a cleaner way to do this but you could create range/mask list columns:
>>> max_time_step = 3
>>> sequence_length = 3
>>> (
... min_df
... .with_columns([
... pl.arange(0, max_time_step + 1).list().alias("range"),
... pl.col("time_idx").arr.eval(
... pl.arange(0, max_time_step + 1).is_in(pl.element()),
... parallel=True
... ).alias("mask")
... ])
... )
shape: (4, 4)
┌───────────┬───────────────┬───────────────┬──────────────────────────┐
│ group_idx | time_idx | range | mask │
│ --- | --- | --- | --- │
│ i64 | list[i64] | list[i64] | list[bool] │
╞═══════════╪═══════════════╪═══════════════╪══════════════════════════╡
│ 0 | [0, 1, ... 3] | [0, 1, ... 3] | [true, true, ... true] │
├───────────┼───────────────┼───────────────┼──────────────────────────┤
│ 1 | [2, 3] | [0, 1, ... 3] | [false, false, ... true] │
├───────────┼───────────────┼───────────────┼──────────────────────────┤
│ 2 | [0, 2, 3] | [0, 1, ... 3] | [true, false, ... true] │
├───────────┼───────────────┼───────────────┼──────────────────────────┤
│ 3 | [0, 3] | [0, 1, ... 3] | [true, false, ... true] │
└─//────────┴─//────────────┴─//────────────┴─//───────────────────────┘
You can then .explode() those columns, replace true with the number and group them back together.
Update #1: Use #ΩΠΟΚΕΚΡΥΜΜΕΝΟΣ's .groupby_rolling() technique to generate correct sub-sequences.
Update #2: Use regular .groupby() and .list().slice() to generate sub-sequences.
>>> min_df = pl.DataFrame({
... "group_idx": [0, 1, 2, 3],
... "time_idx": [[0, 1, 2, 3], [2, 3], [0, 2, 3], [0,3]]
... })
... max_time_step = 3
... sequence_length = 2
... (
... min_df
... .with_columns([
... pl.arange(0, max_time_step + 1).list().alias("range"),
... pl.col("time_idx").arr.eval(
... pl.arange(0, max_time_step + 1).is_in(pl.element()),
... parallel=True
... ).alias("mask")
... ])
... .explode(["range", "mask"])
... .with_column(
... pl.when(pl.col("mask"))
... .then(pl.col("range"))
... .alias("value"))
... .groupby("group_idx", maintain_order=True)
... .agg([
... pl.col("value")
... .list()
... .slice(length=sequence_length, offset=n)
... .suffix(f"{n}")
... for n in range(0, max_time_step - sequence_length + 2)
... ])
... .melt("group_idx", value_name="subsequence")
... .drop("variable")
... .sort("group_idx")
... )
shape: (12, 2)
┌───────────┬──────────────┐
│ group_idx | subsequence │
│ --- | --- │
│ i64 | list[i64] │
╞═══════════╪══════════════╡
│ 0 | [0, 1] │
├───────────┼──────────────┤
│ 0 | [1, 2] │
├───────────┼──────────────┤
│ 0 | [2, 3] │
├───────────┼──────────────┤
│ 1 | [null, null] │
├───────────┼──────────────┤
│ 1 | [null, 2] │
├───────────┼──────────────┤
│ ... | ... │
├───────────┼──────────────┤
│ 2 | [null, 2] │
├───────────┼──────────────┤
│ 2 | [2, 3] │
├───────────┼──────────────┤
│ 3 | [0, null] │
├───────────┼──────────────┤
│ 3 | [null, null] │
├───────────┼──────────────┤
│ 3 | [null, 3] │
└─//────────┴─//───────────┘
It feels like you should be able to use pl.element() inside .then() here to avoid the explode/groupby but it fails:
>>> (
... min_df
... .with_column(
... pl.col("time_idx").arr.eval(
... pl.when(pl.arange(0, max_time_step + 1).is_in(pl.element()))
... .then(pl.element()),
... parallel=True)
... .alias("subsequence")
... )
... )
---------------------------------------------------------------------------
ShapeError Traceback (most recent call last)
import polars as pl
df = pl.DataFrame({'a': [[1, 2, 3], [8, 9, 4]], 'b': [[2, 3, 4], [4, 5, 6]]})
So given the dataframe df
a b
[1, 2, 3] [2, 3, 4]
[8, 9, 4] [4, 5, 6]
I would like to get a column c that is the intersection of a and b:
a b c
[1, 2, 3] [2, 3, 4] [2, 3]
[8, 9, 4] [4, 5, 6] [4]
I know I can use the apply function with python set intersection, but I want to do it using polars expressions.
We can accomplish the intersection using the arr.eval expression. The arr.eval expression allows us to treat a list as a Series/column, so that we can use the same contexts and expressions that we use with columns and Series.
First, let's extend your example so that we can show what happens when the intersection is empty.
df = pl.DataFrame(
{
"a": [[1, 2, 3], [8, 9, 4], [0, 1, 2]],
"b": [[2, 3, 4], [4, 5, 6], [10, 11, 12]],
}
)
df
shape: (3, 2)
┌───────────┬──────────────┐
│ a ┆ b │
│ --- ┆ --- │
│ list[i64] ┆ list[i64] │
╞═══════════╪══════════════╡
│ [1, 2, 3] ┆ [2, 3, 4] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [8, 9, 4] ┆ [4, 5, 6] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [0, 1, 2] ┆ [10, 11, 12] │
└───────────┴──────────────┘
The Algorithm
There are two ways to accomplish this. The first is extendable to the intersection of more than two sets (see Other Notes below).
df.with_column(
pl.col("a")
.arr.concat('b')
.arr.eval(pl.element().filter(pl.count().over(pl.element()) == 2))
.arr.unique()
.alias('intersection')
)
or
df.with_column(
pl.col("a")
.arr.concat('b')
.arr.eval(pl.element().filter(pl.element().is_duplicated()))
.arr.unique()
.alias('intersection')
)
shape: (3, 3)
┌───────────┬──────────────┬──────────────┐
│ a ┆ b ┆ intersection │
│ --- ┆ --- ┆ --- │
│ list[i64] ┆ list[i64] ┆ list[i64] │
╞═══════════╪══════════════╪══════════════╡
│ [1, 2, 3] ┆ [2, 3, 4] ┆ [2, 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [8, 9, 4] ┆ [4, 5, 6] ┆ [4] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [0, 1, 2] ┆ [10, 11, 12] ┆ [] │
└───────────┴──────────────┴──────────────┘
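The same concat-count-unique idea in plain Python, as a sketch for checking results on small inputs. It assumes each input list has no internal duplicates (see Other Notes); intersect_by_count is a hypothetical name.

```python
from collections import Counter


def intersect_by_count(a, b):
    """Concatenate, keep elements seen exactly twice, then deduplicate.

    Mirrors arr.concat -> filter(count == 2) -> arr.unique.
    """
    combined = a + b
    counts = Counter(combined)
    kept = [x for x in combined if counts[x] == 2]
    return list(dict.fromkeys(kept))  # order-preserving unique


rows = [([1, 2, 3], [2, 3, 4]), ([8, 9, 4], [4, 5, 6]), ([0, 1, 2], [10, 11, 12])]
intersections = [intersect_by_count(a, b) for a, b in rows]
```

This sketch preserves first-seen order when deduplicating; arr.unique in polars may not guarantee any particular order.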
How it works
We first concatenate the two lists into a single list. Any element that is in both lists will appear twice.
df.with_column(
pl.col("a")
.arr.concat('b')
.alias('ablist')
)
shape: (3, 3)
┌───────────┬──────────────┬────────────────┐
│ a ┆ b ┆ ablist │
│ --- ┆ --- ┆ --- │
│ list[i64] ┆ list[i64] ┆ list[i64] │
╞═══════════╪══════════════╪════════════════╡
│ [1, 2, 3] ┆ [2, 3, 4] ┆ [1, 2, ... 4] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [8, 9, 4] ┆ [4, 5, 6] ┆ [8, 9, ... 6] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [0, 1, 2] ┆ [10, 11, 12] ┆ [0, 1, ... 12] │
└───────────┴──────────────┴────────────────┘
Then we can use the arr.eval function, which allows us to treat the concatenated list as if it were a Series/column. In this case, we'll use a filter context to keep any element that appears twice. (Inside arr.eval, the polars.element expression refers to the list's elements, much as polars.col refers to a column.)
df.with_column(
pl.col("a")
.arr.concat('b')
.arr.eval(pl.element().filter(pl.count().over(pl.element()) == 2))
.alias('filtered')
)
shape: (3, 3)
┌───────────┬──────────────┬───────────────┐
│ a ┆ b ┆ filtered │
│ --- ┆ --- ┆ --- │
│ list[i64] ┆ list[i64] ┆ list[i64] │
╞═══════════╪══════════════╪═══════════════╡
│ [1, 2, 3] ┆ [2, 3, 4] ┆ [2, 3, ... 3] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [8, 9, 4] ┆ [4, 5, 6] ┆ [4, 4] │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [0, 1, 2] ┆ [10, 11, 12] ┆ [] │
└───────────┴──────────────┴───────────────┘
Note: the above step can also be expressed using the is_duplicated expression. (In the Other Notes section, we'll see that using is_duplicated will not work when calculating the intersection of more than two sets.)
df.with_column(
pl.col("a")
.arr.concat('b')
.arr.eval(pl.element().filter(pl.element().is_duplicated()))
.alias('filtered')
)
All that remains is then to remove the duplicates from the results, using the arr.unique expression (which is the result shown in the beginning).
Other Notes
I'm assuming that your lists are really sets, in that elements appear only once in each list. If there are duplicates in the original lists, we can apply arr.unique to each list before the concatenation step.
Also, this process can be extended to find the intersection of more than two sets. Simply concatenate all the lists together, and then change the filter step from == 2 to == n (where n is the number of sets). (Note: using the is_duplicated expression above will not work with more than two sets.)
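A small plain-Python sketch of the n-set version, which also shows why the is_duplicated variant breaks down: with three lists, an element present in only two of them is still "duplicated". Each list is deduplicated first, mirroring arr.unique before the concatenation; the function names are hypothetical.

```python
from collections import Counter


def _concat_unique(lists):
    # arr.unique on each list, then concatenate
    return [x for lst in lists for x in dict.fromkeys(lst)]


def intersect_many(*lists):
    """Elements present in every list: filter on count == n."""
    combined = _concat_unique(lists)
    counts = Counter(combined)
    return sorted({x for x in combined if counts[x] == len(lists)})


def duplicated_only(*lists):
    """The is_duplicated variant: keeps anything seen more than once."""
    combined = _concat_unique(lists)
    counts = Counter(combined)
    return sorted({x for x in combined if counts[x] > 1})


a, b, c = [1, 2, 3], [2, 3, 4], [3, 4, 5]
three_way = intersect_many(a, b, c)
wrong_three_way = duplicated_only(a, b, c)
```

Here three_way is [3], while the is_duplicated-style filter returns [2, 3, 4] because 2 and 4 each appear in two of the three lists.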
The arr.eval method does have a parallel keyword. You can try setting this to True and see if it yields better performance in your particular situation.
Other Set Operations
Symmetric difference: change the filter criterion to == 1 (and omit the arr.unique step.)
Union: use arr.concat followed by arr.unique.
Set difference: compute the intersection (as above), then concatenate the original list/set and filter for items that appear only once. Alternatively, for small list sizes, you can concatenate “a” to itself and then to “b” and then filter for elements that occur twice (but not three times).
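These count-based recipes can be checked on a small pair of lists in plain Python. This is a sketch that treats results as sorted sets (count_filter is a hypothetical helper), so it does not model list order or the optional arr.unique step:

```python
from collections import Counter


def count_filter(combined, keep):
    """Distinct elements of `combined` whose occurrence count satisfies `keep`."""
    counts = Counter(combined)
    return sorted({x for x in combined if keep(counts[x])})


a, b = [1, 2, 3], [2, 3, 4]

symmetric_difference = count_filter(a + b, lambda n: n == 1)
union = sorted(set(a + b))
intersection = count_filter(a + b, lambda n: n == 2)
# a minus b: concatenate a with the intersection, keep elements seen once
difference_via_intersection = count_filter(a + intersection, lambda n: n == 1)
# a minus b: a + a + b, keep elements seen exactly twice (not three times)
difference_via_doubling = count_filter(a + a + b, lambda n: n == 2)
```

In the doubling trick, elements only in a appear twice, elements in both appear three times, and elements only in b appear once, so == 2 isolates a minus b.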
Assuming I already have a predicate expression, how do I filter with that predicate, but apply it only within groups? For example, the predicate might be to keep all rows equal to the maximum within a group. (Multiple rows could be kept in a group if there is a tie.)
With my dplyr experience, I thought that I could just .groupby and then .filter, but that does not work.
import polars as pl
df = pl.DataFrame(dict(x=[0, 0, 1, 1], y=[1, 2, 3, 3]))
expression = pl.col("y") == pl.col("y").max()
df.groupby("x").filter(expression)
# AttributeError: 'GroupBy' object has no attribute 'filter'
I then thought I could apply .over to the expression, but that does not work either.
import polars as pl
df = pl.DataFrame(dict(x=[0, 0, 1, 1], y=[1, 2, 3, 3]))
expression = pl.col("y") == pl.col("y").max()
df.filter(expression.over("x"))
# RuntimeError: Any(ComputeError("this binary expression is not an aggregation:
# [(col(\"y\")) == (col(\"y\").max())]
# pherhaps you should add an aggregation like, '.sum()', '.min()', '.mean()', etc.
# if you really want to collect this binary expression, use `.list()`"))
For this particular problem, I can invoke .over on the max, but I don't know how to apply this to an arbitrary predicate I don't have control over.
import polars as pl
df = pl.DataFrame(dict(x=[0, 0, 1, 1], y=[1, 2, 3, 3]))
expression = pl.col("y") == pl.col("y").max().over("x")
df.filter(expression)
# shape: (3, 2)
# ┌─────┬─────┐
# │ x ┆ y │
# │ --- ┆ --- │
# │ i64 ┆ i64 │
# ╞═════╪═════╡
# │ 0 ┆ 2 │
# ├╌╌╌╌╌┼╌╌╌╌╌┤
# │ 1 ┆ 3 │
# ├╌╌╌╌╌┼╌╌╌╌╌┤
# │ 1 ┆ 3 │
# └─────┴─────┘
If you had updated to polars>=0.13.0 your second try would have worked. :)
df = pl.DataFrame(dict(
x=[0, 0, 1, 1],
y=[1, 2, 3, 3])
)
df.filter((pl.col("y") == pl.max("y").over("x")))
shape: (3, 2)
┌─────┬─────┐
│ x ┆ y │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 0 ┆ 2 │
├╌╌╌╌╌┼╌╌╌╌╌┤
│ 1 ┆ 3 │
├╌╌╌╌╌┼╌╌╌╌╌┤
│ 1 ┆ 3 │
└─────┴─────┘
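Conceptually, .over("x") evaluates the expression once per group and broadcasts the result back to each row's position, so filtering on it keeps rows according to a predicate evaluated within their own group. Here is a plain-Python sketch of that idea that accepts an arbitrary predicate; it is not how polars implements it, and filter_within_groups is a hypothetical helper.

```python
from collections import defaultdict


def filter_within_groups(rows, key, predicate):
    """Keep rows for which predicate(row, rows_in_same_group) is True.

    Row order is preserved, like DataFrame.filter.
    """
    groups = defaultdict(list)
    for row in rows:
        groups[row[key]].append(row)
    return [row for row in rows if predicate(row, groups[row[key]])]


def is_group_max(row, group):
    # Same predicate as pl.col("y") == pl.col("y").max(), evaluated per group
    return row["y"] == max(r["y"] for r in group)


rows = [{"x": 0, "y": 1}, {"x": 0, "y": 2}, {"x": 1, "y": 3}, {"x": 1, "y": 3}]
result = filter_within_groups(rows, "x", is_group_max)
```

Ties are kept, so both rows with x = 1 survive, matching the shape (3, 2) output above.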