DataStream.top_k

Returns the top k rows of the stream under the given sort order, equivalent to select * from stream order by columns limit k. Each incoming batch is first reduced to its own top k rows; a stateful executor then concatenates those candidates and applies a final sort and limit k.
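The two-phase strategy described above can be sketched in plain Python. This is only an illustration of the idea, not Quokka's implementation: the helper names `top_k_rows` and `streaming_top_k` are hypothetical, rows are modelled as dicts, and batches as lists of rows.

```python
def top_k_rows(rows, columns, k, descending):
    """Sort rows (dicts) by `columns` and keep the first k, like ORDER BY ... LIMIT k."""
    out = list(rows)
    # Stable sorts applied from the last key to the first yield a multi-key ordering.
    for col, desc in reversed(list(zip(columns, descending))):
        out.sort(key=lambda r: r[col], reverse=desc)
    return out[:k]


def streaming_top_k(batches, columns, k, descending):
    # Phase 1: keep at most k candidate rows from each incoming batch.
    candidates = []
    for batch in batches:
        candidates.extend(top_k_rows(batch, columns, k, descending))
    # Phase 2: one final sort and limit over the surviving candidates.
    return top_k_rows(candidates, columns, k, descending)


# Hypothetical data: 20 rows split into batches of at most 6 rows.
rows = [{"l_orderkey": i % 5, "l_linenumber": i} for i in range(20)]
batches = [rows[i:i + 6] for i in range(0, len(rows), 6)]
result = streaming_top_k(batches, ["l_orderkey", "l_linenumber"], 3, [True, False])
```

Because any row in the global top k must also be in the top k of its own batch, discarding the rest of each batch early does not change the final answer; it only bounds how many rows the final stage has to hold.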

Parameters:

columns (str or list, required): a column or a list of columns to sort on.

k (int, required): the number of rows to return. Must be a positive integer.

descending (bool or list, default None): a boolean or a list of booleans indicating whether to sort in descending order. If a list, its length must equal the length of columns; a single boolean is accepted only when sorting on one column. If None, every column is sorted in ascending order.
Returns:

A DataStream object with the specified top k rows.

Examples:

>>> lineitem = qc.read_csv("lineitem.csv")

result will be a DataStream.

>>> result = lineitem.top_k("l_orderkey", 10)
>>> result = lineitem.top_k(["l_orderkey", "l_orderdate"], 10, descending = [True, False])
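To make the meaning of the descending argument concrete, the sketch below mirrors how the source listing turns columns and descending into a SQL statement. The function name `build_top_k_sql` and its `table` parameter are hypothetical, introduced only for illustration; the real method builds the string inline and asserts on bad input instead of raising.

```python
def build_top_k_sql(columns, k, descending=None, table="batch_arrow"):
    """Hypothetical helper: build the ORDER BY ... LIMIT query used for top-k."""
    # A single column name is treated as a one-element list.
    if isinstance(columns, str):
        columns = [columns]
    # No `descending` means ascending order for every column.
    if descending is None:
        descending = [False] * len(columns)
    elif isinstance(descending, bool):
        descending = [descending]
    if len(descending) != len(columns):
        raise ValueError("descending must have one entry per sort column")
    order_by = ",".join(
        f"{col} {'desc' if desc else 'asc'}" for col, desc in zip(columns, descending)
    )
    return f"select * from {table} order by {order_by} limit {k}"


first = build_top_k_sql("l_orderkey", 10)
second = build_top_k_sql(["l_orderkey", "l_orderdate"], 10, descending=[True, False])
print(first)   # select * from batch_arrow order by l_orderkey asc limit 10
print(second)  # select * from batch_arrow order by l_orderkey desc,l_orderdate asc limit 10
```

The second query corresponds to the last example above: rows are ordered by l_orderkey from largest to smallest, with ties broken by the earliest l_orderdate.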
Source code in pyquokka/datastream.py
def top_k(self, columns, k, descending = None):
    """
    Returns the top k rows of the stream under the given sort order, equivalent to select * from stream order by columns limit k.
    Each incoming batch is first reduced to its own top k rows; a stateful executor then concatenates those candidates and applies a final sort and limit k.

    Args:
        columns (str or list): a column or a list of columns to sort on.
        k (int): the number of rows to return.
        descending (bool or list): a boolean or a list of booleans indicating whether to sort in descending order. If a list, its length must equal the length of `columns`. If None (the default), every column is sorted in ascending order.

    Return:
        A DataStream object with the specified top k rows.

    Examples:

        >>> lineitem = qc.read_csv("lineitem.csv")

        `result` will be a DataStream.

        >>> result = lineitem.top_k("l_orderkey", 10)
        >>> result = lineitem.top_k(["l_orderkey", "l_orderdate"], 10, descending = [True, False])
    """
    if type(columns) == str:
        columns = [columns]
    assert type(columns) == list and len(columns) > 0

    if descending is not None:
        if type(descending) == bool:
            descending = [descending]
        assert type(descending) == list and len(descending) == len(columns)
        assert all([type(i) == bool for i in descending])
    else:
        descending = [False] * len(columns)

    assert type(k) == int
    assert k > 0

    new_columns = []
    for i in range(len(columns)):
        if descending[i]:
            new_columns.append(columns[i] + " desc")
        else:
            new_columns.append(columns[i] + " asc")

    sql_statement = "select * from batch_arrow order by " + ",".join(new_columns) + " limit " + str(k)

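    def f(df):
        # Per-batch step: keep only the top k rows of this batch.
        # `batch_arrow` looks unused, but the SQL text refers to it by name and
        # DuckDB resolves that name from this local variable, so the two must match.
        batch_arrow = df.to_arrow()
        # The DuckDB thread count is hardcoded to 8 here.
        con = duckdb.connect().execute('PRAGMA threads=%d' % 8)
        return polars.from_arrow(con.execute(sql_statement).arrow())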

    transformed = self.transform(f, new_schema = self.schema, required_columns=set(self.schema))

    topk_node = StatefulNode(
        schema=self.schema,
        schema_mapping={col: {0: col} for col in self.schema},
        required_columns={0: set(columns)},
        operator=ConcatThenSQLExecutor(sql_statement)
    )
    topk_node.set_placement_strategy(SingleChannelStrategy())
    return self.quokka_context.new_stream(
        sources={0: transformed},
        partitioners={0: BroadcastPartitioner()},
        node=topk_node,
        schema=self.schema,
    )