DataStream.filter_sql

Filters the DataStream so that it contains only the rows that match a predicate written in SQL syntax. You can write any expression you would normally put in a SQL WHERE clause, including arbitrary conjunctions and disjunctions. Every column referenced in the predicate must be in the schema of this DataStream!

Since a DataStream is implemented as a stream of batches, you might be tempted to think of a filtered DataStream as a stream of batches where each batch results directly from applying the filter to one batch of the source DataStream. While this may be the case, Quokka optimizes filters aggressively and will most likely push them all the way down to the input readers. As a result, you typically should not see a filter node in the Quokka execution plan shown by explain().
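
The difference between filtering each batch and filtering inside the reader can be sketched in plain Python. This is only a conceptual illustration, not Quokka's implementation: `read_batches`, `filter_batches`, `matches`, and the sample rows are hypothetical names and data made up for this sketch.

```python
from typing import Callable, Dict, Iterable, Iterator, List, Optional

Row = Dict[str, int]
Predicate = Callable[[Row], bool]


def read_batches(rows: Iterable[Row], batch_size: int,
                 predicate: Optional[Predicate] = None) -> Iterator[List[Row]]:
    """Hypothetical input reader: yields batches, optionally filtering while reading."""
    batch: List[Row] = []
    for row in rows:
        if predicate is not None and not predicate(row):
            continue  # pushed-down filter: the row never enters a batch
        batch.append(row)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def filter_batches(batches: Iterable[List[Row]],
                   predicate: Predicate) -> Iterator[List[Row]]:
    """Naive plan: a separate filter step applied to every batch."""
    for batch in batches:
        yield [row for row in batch if predicate(row)]


def matches(row: Row) -> bool:
    # Same condition as "l_orderkey < 10 and l_partkey > 5"
    return row["l_orderkey"] < 10 and row["l_partkey"] > 5


# Made-up sample data
rows = [{"l_orderkey": k, "l_partkey": (k * 3) % 11} for k in range(20)]

naive_batches = list(filter_batches(read_batches(rows, 4), matches))
pushed_batches = list(read_batches(rows, 4, matches))

naive_rows = [r for b in naive_batches for r in b]
pushed_rows = [r for b in pushed_batches for r in b]
```

Both plans produce the same rows, but the batch boundaries differ: the naive plan yields five batches (three of them empty), while the pushed-down plan yields a single full batch. This is why the output batches of a filtered DataStream should not be assumed to line up one-to-one with the source batches.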

Parameters:

Name       Type  Description                                  Default
predicate  str   A SQL WHERE clause; see the examples below.  required

Returns:

A DataStream consisting of rows from the source DataStream that match the predicate.

Examples:

Read in a CSV file into a DataStream f.

>>> f = qc.read_csv("lineitem.csv")

Filter for all the rows where l_orderkey is smaller than 10 and l_partkey is greater than 5.

>>> f = f.filter_sql("l_orderkey < 10 and l_partkey > 5") 

Nested conditions are supported.

>>> f = f.filter_sql("l_orderkey < 10 and (l_partkey > 5 or l_partkey < 1)") 
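
To see which rows a nested predicate keeps, the same WHERE clause can be tried on a tiny in-memory table first. The sketch below uses Python's built-in sqlite3 purely as a stand-in. Quokka does not use SQLite, so dialect details may differ, and the table contents are made up for illustration.

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE lineitem (l_orderkey INTEGER, l_partkey INTEGER)")
con.executemany(
    "INSERT INTO lineitem VALUES (?, ?)",
    [(1, 0), (2, 3), (3, 8), (12, 9), (15, 0)],  # made-up rows
)


def rows_matching(predicate: str):
    """Return the rows kept by the given WHERE clause, ordered by l_orderkey."""
    query = ("SELECT l_orderkey, l_partkey FROM lineitem WHERE "
             + predicate + " ORDER BY l_orderkey")
    return con.execute(query).fetchall()


# Parenthesized: l_orderkey < 10 must hold in every kept row.
nested = rows_matching("l_orderkey < 10 and (l_partkey > 5 or l_partkey < 1)")

# Without parentheses, AND binds tighter than OR, so any row with
# l_partkey < 1 is kept regardless of l_orderkey.
flat = rows_matching("l_orderkey < 10 and l_partkey > 5 or l_partkey < 1")
```

With this data, `nested` keeps the rows with l_orderkey 1 and 3, while `flat` additionally keeps the row with l_orderkey 15. The parentheses therefore change the result and should not be dropped.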

Most SQL features, such as IN and date literals, are supported. Anything DuckDB supports should work.

>>> f = f.filter_sql("l_shipmode IN ('MAIL','SHIP') and l_receiptdate < date '1995-01-01'")

You can do arithmetic in the predicate just like in SQL.

>>> f = f.filter_sql("l_shipdate < date '1994-01-01' + interval '1' year and l_discount between 0.06 - 0.01 and 0.06 + 0.01")
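
As a worked reading of the predicate above: the date bound evaluates to 1995-01-01, and since BETWEEN is inclusive on both ends in standard SQL, the discount range is 0.05 to 0.07. The sketch below reproduces that arithmetic in plain Python, assuming exact decimal arithmetic. `keeps` is a hypothetical helper written for this illustration, not part of Quokka.

```python
from datetime import date
from decimal import Decimal

# date '1994-01-01' + interval '1' year
# (replace(year=...) is a safe shortcut here because the date is not Feb 29)
upper_date = date(1994, 1, 1).replace(year=1995)

# 0.06 - 0.01 and 0.06 + 0.01, computed exactly with Decimal
low = Decimal("0.06") - Decimal("0.01")
high = Decimal("0.06") + Decimal("0.01")


def keeps(l_shipdate: date, l_discount: Decimal) -> bool:
    """Mirror of the predicate: strict upper date bound, inclusive BETWEEN."""
    return l_shipdate < upper_date and low <= l_discount <= high
```

For example, a row shipped on 1994-12-31 with a discount of 0.05 is kept, while a row shipped on 1995-01-01 is dropped because the date comparison is strict.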

This will fail, assuming c_custkey is not in f.schema.

>>> f = f.filter_sql("c_custkey > 10")
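
The failure comes from an assertion over the columns referenced in the predicate, as the source listing for this method shows. The sketch below mirrors only that check. `check_columns` is a hypothetical stand-in, the column names are passed in by hand rather than extracted from the SQL (the real method uses sqlglot for that), and the schema is made up.

```python
def check_columns(columns, schema):
    """Raise AssertionError for the first referenced column missing from the schema."""
    for column in sorted(columns):
        assert column in schema, \
            "Tried to filter on a column not in the schema {}".format(column)


# Made-up schema standing in for f.schema
schema = ["l_orderkey", "l_partkey", "l_shipmode"]

# Columns of "l_orderkey < 10 and l_partkey > 5": all present, no error.
check_columns({"l_orderkey", "l_partkey"}, schema)

# Column of "c_custkey > 10": not in the schema, so the check fails.
try:
    check_columns({"c_custkey"}, schema)
    error = None
except AssertionError as exc:
    error = str(exc)
```

In practice, checking that a column name is in `f.schema` before building the predicate avoids hitting the assertion at all.
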
Source code in pyquokka/datastream.py
def filter_sql(self, predicate: str):

    """
    This will filter the DataStream to contain only rows that match a certain predicate specified in SQL syntax. 
    You can write any SQL clause you would generally put in a WHERE statement containing arbitrary conjunctions and 
    disjunctions. The columns in your statement must be in the schema of this DataStream! 

    Since a DataStream is implemented as a stream of batches, you might be tempted to think of a filtered DataStream
    as a stream of batches where each batch directly results from a filter being applied to a batch in the source DataStream. 
    While this certainly may be the case, filters are aggressively optimized by Quokka and are most likely pushed all the way down
    to the input readers. As a result, you typically should not see a filter node in a Quokka execution plan shown by `explain()`. 

    Args:
        predicate (str): a SQL WHERE clause, look at the examples.

    Return:
        A DataStream consisting of rows from the source DataStream that match the predicate.

    Examples:

        Read in a CSV file into a DataStream f.

        >>> f = qc.read_csv("lineitem.csv")

        Filter for all the rows where l_orderkey is smaller than 10 and l_partkey is greater than 5.

        >>> f = f.filter_sql("l_orderkey < 10 and l_partkey > 5") 

        Nested conditions are supported.

        >>> f = f.filter_sql("l_orderkey < 10 and (l_partkey > 5 or l_partkey < 1)") 

        Most SQL features such as IN and date are supported. Anything DuckDB supports should work.

        >>> f = f.filter_sql("l_shipmode IN ('MAIL','SHIP') and l_receiptdate < date '1995-01-01'")

        You can do arithmetic in the predicate just like in SQL. 

        >>> f = f.filter_sql("l_shipdate < date '1994-01-01' + interval '1' year and l_discount between 0.06 - 0.01 and 0.06 + 0.01")

        This will fail, assuming c_custkey is not in f.schema.

        >>> f = f.filter_sql("c_custkey > 10")
    """

    assert isinstance(predicate, str)
    # parse the SQL predicate into a sqlglot expression tree
    predicate = sqlglot.parse_one(predicate)
    # convert to CNF
    predicate = optimizer.normalize.normalize(predicate, dnf = False)

    # every column referenced in the predicate must exist in this DataStream's schema
    columns = set(i.name for i in predicate.find_all(
        sqlglot.expressions.Column))
    for column in columns:
        assert column in self.schema, "Tried to filter on a column not in the schema {}".format(column)

    if self.materialized:
        batch_arrow = self._get_materialized_df().to_arrow()
        con = duckdb.connect().execute('PRAGMA threads=%d' % 8)
        df = polars.from_arrow(con.execute("select * from batch_arrow where " + predicate.sql(dialect = "duckdb")).arrow())
        return self.quokka_context.from_polars(df)

    if not optimizer.normalize.normalized(predicate):
        def f(df):
            batch_arrow = df.to_arrow()
            con = duckdb.connect().execute('PRAGMA threads=%d' % 8)
            return polars.from_arrow(con.execute("select * from batch_arrow where " + predicate.sql(dialect = "duckdb")).arrow())

        transformed = self.transform(f, new_schema = self.schema, required_columns=self.schema)
        return transformed
    else:
        return self.quokka_context.new_stream(sources={0: self}, partitioners={0: PassThroughPartitioner()}, node=FilterNode(self.schema, predicate),
                                          schema=self.schema, sorted = self.sorted)