Skip to content

Python Examples

Example scripts are in examples/python/.

RediSearch: Server-Side Filtering and Aggregation

This example demonstrates RediSearch integration including the query builder, server-side filtering, and aggregation:

#!/usr/bin/env python3
"""RediSearch example: Server-side filtering and aggregation.

This example demonstrates:
- Creating a RediSearch index
- Using search_hashes() for server-side filtering
- Using the query builder (col, raw)
- Using aggregate_hashes() for server-side aggregation

Prerequisites:
- Redis Stack running on localhost:6379
- pip install polars-redis redis
"""

import polars as pl
import polars_redis as pr
import redis as redis_py
from polars_redis import col, raw

URL = "redis://localhost:6379"


def setup_sample_data():
    """Create sample employee hashes and the RediSearch index over them.

    Returns the redis-py client so callers could reuse the connection.
    """
    r = redis_py.Redis()

    # Remove any leftover employee keys from a previous run
    for stale_key in r.scan_iter("employee:*"):
        r.delete(stale_key)

    # Sample employees as compact rows; hash fields are stored as strings
    field_names = ("name", "age", "department", "salary", "status")
    sample_rows = [
        ("Alice", "32", "engineering", "120000", "active"),
        ("Bob", "28", "engineering", "95000", "active"),
        ("Carol", "45", "product", "140000", "active"),
        ("Dave", "35", "product", "110000", "inactive"),
        ("Eve", "29", "marketing", "85000", "active"),
        ("Frank", "52", "engineering", "150000", "active"),
        ("Grace", "38", "marketing", "95000", "active"),
        ("Henry", "41", "engineering", "130000", "inactive"),
    ]

    for idx, row in enumerate(sample_rows, 1):
        r.hset(f"employee:{idx}", mapping=dict(zip(field_names, row)))

    # Drop an existing index first; FT.DROPINDEX errors when it is missing
    try:
        r.execute_command("FT.DROPINDEX", "employees_idx")
    except redis_py.ResponseError:
        pass  # Index doesn't exist

    # Index schema: field name, type, optional SORTABLE flag
    schema_args = [
        "name", "TEXT", "SORTABLE",
        "age", "NUMERIC", "SORTABLE",
        "department", "TAG",
        "salary", "NUMERIC", "SORTABLE",
        "status", "TAG",
    ]
    r.execute_command(
        "FT.CREATE",
        "employees_idx",
        "ON",
        "HASH",
        "PREFIX",
        "1",
        "employee:",
        "SCHEMA",
        *schema_args,
    )

    print("Created 8 employees and RediSearch index")
    return r


def example_basic_search():
    """Basic search using a raw RediSearch query string."""
    print("\n=== Basic Search (age > 30) ===")

    # Columns (and their types) we want back from the matching hashes
    result_schema = {
        "name": pl.Utf8,
        "age": pl.Int64,
        "department": pl.Utf8,
        "salary": pl.Float64,
    }

    # "@age:[30 +inf]" is native RediSearch numeric-range syntax
    lazy_result = pr.search_hashes(
        URL,
        index="employees_idx",
        query="@age:[30 +inf]",
        schema=result_schema,
    )
    print(lazy_result.collect())


def example_query_builder():
    """Filtering with the Polars-like query builder."""
    print("\n=== Query Builder (age > 30 AND status == active) ===")

    # Combine two predicates with & (AND); the filter runs server-side
    active_over_30 = (col("age") > 30) & (col("status") == "active")

    result = pr.search_hashes(
        URL,
        index="employees_idx",
        query=active_over_30,
        schema={
            "name": pl.Utf8,
            "age": pl.Int64,
            "department": pl.Utf8,
            "status": pl.Utf8,
        },
    ).collect()
    print(result)


def example_or_conditions():
    """Combining builder predicates with OR."""
    print("\n=== OR Conditions (engineering OR product) ===")

    # | joins two predicates into a single OR query
    eng_or_product = (col("department") == "engineering") | (
        col("department") == "product"
    )

    result = pr.search_hashes(
        URL,
        index="employees_idx",
        query=eng_or_product,
        schema={"name": pl.Utf8, "department": pl.Utf8, "salary": pl.Float64},
    ).collect()
    print(result)


def example_negation():
    """Negating a condition with the != operator."""
    print("\n=== Negation (NOT inactive) ===")

    # != expresses "everything except this value"
    not_inactive = col("status") != "inactive"

    result = pr.search_hashes(
        URL,
        index="employees_idx",
        query=not_inactive,
        schema={"name": pl.Utf8, "status": pl.Utf8},
    ).collect()
    print(result)


def example_raw_query():
    """Escaping to raw() for query syntax the builder doesn't cover."""
    print("\n=== Raw Query (name prefix search) ===")

    # raw() passes the query string through verbatim; A* matches names
    # beginning with "A" (full-text prefix search)
    prefix_query = raw("@name:A*")

    result = pr.search_hashes(
        URL,
        index="employees_idx",
        query=prefix_query,
        schema={"name": pl.Utf8, "department": pl.Utf8},
    ).collect()
    print(result)


def example_sorted_search():
    """Search with server-side sorting on a SORTABLE field."""
    print("\n=== Sorted Search (by salary descending) ===")

    # sort_ascending=False yields highest salaries first
    result = pr.search_hashes(
        URL,
        index="employees_idx",
        query="@status:{active}",
        schema={"name": pl.Utf8, "salary": pl.Float64},
        sort_by="salary",
        sort_ascending=False,
    ).collect()
    print(result)


def example_basic_aggregation():
    """Server-side GROUP BY with a single COUNT reducer."""
    print("\n=== Aggregation (count by department) ===")

    # "*" matches every document; grouping is on the department field
    result = pr.aggregate_hashes(
        URL,
        index="employees_idx",
        query="*",
        group_by=["@department"],
        reduce=[("COUNT", [], "employee_count")],
    )
    print(result)


def example_multi_aggregation():
    """Several reducers over one grouping."""
    print("\n=== Multi Aggregation (salary stats by department) ===")

    # Each reducer is (function, argument fields, output alias)
    salary_reducers = [
        ("COUNT", [], "headcount"),
        ("AVG", ["@salary"], "avg_salary"),
        ("MIN", ["@salary"], "min_salary"),
        ("MAX", ["@salary"], "max_salary"),
        ("SUM", ["@salary"], "total_payroll"),
    ]

    result = pr.aggregate_hashes(
        URL,
        index="employees_idx",
        query="@status:{active}",  # restrict to active employees
        group_by=["@department"],
        reduce=salary_reducers,
        sort_by=[("@avg_salary", False)],  # False = avg salary descending
    )
    print(result)


def example_computed_fields():
    """Using APPLY for computed fields.

    APPLY evaluates an arithmetic expression on each group's reduced
    values — here deriving an average salary from SUM and COUNT.
    """
    # Header previously read "(avg order value)", a copy/paste leftover from
    # another example — this dataset is employees, so it computes avg salary.
    print("\n=== Computed Fields (avg salary per department) ===")

    df = pr.aggregate_hashes(
        URL,
        index="employees_idx",
        query="*",  # match all documents
        group_by=["@department"],
        reduce=[
            ("SUM", ["@salary"], "total_salary"),
            ("COUNT", [], "count"),
        ],
        # Each APPLY entry is (expression, output alias)
        apply=[
            ("@total_salary / @count", "calculated_avg"),
        ],
    )

    print(df)


def example_global_aggregation():
    """Reducing over the whole result set (no GROUP BY)."""
    print("\n=== Global Aggregation (company-wide stats) ===")

    # Omitting group_by aggregates across all matching documents at once
    stats = pr.aggregate_hashes(
        URL,
        index="employees_idx",
        query="@status:{active}",
        reduce=[
            ("COUNT", [], "total_employees"),
            ("AVG", ["@salary"], "company_avg_salary"),
            ("AVG", ["@age"], "avg_age"),
        ],
    )
    print(stats)


def main():
    """Set up sample data, then run every example in order."""
    print("RediSearch Example")
    print("=" * 50)

    setup_sample_data()

    # Search examples first, then aggregation examples
    demos = (
        example_basic_search,
        example_query_builder,
        example_or_conditions,
        example_negation,
        example_raw_query,
        example_sorted_search,
        example_basic_aggregation,
        example_multi_aggregation,
        example_computed_fields,
        example_global_aggregation,
    )
    for demo in demos:
        demo()

    print("\n" + "=" * 50)
    print("All examples completed!")


if __name__ == "__main__":
    main()

Basic Hashes

Scanning and writing Redis hashes:

"""Basic example: scanning and writing Redis hashes."""

import polars as pl
import polars_redis as redis

URL = "redis://localhost:6379"


def main():
    """Scan user hashes, filter them, then write new hashes back."""
    # Expected fields and types for the user:* hashes
    user_schema = {
        "name": pl.Utf8,
        "age": pl.Int64,
        "score": pl.Float64,
        "active": pl.Boolean,
    }

    # Lazy scan over keys matching the pattern
    users = redis.scan_hashes(URL, pattern="user:*", schema=user_schema)

    # Projection pushdown: only _key, name and age are fetched from Redis
    over_25 = (
        users.filter(pl.col("age") > 25)
        .select(["_key", "name", "age"])
        .collect()
    )
    print("Users over 25:")
    print(over_25)

    # Rows to persist back into Redis as new hashes
    new_users = pl.DataFrame(
        {
            "name": ["Charlie", "Diana"],
            "age": [28, 35],
            "score": [88.5, 92.0],
            "active": [True, False],
        }
    )

    # key_column=None: keys are generated from the row index under key_prefix
    count = redis.write_hashes(
        new_users,
        URL,
        key_column=None,
        key_prefix="user:new:",
        ttl=3600,  # expire after one hour
    )
    print(f"\nWrote {count} new hashes")


if __name__ == "__main__":
    main()

JSON Documents

Working with RedisJSON documents:

"""Example: working with RedisJSON documents."""

import polars as pl
import polars_redis as redis

URL = "redis://localhost:6379"


def main():
    """Aggregate RedisJSON product documents, then write new ones."""
    # Expected fields within each product document
    product_schema = {
        "title": pl.Utf8,
        "category": pl.Utf8,
        "price": pl.Float64,
        "in_stock": pl.Boolean,
    }
    products_lf = redis.scan_json(URL, pattern="product:*", schema=product_schema)

    # Per-category stats, highest average price first
    summary = (
        products_lf.group_by("category")
        .agg(
            [
                pl.len().alias("count"),
                pl.col("price").mean().alias("avg_price"),
                pl.col("price").max().alias("max_price"),
            ]
        )
        .sort("avg_price", descending=True)
        .collect()
    )
    print("Products by category:")
    print(summary)

    # Two new products to store as JSON documents
    new_products = pl.DataFrame(
        {
            "title": ["Widget", "Gadget"],
            "category": ["tools", "electronics"],
            "price": [19.99, 49.99],
            "in_stock": [True, True],
        }
    )

    count = redis.write_json(
        new_products,
        URL,
        key_column=None,  # auto-generate keys from the row index
        key_prefix="product:",
    )
    print(f"\nWrote {count} products")


if __name__ == "__main__":
    main()

Schema Inference

Automatic schema detection:

"""Example: automatic schema inference."""

import polars_redis as redis

URL = "redis://localhost:6379"


def main():
    """Infer schemas from existing keys, then read using them."""
    # Sample existing hashes to detect field names and types
    hash_schema = redis.infer_hash_schema(URL, pattern="user:*", sample_size=100)
    print("Inferred hash schema:")
    for field, dtype in hash_schema.items():
        print(f"  {field}: {dtype}")

    # Eagerly read all matching hashes with the inferred schema
    users = redis.read_hashes(URL, pattern="user:*", schema=hash_schema)
    print(f"\nLoaded {len(users)} rows")
    print(users.head())

    # Same inference for RedisJSON documents
    json_schema = redis.infer_json_schema(URL, pattern="product:*", sample_size=50)
    print("\nInferred JSON schema:")
    for field, dtype in json_schema.items():
        print(f"  {field}: {dtype}")


if __name__ == "__main__":
    main()

Strings and Counters

Redis strings and counter aggregation:

"""Example: working with Redis strings."""

import polars as pl
import polars_redis as redis

URL = "redis://localhost:6379"


def main():
    """Sum counter strings, read cache entries, write new counters."""
    # Parse counter values as 64-bit integers
    counters_lf = redis.scan_strings(
        URL,
        pattern="counter:*",
        value_type=pl.Int64,
    )

    # Total across every counter key
    total = counters_lf.select(pl.col("value").sum().alias("total")).collect()
    print(f"Total count: {total['total'][0]}")

    # Cache entries are read as plain strings (the default)
    cache = redis.read_strings(URL, pattern="cache:*")
    print(f"\nCache entries: {len(cache)}")
    print(cache.head())

    # New page counters; string values, keys generated from the row index
    page_counters = pl.DataFrame({"value": ["100", "200", "300"]})
    count = redis.write_strings(
        page_counters,
        URL,
        key_column=None,
        key_prefix="counter:page:",
        ttl=86400,  # expire after one day
    )
    print(f"\nWrote {count} counters")


if __name__ == "__main__":
    main()

TTL and Metadata

TTL and row index columns:

"""Example: TTL and metadata columns."""

import polars as pl
import polars_redis as redis

URL = "redis://localhost:6379"


def main():
    """Scan hashes with key/TTL/row-index metadata columns attached."""
    user_schema = {"name": pl.Utf8, "email": pl.Utf8}

    # Request the _key, _ttl and row-index metadata columns
    lazy_users = redis.scan_hashes(
        URL,
        pattern="user:*",
        schema=user_schema,
        include_key=True,
        include_ttl=True,
        include_row_index=True,
    )
    users = lazy_users.collect()
    print("Users with metadata:")
    print(users)

    # _ttl > 0 means an expiry is set; < 3600 means within the next hour
    expiring_soon = users.filter((pl.col("_ttl") > 0) & (pl.col("_ttl") < 3600))
    print(f"\nKeys expiring within 1 hour: {len(expiring_soon)}")

    # A _ttl of -1 marks keys with no expiry set
    persistent = users.filter(pl.col("_ttl") == -1)
    print(f"Keys without expiry: {len(persistent)}")


if __name__ == "__main__":
    main()

Write Modes

Write modes: fail, replace, append:

"""Example: write modes (fail, replace, append)."""

import polars as pl
import polars_redis as redis

URL = "redis://localhost:6379"


def main():
    """Demonstrate the three write modes: replace, fail, append."""
    # Two users keyed explicitly via the _key column
    users = pl.DataFrame(
        {
            "_key": ["user:1", "user:2"],
            "name": ["Alice", "Bob"],
            "age": [30, 25],
        }
    )

    # "replace" (the default) overwrites existing keys
    count = redis.write_hashes(users, URL, if_exists="replace")
    print(f"Replace mode: wrote {count} hashes")

    # "fail" skips keys that already exist
    count = redis.write_hashes(users, URL, if_exists="fail")
    print(f"Fail mode: wrote {count} hashes (0 because keys exist)")

    # "append" merges new fields into the existing hashes
    score_updates = pl.DataFrame(
        {
            "_key": ["user:1", "user:2"],
            "score": [95.5, 88.0],  # field not present before
        }
    )
    count = redis.write_hashes(score_updates, URL, if_exists="append")
    print(f"Append mode: updated {count} hashes with new field")

    # Read everything back to confirm the merge
    merged = redis.read_hashes(
        URL,
        pattern="user:*",
        schema={"name": pl.Utf8, "age": pl.Int64, "score": pl.Float64},
    )
    print("\nFinal data:")
    print(merged)


if __name__ == "__main__":
    main()